Kernel energy balancing weights via closed-form solution
Source: R/kernel_balance.R
kernel_balance.Rd
Computes balancing weights that minimize a kernelized energy distance between each weighted treatment group (treated and control) and the overall sample. The weights are obtained in closed form by solving a linear system derived from the kernel energy distance objective.
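As a rough illustration of the kind of objective involved, the sketch below computes an MMD-style kernel distance between a weighted subgroup and the full sample using only a kernel matrix. The exact objective used by kernel_balance (including its modified kernel) is not spelled out on this page, so the formula in the comments, the helper name weighted_kernel_distance, and the Gaussian kernel used in place of a forest kernel are all assumptions.

```r
# Hypothetical sketch (not the package's code): MMD-style distance between a
# weighted subgroup g and the full sample, assuming
#   D(w) = w' K_gg w / n_g^2 - 2 w' K_g. 1 / (n_g n) + 1' K 1 / n^2
weighted_kernel_distance <- function(K, in_group, w) {
  n    <- nrow(K)
  idx  <- which(in_group)
  n_g  <- length(idx)
  K_gg <- K[idx, idx, drop = FALSE]
  s_g  <- rowSums(K[idx, , drop = FALSE])  # K_g. 1: similarity of each group unit to all n units
  as.numeric(
    crossprod(w, K_gg %*% w) / n_g^2 -
      2 * sum(w * s_g) / (n_g * n) +
      sum(K) / n^2
  )
}

# Toy data: a Gaussian kernel (assumption) on one covariate, confounded treatment
set.seed(1)
n <- 100
x <- rnorm(n)
A <- rbinom(n, 1, plogis(2 * x))
K <- exp(-as.matrix(dist(x))^2)

# Unit weights: treated and control groups each differ from the overall sample
d_treated <- weighted_kernel_distance(K, A == 1, rep(1, sum(A == 1)))
d_control <- weighted_kernel_distance(K, A == 0, rep(1, sum(A == 0)))
# The full sample with unit weights is at distance zero from itself
d_all <- weighted_kernel_distance(K, rep(TRUE, n), rep(1, n))
```

Balancing weights would be chosen to drive d_treated and d_control toward zero.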
Usage
kernel_balance(
trt,
kern = NULL,
Z = NULL,
num.trees = NULL,
solver = c("auto", "direct", "cg"),
tol = 1e-08,
maxiter = 2000L
)
Arguments
- trt
A binary (0/1) integer or numeric vector indicating treatment assignment (1 = treated, 0 = control).
- kern
A symmetric \(n \times n\) kernel matrix (dense or sparse), or NULL if Z is provided.
- Z
Optional sparse indicator matrix from leaf_node_kernel_Z such that \(K = Z Z^\top / B\). When supplied, the solver can avoid forming the full kernel matrix. If both kern and Z are given, Z takes priority when the CG solver is selected.
- num.trees
Number of trees \(B\). Required when Z is provided.
- solver
Which linear solver to use. "auto" (default) selects "direct" for \(n \le 5000\) and "cg" for \(n > 5000\). "direct" uses sparse Cholesky on the treated and control sub-blocks of the kernel. "cg" uses conjugate gradient iterations with the factored \(Z\) representation, avoiding formation of any kernel matrix.
- tol
Convergence tolerance for the CG solver. Default is 1e-8. Ignored when solver = "direct".
- maxiter
Maximum number of CG iterations. Default is 2000.
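The relationship \(K = Z Z^\top / B\) can be checked on a toy example: entry \(K_{ij}\) is the fraction of trees in which units \(i\) and \(j\) fall in the same leaf. This sketch assumes a hypothetical column layout with one block of L indicator columns per tree (the layout actually returned by leaf_node_kernel_Z may differ) and uses a dense base-R matrix where a sparse one would be used in practice.

```r
# Hypothetical sketch: leaf-indicator matrix Z for B trees with L leaves each
set.seed(2)
n <- 50; B <- 20; L <- 4
# leaf[i, t]: leaf index of unit i in tree t (random here, for illustration)
leaf <- matrix(sample.int(L, n * B, replace = TRUE), n, B)

Z <- matrix(0, n, B * L)
for (tree in seq_len(B)) {
  # exactly one indicator per unit within each tree's block of L columns
  Z[cbind(seq_len(n), (tree - 1) * L + leaf[, tree])] <- 1
}

K <- tcrossprod(Z) / B  # K = Z Z' / B

# Same kernel computed directly: share of trees where i and j share a leaf
K_direct <- outer(seq_len(n), seq_len(n),
                  Vectorize(function(i, j) mean(leaf[i, ] == leaf[j, ])))
```

Every unit shares its own leaf in every tree, so the diagonal of K is 1.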
Value
A list with the following elements:
- weights
A numeric vector of length \(n\) containing the balancing weights. Treated weights sum to \(n_1\) and control weights sum to \(n_0\).
- solver
The solver that was used ("direct" or "cg").
Details
The modified kernel \(K_q\) used in the optimization is block-diagonal (after ordering units by treatment group): the treated–control cross-blocks are zero because \(K_q(i,j) = 0\) whenever \(A_i \neq A_j\), where \(A_i\) is the treatment indicator (trt) of unit \(i\). Both solvers exploit this structure by solving the treated and control blocks independently.
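Because the cross-blocks are zero, one linear system splits into two independent ones. The sketch below uses a generic positive-definite matrix as a stand-in for the kernel and masks out the treated/control cross-blocks; the masking step and the random right-hand side are illustrative assumptions, not the package's construction of \(K_q\).

```r
set.seed(4)
n <- 60
A <- rbinom(n, 1, 0.5)
X <- matrix(rnorm(n * 3), n, 3)
K  <- tcrossprod(X) + diag(n)     # generic SPD stand-in for a kernel (assumption)
Kq <- K * outer(A, A, "==")       # zero out entries with A_i != A_j
b  <- rnorm(n)                    # hypothetical right-hand side

# One solve on the full n x n system
x_full <- solve(Kq, b)

# Two independent solves, one per treatment group
x_block <- numeric(n)
for (g in c(0, 1)) {
  idx <- which(A == g)
  x_block[idx] <- solve(Kq[idx, idx, drop = FALSE], b[idx])
}
```

The two answers agree, while each per-group solve only touches an \(n_g \times n_g\) block.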
The direct solver extracts the sub-blocks \(K_{tt}\) and \(K_{cc}\) and solves via sparse Cholesky. This gives exact solutions but requires forming (at least sub-blocks of) the kernel matrix.
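A minimal sketch of a closed-form, per-group direct solve, assuming an MMD-style objective \(w^\top M w / n_g^2 - 2 w^\top s / (n_g n)\) minimized subject to \(\sum_i w_i = n_g\), where \(M\) is the within-group kernel block plus a small ridge \(\lambda I\) (added here for numerical stability) and \(s\) holds the row sums of the kernel over all \(n\) units. With no nonnegativity constraint, the Lagrangian stationarity condition gives \(w = (n_g/n) M^{-1} s + \nu M^{-1}\mathbf{1}\), with \(\nu\) chosen so the weights sum to \(n_g\); both solves reuse one Cholesky factor. The objective, lambda, helper names, and Gaussian kernel are assumptions, and dense base-R chol() stands in for a sparse Cholesky.

```r
# Solve M x = b given the upper-triangular Cholesky factor R (M = R'R)
chol_solve <- function(R, b) backsolve(R, backsolve(R, b, transpose = TRUE))

# Hypothetical closed-form weights for one group (weights sum to n_g)
closed_form_group_weights <- function(K, in_group, lambda = 1e-2) {
  n   <- nrow(K)
  idx <- which(in_group)
  n_g <- length(idx)
  M <- K[idx, idx, drop = FALSE] + lambda * diag(n_g)  # ridged within-group block
  s <- rowSums(K[idx, , drop = FALSE])                 # K_g. 1
  R <- chol(M)
  sol <- chol_solve(R, cbind(s, 1))                    # M^{-1} s and M^{-1} 1
  a  <- sol[, 1]
  c1 <- sol[, 2]
  nu <- (n_g - (n_g / n) * sum(a)) / sum(c1)           # enforce sum(w) = n_g
  (n_g / n) * a + nu * c1
}

# The (assumed) objective being minimized, for checking
group_objective <- function(K, in_group, w, lambda = 1e-2) {
  n   <- nrow(K)
  idx <- which(in_group)
  n_g <- length(idx)
  M <- K[idx, idx, drop = FALSE] + lambda * diag(n_g)
  s <- rowSums(K[idx, , drop = FALSE])
  as.numeric(crossprod(w, M %*% w)) / n_g^2 - 2 * sum(w * s) / (n_g * n)
}

set.seed(5)
n <- 100
x <- rnorm(n)
A <- rbinom(n, 1, plogis(x))
K <- exp(-as.matrix(dist(x))^2)  # Gaussian kernel stand-in
w1 <- closed_form_group_weights(K, A == 1)
w0 <- closed_form_group_weights(K, A == 0)
```

Since \(M\) is positive definite, the objective is strictly convex on the constraint set, so w1 can do no worse than uniform weights (which are also feasible).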
The CG solver uses the factored representation \(K = Z Z^\top / B\) to compute matrix–vector products as \(K v = Z (Z^\top v) / B\), so no kernel matrix is ever formed. At large \(n\) (e.g., \(n > 5000\)) this is much faster and far more memory-efficient than the direct solver. The CG iterates converge to the exact solution; at the default tolerance (tol = 1e-8) the resulting weight vectors agree with the direct solution to several decimal places.
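A minimal, matrix-free conjugate gradient sketch: the solver touches the kernel only through matvec, which computes \(K v\) as \(Z (Z^\top v) / B\). The ridge term lambda (added so the system is strictly positive definite), the random leaf assignments, the random right-hand side, and the helper name cg_solve are assumptions for illustration; the dense solve at the end forms K only to check the answer.

```r
# Plain conjugate gradient for a symmetric positive-definite operator
cg_solve <- function(matvec, b, tol = 1e-8, maxiter = 2000L) {
  x <- numeric(length(b))
  r <- b                       # residual b - A x with x = 0
  p <- r
  rs_old <- sum(r^2)
  b_norm <- sqrt(sum(b^2))
  iter <- 0L
  while (sqrt(rs_old) > tol * b_norm && iter < maxiter) {
    Ap <- matvec(p)
    alpha <- rs_old / sum(p * Ap)
    x <- x + alpha * p
    r <- r - alpha * Ap
    rs_new <- sum(r^2)
    p <- r + (rs_new / rs_old) * p
    rs_old <- rs_new
    iter <- iter + 1L
  }
  list(x = x, iter = iter, converged = sqrt(rs_old) <= tol * b_norm)
}

# Hypothetical leaf indicators: B trees, L leaves each, one block of columns per tree
set.seed(6)
n <- 300; B <- 50; L <- 8; lambda <- 0.5
Z <- matrix(0, n, B * L)
for (tree in seq_len(B)) {
  Z[cbind(seq_len(n), (tree - 1) * L + sample.int(L, n, replace = TRUE))] <- 1
}

# (K + lambda I) v computed from Z alone; K itself is never formed here
matvec <- function(v) as.vector(Z %*% crossprod(Z, v)) / B + lambda * v

rhs <- rnorm(n)
res <- cg_solve(matvec, rhs, tol = 1e-12)

# Verification only: dense solve that does form K
x_direct <- solve(tcrossprod(Z) / B + lambda * diag(n), rhs)
```

Each CG iteration costs two products with Z, which is cheap when Z is sparse, instead of an \(O(n^2)\) product with K.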
References
De, S. and Huling, J.D. (2025). Data adaptive covariate balancing for causal effect estimation for high dimensional data. arXiv preprint arXiv:2512.18069.
Examples
# \donttest{
library(grf)
set.seed(1)
# Simulate confounded data: treatment depends on X[, 1]
n <- 200
p <- 5
X <- matrix(rnorm(n * p), n, p)
A <- rbinom(n, 1, plogis(X[, 1]))
Y <- X[, 1] + rnorm(n)
# Fit a multi-output forest on (A, Y) and extract its kernel
forest <- multi_regression_forest(X, cbind(A, Y), num.trees = 500)
K <- forest_kernel(forest)
bal <- kernel_balance(A, K)
# Weighted ATE estimate (the true effect is 0 in this simulation)
w <- bal$weights
ate <- weighted.mean(Y[A == 1], w[A == 1]) -
  weighted.mean(Y[A == 0], w[A == 0])
# }
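Because treated weights sum to \(n_1\) and control weights sum to \(n_0\) (see Value), the weights average 1 within each group, so each weighted mean reduces to sum(w * Y) / n_g. The sketch below, using simulated outcomes and arbitrary positive weights rescaled to that normalization (weighted_ate is a hypothetical helper), checks that this matches the weighted.mean form used in the example above.

```r
# Hypothetical helper: difference in weighted means under the documented
# normalization (treated weights sum to n1, control weights sum to n0)
weighted_ate <- function(Y, A, w) {
  n1 <- sum(A == 1)
  n0 <- sum(A == 0)
  sum(w[A == 1] * Y[A == 1]) / n1 - sum(w[A == 0] * Y[A == 0]) / n0
}

set.seed(3)
n <- 40
A <- rep(c(0, 1), each = n / 2)
Y <- rnorm(n) + A
w_raw <- runif(n, 0.5, 2)  # arbitrary positive weights (illustration only)

# Rescale so each group's weights sum to its group size
w <- w_raw
w[A == 1] <- w_raw[A == 1] * sum(A == 1) / sum(w_raw[A == 1])
w[A == 0] <- w_raw[A == 0] * sum(A == 0) / sum(w_raw[A == 0])

ate_norm <- weighted_ate(Y, A, w)
ate_wm <- weighted.mean(Y[A == 1], w[A == 1]) - weighted.mean(Y[A == 0], w[A == 0])
```

Since weighted.mean divides by the sum of the weights, it would give the same estimate with the unnormalized w_raw as well.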