Introduction to R Package ARTtransfer for Transfer Learning

Boxiang Wang, Yunan Wu, and Chenglong Ye

2024-10-16

Introduction to ARTtransfer

The ARTtransfer package implements Adaptive and Robust Transfer Learning (ART), a framework that improves model performance on a primary task by integrating auxiliary data from related domains. ART aims to leverage the information in these auxiliary data while remaining robust to so-called negative transfer: non-informative auxiliary data should not degrade performance on the primary task.

Installation

To install the released version from CRAN, run the following:

# Install the R package from CRAN
install.packages("ARTtransfer")

Getting Started

This section demonstrates how to generate synthetic data for transfer learning and apply the ART framework using different models.

Generate Data

The function generate_data() simulates data for transfer learning, including primary, auxiliary, and noisy datasets. The response can be continuous (for regression tasks) or binary (for classification tasks).

library(ARTtransfer)

# Generate synthetic datasets for transfer learning
dat <- generate_data(n0 = 100, K = 3, nk = 50, p = 10, mu_trgt = 1, xi_aux = 0.5, 
                     ro = 0.3, err_sig = 1, is_test = TRUE, task = "classification")

# Explore the generated data
cat("Primary dataset (X):", dim(dat$X), "\n")
cat("Auxiliary dataset 1 (X_aux[[1]]):", dim(dat$X_aux[[1]]), "\n")
cat("Test dataset (X_test):", dim(dat$X_test), "\n")
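
To make the layout of these objects concrete, the following sketch builds a comparable list by hand using base R only. It is not how generate_data() simulates its data: the helper sim_binary(), the coefficient shifts, and the y_test element are assumptions made for illustration. Only the element names X, y, X_aux, y_aux, and X_test are taken from the examples in this vignette.

```r
# Hand-built stand-in for the list used in this vignette (base R only).
# The data-generating model and the y_test element are illustrative assumptions.
set.seed(1)
n0 <- 100; K <- 3; nk <- 50; p <- 10

# Simulate a binary response from a logistic model with coefficient vector beta
sim_binary <- function(n, beta) {
  X <- matrix(rnorm(n * length(beta)), nrow = n)
  y <- rbinom(n, size = 1, prob = plogis(drop(X %*% beta)))
  list(X = X, y = y)
}

beta0   <- rep(1, p)                 # primary-task coefficients
primary <- sim_binary(n0, beta0)
test    <- sim_binary(n0, beta0)

# Auxiliary sets use coefficients shifted away from the primary ones;
# a larger shift makes the set less informative for the primary task.
shifts <- c(0, 0.5, 5)
aux <- lapply(shifts, function(s) sim_binary(nk, beta0 + s))

toy <- list(
  X = primary$X, y = primary$y,
  X_aux = lapply(aux, `[[`, "X"),    # list of K auxiliary design matrices
  y_aux = lapply(aux, `[[`, "y"),    # list of K auxiliary responses
  X_test = test$X, y_test = test$y
)

cat("Primary dataset (X):", dim(toy$X), "\n")
cat("Auxiliary dataset 1 (X_aux[[1]]):", dim(toy$X_aux[[1]]), "\n")
```

The auxiliary datasets are stored as lists with one element per source, which is why the vignette indexes them with [[1]].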

Fitting ART

Once the data are generated, use the ART() function to apply transfer learning. In this example, we fit a logistic regression using the wrapper function fit_logit(), which wraps R's glm() function.

# Fit the ART model using a generalized linear model (logistic regression).
# The result is stored under a new name so that it does not mask the
# fit_logit() wrapper in the workspace.
fit_glm <- ART(X = dat$X, y = dat$y, X_aux = dat$X_aux, y_aux = dat$y_aux,
               X_test = dat$X_test, func = fit_logit)

# Summary of the fit
summary(fit_glm)

Wrapper Functions

The ARTtransfer package provides several wrapper functions for model fitting, including:

  • fit_lm(): Linear regression
  • fit_logit(): Logistic regression
  • fit_random_forest(): Random forest
  • fit_nnet(): Neural network
  • fit_gbm(): Gradient boosting machine

You can pass any of these functions to ART() using the func argument. You may also write your own function following the format of these wrappers. Specifically, the following requirements must be satisfied:

  • It accepts the required arguments: X, y, X_val, y_val, X_test, min_prod, max_prod.
  • It returns a list containing at least dev (the deviance) and pred (the predictions).
  • If is_coef = TRUE and a regression model is being used, the function must return the coefficients (coef).
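
As a rough illustration of these requirements, here is a hypothetical custom wrapper built on glm(). It is a sketch, not one of the package's own wrappers. Several details are assumptions: that dev is the binomial deviance computed on the validation set, that min_prod and max_prod bound the predicted probabilities away from 0 and 1, that pred holds predicted probabilities, and that is_coef is passed as an argument. The name fit_logit_custom is made up for this example.

```r
# Hypothetical custom wrapper written to match the requirements listed above.
# Assumed semantics: dev = validation deviance; min_prod/max_prod clip probabilities.
fit_logit_custom <- function(X, y, X_val, y_val, X_test,
                             min_prod = 1e-5, max_prod = 1 - 1e-5,
                             is_coef = TRUE) {
  # Use consistent column names so predict() can match the new data
  vars  <- paste0("V", seq_len(ncol(X)))
  as_df <- function(M) setNames(as.data.frame(M), vars)

  train <- data.frame(y = y, as_df(X))
  model <- glm(y ~ ., data = train, family = binomial())

  # Clip predicted probabilities so that log() below stays finite
  clip  <- function(pr) pmin(pmax(unname(pr), min_prod), max_prod)
  p_val <- clip(predict(model, newdata = as_df(X_val), type = "response"))

  # Binomial deviance on the validation set
  dev <- -2 * sum(y_val * log(p_val) + (1 - y_val) * log(1 - p_val))

  # Predictions on the test set, if one is supplied
  pred <- NULL
  if (!is.null(X_test)) {
    pred <- clip(predict(model, newdata = as_df(X_test), type = "response"))
  }

  out <- list(dev = dev, pred = pred)
  if (is_coef) out$coef <- coef(model)
  out
}

# Quick check on simulated data, independent of ART()
set.seed(42)
X_all <- matrix(rnorm(300 * 3), ncol = 3)
y_all <- rbinom(300, 1, plogis(X_all[, 1] - X_all[, 2]))
res <- fit_logit_custom(X_all[1:100, ], y_all[1:100],
                        X_all[101:200, ], y_all[101:200],
                        X_all[201:300, ])
str(res)
```

Before passing a function like this to ART() through func, compare its arguments and return values against the source of a packaged wrapper such as fit_logit(), since the conventions assumed here may differ from the package's.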

ART with Integrated-Aggregating Machines (ART-I-AM)

ART_I_AM is an extension of the ART framework that automatically integrates three pre-defined models: random forest, AdaBoost, and neural networks. Because these models are built into the function, users do not need to specify a fitting function for each dataset. Users who want different models can follow the structure of ART_I_AM and implement their own functions within the same general framework.

Fitting ART with ART_I_AM

To use ART_I_AM, provide the primary and auxiliary datasets; the integrated models are applied automatically. The example below demonstrates ART-I-AM with its three built-in methods: random forest, AdaBoost, and neural networks. To use other learners, modify the R function ART_I_AM, making sure that each replacement function follows the same format as the func argument of ART().

# Fit the ART_I_AM model using integrated Random Forest, AdaBoost, and Neural Network models
fit_I_AM <- ART_I_AM(X = dat$X, y = dat$y, 
            X_aux = dat$X_aux, y_aux = dat$y_aux, X_test = dat$X_test)

# View the predictions and weights
fit_I_AM$pred_FTL
fit_I_AM$W_FTL
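
To give some intuition for how a weight vector can combine the predictions of several candidate learners, here is a small base R sketch. It does not describe how ART_I_AM computes W_FTL: the exponential weighting rule, the function name aggregate_predictions(), and the numbers are all assumptions chosen for illustration.

```r
# Illustrative only: exponential weighting is a stand-in rule chosen for this
# sketch, not a description of how ART_I_AM computes its weights.
aggregate_predictions <- function(pred_mat, dev) {
  # pred_mat: one column of predictions per candidate learner
  # dev:      validation deviance of each candidate (smaller is better)
  stopifnot(ncol(pred_mat) == length(dev))
  w <- exp(-(dev - min(dev)) / 2)   # subtract min(dev) for numerical stability
  w <- w / sum(w)                   # normalise so the weights sum to one
  list(pred = drop(pred_mat %*% w), W = w)
}

# Three hypothetical candidates; the third has a much larger deviance
pred_mat <- cbind(rf    = c(0.9, 0.2, 0.6),
                  boost = c(0.8, 0.3, 0.7),
                  nnet  = c(0.1, 0.9, 0.5))
agg <- aggregate_predictions(pred_mat, dev = c(100, 102, 140))
round(agg$W, 3)
agg$pred
```

Under this rule a candidate with a much larger deviance receives a weight close to zero, so it contributes almost nothing to the combined prediction.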

Conclusion

The ARTtransfer package provides a flexible framework for performing adaptive and robust transfer learning using various models. You can easily integrate your own functions and data to improve performance on primary tasks.
