
Fits a multivariate random forest that jointly models the relationship between covariates, treatment, and outcome, computes a random forest proximity kernel, and then uses kernel energy balancing to produce weights for estimating the average treatment effect (ATE). By default, K-fold cross-fitting is used to reduce the overfitting bias that arises when the kernel is estimated on the same data used for treatment effect estimation.

Usage

forest_balance(
  X,
  A,
  Y,
  num.trees = 1000,
  min.node.size = NULL,
  cross.fitting = TRUE,
  num.folds = 2,
  augmented = FALSE,
  mu.hat = NULL,
  scale.outcomes = TRUE,
  solver = c("auto", "direct", "cg"),
  tol = 1e-08,
  parallel = FALSE,
  ...
)

Arguments

X

A numeric matrix or data frame of covariates (\(n \times p\)).

A

A binary (0/1) vector of treatment assignments.

Y

A numeric vector of outcomes.

num.trees

Number of trees to grow in the forest. Default is 1000.

min.node.size

Minimum number of observations per leaf node. If NULL (default), an adaptive heuristic is used: max(20, min(floor(n/200) + p, floor(n/50))). This scales the leaf size with both the sample size and the number of covariates, which empirically yields low bias. See Details.

cross.fitting

Logical; if TRUE (default), use K-fold cross-fitting to construct the kernel from held-out data, reducing overfitting bias. If FALSE, the kernel is estimated on the full sample.

num.folds

Number of cross-fitting folds. Default is 2. Only used when cross.fitting = TRUE.

augmented

Logical; if TRUE, use an augmented (doubly-robust) estimator that combines the kernel energy balancing weights with group-specific outcome regression models. The resulting estimator is consistent if either the kernel (balancing weights) or the outcome models are correctly specified. Default is FALSE. See Details.

mu.hat

Optional list with components mu1 and mu0, each a numeric vector of length \(n\), containing user-supplied predictions of \(E[Y \mid X, A=1]\) and \(E[Y \mid X, A=0]\). When provided, these are used instead of fitting internal outcome models. If NULL (default) and augmented = TRUE, two regression_forest models are fit automatically (one on treated, one on control). When supplying mu.hat with cross.fitting = TRUE, the user is responsible for ensuring the predictions were cross-fitted externally.

scale.outcomes

If TRUE (default), the joint outcome matrix cbind(A, Y) is column-standardized before fitting the forest. This puts treatment and outcome on a common scale, so that neither dominates the forest's split criterion merely because of its units.

solver

Which linear solver to use for the balancing weights. "auto" (default) selects "direct" for small fold sizes and "cg" for large fold sizes. See kernel_balance for details.

tol

Convergence tolerance for the CG solver. Default is 1e-08.

parallel

Logical or integer. If FALSE (default), folds are processed sequentially. If TRUE, folds are processed in parallel using all available cores via mclapply. An integer value specifies the exact number of cores. Only used when cross.fitting = TRUE. Note: parallel processing is not supported on Windows.

...

Additional arguments passed to multi_regression_forest.

Value

An object of class "forest_balance" (a list) with the following elements:

ate

The estimated average treatment effect. When cross-fitting is used, this is the average of per-fold Hajek estimates (DML1).

weights

The balancing weight vector (length \(n\)). When cross-fitting is used, these are the concatenated per-fold weights.

mu1.hat

Predictions of \(E[Y|X, A=1]\) (length \(n\)), or NULL if augmented = FALSE.

mu0.hat

Predictions of \(E[Y|X, A=0]\) (length \(n\)), or NULL if augmented = FALSE.

kernel

The \(n \times n\) forest proximity kernel (sparse matrix), or NULL when cross-fitting or the CG solver is used.

forest

The trained forest object. When cross-fitting is used, this is the last fold's forest.

X, A, Y

The input data.

n, n1, n0

Total, treated, and control sample sizes.

solver

The solver that was used ("direct" or "cg").

crossfit

Logical indicating whether cross-fitting was used.

augmented

Logical indicating whether augmentation was used.

num.folds

Number of folds (if cross-fitting was used).

fold_ates

Per-fold ATE estimates (if cross-fitting was used).

fold_ids

Fold assignments (if cross-fitting was used).

The object has print and summary methods. Use summary.forest_balance for covariate balance diagnostics.

Details

The method proceeds in three steps:

  1. A multi_regression_forest is fit on covariates X with a bivariate response (A, Y). This jointly models the relationship between covariates, treatment assignment, and outcome.

  2. The forest's leaf co-membership structure defines a proximity kernel: \(K(i,j)\) is the proportion of trees where \(i\) and \(j\) share a leaf. Because the forest splits on both \(A\) and \(Y\), this kernel captures confounding structure.

  3. kernel_balance computes balancing weights via the closed-form kernel energy distance solution. The ATE is then estimated using the Hajek (ratio) estimator with these weights.
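Steps 2 and 3 can be illustrated in base R. The sketch below assumes the forest's leaf memberships are already available as an \(n \times B\) matrix of leaf IDs (`leaves`, a hypothetical input; how it is extracted from a fitted forest is not shown). It builds a dense proximity kernel, whereas the package stores a sparse one. It then applies the Hajek (ratio) estimator to a given weight vector. The kernel energy balancing solve itself (kernel_balance) is not reproduced here. `proximity_kernel` and `hajek_ate` are illustrative names, not package functions.

```r
# Proximity kernel: K[i, j] = share of trees in which i and j fall in the same leaf.
# leaves: n x B matrix; leaves[i, b] is the leaf ID of observation i in tree b (assumed input).
proximity_kernel <- function(leaves) {
  n <- nrow(leaves)
  B <- ncol(leaves)
  K <- matrix(0, n, n)
  for (b in seq_len(B)) {
    # outer(..., "==") is an n x n logical co-membership matrix for tree b
    K <- K + outer(leaves[, b], leaves[, b], "==")
  }
  K / B
}

# Hajek (ratio) ATE estimate: weighted treated mean minus weighted control mean.
hajek_ate <- function(A, Y, w) {
  sum(w * A * Y) / sum(w * A) - sum(w * (1 - A) * Y) / sum(w * (1 - A))
}

# Toy example: 4 observations, 2 trees
leaves <- matrix(c(1, 1, 2, 2,    # tree 1
                   1, 2, 2, 1),   # tree 2
                 nrow = 4)
K <- proximity_kernel(leaves)     # K[1, 2] is 0.5: units 1 and 2 share a leaf in 1 of 2 trees
hajek_ate(A = c(1, 1, 0, 0), Y = c(3, 5, 1, 2), w = c(1, 3, 2, 2))
```

With these weights the treated mean is (3 + 15)/4 = 4.5 and the control mean is (2 + 4)/4 = 1.5, so the estimate is 3.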

Cross-fitting (default): For each fold \(k\), the forest is trained on all data except fold \(k\), and the kernel for fold \(k\) is built from the leaf assignments that this forest produces for the held-out fold-\(k\) observations. Because those observations were not used to grow the forest, the kernel no longer depends on their outcomes, which reduces overfitting bias. The final ATE is the average of the per-fold Hajek estimates (DML1).
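The DML1 loop can be sketched generically. This is not the package's internal code. Both `fit_on` (trains on out-of-fold data) and `estimate_on` (returns a fold-level ATE on the held-out fold) are hypothetical callbacks, and simple random fold assignment is an assumption. In forest_balance, the first would grow the forest and the second would build the fold kernel, solve for weights, and compute the Hajek estimate.

```r
# Generic DML1 cross-fitting: estimate on each held-out fold, then average fold estimates.
crossfit_dml1 <- function(X, A, Y, num.folds, fit_on, estimate_on) {
  n <- length(Y)
  # Random, roughly equal-sized fold assignment (assumption)
  fold_ids <- sample(rep(seq_len(num.folds), length.out = n))
  fold_ates <- numeric(num.folds)
  for (k in seq_len(num.folds)) {
    test <- fold_ids == k
    # Train only on observations outside fold k
    model <- fit_on(X[!test, , drop = FALSE], A[!test], Y[!test])
    # Evaluate only on observations in fold k
    fold_ates[k] <- estimate_on(model, X[test, , drop = FALSE], A[test], Y[test])
  }
  list(ate = mean(fold_ates), fold_ates = fold_ates, fold_ids = fold_ids)
}

# Usage with placeholder callbacks (difference in means per fold)
set.seed(1)
n <- 200
X <- matrix(rnorm(n * 2), n, 2)
A <- rbinom(n, 1, 0.5)
Y <- X[, 1] + rnorm(n)
res <- crossfit_dml1(
  X, A, Y, num.folds = 2,
  fit_on = function(X, A, Y) NULL,   # placeholder: nothing is trained
  estimate_on = function(model, X, A, Y) mean(Y[A == 1]) - mean(Y[A == 0])
)
res$fold_ates
```

The returned `ate`, `fold_ates`, and `fold_ids` mirror the components of the same names in the Value section.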

Augmented estimator: When augmented = TRUE, two group-specific outcome models \(\hat\mu_1(X) = E[Y|X, A=1]\) and \(\hat\mu_0(X) = E[Y|X, A=0]\) are fit, and the ATE is estimated via the doubly-robust formula: $$\hat\tau = \frac{1}{n}\sum_i [\hat\mu_1(X_i) - \hat\mu_0(X_i)] + \frac{\sum w_i A_i (Y_i - \hat\mu_1(X_i))}{\sum w_i A_i} - \frac{\sum w_i (1-A_i)(Y_i - \hat\mu_0(X_i))}{\sum w_i (1-A_i)}.$$ The first term is the regression-based estimate of the ATE; the remaining terms are weighted bias corrections. This is consistent if either the kernel (balancing weights) or the outcome models are correctly specified. When combined with cross-fitting, the outcome models are automatically cross-fitted in lockstep with the kernel.
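The augmented formula translates directly into R. This sketch assumes the weights `w` and the outcome-model predictions `mu1`, `mu0` (each of length \(n\)) have already been computed. `augmented_ate` is an illustrative name, not a package function.

```r
# Doubly-robust ATE: regression estimate plus weighted (Hajek-normalized) residual corrections.
augmented_ate <- function(A, Y, w, mu1, mu0) {
  reg   <- mean(mu1 - mu0)                                     # (1/n) sum [mu1 - mu0]
  corr1 <- sum(w * A * (Y - mu1)) / sum(w * A)                 # treated residual correction
  corr0 <- sum(w * (1 - A) * (Y - mu0)) / sum(w * (1 - A))     # control residual correction
  reg + corr1 - corr0
}

A <- c(1, 1, 0, 0)
Y <- c(3, 5, 1, 2)
w <- c(1, 3, 2, 2)
# With zero outcome models, the corrections carry the whole estimate (plain Hajek, here 3)
augmented_ate(A, Y, w, mu1 = rep(0, 4), mu0 = rep(0, 4))
```

Setting both outcome models to zero recovers the Hajek estimator, which shows that augmentation only adds a regression term and re-centres the residuals.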

Adaptive leaf size: The default min.node.size is set adaptively via max(20, min(floor(n/200) + p, floor(n/50))). Larger leaves produce smoother kernels that generalize better, while the cap at n/50 prevents kernel degeneracy.
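The default leaf-size rule can be evaluated directly. The helper name `default_min_node_size` is hypothetical; the formula is the one stated above.

```r
# Adaptive default for min.node.size: max(20, min(floor(n/200) + p, floor(n/50)))
default_min_node_size <- function(n, p) {
  max(20, min(floor(n / 200) + p, floor(n / 50)))
}

default_min_node_size(500, 10)     # floor of 20 binds (the Examples setting)
default_min_node_size(2000, 30)    # floor(n/200) + p and the n/50 cap coincide at 40
default_min_node_size(10000, 10)   # floor(n/200) + p = 60 is below the cap of 200
```

For small samples the floor of 20 dominates. As \(n\) grows, the \(\lfloor n/200 \rfloor + p\) term takes over, and the \(\lfloor n/50 \rfloor\) cap only binds when \(p\) is large relative to \(n\).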

References

Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., Hansen, C., Newey, W. and Robins, J. (2018). Double/debiased machine learning for treatment and structural parameters. The Econometrics Journal, 21(1), C1–C68.

De, S. and Huling, J.D. (2025). Data adaptive covariate balancing for causal effect estimation for high dimensional data. arXiv preprint arXiv:2512.18069.

Examples

# \donttest{
n <- 500
p <- 10
X <- matrix(rnorm(n * p), n, p)
A <- rbinom(n, 1, plogis(0.5 * X[, 1]))
Y <- X[, 1] + rnorm(n)  # true ATE = 0

# Default: cross-fitting with adaptive leaf size
result <- forest_balance(X, A, Y)
result
#> Forest Kernel Energy Balancing (cross-fitted)
#> -------------------------------------------------- 
#>   n = 500  (n_treated = 237, n_control = 263)
#>   Trees: 1000
#>   Cross-fitting: 2 folds
#>   Solver: cg
#>   ATE estimate: 0.0576
#>   Fold ATEs: 0.0499, 0.0652
#>   ESS: treated = 172/237 (73%)   control = 198/263 (75%)
#> -------------------------------------------------- 
#> Use summary() for covariate balance details.

# Augmented (doubly-robust) estimator
result_aug <- forest_balance(X, A, Y, augmented = TRUE)

# Without cross-fitting
result_nocf <- forest_balance(X, A, Y, cross.fitting = FALSE)
# }