Creates an augmentation function that optionally utilizes cross-fitting
create.augmentation.function(
  family,
  crossfit = TRUE,
  nfolds.crossfit = 10,
  cv.glmnet.args = NULL
)
family: The response type (see the options in the glmnet help file).
crossfit: A logical value indicating whether to use cross-fitting (TRUE) or not (FALSE). Cross-fitting is more computationally intensive but helps prevent overfitting; see Chernozhukov et al. (2018).
nfolds.crossfit: An integer specifying the number of folds to use for cross-fitting. Must be greater than 1.
cv.glmnet.args: A list of NAMED arguments to pass to the cv.glmnet function, for example cv.glmnet.args = list(type.measure = "mse", nfolds = 10). See cv.glmnet and glmnet for all possible options.
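Why the elements of cv.glmnet.args need names: when a list of arguments is forwarded with R's do.call(), unnamed elements are matched by position to whichever formal arguments are still unfilled, so list("mae", 5) would not reach type.measure and nfolds. The sketch below illustrates that forwarding pattern with a hypothetical helper, forward_named_args(), and a stand-in function, echo_args(), in place of cv.glmnet. It assumes user values override defaults while fixed arguments (such as family) cannot be overridden; this is an illustration, not necessarily how personalized merges them.

```r
# Hypothetical helper (not part of personalized): merge user-supplied named
# arguments over defaults, then forward them with do.call().
# In create.augmentation.function the target would be glmnet::cv.glmnet;
# here any function can stand in for it.
forward_named_args <- function(fit.fun, fixed.args, defaults,
                               user.args = NULL) {
  if (is.null(user.args)) user.args <- list()
  arg.names <- names(user.args)
  # unnamed elements would be matched positionally by do.call(), so reject them
  if (length(user.args) > 0 && (is.null(arg.names) || any(arg.names == ""))) {
    stop("every element of user.args must be named")
  }
  # user values replace defaults of the same name
  args <- modifyList(defaults, user.args)
  # fixed arguments (e.g. family) always take precedence
  args <- modifyList(args, fixed.args)
  do.call(fit.fun, args)
}

# Stand-in for cv.glmnet that simply returns the arguments it received
echo_args <- function(...) list(...)

res <- forward_named_args(
  echo_args,
  fixed.args = list(family = "gaussian"),
  defaults   = list(type.measure = "mse", nfolds = 10),
  user.args  = list(type.measure = "mae", nfolds = 5)
)
str(res)
```

Passing an unnamed list such as list(5) as user.args stops with an error instead of silently mis-assigning the value.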
Returns a function that can be passed to the augment.func argument of the fit.subgroup function.
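As a rough illustration of what crossfit and nfolds.crossfit control, the following base-R sketch builds a factory that returns a function of (x, y), a simplified version of the kind of function described above. It is not the personalized implementation: make_augment_sketch() is a hypothetical name, and an unpenalized glm() fit stands in for the cv.glmnet fit. With cross-fitting, each observation's prediction (on the link scale) comes from a model fit only on the other folds, so no observation is predicted by a model that saw its own response.

```r
# Hypothetical sketch, not the personalized implementation: a factory that
# returns an augmentation-style function of (x, y). Base-R glm() stands in
# for the penalized cv.glmnet fit used by create.augmentation.function.
make_augment_sketch <- function(family = gaussian(), crossfit = TRUE,
                                nfolds.crossfit = 10) {
  if (crossfit && nfolds.crossfit < 2) {
    stop("nfolds.crossfit must be greater than 1")
  }
  function(x, y) {
    dat <- data.frame(y = y, x)
    n <- nrow(dat)
    if (!crossfit) {
      # fit and predict on the same observations (prone to overfitting)
      fit <- glm(y ~ ., data = dat, family = family)
      return(unname(predict(fit, newdata = dat, type = "link")))
    }
    # randomly assign each observation to one of nfolds.crossfit folds
    foldid <- sample(rep(seq_len(nfolds.crossfit), length.out = n))
    preds <- numeric(n)
    for (k in seq_len(nfolds.crossfit)) {
      held.out <- foldid == k
      # fit only on the other folds ...
      fit <- glm(y ~ ., data = dat[!held.out, , drop = FALSE], family = family)
      # ... and predict the held-out fold (out-of-fold predictions)
      preds[held.out] <- predict(fit, newdata = dat[held.out, , drop = FALSE],
                                 type = "link")
    }
    preds
  }
}

# Usage on simulated data
set.seed(1)
n <- 200
x <- matrix(rnorm(n * 3), n, 3)
y <- 1 + x[, 1] - 2 * x[, 2] + rnorm(n)
aug <- make_augment_sketch(crossfit = TRUE, nfolds.crossfit = 5)
p <- aug(x, y)
summary(p)
```

Setting crossfit = FALSE skips the fold loop and predicts each observation from a model that was fit on it, which is cheaper but is exactly the overfitting that cross-fitting is meant to avoid.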
Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., Hansen, C., Newey, W., & Robins, J. (2018). Double/debiased machine learning for treatment and structural parameters. https://arxiv.org/abs/1608.00060
See fit.subgroup for estimating individualized treatment rules (ITRs) and create.propensity.function for creating propensity functions.
library(personalized)
set.seed(123)
n.obs <- 500
n.vars <- 15
x <- matrix(rnorm(n.obs * n.vars, sd = 3), n.obs, n.vars)
# simulate non-randomized treatment
xbetat <- 0.5 + 0.5 * x[,7] - 0.5 * x[,9]
trt.prob <- exp(xbetat) / (1 + exp(xbetat))
trt01 <- rbinom(n.obs, 1, prob = trt.prob)
trt <- 2 * trt01 - 1
# simulate response
# delta below drives treatment effect heterogeneity
delta <- 2 * (0.5 + x[,2] - x[,3] - x[,11] + x[,1] * x[,12] )
xbeta <- x[,1] + x[,11] - 2 * x[,12]^2 + x[,13] + 0.5 * x[,15] ^ 2
xbeta <- xbeta + delta * trt
# continuous outcomes
y <- drop(xbeta) + rnorm(n.obs, sd = 2)
aug.func <- create.augmentation.function(family = "gaussian",
                                         crossfit = TRUE,
                                         nfolds.crossfit = 10,
                                         cv.glmnet.args = list(type.measure = "mae",
                                                               nfolds = 5))
prop.func <- create.propensity.function(crossfit = TRUE,
                                        nfolds.crossfit = 10,
                                        cv.glmnet.args = list(type.measure = "auc",
                                                              nfolds = 5))
if (FALSE) {
  subgrp.model <- fit.subgroup(x = x, y = y,
                               trt = trt01,
                               propensity.func = prop.func,
                               augment.func = aug.func,
                               loss = "sq_loss_lasso",
                               nfolds = 10) # option for cv.glmnet (for ITR estimation)
  summary(subgrp.model)
}