Valid Prediction

This article provides more detail on how bases ensures valid prediction, i.e., prevents any data leakage when new predictions are made.

Every basis function provides a method for the makepredictcall() generic, which is called by model.frame(). The method's job is to save any statistics used by the basis expansion (such as a set of randomly sampled frequencies and phase shifts) so that they can be reused at prediction time.
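The same mechanism can be inspected with base R alone. The sketch below uses stats::poly() rather than a bases function, only because poly() also has a makepredictcall() method and ships with R; the object names fit, pv, sub, and same are illustrative. It shows that model.frame() stores a rewritten call, with the fitted statistics included, in the "predvars" attribute of the model terms.

```r
# Base R only: poly() has a makepredictcall() method, as the bases functions do.
fit <- lm(mpg ~ poly(hp, 2), data = mtcars)

# model.frame() recorded a rewritten call that carries the statistics
# computed on the training data (here, the `coefs` argument of poly()).
pv <- attr(terms(fit), "predvars")
print(pv)

# Because the saved call is reused, predictions for a subset of rows match
# the corresponding in-sample fitted values.
sub <- mtcars[5:10, ]
same <- all.equal(unname(predict(fit, newdata = sub)),
                  unname(fitted(fit)[5:10]))
print(same)
```

For a model fit with a bases function, the same "predvars" attribute should hold the basis call along with whatever that function chose to save.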

Basis functions support the predict() generic, so that if they are called outside of a model formula, they can be updated with new data. Behind the scenes, predict() for the various basis functions is just a small wrapper around makepredictcall().

To demonstrate these points, we will use the b_rff() basis function, which is built on random features. The features are sampled only once, when the basis is constructed, and are then retained for reuse. First, in the modeling context, we’ll fit a model with b_rff() in the formula.

library(bases)
data(mtcars)

m = lm(mpg ~ b_rff(cyl, disp, hp, wt, p = 10), mtcars)

Repeated calls to predict() yield the same predictions, even when the newdata argument is supplied.

all.equal(predict(m), predict(m, newdata = mtcars))
#> [1] TRUE
all.equal(predict(m, newdata = mtcars[5:10, ]), 
          predict(m, newdata = mtcars[5:10, ]))
#> [1] TRUE
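For contrast, here is a minimal base R sketch of what goes wrong when a transformation does not save its statistics. The function naive_center() is hypothetical and written only for this illustration: it recomputes the mean on whatever data it receives. The safe counterpart uses scale(), whose centering value is saved by makepredictcall().

```r
# Hypothetical transformation with no makepredictcall() method:
# the mean is recomputed from whatever data is passed in.
naive_center <- function(x) x - mean(x)

fit_naive <- lm(mpg ~ naive_center(hp), data = mtcars)
fit_safe  <- lm(mpg ~ scale(hp, scale = FALSE), data = mtcars)

sub <- mtcars[5:10, ]

# Leaky: the subset is centered by its own mean, so predictions shift.
naive_ok <- isTRUE(all.equal(unname(predict(fit_naive, newdata = sub)),
                             unname(fitted(fit_naive)[5:10])))

# Valid: the training mean saved at fit time is reused.
safe_ok <- isTRUE(all.equal(unname(predict(fit_safe, newdata = sub)),
                            unname(fitted(fit_safe)[5:10])))

print(c(naive = naive_ok, safe = safe_ok))
```

The naive version disagrees with its own fitted values on the same rows, which is exactly the failure that saving the basis statistics prevents.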

The same is true if b_rff() is used outside of a formula.

B = with(mtcars, b_rff(cyl, disp, hp, wt, p = 10))

all.equal(B, predict(B))
#> [1] TRUE
all.equal(B, predict(B, newdata = mtcars), check.attributes = FALSE)
#> [1] TRUE
nrow(predict(B, newdata = mtcars[1:3, ]))
#> [1] 3
