Introduction to bartcs

Yeonghoon Yoo

set.seed(42)
library(bartcs)

The bartcs package selects confounders and estimates treatment effects with Bayesian Additive Regression Trees (BART).

Tutorial with IHDP dataset

This tutorial uses the Infant Health and Development Program (IHDP) dataset. IHDP was a randomized experiment, run from 1985 to 1988, that studied the effect of home visits on infants' cognitive test scores. Our version is a synthetic dataset generated by Louizos et al. (2017), which provides true values to compare against. It consists of 6 continuous and 19 binary covariates, with a simulated cognitive test score as the outcome.

data(ihdp, package = "bartcs")

fit <- single_bart(
  Y               = ihdp$y_factual,
  trt             = ihdp$treatment,
  X               = ihdp[, 6:30],
  num_tree        = 10,
  num_chain       = 4,
  num_post_sample = 100,
  num_burn_in     = 100,
  verbose         = FALSE
)
#> Warning in .recacheSubclasses(def@className, def, env): undefined subclass
#> "ndiMatrix" of class "replValueSp"; definition not updated
fit
#> `bartcs` fit by `single_bart()`
#> 
#>         mean     2.5%    97.5%
#> ATE 3.989180 3.757940 4.194852
#> Y1  6.414511 6.209326 6.594922
#> Y0  2.425331 2.352195 2.512516

The printed fit reports the posterior mean and 95% credible interval of the average treatment effect (ATE) and of the potential outcomes Y1 and Y0.

ATE <- mean(ihdp$mu1 - ihdp$mu0)
ATE
#> [1] 4.016067
mu1 <- mean(ihdp$mu1)
mu1
#> [1] 6.44858
mu0 <- mean(ihdp$mu0)
mu0
#> [1] 2.432513
mse <- mean((unlist(fit$mcmc_list[, "ATE"]) - ATE)^2)
mse
#> [1] 0.01311197

The true values of ATE, mu1, and mu0 all lie within the 95% credible intervals. The mean squared error (MSE) between the posterior draws of the ATE and its true value is about 0.013.
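As a quick cross-check of the statement above, the following sketch rebuilds the comparison from the numbers already printed in this tutorial. It needs only base R; the data frame `post` and the vector `truth` are names introduced here for illustration and are not part of bartcs.

```r
# Posterior summaries copied from the printed `fit` output above.
post <- data.frame(
  estimand = c("ATE", "Y1", "Y0"),
  mean     = c(3.989180, 6.414511, 2.425331),
  lower    = c(3.757940, 6.209326, 2.352195),
  upper    = c(4.194852, 6.594922, 2.512516),
  stringsAsFactors = FALSE
)

# True values computed from ihdp$mu1 and ihdp$mu0 above.
truth <- c(ATE = 4.016067, Y1 = 6.44858, Y0 = 2.432513)

# Does each 95% credible interval contain the true value?
post$truth   <- unname(truth[post$estimand])
post$covered <- post$lower <= post$truth & post$truth <= post$upper

# Bias of each posterior mean relative to the truth.
post$bias <- post$mean - post$truth

print(post)
```

The posterior mean of the ATE also equals the difference between the posterior means of Y1 and Y0, which is a useful sanity check when reading the table.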

Result

Both separate_bart() and single_bart() fit multiple MCMC chains. summary() reports results for each chain and for the aggregated chain.

summary(fit)
#> `bartcs` fit by `single_bart()`
#> 
#> Treatment Value
#>   Treated group    :      1
#>   Control group    :      0
#> 
#> Tree Parameters
#>   Number of Tree   :     10      Value  of alpha    :   0.95
#>   Prob.  of Grow   :   0.28      Value  of beta     :      2
#>   Prob.  of Prune  :   0.28      Value  of nu       :      3
#>   Prob.  of Change :   0.44      Value  of q        :   0.95
#> 
#> Chain Parameters
#>   Number of Chains :      4      Number of burn-in  :    100
#>   Number of Iter   :    200      Number of thinning :      1
#>   Number of Sample :    100
#> 
#> Outcome 
#>  estimand chain     2.5%       1Q     mean   median       3Q    97.5%
#>       ATE     1 3.773543 3.911107 3.972227 3.970111 4.056683 4.151239
#>       ATE     2 3.772706 3.941568 3.999966 4.008911 4.055016 4.193598
#>       ATE     3 3.741889 3.859122 3.943162 3.959975 4.021592 4.118731
#>       ATE     4 3.819920 3.971550 4.041364 4.045832 4.111067 4.252098
#>       ATE   agg 3.757940 3.914043 3.989180 3.990748 4.061248 4.194852
#>        Y1     1 6.215452 6.328724 6.396984 6.404919 6.460931 6.543268
#>        Y1     2 6.261225 6.359134 6.426559 6.434194 6.486718 6.590550
#>        Y1     3 6.184771 6.320276 6.377308 6.382648 6.450492 6.545511
#>        Y1     4 6.245013 6.393345 6.457192 6.465348 6.513756 6.672324
#>        Y1   agg 6.209326 6.347214 6.414511 6.423418 6.480329 6.594922
#>        Y0     1 2.342513 2.399295 2.424757 2.423876 2.451448 2.506433
#>        Y0     2 2.355629 2.398670 2.426593 2.426521 2.448976 2.495499
#>        Y0     3 2.357711 2.401557 2.434146 2.431961 2.460607 2.517344
#>        Y0     4 2.346422 2.389331 2.415827 2.412804 2.441819 2.511036
#>        Y0   agg 2.352195 2.397022 2.425331 2.424030 2.450444 2.512516
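The "agg" rows can be related to the per-chain rows with a little arithmetic. The sketch below assumes that the aggregated chain simply pools the draws of all four equally sized chains; that is an assumption on my part, although the printed numbers are consistent with it. The variable names are illustrative.

```r
# Per-chain posterior means of the ATE, copied from the summary() output above.
chain_means <- c(3.972227, 3.999966, 3.943162, 4.041364)
agg_mean    <- 3.989180

# With equally sized chains, the mean of the pooled draws equals the
# average of the per-chain means.
pooled_mean <- mean(chain_means)
all.equal(pooled_mean, agg_mean, tolerance = 1e-6)

# Spread of the chain means: a rough look at between-chain disagreement.
range(chain_means)
```

The same shortcut does not apply to the quantile columns: quantiles of pooled draws are generally not averages of per-chain quantiles.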

Data Visualization

You can plot the posterior inclusion probability (PIP) of each variable.

plot(fit, method = "pip")

Since inclusion_plot() is a wrapper around ggcharts::bar_chart(), you can pass bar_chart() arguments, such as top_n and threshold, to customize the plot.

plot(fit, method = "pip", top_n = 10)

plot(fit, method = "pip", threshold = 0.5)
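To make the quantity behind these plots concrete, here is a minimal base R sketch. It assumes the usual definition of a posterior inclusion probability, the share of posterior draws in which a covariate is used, and it works on simulated indicators rather than on values extracted from a bartcs fit. The covariate names x1 to x4 and the matrix `included` are hypothetical.

```r
set.seed(1)

# Hypothetical inclusion indicators: one row per posterior draw, one column
# per covariate; TRUE means the covariate is used in that draw.
# These are simulated, not extracted from a bartcs fit.
n_draw   <- 400
use_prob <- c(x1 = 0.9, x2 = 0.6, x3 = 0.3, x4 = 0.05)
included <- sapply(use_prob, function(p) runif(n_draw) < p)

# Posterior inclusion probability: share of draws that use each covariate.
pip <- sort(colMeans(included), decreasing = TRUE)
pip

# Analogue of top_n = 2: keep the two largest PIPs.
head(pip, 2)

# Analogue of threshold = 0.5: keep covariates with PIP above 0.5.
pip[pip > 0.5]
```

In the plots above, top_n and threshold play the same filtering roles, but on the PIPs of the fitted model.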

With trace_plot(), you can visually inspect the trace of the effects or of other parameters.

plot(fit, method = "trace")

Connection to coda Package

You can also apply functions from the coda package to components of a bartcs object. Components with the mcmc_ prefix are mcmc.list objects from the coda package.

coda::gelman.diag(fit$mcmc_list[, "ATE"])
#> Potential scale reduction factors:
#> 
#>      Point est. Upper C.I.
#> [1,]        1.1       1.28
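For readers unfamiliar with the potential scale reduction factor, the sketch below computes the basic Gelman-Rubin ratio by hand in base R on simulated chains. It is an illustration under stated assumptions, not the exact formula used by coda::gelman.diag(), which applies an additional correction, so the two need not agree to the last digit. The matrix `draws` is hypothetical.

```r
set.seed(1)

# Hypothetical example: 4 chains of 100 draws each, all sampling the same
# distribution (values loosely modeled on the ATE summary above).
n_iter  <- 100
n_chain <- 4
draws   <- matrix(rnorm(n_iter * n_chain, mean = 4, sd = 0.11),
                  nrow = n_iter, ncol = n_chain)

# Within-chain variance: average of the per-chain variances.
W <- mean(apply(draws, 2, var))

# Between-chain variance: variance of the chain means, scaled by chain length.
B <- n_iter * var(colMeans(draws))

# Pooled estimate of the posterior variance and the resulting ratio.
var_hat <- (n_iter - 1) / n_iter * W + B / n_iter
rhat    <- sqrt(var_hat / W)
rhat
```

Values close to 1 indicate that the chains agree with each other. A point estimate of 1.1 with an upper limit of 1.28, as reported above, suggests that 100 posterior samples per chain may be on the short side and that longer runs could be worth trying.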

Multi-threading

count_omp_thread()
#> [1] 6

Use count_omp_thread() to check whether OpenMP is supported. Multi-threading requires more than one thread. Because of its overhead, parallelization will NOT pay off on small or moderately sized datasets.

For comparison, I will create a dataset with 40,000 rows by bootstrapping from the IHDP dataset. Then, to keep computation fast, I will fit the model with minimal settings: one tree, one chain, 10 posterior samples, and no burn-in.

idx  <- sample(nrow(ihdp), 4e4, TRUE)
ihdp <- ihdp[idx, ]

microbenchmark::microbenchmark(
  simple = single_bart(
    Y               = ihdp$y_factual,
    trt             = ihdp$treatment,
    X               = ihdp[, 6:30],
    num_tree        = 1,
    num_chain       = 1,
    num_post_sample = 10,
    num_burn_in     = 0,
    verbose         = FALSE,
    parallel        = FALSE
  ),
  parallel = single_bart(
    Y               = ihdp$y_factual,
    trt             = ihdp$treatment,
    X               = ihdp[, 6:30],
    num_tree        = 1,
    num_chain       = 1,
    num_post_sample = 10,
    num_burn_in     = 0,
    verbose         = FALSE,
    parallel        = TRUE
  ),
  times = 50
)
#> Unit: milliseconds
#>      expr      min       lq     mean   median       uq      max neval
#>    simple 115.4550 128.5230 135.5252 133.5138 142.9197 156.5797    50
#>  parallel 102.2961 113.2127 118.8925 117.3296 121.2551 229.1426    50

The results show that parallelization reduces computation time on this larger dataset: the median run time drops from about 134 ms to about 117 ms.
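The size of the gain can be quantified directly from the benchmark table. This short sketch uses only the printed medians; the variable names are illustrative.

```r
# Median run times in milliseconds, copied from the benchmark output above.
median_ms <- c(simple = 133.5138, parallel = 117.3296)

# Speedup factor and relative time saved by the parallel run.
speedup   <- median_ms[["simple"]] / median_ms[["parallel"]]
reduction <- 1 - median_ms[["parallel"]] / median_ms[["simple"]]

round(c(speedup = speedup, reduction = reduction), 3)
```

The gain is modest compared with the 6 threads reported by count_omp_thread(). One plausible reading is that threading overhead is still noticeable at this data size and with such a small model, so timings on your own machine and data may differ.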

References

Louizos, Christos, Uri Shalit, Joris M Mooij, David Sontag, Richard Zemel, and Max Welling. 2017. “Causal Effect Inference with Deep Latent-Variable Models.” Advances in Neural Information Processing Systems 30. https://doi.org/10.48550/arXiv.1705.08821.
