import os
import sys

sys.path.insert(0, os.path.abspath("../../"))  # for running from the repo

try:
    from brmspy import brms
    import seaborn
except ImportError:
    %pip install -q brmspy seaborn
    from brmspy import brms

from brmspy.brms import set_rescor, bf, lf
import pandas as pd
from arviz_stats import loo, compare
from arviz_plots import plot_ppc_dist
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("darkgrid")

renv = "multivariate-models"
if not brms.environment_exists(renv):
    with brms.manage(environment_name=renv) as ctx:
        ctx.install_runtime()
        ctx.install_rpackage("MCMCglmm")
else:
    brms.environment_activate(renv)
Introduction
In this example, we discuss how to specify multivariate multilevel models using brms. We call a model multivariate if it contains multiple response variables, each being predicted by its own set of predictors. Consider an example from biology: Hadfield, Nutall, Osorio, and Owens (2007) analyzed data of the Eurasian blue tit (https://en.wikipedia.org/wiki/Eurasian_blue_tit). They predicted the tarsus length as well as the back color of chicks. Half of each brood was put into another fosternest, while the other half stayed in the fosternest of their own dam. This allows us to separate genetic from environmental factors. Additionally, we have information about the hatchdate and sex of the chicks (the latter being known for 94% of the animals).
df = brms.get_data("BTdata", package = "MCMCglmm")
df.head()
| | tarsus | back | animal | dam | fosternest | hatchdate | sex |
|---|---|---|---|---|---|---|---|
| 1 | -1.892297 | 1.146421 | R187142 | R187557 | F2102 | -0.687402 | Fem |
| 2 | 1.136110 | -0.759652 | R187154 | R187559 | F1902 | -0.687402 | Male |
| 3 | 0.984689 | 0.144937 | R187341 | R187568 | A602 | -0.427981 | Male |
| 4 | 0.379008 | 0.255585 | R046169 | R187518 | A1302 | -1.465664 | Male |
| 5 | -0.075253 | -0.300699 | R046161 | R187528 | A2602 | -1.465664 | Fem |
Basic Multivariate Models
We begin with a relatively simple multivariate normal model.
bform1 = bf("""
mvbind(tarsus, back) ~
sex +
hatchdate +
(1|p|fosternest) +
(1|q|dam)
""") + set_rescor(rescor=True)
fit1 = brms.brm(bform1, data = df, chains = 2, cores = 2, silent = 2, refresh = 0)
[brmspy][worker_main] Fitting model with brms (backend: cmdstanr)... [brmspy][worker_main] Fit done!
As can be seen in the model code, we have used mvbind notation to tell brms that both tarsus and back are separate response variables. The term (1|p|fosternest) indicates a varying intercept over fosternest. By writing |p| in between, we indicate that all varying effects of fosternest should be modeled as correlated. This makes sense since we actually have two model parts, one for tarsus and one for back. The indicator p is arbitrary and can be replaced by any other symbol that comes to mind (for details about the multilevel syntax of brms, see help("brmsformula") and vignette("brms_multilevel")). Similarly, the term (1|q|dam) indicates correlated varying effects of the genetic mother of the chicks. Alternatively, we could also have modeled the genetic similarities through pedigrees and corresponding relatedness matrices, but this is not the focus of this vignette (see vignette("brms_phylogenetics") instead). The model results are readily summarized via
for var in fit1.idata.posterior_predictive.data_vars:
    print(var)
    print(loo(fit1.idata, var_name=var))
    print("\n")
tarsus
/Users/sebastian/PycharmProjects/pybrms/.venv/lib/python3.12/site-packages/arviz_stats/loo/helper_loo.py:1143: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations. warnings.warn(
Computed from 2000 posterior samples and 828 observations log-likelihood matrix.
Estimate SE
elpd_loo -997.61 27.21
p_loo 99.08 -
There has been a warning during the calculation. Please check the results.
------
Pareto k diagnostic values:
Count Pct.
(-Inf, 0.70] (good) 827 99.9%
(0.70, 1] (bad) 1 0.1%
(1, Inf) (very bad) 0 0.0%
back
Computed from 2000 posterior samples and 828 observations log-likelihood matrix.
Estimate SE
elpd_loo -1128.65 19.43
p_loo 75.27 -
------
Pareto k diagnostic values:
Count Pct.
(-Inf, 0.70] (good) 828 100.0%
(0.70, 1] (bad) 0 0.0%
(1, Inf) (very bad) 0 0.0%
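The Pareto-k table printed alongside the LOO estimates bins one k value per observation into fixed thresholds. A minimal sketch of that binning (the thresholds are those shown in the output above; the function itself is illustrative, not the ArviZ internals):

```python
import numpy as np

def pareto_k_counts(k):
    """Bin Pareto shape estimates into the good/bad/very-bad diagnostic categories."""
    k = np.asarray(k, dtype=float)
    return {
        "(-Inf, 0.70] (good)": int(np.sum(k <= 0.7)),
        "(0.70, 1] (bad)": int(np.sum((k > 0.7) & (k <= 1.0))),
        "(1, Inf) (very bad)": int(np.sum(k > 1.0)),
    }

counts = pareto_k_counts([0.1, 0.65, 0.8, 1.5])
```

Observations falling in the bad or very-bad bins are the highly influential ones the warning above refers to; importance sampling is unreliable for them.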
brms.summary(fit1)
Family: MV(gaussian, gaussian)
Links: mu = identity
mu = identity
Formula: tarsus ~ sex + hatchdate + (1 | p | fosternest) + (1 | q | dam)
back ~ sex + hatchdate + (1 | p | fosternest) + (1 | q | dam)
Data: structure(list(tarsus = c(-1.89229718155107, 1.136 (Number of observations: 828)
Draws: 2 chains, each with iter = 2000; warmup = 1000; thin = 1;
total post-warmup draws = 2000
Multilevel Hyperparameters:
~dam (Number of levels: 106)
Estimate Est.Error l-95% CI u-95% CI Rhat
sd(tarsus_Intercept) 0.48 0.05 0.39 0.59 1.00
sd(back_Intercept) 0.25 0.08 0.10 0.39 1.01
cor(tarsus_Intercept,back_Intercept) -0.52 0.22 -0.94 -0.08 1.01
Bulk_ESS Tail_ESS
sd(tarsus_Intercept) 781 1347
sd(back_Intercept) 353 773
cor(tarsus_Intercept,back_Intercept) 510 554
~fosternest (Number of levels: 104)
Estimate Est.Error l-95% CI u-95% CI Rhat
sd(tarsus_Intercept) 0.27 0.05 0.17 0.37 1.00
sd(back_Intercept) 0.35 0.06 0.24 0.46 1.01
cor(tarsus_Intercept,back_Intercept) 0.70 0.20 0.24 0.98 1.00
Bulk_ESS Tail_ESS
sd(tarsus_Intercept) 771 1063
sd(back_Intercept) 717 1200
cor(tarsus_Intercept,back_Intercept) 244 515
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
tarsus_Intercept -0.41 0.07 -0.54 -0.27 1.00 1429 1466
back_Intercept -0.01 0.06 -0.13 0.11 1.00 2737 1504
tarsus_sexMale 0.77 0.06 0.66 0.88 1.00 3933 1549
tarsus_sexUNK 0.23 0.12 -0.02 0.48 1.00 3764 1699
tarsus_hatchdate -0.04 0.06 -0.16 0.07 1.00 1365 1424
back_sexMale 0.01 0.06 -0.12 0.13 1.00 4861 1644
back_sexUNK 0.15 0.14 -0.13 0.43 1.01 3372 1626
back_hatchdate -0.09 0.05 -0.19 0.02 1.00 2593 1661
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma_tarsus 0.76 0.02 0.72 0.80 1.00 3610 1520
sigma_back 0.90 0.02 0.86 0.95 1.00 2453 1378
Residual Correlations:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
rescor(tarsus,back) -0.05 0.04 -0.13 0.02 1.00 2992 1400
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
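The summary footer mentions split-chain Rhat. A compact sketch of that diagnostic (simplified from the Vehtari et al. (2021) formulation, not the implementation brms calls) shows the idea: each chain is split in half, and the between-half variance is compared to the within-half variance.

```python
import numpy as np

def split_rhat(draws):
    """Split-chain Rhat for draws of shape (chains, iterations); ~1 at convergence."""
    c, n = draws.shape
    halves = draws.reshape(2 * c, n // 2)       # split each chain into two halves
    k = halves.shape[1]
    chain_means = halves.mean(axis=1)
    between = k * chain_means.var(ddof=1)       # variance between half-chains
    within = halves.var(axis=1, ddof=1).mean()  # average variance within half-chains
    var_plus = (k - 1) / k * within + between / k
    return float(np.sqrt(var_plus / within))

rng = np.random.default_rng(0)
good = split_rhat(rng.standard_normal((2, 1000)))      # well-mixed chains: near 1
bad = split_rhat(np.vstack([rng.standard_normal(1000),
                            rng.standard_normal(1000) + 3]))  # shifted chain: well above 1
```

Chains sampling the same distribution give a value near 1; a chain stuck at a different location inflates the between-chain variance and pushes Rhat up.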
The summary output of multivariate models closely resembles that of univariate models, except that the parameters now carry the corresponding response variable as a prefix. Across dams, tarsus length and back color appear to be negatively correlated, while across fosternests the opposite is true. This indicates differential effects of genetic and environmental factors on these two characteristics. Further, the small residual correlation rescor(tarsus,back) at the bottom of the output indicates that there is little unmodeled dependency between tarsus length and back color. Although not necessary at this point, we have already computed and stored the LOO information criterion of fit1, which we will use for model comparisons. Next, let’s take a look at some posterior-predictive checks, which give us a first impression of the model fit.
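The contrast between a negative dam-level correlation and a near-zero residual correlation can be illustrated with simulated data (a toy example, not the blue tit fit): group-level intercepts for two traits are drawn with a negative correlation, while observation-level noise is independent.

```python
import numpy as np

rng = np.random.default_rng(42)

# 200 "dams", 20 chicks each; dam-level intercepts negatively correlated
# across the two traits, observation-level noise independent.
n_groups, n_per = 200, 20
cov = np.array([[1.0, -0.6], [-0.6, 1.0]])
dam_fx = rng.multivariate_normal([0.0, 0.0], cov, size=n_groups)
y = np.repeat(dam_fx, n_per, axis=0) + rng.standard_normal((n_groups * n_per, 2))

# Correlation of group means recovers the (negative) group-level correlation...
group_means = y.reshape(n_groups, n_per, 2).mean(axis=1)
r_group = np.corrcoef(group_means.T)[0, 1]
# ...while within-group residuals remain essentially uncorrelated.
resid = y - np.repeat(group_means, n_per, axis=0)
r_resid = np.corrcoef(resid.T)[0, 1]
```

This is exactly the decomposition the model above estimates: cor(tarsus_Intercept,back_Intercept) per grouping factor, plus rescor(tarsus,back) for whatever correlation remains at the observation level.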
plot_ppc_dist(fit1.idata, var_names=['tarsus'])
<arviz_plots.plot_collection.PlotCollection at 0x11897b8c0>
plot_ppc_dist(fit1.idata, var_names=["back"])
<arviz_plots.plot_collection.PlotCollection at 0x11cb8ae70>
This looks pretty solid, but we notice a slight unmodeled left skewness in the distribution of tarsus. We will come back to this later on. Next, we want to investigate how much variation in the response variables can be explained by our model, using a Bayesian generalization of the R² coefficient.
brms.call("bayes_R2", fit1)
| | Estimate | Est.Error | Q2.5 | Q97.5 |
|---|---|---|---|---|
| R2tarsus | 0.434311 | 0.023093 | 0.386106 | 0.477110 |
| R2back | 0.198187 | 0.027391 | 0.141256 | 0.249358 |
Clearly, there is much variation in both animal characteristics that we cannot explain, but apparently we can explain more of the variation in tarsus length than in back color.
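The Bayesian R² reported by bayes_R2 is computed per posterior draw as the variance of the linear predictor divided by that variance plus the residual variance (the Gelman et al., 2019 definition). A minimal numpy sketch (a hypothetical helper, not the brms code):

```python
import numpy as np

def bayes_r2(mu_draws, sigma_draws):
    """mu_draws: (draws, observations) linear predictor; sigma_draws: (draws,) residual sd.

    Returns one R^2 value per posterior draw; summarizing these draws gives
    the Estimate / Est.Error / quantile columns of a bayes_R2 table.
    """
    var_fit = np.asarray(mu_draws).var(axis=1)
    var_res = np.asarray(sigma_draws) ** 2
    return var_fit / (var_fit + var_res)

# Sanity check: equal predictor and residual variance gives R^2 = 0.5.
mu = np.array([[1.0, -1.0, 1.0, -1.0]])  # variance 1 across observations
r2 = bayes_r2(mu, np.array([1.0]))
```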
More Complex Multivariate Models
Now, suppose we only want to control for sex in tarsus but not in back, and vice versa for hatchdate. Not that this is particularly reasonable for the present example, but it allows us to illustrate how to specify different formulas for different response variables. We can no longer use mvbind syntax, and so we have to take a more verbose approach:
bf_tarsus = bf("tarsus ~ sex + (1|p|fosternest) + (1|q|dam)")
bf_back = bf("back ~ hatchdate + (1|p|fosternest) + (1|q|dam)")
fit2 = brms.brm(bf_tarsus + bf_back + set_rescor(True), data = df, chains = 2, cores = 2, silent = 2, refresh = 0)
[brmspy][worker_main] Fitting model with brms (backend: cmdstanr)... [brmspy][worker_main] Fit done!
Note that we have literally added the two model parts via the + operator, which is in this case equivalent to writing mvbf(bf_tarsus, bf_back). See help("brmsformula") and help("mvbrmsformula") for more details about this syntax. Again, we summarize the model first.
for var in fit2.idata.posterior_predictive.data_vars:
    print(var)
    print(loo(fit2.idata, var_name=var))
    print("\n")
tarsus
/Users/sebastian/PycharmProjects/pybrms/.venv/lib/python3.12/site-packages/arviz_stats/loo/helper_loo.py:1143: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations. warnings.warn(
Computed from 2000 posterior samples and 828 observations log-likelihood matrix.
Estimate SE
elpd_loo -997.68 27.22
p_loo 99.60 -
There has been a warning during the calculation. Please check the results.
------
Pareto k diagnostic values:
Count Pct.
(-Inf, 0.70] (good) 827 99.9%
(0.70, 1] (bad) 1 0.1%
(1, Inf) (very bad) 0 0.0%
back
Computed from 2000 posterior samples and 828 observations log-likelihood matrix.
Estimate SE
elpd_loo -1126.79 19.35
p_loo 72.72 -
------
Pareto k diagnostic values:
Count Pct.
(-Inf, 0.70] (good) 828 100.0%
(0.70, 1] (bad) 0 0.0%
(1, Inf) (very bad) 0 0.0%
brms.summary(fit2)
Family: MV(gaussian, gaussian)
Links: mu = identity
mu = identity
Formula: tarsus ~ sex + (1 | p | fosternest) + (1 | q | dam)
back ~ hatchdate + (1 | p | fosternest) + (1 | q | dam)
Data: structure(list(tarsus = c(-1.89229718155107, 1.136 (Number of observations: 828)
Draws: 2 chains, each with iter = 2000; warmup = 1000; thin = 1;
total post-warmup draws = 2000
Multilevel Hyperparameters:
~dam (Number of levels: 106)
Estimate Est.Error l-95% CI u-95% CI Rhat
sd(tarsus_Intercept) 0.48 0.05 0.39 0.58 1.00
sd(back_Intercept) 0.25 0.07 0.09 0.38 1.01
cor(tarsus_Intercept,back_Intercept) -0.52 0.23 -0.93 -0.08 1.01
Bulk_ESS Tail_ESS
sd(tarsus_Intercept) 814 1016
sd(back_Intercept) 301 559
cor(tarsus_Intercept,back_Intercept) 554 757
~fosternest (Number of levels: 104)
Estimate Est.Error l-95% CI u-95% CI Rhat
sd(tarsus_Intercept) 0.27 0.05 0.17 0.38 1.00
sd(back_Intercept) 0.35 0.06 0.23 0.46 1.00
cor(tarsus_Intercept,back_Intercept) 0.68 0.21 0.20 0.98 1.00
Bulk_ESS Tail_ESS
sd(tarsus_Intercept) 657 1183
sd(back_Intercept) 463 841
cor(tarsus_Intercept,back_Intercept) 248 412
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
tarsus_Intercept -0.41 0.07 -0.54 -0.28 1.00 1161 1417
back_Intercept -0.00 0.05 -0.10 0.10 1.00 1500 1356
tarsus_sexMale 0.77 0.06 0.66 0.89 1.00 3130 1584
tarsus_sexUNK 0.23 0.12 -0.00 0.48 1.00 3087 1693
back_hatchdate -0.08 0.05 -0.19 0.02 1.00 1763 1323
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma_tarsus 0.76 0.02 0.72 0.80 1.00 2403 1577
sigma_back 0.90 0.02 0.86 0.95 1.00 2630 1690
Residual Correlations:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
rescor(tarsus,back) -0.05 0.04 -0.12 0.02 1.00 2172 1607
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Let’s find out how the model fit changed by excluding certain effects from the initial model:
var = "back"
loo1 = loo(fit1.idata, var_name=var)
loo2 = loo(fit2.idata, var_name=var)
cmp = compare({"m1": fit1.idata, "m2": fit2.idata}, var_name=var)
cmp
| | rank | elpd | p | elpd_diff | weight | se | dse | warning |
|---|---|---|---|---|---|---|---|---|
| m2 | 0 | -1130.0 | 72.7 | 0.0 | 1.0 | 19.0 | 0.0 | False |
| m1 | 1 | -1130.0 | 75.3 | 2.0 | 0.0 | 19.0 | 1.1 | False |
Apparently, there is no noteworthy difference in the model fit. Accordingly, we do not really need to model sex and hatchdate for both response variables, but there is also no harm in including them (so I would probably just include them).
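"No noteworthy difference" can be made concrete by judging elpd_diff against its standard error dse. A common heuristic (a rule of thumb, not a formal test) treats a difference within roughly two standard errors as negligible:

```python
# Values hand-copied from the comparison table above.
elpd_diff, dse = 2.0, 1.1

# Heuristic: only call the difference noteworthy if it exceeds ~2 dse.
noteworthy = abs(elpd_diff) > 2 * dse
print(noteworthy)  # False: the models fit back about equally well
```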
To give you a glimpse of the capabilities of brms’ multivariate syntax, we change our model in several directions at once. Remember the slight left skewness of tarsus, which we will now model with the skew_normal family instead of the gaussian family. Since we no longer have a multivariate normal (or Student-t) model, estimating residual correlations is no longer possible. We make this explicit using the set_rescor function. Further, we investigate whether the relationship of back and hatchdate is really linear, as previously assumed, by fitting a non-linear spline of hatchdate. On top of that, we model separate residual variances of tarsus for male and female chicks.
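The skew_normal family adds a shape parameter alpha that controls the direction of skew. A quick simulation (plain numpy via Azzalini’s construction, not brms code) shows that a negative alpha produces exactly the kind of left skew we saw in tarsus:

```python
import numpy as np

rng = np.random.default_rng(0)

def rskew_normal(alpha, size, rng):
    """Sample a standard skew-normal: delta*|U0| + sqrt(1-delta^2)*U1 ~ SN(alpha)."""
    delta = alpha / np.sqrt(1.0 + alpha ** 2)
    u0 = np.abs(rng.standard_normal(size))
    u1 = rng.standard_normal(size)
    return delta * u0 + np.sqrt(1.0 - delta ** 2) * u1

x = rskew_normal(-3.0, 100_000, rng)
# Sample skewness is clearly negative for negative alpha (left-skewed).
sample_skew = np.mean((x - x.mean()) ** 3) / x.std() ** 3
```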
from brmspy.brms import skew_normal, gaussian
bf_tarsus = bf("tarsus ~ sex + (1|p|fosternest) + (1|q|dam)") + lf("sigma ~ 0 + sex") + skew_normal()
bf_back = bf("back ~ s(hatchdate) + (1|p|fosternest) + (1|q|dam)") + gaussian()
fit3 = brms.brm(
bf_tarsus + bf_back + set_rescor(False),
data = df, chains = 2, cores = 2,
control = {"adapt_delta": 0.95},
silent = 2, refresh = 0
)
[brmspy][worker_main] Fitting model with brms (backend: cmdstanr)... [brmspy][worker_main] Fit done!
Again, we summarize the model and look at some posterior-predictive checks.
for var in fit3.idata.posterior_predictive.data_vars:
    print(var)
    print(loo(fit3.idata, var_name=var))
    print("\n")
tarsus
/Users/sebastian/PycharmProjects/pybrms/.venv/lib/python3.12/site-packages/arviz_stats/loo/helper_loo.py:1143: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations. warnings.warn(
Computed from 2000 posterior samples and 828 observations log-likelihood matrix.
Estimate SE
elpd_loo -998.94 28.04
p_loo 105.18 -
There has been a warning during the calculation. Please check the results.
------
Pareto k diagnostic values:
Count Pct.
(-Inf, 0.70] (good) 824 99.5%
(0.70, 1] (bad) 4 0.5%
(1, Inf) (very bad) 0 0.0%
back
Computed from 2000 posterior samples and 828 observations log-likelihood matrix.
Estimate SE
elpd_loo -1125.71 19.41
p_loo 70.09 -
------
Pareto k diagnostic values:
Count Pct.
(-Inf, 0.70] (good) 828 100.0%
(0.70, 1] (bad) 0 0.0%
(1, Inf) (very bad) 0 0.0%
brms.summary(fit3)
Family: MV(skew_normal, gaussian)
Links: mu = identity; sigma = log
mu = identity
Formula: tarsus ~ sex + (1 | p | fosternest) + (1 | q | dam)
sigma ~ 0 + sex
back ~ s(hatchdate) + (1 | p | fosternest) + (1 | q | dam)
Data: structure(list(tarsus = c(-1.89229718155107, 1.136 (Number of observations: 828)
Draws: 2 chains, each with iter = 2000; warmup = 1000; thin = 1;
total post-warmup draws = 2000
Smoothing Spline Hyperparameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
sds(back_shatchdate_1) 2.07 1.10 0.32 4.75 1.00 378
Tail_ESS
sds(back_shatchdate_1) 427
Multilevel Hyperparameters:
~dam (Number of levels: 106)
Estimate Est.Error l-95% CI u-95% CI Rhat
sd(tarsus_Intercept) 0.47 0.05 0.38 0.57 1.00
sd(back_Intercept) 0.23 0.07 0.09 0.37 1.00
cor(tarsus_Intercept,back_Intercept) -0.52 0.23 -0.96 -0.06 1.00
Bulk_ESS Tail_ESS
sd(tarsus_Intercept) 564 1115
sd(back_Intercept) 266 351
cor(tarsus_Intercept,back_Intercept) 336 308
~fosternest (Number of levels: 104)
Estimate Est.Error l-95% CI u-95% CI Rhat
sd(tarsus_Intercept) 0.26 0.06 0.16 0.38 1.01
sd(back_Intercept) 0.31 0.06 0.20 0.43 1.01
cor(tarsus_Intercept,back_Intercept) 0.62 0.22 0.14 0.97 1.00
Bulk_ESS Tail_ESS
sd(tarsus_Intercept) 468 916
sd(back_Intercept) 400 834
cor(tarsus_Intercept,back_Intercept) 249 394
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
tarsus_Intercept -0.41 0.07 -0.55 -0.28 1.00 852
back_Intercept -0.00 0.05 -0.10 0.10 1.00 1189
tarsus_sexMale 0.77 0.06 0.65 0.89 1.00 2183
tarsus_sexUNK 0.22 0.12 -0.02 0.46 1.00 2038
sigma_tarsus_sexFem -0.30 0.04 -0.38 -0.22 1.00 1869
sigma_tarsus_sexMale -0.25 0.04 -0.32 -0.16 1.00 1875
sigma_tarsus_sexUNK -0.40 0.13 -0.65 -0.14 1.00 1436
back_shatchdate_1 -0.31 3.21 -5.99 6.96 1.00 634
Tail_ESS
tarsus_Intercept 1304
back_Intercept 1508
tarsus_sexMale 1405
tarsus_sexUNK 1437
sigma_tarsus_sexFem 1427
sigma_tarsus_sexMale 1575
sigma_tarsus_sexUNK 1207
back_shatchdate_1 798
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma_back 0.90 0.02 0.85 0.95 1.00 1709 1456
alpha_tarsus -1.22 0.44 -1.87 0.16 1.00 924 411
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
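Because sigma for tarsus is modeled on the log scale (see "Links: sigma = log" above), the sigma_tarsus_sex* coefficients need to be exponentiated before they can be read as standard deviations. A quick back-transformation (point estimates hand-copied from the summary above):

```python
import math

# Posterior means of log(sigma) for tarsus, per sex, from the summary above.
log_sd = {"Fem": -0.30, "Male": -0.25, "UNK": -0.40}

# exp() maps them back to the standard-deviation scale.
sd = {sex: round(math.exp(v), 3) for sex, v in log_sd.items()}
print(sd)  # {'Fem': 0.741, 'Male': 0.779, 'UNK': 0.67}
```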
We see that the (log) residual standard deviation of tarsus is somewhat smaller for chicks whose sex could not be identified than for male or female chicks. Further, we see from the negative alpha (skewness) parameter of tarsus that the residuals are indeed slightly left-skewed. Lastly, running
result = brms.call("conditional_effects", fit3, "hatchdate", resp="back")
df_result = result['back.back_hatchdate']
df_plot = df_result.sort_values("hatchdate")
fig, ax = plt.subplots()
ax.plot(
df_plot["hatchdate"],
df_plot["estimate__"],
color="blue"
)
ax.fill_between(
df_plot["hatchdate"],
df_plot["lower__"],
df_plot["upper__"],
alpha=0.3,
color="blue"
)
ax.set_xlabel("hatchdate")
ax.set_ylabel("back")
ax.set_title("Conditional effect of hatchdate on back")
plt.show()
reveals a non-linear relationship between hatchdate and back color, which seems to change in waves over the course of the hatch dates.
There are many more modeling options for multivariate models, which are not discussed in this vignette. Examples include autocorrelation structures, Gaussian processes, or explicit non-linear predictors (e.g., see help("brmsformula") or vignette("brms_multilevel")). In fact, nearly all the flexibility of univariate models is retained in multivariate models.
References
Hadfield JD, Nutall A, Osorio D, Owens IPF (2007). Testing the phenotypic gambit: phenotypic, genetic and environmental correlations of colour. Journal of Evolutionary Biology, 20(2), 549-557.