ALEPlot and ale packages
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
The ALEPlot package is the reference implementation of the brilliant idea of accumulated local effects (ALE) by Daniel Apley and Jingyu Zhu. However, it has not been updated since 2018. The ale package attempts to rewrite and extend the original base work.
In developing the ale package, we must ensure that we correctly implement the original ALE algorithm while extending it. Indeed, some permanent unit tests call ALEPlot to make sure that ale provides identical results for identical inputs. We thought that presenting some of these comparisons as a vignette might be helpful. We focus here on the examples from the ALEPlot package so that results are directly comparable.
Other than its extensions for ALE-based statistics, here are some of the main points in which ale differs from ALEPlot where it provides otherwise similar functionality:
It uses ggplot2 instead of base R graphics. We consider ggplot2 to be a more modern and versatile graphics system.
relative_y argument, as we demonstrate.) We believe that such plots are more easily interpretable.
One notable difference between the two packages is that the ale package does not and will not implement partial dependence plots (PDP). The package is focused exclusively on accumulated local effects (ALE); users who need PDPs may use ALEPlot or other implementations.
In each section here, we cover an example from ALEPlot and then reimplement it with ale.
We begin with the second code example directly from the ALEPlot package. (We skip the first example because it is a subset of the second, simply without interactions.) Here is the code from the example to create a simulated dataset and train a neural network on it:
## R code for Example 2
## Load relevant packages
library(ALEPlot)
library(nnet)
## Generate some data and fit a neural network supervised learning model
set.seed(0) # not in the original, but added for reproducibility
n = 5000
x1 <- runif(n, min = 0, max = 1)
x2 <- runif(n, min = 0, max = 1)
x3 <- runif(n, min = 0, max = 1)
x4 <- runif(n, min = 0, max = 1)
y = 4*x1 + 3.87*x2^2 + 2.97*exp(-5+10*x3)/(1+exp(-5+10*x3))+
13.86*(x1-0.5)*(x2-0.5)+ rnorm(n, 0, 1)
DAT <- data.frame(y, x1, x2, x3, x4)
nnet.DAT <- nnet(y~., data = DAT, linout = T, skip = F, size = 6,
decay = 0.1, maxit = 1000, trace = F)
For the demonstration, x1 has a linear relationship with y, x2 and x3 have non-linear relationships, and x4 is a random variable with no relationship with y. x1 and x2 interact with each other in their relationship with y.
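These roles can be read directly off the simulation formula. As a small sketch (the helper names main_x1, main_x2, and main_x3 are ours for illustration and belong to neither package), the three main-effect terms can be evaluated on a coarse grid. x4 has no term because it does not appear in the formula at all:

```r
# Main-effect terms copied from the simulation formula for y above
# (helper names are illustrative only)
main_x1 <- function(x1) 4 * x1                       # linear
main_x2 <- function(x2) 3.87 * x2^2                  # quadratic
main_x3 <- function(x3) 2.97 * plogis(-5 + 10 * x3)  # logistic (S-shaped)

# plogis(z) is exp(z) / (1 + exp(z)), the same expression used in the formula
grid <- seq(0, 1, by = 0.25)
effects <- data.frame(
  x  = grid,
  x1 = main_x1(grid),
  x2 = main_x2(grid),
  x3 = main_x3(grid)
)
print(effects)
```

These are the true main effects of the simulated data; the ALE plots of the fitted neural network should only approximate them.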
To create ALE data and plots, ALEPlot requires the creation of a custom prediction function:
## Define the predictive function
yhat <- function(X.model, newdata) as.numeric(predict(X.model, newdata,
type = "raw"))
Now the ALEPlot function can be called to create the ALE data and plot it. The function returns a specially formatted list with the ALE data; it can be saved for subsequent custom plotting.
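## Calculate and plot the ALE main effects of x1, x2, x3, and x4
## (ALE.2, ALE.3, and ALE.4 are needed by the manual plotting code further on)
ALE.1 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 1, K = 500, NA.plot = TRUE)
ALE.2 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 2, K = 500, NA.plot = TRUE)
ALE.3 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 3, K = 500, NA.plot = TRUE)
ALE.4 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 4, K = 500, NA.plot = TRUE)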
In the ALEPlot implementation, calling the function automatically prints a plot. While this provides some convenience if that is what the user wants, it is not so convenient if the user does not want to print a plot at the very point of ALE creation. It is particularly inconvenient for script building. Although it is possible to configure R to suspend graphic output before ALEPlot is called and then restart it after the function call, this is not so straightforward; the function itself does not give any option to control this behaviour.
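One way to suspend graphic output around such a call is to open a null graphics device first and close it afterwards. The sketch below is only an illustration: quiet_plot and plot_and_return are hypothetical helpers of ours, not part of ALEPlot or ale, and the stand-in function merely mimics "plot as a side effect, return data".

```r
# Evaluate a plotting expression while sending all graphics to a null device.
# `quiet_plot` is a hypothetical helper, not part of ALEPlot or ale.
quiet_plot <- function(expr) {
  grDevices::pdf(file = NULL)          # open a device that draws nowhere
  dev_id <- grDevices::dev.cur()
  on.exit(grDevices::dev.off(dev_id))  # always close it, even on error
  expr                                 # lazy evaluation: the call runs only here
}

# Stand-in for a function that plots as a side effect and returns data
plot_and_return <- function() {
  plot(1:10)
  list(x.values = 1:10)
}

res <- quiet_plot(plot_and_return())
str(res)
```

Assuming ALEPlot is installed, the same wrapper could presumably be applied to the call shown earlier, for example quiet_plot(ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 1, K = 500)).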
ALE interactions can also be calculated and plotted:
## Calculate and plot the ALE second-order effects of {x1, x2} and {x1, x4}
ALE.12 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = c(1,2), K = 100,
NA.plot = TRUE)
ALE.14 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = c(1,4), K = 100,
NA.plot = TRUE)
If the output of ALEPlot has been saved to variables, then its contents can be plotted with finer user control using the generic R plot method:
## Manually plot the ALE main effects on the same scale for easier comparison
## of the relative importance of the four predictor variables
par(mfrow = c(3,2))
plot(ALE.1$x.values, ALE.1$f.values, type="l", xlab="x1",
ylab="ALE_main_x1", xlim = c(0,1), ylim = c(-2,2), main = "(a)")
plot(ALE.2$x.values, ALE.2$f.values, type="l", xlab="x2",
ylab="ALE_main_x2", xlim = c(0,1), ylim = c(-2,2), main = "(b)")
plot(ALE.3$x.values, ALE.3$f.values, type="l", xlab="x3",
ylab="ALE_main_x3", xlim = c(0,1), ylim = c(-2,2), main = "(c)")
plot(ALE.4$x.values, ALE.4$f.values, type="l", xlab="x4",
ylab="ALE_main_x4", xlim = c(0,1), ylim = c(-2,2), main = "(d)")
## Manually plot the ALE second-order effects of {x1, x2} and {x1, x4}
image(ALE.12$x.values[[1]], ALE.12$x.values[[2]], ALE.12$f.values, xlab = "x1",
ylab = "x2", main = "(e)")
contour(ALE.12$x.values[[1]], ALE.12$x.values[[2]], ALE.12$f.values, add=TRUE,
drawlabels=TRUE)
image(ALE.14$x.values[[1]], ALE.14$x.values[[2]], ALE.14$f.values, xlab = "x1",
ylab = "x4", main = "(f)")
contour(ALE.14$x.values[[1]], ALE.14$x.values[[2]], ALE.14$f.values, add=TRUE,
drawlabels=TRUE)
ale package equivalent
Now we demonstrate the same functionality with the ale package. We will work with the same model on the same data, so we will not create them again.
To create the ALE data, we invoke the ale function, which returns a list with various ALE elements.
library(ale)
#> Loading required package: ggplot2
nn_ale <- ale(DAT, nnet.DAT, pred_type = "raw")
Here are some notable differences compared to ALEPlot:
Unlike ALEPlot, which works on only one variable at a time, ale generates ALE data for multiple variables in a dataset at once. By default, it generates ALE elements for all the predictor variables in the dataset that it is given; the user can specify a single variable or any subset of variables. We will cover more details in another vignette, but for our purposes here, we note that the data element returns a list of the ALE data for each variable and the plots element returns a list of ggplot plots.
ale creates a default generic predict function that matches most standard R models. When the prediction type is not the default “response”, as in our case, the user can set the desired type with the pred_type argument. However, for more complex or non-standard prediction functions, ale supports custom functions with the pred_fun argument.
Since the plots are saved as a list, they can easily be printed out all at once:
The ale package plots have various features that enhance interpretability:
It might not be clear that the previous plots display exactly the same data as those shown above from ALEPlot. To make the comparison clearer, we can recalculate the ALEs on a zero-centred scale:
# Zero-centred ALE
nn_ale <- ale(DAT, nnet.DAT, pred_type = "raw", relative_y = 'zero')
gridExtra::grid.arrange(grobs = nn_ale$plots, ncol = 2)
With these zero-centred plots, the full range of y values and the rug plots give some context that aids interpretation. (If the rugs look slightly different, it is because they are randomly jittered to avoid overplotting.)
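The difference between a zero-centred scale and a median-centred scale is only a vertical shift. As a rough sketch with made-up numbers (the vectors below are illustrative, were not produced by either package, and assume the shift is simply the median of y), zero-centred ALE values can be moved onto a median-centred scale like this:

```r
# Made-up zero-centred ALE values for one predictor (illustration only)
ale_zero <- c(-1.8, -0.9, 0.0, 0.9, 1.8)

# Made-up outcome values standing in for y
y <- c(1.2, 2.5, 3.1, 4.0, 5.6, 6.3, 7.9)

# Assumption: median centring adds median(y) to every zero-centred value
ale_median <- ale_zero + median(y)

# The axis labels change, but the steps between points (the shape) do not
print(ale_median)
print(diff(ale_zero))
print(diff(ale_median))
```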
The ale package also produces interaction plots. Unlike ALEPlot, it does so with a separate dedicated function, ale_ixn. By default, it calculates all possible two-way interactions between all variables in the data.
Because the variables interact with each other, the output data structure is a two-layer list, so the print code is slightly more complicated. However, the sample code we provide here, which uses functions from the purrr package for iterating through lists and the gridExtra package for arranging plots, is reusable in any application.
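To make the two-layer structure concrete, here is a minimal base R sketch. The toy list below is our own stand-in with strings in place of plots; the exact shape of the object returned by ale_ixn is assumed from the description above, not taken from the package.

```r
# Toy stand-in for a two-layer list: outer names are the first variable,
# inner names are the variable it interacts with (strings replace plots)
ixn_plots <- list(
  x1 = list(x2 = "plot x1:x2", x3 = "plot x1:x3"),
  x2 = list(x3 = "plot x2:x3")
)

# Outer loop: one first variable at a time; inner loop: each of its partners
for (x1_name in names(ixn_plots)) {
  for (x2_name in names(ixn_plots[[x1_name]])) {
    cat(x1_name, "with", x2_name, "->", ixn_plots[[x1_name]][[x2_name]], "\n")
  }
}

# Flattening one level gives a single list of all interaction entries
all_plots <- unlist(ixn_plots, recursive = FALSE)
print(names(all_plots))
```

The purrr::walk call used in this vignette plays the role of the outer loop, while grid.arrange handles all the inner entries at once.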
# Create and plot interactions
nn_ale_ixn <- ale_ixn(DAT, nnet.DAT, pred_type = "raw")
# Print plots
nn_ale_ixn$plots |>
purrr::walk(\(.x1) { # extract list of x1 ALE outputs
gridExtra::grid.arrange(grobs = .x1, ncol = 1) # plot all x1 plots
})
These interaction plots are heat maps that use colours to indicate the interaction regions that are above or below the average value of y. Grey indicates no meaningful interaction; blue indicates a positive interaction effect; red indicates a negative effect. We find these easier to interpret than the contour maps from ALEPlot, especially since the colours in each plot are on the same scale and so the plots are directly comparable with each other.
The range of outcome (y) values is divided into quantiles, deciles by default. However, the middle quantiles are modified. Rather than showing the middle 10% or 20% of values, the middle band is much narrower: it shows the middle 5%. (This value is based on the notion of alpha of 0.05 for confidence intervals; it can be customized with the median_band argument.)
The legend shows the midpoint y value of each quantile, which is usually the mean of the boundaries of the quantile. The exception is the special middle quantile, whose displayed midpoint value is the median of the entire dataset.
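One way to picture this banding is the sketch below. It is only our reading of the description above, not the ale package's internal code: band_breaks is a hypothetical helper, its median_band parameter simply borrows the argument name from the text, and the simulated y is made up.

```r
# Hypothetical helper: decile boundaries with a narrow band around the median
band_breaks <- function(y, median_band = 0.05) {
  half <- median_band / 2
  probs <- c(
    seq(0, 0.4, by = 0.1),   # lower decile boundaries
    0.5 - half, 0.5 + half,  # narrow middle band (middle 5% by default)
    seq(0.6, 1, by = 0.1)    # upper decile boundaries
  )
  stats::quantile(y, probs = probs, names = FALSE)
}

set.seed(0)
y <- rnorm(1000)  # made-up outcome values
breaks <- band_breaks(y)

# The midpoint of an ordinary band is the mean of its two boundaries
midpoints <- (head(breaks, -1) + tail(breaks, -1)) / 2
# The special middle band (the 6th of 11) is labelled with the median instead
midpoints[6] <- median(y)

print(round(breaks, 2))
print(round(midpoints, 2))
```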
The interpretation of these interaction plots is that in any given region, the interaction between x1 and x2 increases (blue) or decreases (red) y by the amount indicated over and above the separate individual direct effects of x1 and x2 shown in the one-way ALE plots above. It is not an indication of the total effect of both variables together but rather of the additional effect of their interaction beyond their individual effects. Thus, only the x1-x2 interaction shows any effect. For the interactions with x3, even though x3 indeed has a strong effect on y as we see in the one-way ALE plot above, it has no additional effect in interaction with the other variables, and so its interaction plots are entirely grey.
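For the simulated data, the additional effect can be checked against the interaction term in the formula for y. The following sketch evaluates that term on a small grid; the function name ixn_x1_x2 is ours for illustration, and the values describe the data-generating formula, which the fitted network's ALE should only approximate.

```r
# Interaction term copied from the simulation formula for y
# (the function name is illustrative only)
ixn_x1_x2 <- function(x1, x2) 13.86 * (x1 - 0.5) * (x2 - 0.5)

# Evaluate the term at the corners, edges, and centre of the unit square
grid <- c(0, 0.5, 1)
ixn <- outer(grid, grid, ixn_x1_x2)
dimnames(ixn) <- list(x1 = grid, x2 = grid)
print(ixn)
```

The term is zero whenever either variable equals 0.5, positive when both variables lie on the same side of 0.5, and negative otherwise. No such term exists for x3 or x4 in the formula.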
The next code example from the ALEPlot package analyzes a real dataset with a binary outcome variable. Whereas the ALEPlot example has the user load a CSV file that might not be readily available, we make that dataset available as the census dataset. We load it here with the adjustments necessary to run the ALEPlot example.
## R code for Example 3
## Load relevant packages
library(ALEPlot)
library(gbm)
#> Loaded gbm 2.1.8.1
## Read data and fit a boosted tree supervised learning model
data(census, package = 'ale') # load ale package version of the data
data <-
  census |>
  as.data.frame() |> # ALEPlot is not compatible with the tibble format
  select(age:native_country, higher_income) |> # Rearrange columns to match ALEPlot order
  na.omit() # drop rows with missing values
Although gradient boosted trees generally perform quite well, they are rather slow. Rather than having you wait for it to run, the code here downloads a pretrained GBM model. However, the code used to generate it is provided in comments so that you can see it and run it yourself if you want to. Note that the model call is based on data[,-c(3,4)], which drops the third and fourth variables (fnlwgt and education, respectively).
# To generate the model, uncomment the following lines.
# But it is slow, so this vignette loads a pre-created model object.
# set.seed(0)
# gbm.data <- gbm(higher_income ~ ., data= data[,-c(3,4)],
# distribution = "bernoulli", n.trees=6000, shrinkage=0.02,
# interaction.depth=3)
# saveRDS(gbm.data, file.choose())
gbm.data <- url('https://github.com/Tripartio/ale/raw/main/download/gbm.data_model.rds') |>
readRDS()
gbm.data
#> gbm(formula = higher_income ~ ., distribution = "bernoulli",
#> data = data[, -c(3, 4)], n.trees = 6000, interaction.depth = 3,
#> shrinkage = 0.02)
#> A gradient boosted model with bernoulli loss function.
#> 6000 iterations were performed.
#> There were 12 predictors of which 12 had non-zero influence.
As before, we create a custom prediction function and then call the ALEPlot function to generate the plots. The prediction type here is “link”, which represents the log odds in the gbm package.
Creation of the ALE plots here is rather slow because the gbm predict function is slow. In this example, only age, education_num (number of years of education), and hours_per_week are plotted, along with the interaction between age and hours_per_week.
## Define the predictive function; note the additional arguments for the
## predict function in gbm
yhat <- function(X.model, newdata) as.numeric(predict(X.model, newdata,
n.trees = 6000, type="link"))
## Calculate and plot the ALE main and interaction effects for x_1, x_3,
## x_11, and {x_1, x_11}
par(mfrow = c(2,2), mar = c(4,4,2,2)+ 0.1)
ALE.1=ALEPlot(data[,-c(3,4,15)], gbm.data, pred.fun=yhat, J=1, K=500,
NA.plot = TRUE)
ALE.3=ALEPlot(data[,-c(3,4,15)], gbm.data, pred.fun=yhat, J=3, K=500,
NA.plot = TRUE)
ALE.11=ALEPlot(data[,-c(3,4,15)], gbm.data, pred.fun=yhat, J=11, K=500,
NA.plot = TRUE)
ALE.1and11=ALEPlot(data[,-c(3,4,15)], gbm.data, pred.fun=yhat, J=c(1,11),
K=50, NA.plot = FALSE)
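Because the “link” prediction type used above returns log odds, the y axes of these plots are not probabilities. As a brief base R sketch of how the two scales relate (the log odds values are arbitrary and chosen only for illustration):

```r
# Log odds (the "link" scale) and probabilities (the "response" scale)
log_odds <- c(-4, -2, 0, 2, 4)  # arbitrary illustrative values

prob <- plogis(log_odds)  # inverse logit: exp(z) / (1 + exp(z))
back <- qlogis(prob)      # logit: log(p / (1 - p))

print(data.frame(log_odds, prob = round(prob, 3)))
```

Log odds of 0 correspond to a probability of 0.5, and equal steps in log odds produce progressively smaller steps in probability as the values approach 0 or 1.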
ale package equivalent
Here is the analogous code using the ale package. In this case, we also need to define a custom predict function because of the particular n.trees = 6000 argument. To speed things up, we provide a pretrained ale object. This is possible because ale returns objects with data and plots bundled together with no side effects (like automatic printing of created plots). (It is probably possible to similarly cache ALEPlot ALE objects, but it is not quite as straightforward.)
We display all the plots because it is easy to do so with the ale package, but we focus on age, education_num, and hours_per_week for comparison with ALEPlot. If the shapes of these plots look different, it is because ale tries as much as possible to display plots on the same y-axis coordinate scale for easy comparison across plots.
# Custom predict function that returns log odds
yhat <- function(object, newdata) {
as.numeric(
predict(object, newdata, n.trees = 6000,
type="link") # return log odds
)
}
# Generate ALE data for all variables
# # To generate the ALE object, uncomment the following lines.
# # But it is slow, so this vignette loads a pre-created ALE object.
# gbm_ale_link <- ale(
# data[,-c(3,4)], gbm.data,
# pred_fun = yhat,
# x_intervals = 500,
# rug_sample_size = 600, # technical issue: rug_sample_size must be > x_intervals + 1
# relative_y = 'zero' # compatibility with ALEPlot
# )
# saveRDS(gbm_ale_link, file.choose())
gbm_ale_link <- url('https://github.com/Tripartio/ale/raw/main/download/gbm_ale_link.rds') |>
readRDS()
# Print plots
gridExtra::grid.arrange(grobs = gbm_ale_link$plots, ncol = 2)
Now we generate ALE data for all two-way interactions and then plot them. Again, note the interaction between age and hours_per_week. The interaction is minimal except for the extremely high cases of hours per week.
# # To generate the ALE interaction object, uncomment the following lines.
# # But it is slow, so this vignette loads a pre-created ALE interaction object.
# gbm_ale_ixn_link <- ale_ixn(
# data[,-c(3,4)], gbm.data,
# pred_fun = yhat,
# x_intervals = 500,
# rug_sample_size = 600, # technical issue: rug_sample_size must be > x_intervals + 1
# relative_y = 'zero' # compatibility with ALEPlot
# )
# saveRDS(gbm_ale_ixn_link, file.choose())
gbm_ale_ixn_link <- url('https://github.com/Tripartio/ale/raw/main/download/gbm_ale_ixn_link.rds') |>
readRDS()
# Print plots
gbm_ale_ixn_link$plots |>
purrr::walk(\(.x1) { # extract list of x1 ALE outputs
gridExtra::grid.arrange(grobs = .x1, ncol = 2) # plot all x1 interaction plots
})
Log odds are not necessarily the most interpretable way to express probabilities (though we will show shortly that they are sometimes uniquely valuable). So, we repeat the ALE creation using the “response” prediction type for probabilities and the default median centring of the plots.
As we can see, the shapes of the plots are similar, but the y axes are more easily interpretable as the probability (from 0 to 1) that a census respondent is in the higher income category. The median of around 10% or so indicates the median prediction of the GBM model: half of the respondents were predicted to have higher than a 10% likelihood of being higher income and half were predicted to have lower likelihood. The y-axis rug plots indicate that the predictions were generally rather extreme, either relatively close to 0 or 1, with few predictions in the middle.
# Custom predict function that returns predicted probabilities
yhat <- function(object, newdata) {
as.numeric(
predict(object, newdata, n.trees = 6000,
type="response") # return predicted probabilities
)
}
# Generate ALE data for all variables
# # To generate the ALE object, uncomment the following lines.
# # But it is slow, so this vignette loads a pre-created ALE object.
# gbm_ale_prob <- ale(
# data[,-c(3,4)], gbm.data,
# pred_fun = yhat,
# x_intervals = 500,
# rug_sample_size = 600 # technical issue: rug_sample_size must be > x_intervals + 1
# )
# saveRDS(gbm_ale_prob, file.choose())
gbm_ale_prob <- url('https://github.com/Tripartio/ale/raw/main/download/gbm_ale_prob.rds') |>
readRDS()
# Print plots
gridExtra::grid.arrange(grobs = gbm_ale_prob$plots, ncol = 2)
Finally, we again generate two-way interactions, this time based on probabilities instead of on log odds. However, probabilities might not be the best choice for indicating interactions because, as we see from the rugs in the one-way ALE plots, the GBM model heavily concentrates its probabilities in the extremes near 0 and 1. Thus, the plots’ suggestions of strong interactions are likely exaggerated. In this case, the log odds ALEs shown above are probably more relevant.
# # To generate the ALE interaction object, uncomment the following lines.
# # But it is slow, so this vignette loads a pre-created ALE interaction object.
# gbm_ale_ixn_prob <- ale_ixn(
# data[,-c(3,4)], gbm.data,
# pred_fun = yhat,
# x_intervals = 500,
# rug_sample_size = 600 # technical issue: rug_sample_size must be > x_intervals + 1
# )
# saveRDS(gbm_ale_ixn_prob, file.choose())
gbm_ale_ixn_prob <- url('https://github.com/Tripartio/ale/raw/main/download/gbm_ale_ixn_prob.rds') |>
readRDS()
# Print plots
gbm_ale_ixn_prob$plots |>
purrr::walk(\(.x1) { # extract list of x1 ALE outputs
gridExtra::grid.arrange(grobs = .x1, ncol = 2) # plot all x1 plots
})