Skip to contents

Overview

The default forest_balance() estimator uses kernel energy balancing weights alone. When the kernel captures all confounding structure, this produces an unbiased ATE estimate. In practice, however, residual confounding may remain — particularly when the confounding operates through complex nonlinear functions of the covariates.

The augmented (doubly-robust) estimator combines the balancing weights with group-specific outcome regression models μ̂1(X)=E[YX,A=1]\hat\mu_1(X) = E[Y \mid X, A{=}1] and μ̂0(X)=E[YX,A=0]\hat\mu_0(X) = E[Y \mid X, A{=}0]. The ATE is estimated as:

τ̂aug=1ni=1n[μ̂1(Xi)μ̂0(Xi)]+iwiAi(Yiμ̂1(Xi))iwiAiiwi(1Ai)(Yiμ̂0(Xi))iwi(1Ai). \hat\tau_{\mathrm{aug}} = \frac{1}{n}\sum_{i=1}^n \bigl[\hat\mu_1(X_i) - \hat\mu_0(X_i)\bigr] + \frac{\sum_i w_i A_i \bigl(Y_i - \hat\mu_1(X_i)\bigr)} {\sum_i w_i A_i} - \frac{\sum_i w_i (1 - A_i) \bigl(Y_i - \hat\mu_0(X_i)\bigr)} {\sum_i w_i (1 - A_i)}.

The first term is the outcome-model-based regression estimate of the ATE. The second and third terms are weighted bias corrections that use the balancing weights to account for errors in the outcome models.

This has a key robustness property: the estimator is consistent if either the kernel (balancing weights) or the outcome models are correctly specified. When both are reasonable, the augmented estimator typically has lower bias and RMSE than either approach alone.

Basic usage

Enable augmentation with augmented = TRUE:

set.seed(123)
dat <- simulate_data(n = 2000, p = 20, ate = 0, dgp = 2)

# Standard (weighting only)
fit <- forest_balance(dat$X, dat$A, dat$Y, num.trees = 500)

# Augmented (doubly-robust)
fit_aug <- forest_balance(dat$X, dat$A, dat$Y, num.trees = 500,
                           augmented = TRUE)

c("Standard" = fit$ate, "Augmented" = fit_aug$ate, "Truth" = 0)
#>   Standard  Augmented      Truth 
#> 0.01664814 0.06350369 0.00000000

The print method indicates when augmentation is active:

fit_aug
#> Forest Kernel Energy Balancing (cross-fitted)
#> -------------------------------------------------- 
#>   n = 2,000  (n_treated = 1042, n_control = 958)
#>   Trees: 500
#>   Cross-fitting: 2 folds
#>   Augmented: yes (doubly-robust)
#>   Solver: direct
#>   ATE estimate: 0.0635
#>   Fold ATEs: 0.0898, 0.0372
#>   ESS: treated = 795/1042 (76%)   control = 712/958 (74%)
#> -------------------------------------------------- 
#> Use summary() for covariate balance details.

How it works with cross-fitting

When both cross.fitting = TRUE (the default) and augmented = TRUE, the outcome model is cross-fitted in lockstep with the kernel:

For each fold kk:

  1. Train the joint forest (for the kernel) on folds k-k.
  2. Train two outcome regression forests on folds k-k: one on treated observations (μ̂1\hat\mu_1) and one on control observations (μ̂0\hat\mu_0).
  3. Predict leaf nodes, μ̂1\hat\mu_1, and μ̂0\hat\mu_0 for fold kk out-of-sample.
  4. Compute the fold-level doubly-robust ATE using the cross-fitted weights and outcome predictions.

This ensures that neither the kernel nor the outcome models overfit to the estimation sample. The user does not need to manage the cross-fitting of μ̂1\hat\mu_1 and μ̂0\hat\mu_0 separately — it is handled automatically.

Simulation comparison

We compare four configurations using DGP 2, which has a rich outcome surface with moderate confounding (n=5,000n = 5{,}000, p=20p = 20):

set.seed(123)
nreps <- 5
n <- 5000; p <- 20

