
Why scalar summaries

metahunt() predicts a function on the grid; conformal routines can return a band at every grid point. Often, though, the inferential target is a single number derived from that function:

  • The average treatment effect (mean of a CATE function over a reference patient distribution).
  • The treatment effect at a specific patient profile.
  • The fraction of the population with a positive treatment effect.
  • A contrast between two endpoints.

For all of these, MetaHunt accepts a wrapper argument that collapses the predicted function to a scalar before any further calculation. The same wrapper is applied identically to predictions and to calibration residuals, so conformal coverage transfers directly to the scalar summary.

The wrapper protocol

apply_wrapper(F_mat, wrapper, grid_weights) defines the contract.

  • F_mat is an n-by-G_grid numeric matrix; row j is one function on the grid.
  • If wrapper is NULL, apply_wrapper() returns the weighted average of each row, sum(row * grid_weights) / sum(grid_weights). By default grid_weights is uniform (1/G_grid), so this reduces to the plain row mean.
  • If wrapper is a function, apply_wrapper() calls apply(F_mat, 1, wrapper), which means the wrapper receives a single numeric vector of length G_grid — one row of F_mat at a time — and must return a single numeric value.

The contract therefore is:

wrapper :: numeric vector of length G_grid  ->  numeric scalar

Any function satisfying that signature is a valid wrapper. After applying it, the package checks that the result is numeric with exactly one entry per row of F_mat.
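The contract is small enough to sketch in a few lines of base R. The following is an illustrative reimplementation of the behaviour described above, not the package's actual code (apply_wrapper() may differ in details such as input validation):

```r
# Illustrative base-R sketch of the apply_wrapper() contract;
# NOT the package's actual implementation.
apply_wrapper_sketch <- function(F_mat, wrapper = NULL,
                                 grid_weights = rep(1 / ncol(F_mat), ncol(F_mat))) {
  out <- if (is.null(wrapper)) {
    # default: weighted average of each row
    as.numeric(F_mat %*% grid_weights) / sum(grid_weights)
  } else {
    # each row is handed to the wrapper as a numeric vector of length G_grid
    apply(F_mat, 1, wrapper)
  }
  # post-hoc check: one numeric value per row
  stopifnot(is.numeric(out), length(out) == nrow(F_mat))
  out
}

F_toy <- rbind(c(1, 2, 3), c(4, 5, 6))
apply_wrapper_sketch(F_toy)       # row means: 2 and 5
apply_wrapper_sketch(F_toy, max)  # row maxima: 3 and 6
```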

An ATE example with grf::causal_forest

We simulate a multi-site clinical trial with m = 8 sites. Each site has its own individual-level data (Y, X, T), where Y is a continuous outcome, X is a single patient covariate (age), and T is binary treatment. The site-level CATE function τ^(i)(age) = E[Y(1) − Y(0) | age, site = i] varies across sites in a way that depends on the site's metadata. Each site fits its own grf::causal_forest on its individual-level data, and shares only the fitted model (not the patient data) with us.

m <- 8
n_per_site <- 200
G <- 30

W <- data.frame(
  year        = sample(2010:2020, m, replace = TRUE),
  pct_treated = round(runif(m, 0.3, 0.6), 2)
)

site_data_list <- lapply(seq_len(m), function(i) {
  age <- runif(n_per_site, 30, 80)
  T   <- rbinom(n_per_site, 1, W$pct_treated[i])
  site_eff <- (W$year[i] - 2015) / 5   # site-level shift in CATE
  tau_age  <- 0.02 * (age - 50) + site_eff
  Y0  <- 0.01 * age + rnorm(n_per_site, sd = 0.5)
  Y1  <- Y0 + tau_age
  Y   <- ifelse(T == 1, Y1, Y0)
  data.frame(Y = Y, age = age, T = T)
})

grid <- data.frame(age = seq(30, 80, length.out = G))

Each site fits its own causal_forest. We use num.trees = 200 to keep the vignette fast; in practice you would use the default 2000 or more.

cf_models <- lapply(site_data_list, function(d)
  grf::causal_forest(X = matrix(d$age, ncol = 1),
                     Y = d$Y,
                     W = d$T,
                     num.trees = 200))

We stack the per-site CATE estimates on the shared age grid into the m-by-G matrix F_hat. We pass an explicit predict_fn to illustrate the general pattern; the dispatch table inside f_hat_from_models() already knows how to call causal_forest, so for a standard grf::causal_forest the predict_fn argument can simply be omitted.

cate_predict <- function(model, grid) {
  as.numeric(stats::predict(model, newdata = matrix(grid$age, ncol = 1))$predictions)
}
F_hat <- f_hat_from_models(cf_models, grid, predict_fn = cate_predict)
dim(F_hat)
#> [1]  8 30

We now fit metahunt() on (F_hat, W) and ask for the predicted ATE at a hypothetical new site.

fit <- metahunt(F_hat, W, K = 3, dfspa_args = list(denoise = FALSE))
W_new <- data.frame(year = 2018, pct_treated = 0.45)
ate_pred <- predict(fit, newdata = W_new, wrapper = mean)
ate_pred
#> [1] 0.9247137

The scalar ate_pred is the predicted average treatment effect for a hypothetical new site with metadata (year = 2018, pct_treated = 0.45), taking the unweighted mean over the 30-point age grid.

Three custom wrappers

Below are three short, self-contained wrappers, each illustrating a different idea. All three are applied to the F_hat, fit, and W_new constructed in the previous section.

Plain mean

mean already maps a numeric vector to a numeric scalar, so it is a valid wrapper. With a uniform grid this is just the unweighted average of the function over the grid, i.e. the grid-uniform ATE.

predict(fit, newdata = W_new, wrapper = mean)
#> [1] 0.9247137

Restricted positive mean

Suppose we only credit treatment effects that are positive (for example, in a cost-effectiveness setting). The wrapper averages max(f(x), 0) over the grid:

restricted_pos_mean <- function(f) sum(pmax(f, 0)) / length(f)
predict(fit, newdata = W_new, wrapper = restricted_pos_mean)
#> [1] 0.9247137

Because every row of F_mat is passed in turn, f inside the wrapper is just a numeric vector of length G_grid. length(f) is therefore the grid size, and dividing by it gives a uniform-weighted average.
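The difference from a plain mean only shows up when the function dips below zero. A toy vector (unrelated to the fit above) makes this concrete:

```r
f_toy <- c(-0.2, 0.1, 0.5)
mean(f_toy)                          # 0.1333...
sum(pmax(f_toy, 0)) / length(f_toy)  # 0.2: the negative value is floored at 0
```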

Endpoint contrast

The difference f(x_G) - f(x_1) is a useful summary when the grid is ordered (e.g. age, dose, or time). For our age grid it is the gap in CATE between an 80-year-old and a 30-year-old patient at the new site:

endpoint_contrast <- function(f) f[length(f)] - f[1]
predict(fit, newdata = W_new, wrapper = endpoint_contrast)
#> [1] 0.7503231

Conformal coverage with a wrapper

When you pass wrapper into split_conformal() (or cross_conformal(), or conformal_from_fit()), conformity scores are computed after applying the wrapper, so a single shared quantile governs the interval. The interval covers the wrapped scalar at the nominal level, not the underlying function pointwise.
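To make the mechanics concrete, here is a hedged base-R sketch of the wrapped scoring step as we understand it. The variable names and the exact quantile rule are assumptions consistent with the package's warning message, not MetaHunt's verbatim code:

```r
# Hedged sketch of split-conformal scoring with a wrapper (assumed
# mechanics, not MetaHunt's actual code).
set.seed(1)
alpha    <- 0.1
obs_cal  <- matrix(rnorm(5 * 4), nrow = 5)  # 5 calibration functions on a 4-point grid
pred_cal <- matrix(rnorm(5 * 4), nrow = 5)

# wrapper first, then one absolute-residual score per calibration row
scores <- abs(apply(obs_cal, 1, mean) - apply(pred_cal, 1, mean))

# finite-sample conformal quantile: the k-th smallest score with
# k = ceiling((n_cal + 1) * (1 - alpha)); if k > n_cal the quantile is
# infinite, which is why tiny calibration sets produce unbounded intervals
n_cal <- length(scores)
k <- ceiling((n_cal + 1) * (1 - alpha))
q <- if (k > n_cal) Inf else sort(scores)[k]
q  # Inf here: 5 calibration rows cannot support alpha = 0.1
```

The same arithmetic explains the warning in the example below: with n_cal = 3 and alpha = 0.1, k = ceiling(4 * 0.9) = 4 > 3, so the quantile is infinite.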

With only m = 8 sites, we hold out a single site (the 8th) and use the other seven for training plus calibration. The calibration set is small, so we use alpha = 0.1 rather than 0.05.

# Use 7 sites for training+calibration, predict for the held-out 8th
tr_cal <- 1:7; new <- 8
res <- split_conformal(
  F_hat[tr_cal, , drop = FALSE],
  W[tr_cal, , drop = FALSE],
  W[new, , drop = FALSE],
  K = 3, wrapper = mean, alpha = 0.1, cal_frac = 0.5, seed = 1,
  dfspa_args = list(denoise = FALSE)
)
#> Warning in .build_conformal_output(obs_cal = F_cal, pred_cal = pred_cal, : With
#> n_cal = 3 and alpha = 0.1, the conformal quantile is infinite; intervals are
#> unbounded. Increase calibration size or use a larger `alpha`.
data.frame(prediction = res$prediction,
           lower      = res$lower,
           upper      = res$upper)
#>   prediction lower upper
#> 1  -0.830792  -Inf   Inf

With only 8 sites, an empirical-coverage check on a single held-out site is not informative. For coverage diagnostics, use a leave-one-out loop or simulate a larger number of sites. See ?coverage for the helper function and the conformal-prediction vignette for split conformal at scale.
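A leave-one-out loop along those lines might look like the following sketch. It reuses the F_hat and W from this vignette and the split_conformal() call shown above; note that treating mean(F_hat[i, ]) as the "truth" for site i is itself an approximation, since F_hat holds estimates:

```r
# Sketch of a leave-one-out coverage check (not run here; assumes the
# split_conformal() signature used above, with each site's own wrapped
# estimate standing in for the truth).
covered <- vapply(seq_len(nrow(F_hat)), function(i) {
  res_i <- split_conformal(
    F_hat[-i, , drop = FALSE],
    W[-i, , drop = FALSE],
    W[i, , drop = FALSE],
    K = 3, wrapper = mean, alpha = 0.1, cal_frac = 0.5, seed = i,
    dfspa_args = list(denoise = FALSE)
  )
  res_i$lower <= mean(F_hat[i, ]) && mean(F_hat[i, ]) <= res_i$upper
}, logical(1))
mean(covered)  # empirical coverage across the 8 sites
```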

Pointwise vs scalar — quick reference

Aspect by aspect, pointwise (wrapper = NULL) versus scalar (wrapper supplied):

  • Output shape: an nrow(W_new) x G_grid matrix vs. a length-nrow(W_new) numeric vector.
  • Conformal quantile: one per grid point (a length-G_grid vector) vs. a single scalar.
  • Coverage guarantee: per grid point, marginally (not joint over the grid) vs. for the scalar summary, marginally.
  • Best for: visualising the predicted function with a band vs. reporting a single number with a valid CI.
  • Example call: split_conformal(F, W, W_new, K = 3) vs. split_conformal(F, W, W_new, K = 3, wrapper = mean).

A pointwise band is a visualisation aid; a scalar interval is the right object for an inferential claim about a specific functional. Pick the wrapper that matches the question you actually want to answer, and let the conformal machinery do the rest.

See also