Fit a weight model mapping study-level covariates to simplex weights
Source: R/weight_estimation.R, fit_weight_model.Rd
Given an m-by-K matrix whose rows are simplex-valued weight vectors \(\hat\pi_1,\ldots,\hat\pi_m\)
(e.g. from project_to_simplex()) and the associated study-level covariates
\(\mathbf W_1,\ldots,\mathbf W_m\), fit a model
\(\widehat{\mathcal M}:\mathbf W \mapsto \boldsymbol\pi\) that maps covariates to simplex weights.
The default, and currently the only, method is Dirichlet regression via the DirichletReg package.
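To make the mapping \(\mathbf W \mapsto \boldsymbol\pi\) concrete, the sketch below computes the expected weights implied by a Dirichlet regression with a log link on each concentration parameter, \(\alpha_k = \exp(\mathbf x^\top\beta_k)\), which is the form of DirichletReg's default "common" parameterisation. It assumes that form, uses made-up coefficients, and defines a hypothetical helper dirichlet_mean_weights(); it is not how fit_weight_model() is implemented internally.

```r
# Hypothetical helper (not part of the package): expected simplex weights
# E[pi | W] = alpha / sum(alpha) under a log link alpha_k = exp(x' beta_k).
dirichlet_mean_weights <- function(W, beta) {
  X <- cbind(1, as.matrix(W))   # design matrix with an intercept column
  alpha <- exp(X %*% beta)      # m-by-K concentrations, all strictly > 0
  alpha / rowSums(alpha)        # normalise each row onto the simplex
}

# Two studies with covariates (w1, w2); coefficients are illustrative only.
W <- matrix(c(0,  0,
              1, -1), nrow = 2, byrow = TRUE)
beta <- rbind(c(1.0,  1.0, 1.0),   # intercepts, one per component
              c(0.5,  0.0, 0.0),   # effect of w1 on each component
              c(0.0, -0.3, 0.0))   # effect of w2 on each component
pi_mean <- dirichlet_mean_weights(W, beta)
pi_mean   # row 1: equal weights; row 2: tilted toward component 1
```

Because every row of alpha is positive, each row of the result lies strictly inside the simplex, which is also why boundary values in the training weights need special handling (see Details).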
Usage
fit_weight_model(
  pi_hat,
  W,
  method = c("dirichlet"),
  boundary_eps = 1e-04,
  formula = NULL,
  ...
)
Arguments
- pi_hat
  An m-by-K numeric matrix of simplex weights; each row must be non-negative and sum to 1 (within a tolerance of 1e-6).
- W
  An m-by-p matrix or data frame of study-level covariates, one row per study, aligned with the rows of pi_hat.
- method
  Weight-model method. Currently only "dirichlet" is supported.
- boundary_eps
  Small positive scalar used to shrink weights away from the simplex boundary before Dirichlet fitting (see Details). Defaults to 1e-4.
- formula
  Optional right-hand-side-only formula (e.g. ~ x1 + I(x2^2)) describing the covariate part of the Dirichlet regression. Defaults to ~ . (all columns of W).
- ...
  Additional arguments passed through to DirichletReg::DirichReg().
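The requirement on pi_hat can be checked before calling the function. The sketch below uses a hypothetical helper, is_simplex_rows(), that applies the stated rule (non-negative entries, rows summing to 1 within 1e-6); the package's own validation may differ in detail.

```r
# Hypothetical helper (not part of the package API): TRUE when every row of
# an m-by-K matrix lies on the probability simplex, using the 1e-6 tolerance
# stated for `pi_hat`.
is_simplex_rows <- function(pi_hat, tol = 1e-6) {
  pi_hat <- as.matrix(pi_hat)
  is.numeric(pi_hat) &&
    all(pi_hat >= 0) &&                     # non-negative entries
    all(abs(rowSums(pi_hat) - 1) <= tol)    # each row sums to 1 within tol
}

ok  <- rbind(c(0.2, 0.3, 0.5), c(1, 0, 0))
bad <- rbind(c(0.2, 0.3, 0.6), c(0.5, 0.5, 0))
is_simplex_rows(ok)    # TRUE
is_simplex_rows(bad)   # FALSE: the first row sums to 1.1
```

Note that a row such as c(1, 0, 0) passes this check even though it sits on the boundary; such rows are valid input and are handled by the boundary_eps shrinkage.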
Value
An object of class "metahunt_weight_model": a list with the
fitted model, formula, method, K, and training covariate names.
Details
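Dirichlet regression cannot handle weights exactly at the simplex boundary
(0 or 1), which frequently arise after constrained projection: the Dirichlet
log-likelihood contains \(\log\pi_k\) terms, which are undefined when \(\pi_k = 0\).
Before fitting, rows of pi_hat are therefore shrunk toward the barycenter via
\(\tilde\pi = (\pi + \varepsilon) / (1 + K\varepsilon)\), with
\(\varepsilon\) set by boundary_eps. Equivalently,
\(\tilde\pi = (1-\lambda)\,\pi + \lambda\,\mathbf 1/K\) with
\(\lambda = K\varepsilon/(1 + K\varepsilon)\), so each shrunk row stays on the
simplex and every entry is at least \(\varepsilon/(1 + K\varepsilon) > 0\).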
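The shrinkage step is a one-line matrix operation. The sketch below applies \(\tilde\pi = (\pi + \varepsilon)/(1 + K\varepsilon)\) row-wise; the function name shrink_to_interior() is hypothetical and only illustrates the formula, not the package's internal code.

```r
# Hypothetical helper: boundary shrinkage pi_tilde = (pi + eps) / (1 + K * eps),
# applied to every row of an m-by-K matrix of simplex weights.
shrink_to_interior <- function(pi_hat, eps = 1e-4) {
  K <- ncol(pi_hat)
  (pi_hat + eps) / (1 + K * eps)
}

# Rows on the boundary, as often produced by constrained projection.
pi_hat <- rbind(c(1,   0,   0),
                c(0.5, 0.5, 0))
pi_tilde <- shrink_to_interior(pi_hat, eps = 1e-4)
rowSums(pi_tilde)    # each row still sums to 1
min(pi_tilde) > 0    # TRUE: no weight is left exactly at 0
```

Since the map only adds a constant and rescales, the ordering of weights within each row is preserved, and for eps = 1e-4 the fitted data differ from pi_hat by at most about K * eps.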
Examples
# \donttest{
set.seed(1)
m <- 80; K <- 3; p <- 2
W <- matrix(rnorm(m * p), m, p); colnames(W) <- c("w1", "w2")
# generate simplex weights driven by W
eta <- cbind(0.5 * W[, 1], -0.3 * W[, 2], rep(0, m))
pi_true <- exp(eta) / rowSums(exp(eta))
pi_hat <- pi_true + matrix(rnorm(m * K, sd = 0.01), m, K)
pi_hat <- pmax(pi_hat, 0); pi_hat <- pi_hat / rowSums(pi_hat)
model <- fit_weight_model(pi_hat, W)
predict(model, newdata = matrix(c(0, 0), 1, 2, dimnames = list(NULL, c("w1","w2"))))
#>           [,1]      [,2]      [,3]
#> [1,] 0.3328722 0.3340615 0.3330662
# }
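The predicted weights at w1 = w2 = 0 in the example are all close to 1/3. That is consistent with how the example data were generated: at that point every linear predictor in eta is 0, so the softmax gives exactly equal weights, and a well-fitting model should approximately recover them. The sketch below re-evaluates the example's data-generating weights directly; softmax_rows() is a hypothetical helper written for this illustration.

```r
# Hypothetical helper: row-wise softmax, as used to build pi_true in the example.
softmax_rows <- function(eta) exp(eta) / rowSums(exp(eta))

# Data-generating linear predictors from the example, evaluated at one point.
eta_at <- function(w) matrix(c(0.5 * w[1], -0.3 * w[2], 0), nrow = 1)

pi_origin <- softmax_rows(eta_at(c(0, 0)))   # equal weights of 1/3
pi_other  <- softmax_rows(eta_at(c(1, 1)))   # tilted toward component 1
pi_origin
pi_other
```

Comparing predict() output at a non-zero covariate value, such as c(1, 1), with pi_other is a quick way to check that the fitted model tracks the data-generating weights away from the origin as well.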