methods <- c("Standard", "CF only", "Aug only", "CF + Aug")
res <- matrix(NA, nreps, length(methods), dimnames = list(NULL, methods))

for (r in seq_len(nreps)) {
  dat <- simulate_data(n = n, p = p, ate = 0, dgp = 2)

  res[r, "Standard"] <- forest_balance(dat$X, dat$A, dat$Y, num.trees = 500,
                                        cross.fitting = FALSE,
                                        min.node.size = 10)$ate
  res[r, "CF only"]  <- forest_balance(dat$X, dat$A, dat$Y,
                                        num.trees = 500)$ate
  res[r, "Aug only"] <- forest_balance(dat$X, dat$A, dat$Y, num.trees = 500,
                                        cross.fitting = FALSE,
                                        augmented = TRUE,
                                        min.node.size = 10)$ate
  res[r, "CF + Aug"] <- forest_balance(dat$X, dat$A, dat$Y, num.trees = 500,
                                        augmented = TRUE)$ate
}
DGP 2: n = 5,000, p = 20, true ATE = 0, 5 reps.
Method Bias SD RMSE
Standard 0.1204 0.0437 0.1281
CF only 0.0151 0.0673 0.0690
Aug only 0.0515 0.0535 0.0743
CF + Aug 0.0055 0.0609 0.0611

The combination of cross-fitting and augmentation (CF + Aug) achieves the lowest bias. Cross-fitting alone handles the kernel overfitting, while augmentation accounts for outcome variation that the kernel does not fully capture. DGP 2 has a rich outcome surface depending on 8 covariates, of which only 2 are confounders, making the outcome model particularly valuable.

User-supplied outcome predictions

You can supply your own group-specific outcome predictions via mu.hat as a list with components mu1 (predictions of E[Y|X,A=1]E[Y|X, A{=}1]) and mu0 (predictions of E[Y|X,A=0]E[Y|X, A{=}0]):

set.seed(123)
dat <- simulate_data(n = 1000, p = 20, ate = 0, dgp = 2)

# Fit your own group-specific outcome models
df_all <- data.frame(Y = dat$Y, dat$X)
df1 <- df_all[dat$A == 1, ]
df0 <- df_all[dat$A == 0, ]
lm1 <- lm(Y ~ ., data = df1)
lm0 <- lm(Y ~ ., data = df0)
mu_custom <- list(
  mu1 = predict(lm1, newdata = df_all),
  mu0 = predict(lm0, newdata = df_all)
)

fit_custom <- forest_balance(dat$X, dat$A, dat$Y, num.trees = 500,
                              mu.hat = mu_custom)
fit_custom
#> Forest Kernel Energy Balancing (cross-fitted)
#> -------------------------------------------------- 
#>   n = 1,000  (n_treated = 505, n_control = 495)
#>   Trees: 500
#>   Cross-fitting: 2 folds
#>   Augmented: yes (doubly-robust)
#>   Solver: direct
#>   ATE estimate: 0.1576
#>   Fold ATEs: 0.2724, 0.0428
#>   ESS: treated = 391/505 (77%)   control = 380/495 (77%)
#> -------------------------------------------------- 
#> Use summary() for covariate balance details.

When mu.hat is supplied with cross.fitting = TRUE, the package uses the provided predictions as-is and emits a message reminding you to ensure they were cross-fitted externally.

When to use augmentation

Augmentation is most beneficial when:

  • The confounding is nonlinear and the kernel alone cannot fully capture it.
  • The outcome model is reasonably accurate, even if not perfectly specified.
  • You want robustness against partial misspecification of either the kernel or the outcome model.

The computational cost of augmentation is modest: fitting two additional regression_forest models (one per treatment group) per fold, which is typically fast relative to the joint forest and kernel construction.

References

Robins, J.M., Rotnitzky, A. and Zhao, L.P. (1994). Estimation of regression coefficients when some regressors are not always observed. Journal of the American Statistical Association, 89(427), 846–866.

De, S. and Huling, J.D. (2025). Data adaptive covariate balancing for causal effect estimation for high dimensional data. arXiv preprint arXiv:2512.18069.