JAGSModel R6 Class
This class allows the user to fit and use a mixed effects model with
intercepts that vary by country, facility and output type, as well as fixed
covariate effects. It is used to compare different model structures, as in
vignette("03_model-comparisons", package = "capturetb"), and is the class
underlying the unitcost, unitcost_fixed and unitcost_ohd models.
Methods
Method new()
Initialize a new model instance.
Usage
JAGSModel$new(dat, covariates, target, priors = NULL)
Arguments
dat: Data.frame. Training data.
covariates: Character vector. Names of covariate columns.
target: Character. Name of the target variable.
priors: List of class "capturetbpriors". Should be created using capturetb_priors. If NULL, non-informative priors will be used.
Method fit()
Fit the model using JAGS. Requires JAGS and runjags to be installed.
Usage
JAGSModel$fit(
n.chains = 3,
n.iter = 1e+06,
n.burnin = 5000,
n.adapt = 5000,
n.thin = 100,
seed = NULL,
...
)
Arguments
n.chains: Integer. Number of MCMC chains. Default is 3.
n.iter: Integer. Number of total iterations per chain. Default is 1000000.
n.burnin: Integer. Number of burn-in iterations. Default is 5000.
n.adapt: Integer. Number of adaptation iterations. Default is 5000.
n.thin: Integer. Thinning interval. Default is 100.
seed: Optional integer. Used to seed both the R and JAGS random generators for reproducible results.
...: Additional arguments passed to runjags::run.jags.
Method predict()
Generate predictions from the fitted model.
Usage
JAGSModel$predict(
dat,
scale = "log",
summarised = FALSE,
centrality = "mean",
ci = 0.95,
ci_type = "predictive",
test = NULL,
...
)
Arguments
dat: New input data for predictions. This should be prepared for the model using prepare_covariates.
scale: One of "log" or "natural". Default "log".
summarised: Logical. If TRUE, summarises predictions using bayestestR::describe_posterior. See bayestestR::describe_posterior for full documentation of available arguments. Default FALSE.
centrality: The point-estimates (centrality indices) to compute. Default "mean".
ci: Probability, or vector of probabilities, of the credible interval (between 0 and 1) to be estimated. Default 0.95 (95%).
ci_type: One of "mean" or "predictive". Specifies whether to return the credible interval around the mean prediction or the predictive interval. Default "predictive".
test: The indices of effect existence to compute. Default NULL. See bayestestR::describe_posterior for options.
...: Other arguments that will be passed to bayestestR::describe_posterior.
ci_method: The type of index used for the credible interval. You probably want either ETI (Equal Tailed Interval) or HDI (Highest Density Interval). Default ETI. See bayestestR::describe_posterior for all options.
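The difference between ci_type = "mean" and ci_type = "predictive" can be illustrated with simulated draws. The following is a minimal base R sketch, not the class's implementation: mu_draws and sigma_draws are hypothetical posterior draws for a single observation, and a normal likelihood on the log scale is assumed purely for illustration.

```r
set.seed(1)
n_draws <- 4000

# Hypothetical posterior draws for one observation on the log scale:
# mu_draws is the mean prediction, sigma_draws the residual SD.
mu_draws    <- rnorm(n_draws, mean = 2.1, sd = 0.12)
sigma_draws <- abs(rnorm(n_draws, mean = 0.35, sd = 0.02))

# Posterior predictive draws add observation-level noise to the mean.
y_draws <- rnorm(n_draws, mean = mu_draws, sd = sigma_draws)

# Equal-tailed 95% intervals on the log scale.
ci_mean       <- quantile(mu_draws, c(0.025, 0.975))
ci_predictive <- quantile(y_draws, c(0.025, 0.975))

# Exponentiating the endpoints maps an interval to the natural scale.
ci_predictive_natural <- exp(ci_predictive)

width_mean       <- unname(diff(ci_mean))
width_predictive <- unname(diff(ci_predictive))
```

In this sketch the predictive interval is wider than the interval around the mean, because it includes residual variation as well as uncertainty in the mean.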
Method covariate_correlation()
View correlations between covariates in the training data.
Arguments
plot: Logical. If TRUE, return a ggplot2::ggplot object. If FALSE, return a correlation matrix. Default TRUE.
Returns
ggplot2::ggplot object or correlation matrix
Method mcmc_trace()
Create trace plots for MCMC chains using bayesplot::mcmc_trace.
Arguments
...: Additional arguments passed to bayesplot::mcmc_trace.
Returns
A ggplot2::ggplot object showing trace plots.
Method mcmc_rhat()
Compute and plot R-hat convergence diagnostics.
Returns
A ggplot2::ggplot object showing R-hat diagnostics.
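As background on what an R-hat value measures, here is a minimal base R sketch of the classic Gelman-Rubin statistic computed on simulated chains. It is an illustration under stated assumptions only: the rhat() helper is hypothetical, and the exact R-hat variant computed by mcmc_rhat() is not specified here.

```r
set.seed(42)
n_iter   <- 1000
n_chains <- 3

# Simulated draws for one parameter: rows are iterations, columns are chains.
chains <- matrix(rnorm(n_iter * n_chains), nrow = n_iter, ncol = n_chains)

# Hypothetical helper: classic (non-split) Gelman-Rubin R-hat.
rhat <- function(chains) {
  n <- nrow(chains)
  W <- mean(apply(chains, 2, var))   # mean within-chain variance
  B <- n * var(colMeans(chains))     # between-chain variance
  var_hat <- (n - 1) / n * W + B / n # pooled variance estimate
  sqrt(var_hat / W)
}

# Chains sampling the same distribution give a value close to 1.
rhat_mixed <- rhat(chains)

# Shifting one chain mimics a chain stuck in a different region.
shifted <- chains
shifted[, 1] <- shifted[, 1] + 3
rhat_stuck <- rhat(shifted)
```

Values well above 1 indicate that the chains disagree and that the sampler has probably not converged.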
Method mcmc_acf()
Create autocorrelation plots for MCMC chains using bayesplot::mcmc_acf.
Arguments
...: Additional arguments passed to bayesplot::mcmc_acf.
Returns
A ggplot2::ggplot object showing autocorrelation plots.
Method n_eff()
Computes the effective sample size of the posterior samples using the coda::effectiveSize function.
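To show what coda::effectiveSize reports, the sketch below applies it to simulated draws rather than to a fitted JAGSModel. The two series (white and ar1) are hypothetical stand-ins for posterior samples, and the coda package is assumed to be installed.

```r
library(coda)

set.seed(7)
n <- 5000

# Independent draws, and strongly autocorrelated AR(1) draws.
white <- rnorm(n)
ar1   <- as.numeric(stats::filter(rnorm(n), filter = 0.9, method = "recursive"))

# Wrap the draws as an mcmc object, one column per parameter.
draws <- coda::mcmc(cbind(white = white, ar1 = ar1))

# Named numeric vector with one effective sample size per column.
ess <- coda::effectiveSize(draws)
```

The autocorrelated series has a much smaller effective sample size than its length, which is the situation the autocorrelation plots from mcmc_acf() help diagnose.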
Method plot_posteriors()
Plot posterior distributions using bayesplot::mcmc_areas.
Arguments
prob: Numeric. Density to highlight. Default 0.9.
...: Additional arguments passed to bayesplot::mcmc_areas.
Returns
A ggplot2::ggplot object showing posterior distributions.
Method performance()
Calculate model performance metrics on known data.
This method evaluates the fitted model's performance by comparing predictions to known costs and computing mean absolute error (MAE), root mean square error (RMSE), Bayesian R2, credible interval coverage and median credible interval width. By default the model training data is used, but a different dataset can also be provided.
Usage
JAGSModel$performance(
scale = "natural",
conditional = FALSE,
by_country = FALSE,
dat = NULL
)
Arguments
scale: One of "log" or "natural". Default "natural".
conditional: Logical. If TRUE, returns conditional performance. If FALSE, returns performance marginalised over facility random effects. Default FALSE.
by_country: Logical. If TRUE, returns metrics calculated on country sub-groups. Default FALSE.
dat: Optional data prepared using prepare_covariates(). If provided, uses this data for performance calculation instead of the training data.
Returns
A data.frame with performance metrics:
country: Only present if by_country = TRUE
mae: Mean Absolute Error between observed and predicted values
rmse: Root Mean Square Error between observed and predicted values
bayesian_r2: Mean Bayesian R2 estimate.
ci_coverage: Proportion of observations within 95% credible intervals
median_ci: The median width of 95% credible intervals
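The metrics listed above can be sketched in base R from a matrix of posterior predictive draws. This is an illustration on simulated data, not the method's implementation: observed, pred_draws and the normal model used to generate them are assumptions, and the Bayesian R2 is omitted.

```r
set.seed(123)
n_obs   <- 200
n_draws <- 2000
sigma   <- 0.35

# Simulated "known" values and matching posterior predictive draws
# (rows are draws, columns are observations).
mu         <- rnorm(n_obs, mean = 2, sd = 0.5)
observed   <- rnorm(n_obs, mean = mu, sd = sigma)
pred_draws <- matrix(
  rnorm(n_draws * n_obs, mean = rep(mu, each = n_draws), sd = sigma),
  nrow = n_draws
)

# Point predictions and 95% equal-tailed interval bounds per observation.
predicted <- colMeans(pred_draws)
lower     <- apply(pred_draws, 2, quantile, probs = 0.025)
upper     <- apply(pred_draws, 2, quantile, probs = 0.975)

metrics <- data.frame(
  mae         = mean(abs(observed - predicted)),
  rmse        = sqrt(mean((observed - predicted)^2)),
  ci_coverage = mean(observed >= lower & observed <= upper),
  median_ci   = median(upper - lower)
)
```

Because the simulated predictions come from the same distribution as the observations, ci_coverage lands near the nominal 0.95; coverage far from that value on real data would suggest intervals that are too narrow or too wide.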
Examples
model <- unitcost()
model$performance()
Method plot_residuals()
Create a residual plot for diagnosing model fit.
This method generates a diagnostic plot showing residuals (observed minus predicted values) against fitted values. Residuals are on the log scale as the model is fitted on a log scale. The plot includes a reference line at zero, a LOESS smooth curve to identify patterns, and points colored by country.
Arguments
add_smooth: Logical. Whether to add a LOESS smooth curve to identify patterns in residuals. Default TRUE.
color_by_country: Logical. Whether to color points by country. Default TRUE.
Returns
A ggplot2::ggplot object showing residuals vs fitted values.
Method plot_fit()
Create a scatter plot of observed vs predicted values to assess fit.
This method generates a diagnostic plot showing the relationship between observed and predicted values on the training data, with a reference line for perfect predictions and optional predictive intervals shown as error bars.
Usage
JAGSModel$plot_fit(
scale = "log",
conditional = FALSE,
include_ci = TRUE,
color_by_country = TRUE
)
Arguments
scale: One of "log" or "natural". Default "log".
conditional: Logical. If TRUE, shows full conditional fit. If FALSE, shows marginal fit. Default FALSE.
include_ci: Logical. Whether to show predictive intervals as error bars. Default TRUE.
color_by_country: Logical. Whether to color points by country. Default TRUE.
Returns
A ggplot2::ggplot object showing observed vs predicted values.
Method k_fold_cv()
Perform k-fold cross-validation.
Method leave_one_country_out()
Perform leave-one-country-out cross-validation.
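The two cross-validation schemes differ only in how the data are split. The sketch below shows both splits in base R on simulated data; the data frame, its columns and the use of lm() as a stand-in for the JAGS model are all assumptions made for illustration, and the arguments and return values of k_fold_cv() and leave_one_country_out() are not documented here.

```r
set.seed(2024)

# Simulated data: four countries with 25 observations each.
dat <- data.frame(
  country = rep(c("A", "B", "C", "D"), each = 25),
  x       = rnorm(100)
)
dat$log_cost <- 2 + 0.5 * dat$x + rnorm(100, sd = 0.3)

# Leave-one-country-out: hold out each country in turn,
# fit on the remaining countries, and score the held-out one.
countries <- unique(dat$country)
loco <- do.call(rbind, lapply(countries, function(held_out) {
  train <- dat[dat$country != held_out, ]
  test  <- dat[dat$country == held_out, ]
  fit   <- lm(log_cost ~ x, data = train) # stand-in for the JAGS model
  pred  <- predict(fit, newdata = test)
  data.frame(country = held_out,
             rmse    = sqrt(mean((test$log_cost - pred)^2)))
}))

# k-fold: assign every row to one of k folds at random,
# with fold sizes as equal as possible.
k    <- 5
fold <- sample(rep(seq_len(k), length.out = nrow(dat)))
```

Leave-one-country-out tests how well the model transfers to a country it has never seen, whereas k-fold splits mix all countries into every training set.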
Method fitted_parameters()
Extract fitted model parameters with credible intervals.
This method summarises fitted parameters using bayestestR::describe_posterior. See bayestestR::describe_posterior for full documentation of available arguments.
Usage
JAGSModel$fitted_parameters(
centrality = "mean",
ci = 0.95,
ci_method = "eti",
test = NULL,
...
)
Arguments
centrality: The point-estimates (centrality indices) to compute. Default "mean".
ci: Value or vector of probability of the CI (between 0 and 1) to be estimated. Default 0.95 (95%).
ci_method: The type of index used for Credible Interval. Default ETI.
test: The indices of effect existence to compute. Default NULL.
...: Other arguments that will be passed to bayestestR::describe_posterior.
Method mcmc_DIC()
Retrieve penalized deviance statistics.
This method returns cached penalized deviance statistics created at the time of model fitting.
Arguments
summarised: Logical. If TRUE (default), return the total DIC as a single numeric value. If FALSE, return all DIC samples.
Returns
If summarised = TRUE, a numeric scalar of the total DIC.
If summarised = FALSE, an object of class "dic"; see rjags::dic.samples().
Examples
mod <- unitcost()
mod$mcmc_DIC()
Examples
## ------------------------------------------------
## Method `JAGSModel$performance`
## ------------------------------------------------
model <- unitcost()
#> Multiple outputs detected. Including output-level random effects in model.
model$performance()
#>        mae     rmse ci_coverage median_ci bayesian_r2
#> 1 4.780022 6.933826   0.9672727  23.36332   0.4920227
## ------------------------------------------------
## Method `JAGSModel$plot_fit`
## ------------------------------------------------
if (FALSE) { # \dontrun{
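# dat, covariates and target are placeholders; supply your own (see Method new())
model <- JAGSModel$new(dat, covariates, target)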
model$fit()
p <- model$plot_fit()
print(p)
# Natural scale without confidence intervals
p2 <- model$plot_fit(scale = "natural", include_ci = FALSE)
print(p2)
} # }
## ------------------------------------------------
## Method `JAGSModel$fitted_parameters`
## ------------------------------------------------
model <- unitcost()
#> Multiple outputs detected. Including output-level random effects in model.
params <- model$fitted_parameters()
print(params)
#> Summary of Posterior Distribution
#>
#> Parameter | Mean | 95% CI
#> ----------------------------------------------
#> alpha | 2.10 | [ 1.86, 2.35]
#> beta[1] | -0.36 | [-0.47, -0.24]
#> beta[2] | 0.22 | [ 0.06, 0.38]
#> beta[3] | 0.17 | [-0.04, 0.37]
#> beta[4] | 0.35 | [ 0.12, 0.59]
#> beta[5] | 0.12 | [-0.20, 0.45]
#> beta[6] | 0.12 | [ 0.04, 0.20]
#> beta[7] | 0.11 | [ 0.03, 0.19]
#> beta[8] | -0.22 | [-0.31, -0.13]
#> beta[9] | -0.29 | [-0.40, -0.17]
#> sigma | 0.35 | [ 0.33, 0.38]
#> sigma_c | 0.10 | [ 0.00, 0.28]
#> country_effect[1] | -0.02 | [-0.21, 0.16]
#> country_effect[2] | 0.10 | [-0.05, 0.37]
#> country_effect[3] | -0.08 | [-0.34, 0.06]
#> country_effect[4] | 6.07e-03 | [-0.17, 0.19]
#> country_effect[5] | -3.24e-03 | [-0.20, 0.18]
#> sigma_f | 0.41 | [ 0.35, 0.48]
#> sigma_v | 0.16 | [ 0.09, 0.28]
#> output_effect[1] | -0.06 | [-0.30, 0.17]
#> output_effect[2] | -0.21 | [-0.36, -0.07]
#> output_effect[3] | 0.20 | [ 0.08, 0.32]
#> output_effect[4] | 0.03 | [-0.24, 0.32]
#> output_effect[5] | 0.13 | [-0.01, 0.28]
#> output_effect[6] | 0.04 | [-0.07, 0.16]
#> output_effect[7] | 0.06 | [-0.18, 0.33]
#> output_effect[8] | -0.02 | [-0.16, 0.12]
#> output_effect[9] | -4.16e-03 | [-0.14, 0.13]
#> output_effect[10] | 0.08 | [-0.06, 0.22]
#> output_effect[11] | 0.12 | [-0.11, 0.38]
#> output_effect[12] | -0.10 | [-0.22, 0.03]
#> output_effect[13] | -0.17 | [-0.45, 0.06]
#> output_effect[14] | -0.10 | [-0.31, 0.09]
# 90% credible intervals
params_90 <- model$fitted_parameters(ci = 0.9)
## ------------------------------------------------
## Method `JAGSModel$mcmc_DIC`
## ------------------------------------------------
mod <- unitcost()
#> Multiple outputs detected. Including output-level random effects in model.
mod$mcmc_DIC()
#> [1] 527.3794