Introduction to the ale package

Chitu Okoli

2023-08-29

This vignette demonstrates the basic functionality of the ale package on standard large datasets used for machine learning. A separate vignette is devoted to its use on small datasets, as is often the case with statistical inference. (How small is small? That’s a tough question, but as that vignette explains, most datasets of fewer than 2000 rows are probably “small”, and even many datasets of more than 2000 rows are nonetheless “small”.)

diamonds dataset

For this introduction, we use the diamonds dataset, built into ggplot2 (a required package for ale): “a dataset containing the prices and other attributes of almost 54,000 diamonds”.

help(diamonds)
diamonds
#> # A tibble: 53,940 × 10
#>    carat cut       color clarity depth table price     x     y     z
#>    <dbl> <ord>     <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
#>  1  0.23 Ideal     E     SI2      61.5    55   326  3.95  3.98  2.43
#>  2  0.21 Premium   E     SI1      59.8    61   326  3.89  3.84  2.31
#>  3  0.23 Good      E     VS1      56.9    65   327  4.05  4.07  2.31
#>  4  0.29 Premium   I     VS2      62.4    58   334  4.2   4.23  2.63
#>  5  0.31 Good      J     SI2      63.3    58   335  4.34  4.35  2.75
#>  6  0.24 Very Good J     VVS2     62.8    57   336  3.94  3.96  2.48
#>  7  0.24 Very Good I     VVS1     62.3    57   336  3.95  3.98  2.47
#>  8  0.26 Very Good H     SI1      61.9    55   337  4.07  4.11  2.53
#>  9  0.22 Fair      E     VS2      65.1    61   337  3.87  3.78  2.49
#> 10  0.23 Very Good H     VS1      59.4    61   338  4     4.05  2.39
#> # ℹ 53,930 more rows
str(diamonds)
#> tibble [53,940 × 10] (S3: tbl_df/tbl/data.frame)
#>  $ carat  : num [1:53940] 0.23 0.21 0.23 0.29 0.31 0.24 0.24 0.26 0.22 0.23 ...
#>  $ cut    : Ord.factor w/ 5 levels "Fair"<"Good"<..: 5 4 2 4 2 3 3 3 1 3 ...
#>  $ color  : Ord.factor w/ 7 levels "D"<"E"<"F"<"G"<..: 2 2 2 6 7 7 6 5 2 5 ...
#>  $ clarity: Ord.factor w/ 8 levels "I1"<"SI2"<"SI1"<..: 2 3 5 4 2 6 7 3 4 5 ...
#>  $ depth  : num [1:53940] 61.5 59.8 56.9 62.4 63.3 62.8 62.3 61.9 65.1 59.4 ...
#>  $ table  : num [1:53940] 55 61 65 58 58 57 57 55 61 61 ...
#>  $ price  : int [1:53940] 326 326 327 334 335 336 336 337 337 338 ...
#>  $ x      : num [1:53940] 3.95 3.89 4.05 4.2 4.34 3.94 3.95 4.07 3.87 4 ...
#>  $ y      : num [1:53940] 3.98 3.84 4.07 4.23 4.35 3.96 3.98 4.11 3.78 4.05 ...
#>  $ z      : num [1:53940] 2.43 2.31 2.31 2.63 2.75 2.48 2.47 2.53 2.49 2.39 ...

For any valid machine learning analysis, we must first split the dataset into training and test samples. The model is developed on the training set and then evaluated on the test set. Although this is well known to all machine learning scientists, we emphasize it here because it is not obvious to everyone that ALE results are valid only when calculated on a distinct test set; they are not valid when calculated on the same data that was used to train the model. (When a dataset is too small to feasibly split into training and test sets, the ale package has tools to appropriately handle such small datasets.)

# Split the dataset into training and test sets
# https://stackoverflow.com/a/54892459/2449926
set.seed(0)
train_test_split <- sample(c(TRUE, FALSE), nrow(diamonds), replace = TRUE, prob = c(0.8, 0.2))
diamonds_train <- diamonds[train_test_split, ]
diamonds_test <- diamonds[!train_test_split, ]
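
As an optional sanity check, we can confirm the sizes of the two samples (the exact counts depend on the random seed set above):

# Optional check: confirm the approximate 80-20 split
nrow(diamonds_train)
nrow(diamonds_test)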

So, we have split the dataset 80-20, giving a training set of 43,192 rows and a test set of 10,748 rows. Now we can build our model.

Modelling with generalized additive models (GAM)

ALE is a model-agnostic IML approach, that is, it works with any kind of machine learning model. As such, the ale package works with any R model, with the only condition that the model can predict numeric outcomes (such as raw estimates for regression and probabilities or odds ratios for classification). For this demonstration, we will use generalized additive models (GAM), a relatively fast algorithm that models data more flexibly than ordinary least squares regression. It is beyond our scope here to explain how GAM works (you can learn more with Noam Ross’s excellent tutorial), but the examples here will work with any machine learning algorithm.
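
To illustrate what model-agnostic means in practice, here is a minimal sketch (purely for illustration; this lm model is not used in the rest of the vignette): any model whose predictions are numeric, such as an ordinary linear regression, could be passed through the same ALE workflow shown below.

# Illustration only: any model that yields numeric predictions can be used
lm_diamonds <- lm(price ~ carat + depth + table + cut + color + clarity,
                  data = diamonds_train)
head(predict(lm_diamonds, newdata = diamonds_test))  # numeric predictions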

We train a GAM model to predict diamond price:

# Create a GAM model with flexible curves to predict diamond price
# Smooth all numeric variables and include all other variables
# Build model on training data, not on the full dataset.
gam_diamonds <- mgcv::gam(
  price ~ s(carat) + s(depth) + s(table) + s(x) + s(y) + s(z) + 
    cut + color + clarity,
  data = diamonds_train
  )
summary(gam_diamonds)
#> 
#> Family: gaussian 
#> Link function: identity 
#> 
#> Formula:
#> price ~ s(carat) + s(depth) + s(table) + s(x) + s(y) + s(z) + 
#>     cut + color + clarity
#> 
#> Parametric coefficients:
#>              Estimate Std. Error  t value Pr(>|t|)    
#> (Intercept)  3514.642     13.324  263.784  < 2e-16 ***
#> cut.L         240.309     38.737    6.204 5.57e-10 ***
#> cut.Q         -42.631     27.365   -1.558 0.119282    
#> cut.C          57.874     19.290    3.000 0.002700 ** 
#> cut^4          -4.875     13.250   -0.368 0.712925    
#> color.L     -2004.557     17.689 -113.324  < 2e-16 ***
#> color.Q      -677.079     16.059  -42.162  < 2e-16 ***
#> color.C      -151.158     14.955  -10.108  < 2e-16 ***
#> color^4        49.606     13.701    3.621 0.000294 ***
#> color^5       -96.562     12.937   -7.464 8.55e-14 ***
#> color^6       -41.294     11.752   -3.514 0.000442 ***
#> clarity.L    3733.499     31.062  120.196  < 2e-16 ***
#> clarity.Q   -1649.288     28.963  -56.945  < 2e-16 ***
#> clarity.C     767.873     24.730   31.050  < 2e-16 ***
#> clarity^4    -252.474     19.662  -12.840  < 2e-16 ***
#> clarity^5     179.128     15.982   11.208  < 2e-16 ***
#> clarity^6      35.125     13.877    2.531 0.011372 *  
#> clarity^7      84.787     12.253    6.920 4.59e-12 ***
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#> 
#> Approximate significance of smooth terms:
#>            edf Ref.df       F p-value    
#> s(carat) 8.550  8.935  44.029 < 2e-16 ***
#> s(depth) 8.163  8.771   6.232 < 2e-16 ***
#> s(table) 5.749  6.840   3.451 0.00111 ** 
#> s(x)     8.527  8.893  69.579 < 2e-16 ***
#> s(y)     9.000  9.000 219.239 < 2e-16 ***
#> s(z)     9.000  9.000  13.735 < 2e-16 ***
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#> 
#> R-sq.(adj) =  0.935   Deviance explained = 93.5%
#> GCV = 1.0469e+06  Scale est. = 1.0452e+06  n = 43192
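
As an optional check of the kind mentioned above (developing on the training set, then evaluating on the test set), we can sketch a simple out-of-sample evaluation; this step is not required by the ale package.

# Optional: evaluate the GAM on the held-out test set (not required for ALE)
gam_preds <- predict(gam_diamonds, newdata = diamonds_test)
sqrt(mean((diamonds_test$price - gam_preds)^2))  # test-set RMSE in US dollars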

ale function for generating ALE data and plots

The core function in the ale package is the ale function. Consistent with tidyverse conventions, its first argument is a dataset. Its second argument is a model object: any R model object that can generate numeric predictions is acceptable. By default, it generates ALE data and plots on all the input variables used for the model. To change these options (e.g., to calculate ALE for only a subset of variables; to output the data only or the plots only rather than both; or to use a custom, non-standard predict function for the model), see details in the help file for the function.

# help(ale::ale)
help(ale)

The ale function returns a list with one element per input variable, as well as a .common_data element that has some details about the outcome (y) variable. Each variable’s element consists of a list with two elements: the ALE data for that variable and a ggplot plot object that plots that ALE data.

To iterate the list and plot all the ALE plots, we provide here some demonstration code using the purrr package for list iteration and gridExtra for arranging multiple plots in a common plot grid.

# Simple ALE without bootstrapping
ale_gam_diamonds <- ale(diamonds_test, gam_diamonds)

ale_gam_diamonds[setdiff(names(ale_gam_diamonds), '.common_data')] |> 
  purrr::map(\(.x) .x$plot) |>  # extract plots as a list
  gridExtra::grid.arrange(grobs = _, ncol = 2)
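
To look at a single variable rather than the full grid, you can access its element of the returned list directly. The plot element is named plot, as used above; use str() to see the name of the accompanying ALE data element.

# Inspect one variable's ALE output directly
str(ale_gam_diamonds$carat, max.level = 1)  # list this element's components
ale_gam_diamonds$carat$plot                 # display only the carat plot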

Bootstrapped ALE

One of the key features of the ale package is bootstrapping of the ALE results to ensure that they are reliable, that is, generalizable to data beyond the sample on which the model was trained. Again, we stress that ALE results are valid only when calculated on test data distinct from the training data. When samples are too small for such a split, we provide a different bootstrapping method, model_bootstrap, explained in the vignette for small datasets.

Although ALE is faster than most other IML techniques for global explanation, such as partial dependence plots (PDP) and SHAP, it still takes some time to run. Bootstrapping multiplies that time by the number of bootstrap iterations. Since this vignette is a demonstration of package functionality rather than a real analysis, we will bootstrap only a small subset of the test data. This will run much faster, since the speed of the ALE algorithm depends on the size of the dataset. So, let us take a random sample of 200 rows of the test set.

# Bootstrapping is rather slow, so create a smaller subset of new data for demonstration
new_rows <- sample(nrow(diamonds_test), 200, replace = FALSE)
diamonds_new <- diamonds_test[new_rows, ]

Now we create bootstrapped ALE data and plots using the boot_it argument. ALE is a relatively stable IML algorithm (compared to others like PDP), so 100 bootstrap samples should be sufficient for relatively stable results, especially for model development. Final results could be confirmed with 1000 bootstrap samples or more, but there should not be much difference in the results beyond 100.

ale_gam_diamonds_boot <- ale(diamonds_new, gam_diamonds, boot_it = 100)

# Bootstrapping produces confidence intervals
ale_gam_diamonds_boot[setdiff(names(ale_gam_diamonds_boot), '.common_data')] |> 
  purrr::map(\(.x) .x$plot) |>  # extract plots as a list
  gridExtra::grid.arrange(grobs = _, ncol = 2)

In this case, the bootstrapped results are mostly similar to the simple ALE results, though the results for the x input variable are rather different. In principle, we should always bootstrap the results and trust only bootstrapped results.

ALE interactions

Another advantage of ALE is that it provides data for two-way interactions between variables. This is implemented with the ale::ale_ixn function. Like the ale function, ale_ixn requires an input dataset and a model object. By default, it generates ALE data and plots for all possible pairs of input variables used for the model. However, an ALE interaction requires at least one of the variables to be numeric. So, ale_ixn has a notion of x1 and x2 variables: the x1 variable must be numeric, whereas the x2 variable can be of any input datatype. To change the default options (e.g., to calculate interactions for only certain pairs of variables), see details in the help file for the function.

help(ale_ixn)
# help(ale::ale_ixn)

Like the ale function, the ale_ixn function returns a list with one element per input x1 variable, as well as a .common_data element with details about the outcome (y) variable. However, in this case, each x1 variable’s element is itself a list with one element per x2 variable for which the interaction with that x1 was calculated. Each x2 element then has two elements: the ALE data for that interaction and a ggplot plot object that plots that ALE data.

Again, to iterate the list and plot all the ALE plots, we provide here some demonstration code. It is different from the plot code given earlier because of the two levels of interacting variables in the output data.

# ALE two-way interactions
ale_ixn_gam_diamonds <- ale_ixn(diamonds_test, gam_diamonds)

# Skip .common_data when iterating through the data for plotting
ale_ixn_gam_diamonds[setdiff(names(ale_ixn_gam_diamonds), '.common_data')] |> 
  purrr::walk(\(x1) {  # iterate over each x1 variable's list of interactions
    purrr::map(x1, \(.x) .x$plot) |>  # extract the plot for each x2 variable
      gridExtra::grid.arrange(grobs = _, ncol = 2)  # arrange this x1's plots in a grid
  })
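
To view a single interaction rather than the full set, you can drill into the nested list directly, for example, the carat-by-depth interaction (assuming that pair was computed). As before, use str() to see the exact element names.

# Inspect a single interaction, e.g., carat (x1) by depth (x2)
str(ale_ixn_gam_diamonds$carat, max.level = 1)  # list the x2 variables computed for carat
ale_ixn_gam_diamonds$carat$depth$plot  # display just that interaction plot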