brmspy — Variational Inference
In [ ]:
Copied!
import os
import sys
import sys, os # for running from repo
sys.path.insert(0, os.path.abspath(".."))
try: from brmspy import brms
except ImportError:
%pip install -q brmspy
from brmspy import brms
from brmspy import bf, set_rescor, lf
import pandas as pd
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")
#brms.install_runtime()
import os
import sys
import sys, os # for running from repo
sys.path.insert(0, os.path.abspath(".."))
try: from brmspy import brms
except ImportError:
%pip install -q brmspy
from brmspy import brms
from brmspy import bf, set_rescor, lf
import pandas as pd
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")
#brms.install_runtime()
R callback write-console: Error in loadNamespace(x) : there is no package called ‘cmdstanr’ R callback write-console: CmdStan path set to: /Users/sebastian/.brmspy/runtime/macos-arm64-r4.5-0.1.0/cmdstan
In [2]:
Copied!
# Load the "epilepsy" example dataset via brmspy (presumably a pandas DataFrame — confirm).
# The model formula reads count, zAge, zBase, Trt and patient from it.
df = brms.get_brms_data("epilepsy")
# Load the "epilepsy" example dataset via brmspy (presumably a pandas DataFrame — confirm).
# The model formula reads count, zAge, zBase, Trt and patient from it.
df = brms.get_brms_data("epilepsy")
In [6]:
Copied!
# Fit a negative-binomial model with a per-patient random intercept using
# mean-field variational inference instead of MCMC sampling.
model = brms.brm(
    formula="count ~ zAge + zBase * Trt + (1|patient)",
    family="negbinomial",
    data=df,
    # 1. ALGORITHM
    algorithm="meanfield",
    # 2. OPTIMIZATION LOOP
    iter=30000,
    # With adaptation engaged, Stan's eta adaptation chooses the step size
    # itself and overwrites any user-supplied eta (the log reports the chosen
    # value). To force a fixed step size, pass eta=... with adapt_engaged=False.
    adapt_engaged=True,
    # 3. GRADIENT ESTIMATION (Reducing Noise)
    grad_samples=5,
    elbo_samples=100,
    # 4. CONVERGENCE CRITERIA
    eval_elbo=200,
    # NOTE(review): in the recorded run the relative ELBO change plateaus
    # around 1e-3, so this tolerance is never met and optimization stops at
    # `iter` ("maximum number of iterations is reached"). Consider ~1e-3.
    tol_rel_obj=0.0001,
    seed=42,
    refresh=0, silent=2
)
# Fit a negative-binomial model with a per-patient random intercept using
# mean-field variational inference instead of MCMC sampling.
model = brms.brm(
    formula="count ~ zAge + zBase * Trt + (1|patient)",
    family="negbinomial",
    data=df,
    # 1. ALGORITHM
    algorithm="meanfield",
    # 2. OPTIMIZATION LOOP
    iter=30000,
    # With adaptation engaged, Stan's eta adaptation chooses the step size
    # itself and overwrites any user-supplied eta (the log reports the chosen
    # value). To force a fixed step size, pass eta=... with adapt_engaged=False.
    adapt_engaged=True,
    # 3. GRADIENT ESTIMATION (Reducing Noise)
    grad_samples=5,
    elbo_samples=100,
    # 4. CONVERGENCE CRITERIA
    eval_elbo=200,
    # NOTE(review): in the recorded run the relative ELBO change plateaus
    # around 1e-3, so this tolerance is never met and optimization stops at
    # `iter` ("maximum number of iterations is reached"). Consider ~1e-3.
    tol_rel_obj=0.0001,
    seed=42,
    refresh=0, silent=2
)
[brmspy] Fitting model with brms (backend: cmdstanr)...
------------------------------------------------------------ EXPERIMENTAL ALGORITHM: This procedure has not been thoroughly tested and may be unstable or buggy. The interface is subject to change. ------------------------------------------------------------ Gradient evaluation took 6.6e-05 seconds 1000 transitions using 10 leapfrog steps per transition would take 0.66 seconds. Adjust your expectations accordingly! Begin eta adaptation. Iteration: 1 / 250 [ 0%] (Adaptation) Iteration: 50 / 250 [ 20%] (Adaptation) Iteration: 100 / 250 [ 40%] (Adaptation) Iteration: 150 / 250 [ 60%] (Adaptation) Iteration: 200 / 250 [ 80%] (Adaptation) Success! Found best value [eta = 1] earlier than expected. Begin stochastic gradient ascent. iter ELBO delta_ELBO_mean delta_ELBO_med notes 200 -652.394 1.000 1.000 400 -652.970 0.500 1.000 600 -650.120 0.335 0.004 800 -652.771 0.252 0.004 1000 -650.305 0.203 0.004 1200 -649.219 0.169 0.004 1400 -651.298 0.145 0.004 1600 -650.483 0.127 0.004 1800 -652.597 0.114 0.003 2000 -650.026 0.103 0.004 2200 -650.406 0.093 0.003 2400 -650.830 0.086 0.003 2600 -650.437 0.079 0.003 2800 -648.942 0.074 0.003 3000 -650.946 0.069 0.003 3200 -650.071 0.002 0.002 3400 -649.236 0.002 0.002 3600 -651.772 0.002 0.002 3800 -650.013 0.002 0.002 4000 -649.779 0.002 0.002 4200 -650.599 0.002 0.001 4400 -649.986 0.002 0.001 4600 -649.596 0.002 0.001 4800 -650.784 0.002 0.001 5000 -651.083 0.001 0.001 5200 -648.887 0.002 0.001 5400 -649.172 0.002 0.001 5600 -649.446 0.002 0.001 5800 -648.973 0.002 0.001 6000 -648.896 0.001 0.001 6200 -649.314 0.001 0.001 6400 -649.089 0.001 0.001 6600 -649.375 0.001 0.001 6800 -649.641 0.001 0.000 7000 -649.083 0.001 0.001 7200 -648.838 0.001 0.000 7400 -648.887 0.001 0.000 7600 -648.659 0.001 0.000 7800 -648.117 0.001 0.000 8000 -650.031 0.001 0.000 8200 -649.920 0.001 0.000 8400 -649.288 0.001 0.000 8600 -648.947 0.001 0.000 8800 -647.518 0.001 0.000 9000 -649.226 0.001 0.001 9200 -650.370 0.001 0.001 9400 -648.902 0.001 0.001 
9600 -648.854 0.001 0.001 9800 -650.103 0.001 0.001 10000 -649.133 0.001 0.001 10200 -648.926 0.001 0.001 10400 -648.198 0.001 0.001 10600 -650.328 0.002 0.001 10800 -648.939 0.002 0.002 11000 -649.499 0.001 0.001 11200 -648.618 0.002 0.001 11400 -648.918 0.001 0.001 11600 -648.060 0.002 0.001 11800 -649.188 0.002 0.001 12000 -649.402 0.001 0.001 12200 -649.857 0.001 0.001 12400 -650.381 0.001 0.001 12600 -649.190 0.001 0.001 12800 -649.261 0.001 0.001 13000 -649.857 0.001 0.001 13200 -649.480 0.001 0.001 13400 -648.946 0.001 0.001 13600 -647.745 0.001 0.001 13800 -649.635 0.001 0.001 14000 -649.332 0.001 0.001 14200 -649.666 0.001 0.001 14400 -649.317 0.001 0.001 14600 -649.858 0.001 0.001 14800 -649.478 0.001 0.001 15000 -650.213 0.001 0.001 15200 -649.809 0.001 0.001 15400 -649.419 0.001 0.001 15600 -650.002 0.001 0.001 15800 -647.672 0.001 0.001 16000 -649.595 0.001 0.001 16200 -649.908 0.001 0.001 16400 -650.165 0.001 0.001 16600 -649.076 0.001 0.001 16800 -649.755 0.001 0.001 17000 -649.846 0.001 0.001 17200 -648.930 0.001 0.001 17400 -649.366 0.001 0.001 17600 -649.002 0.001 0.001 17800 -650.382 0.001 0.001 18000 -650.079 0.001 0.001 18200 -650.501 0.001 0.001 18400 -648.699 0.001 0.001 18600 -647.671 0.001 0.001 18800 -649.414 0.001 0.001 19000 -648.623 0.001 0.001 19200 -649.478 0.001 0.001 19400 -649.675 0.001 0.001 19600 -648.411 0.001 0.001 19800 -648.159 0.001 0.001 20000 -649.112 0.001 0.001 20200 -649.459 0.001 0.001 20400 -647.534 0.001 0.001 20600 -649.228 0.002 0.001 20800 -649.454 0.001 0.001 21000 -649.684 0.001 0.001 21200 -648.718 0.001 0.001 21400 -650.052 0.001 0.001 21600 -649.061 0.001 0.001 21800 -649.027 0.001 0.001 22000 -649.068 0.001 0.001 22200 -648.792 0.001 0.001 22400 -649.876 0.001 0.001 22600 -649.991 0.001 0.001 22800 -649.223 0.001 0.001 23000 -649.077 0.001 0.001 23200 -648.953 0.001 0.000 23400 -649.894 0.001 0.000 23600 -648.379 0.001 0.000 23800 -648.999 0.001 0.001 24000 -650.932 0.001 0.001 24200 -648.653 0.001 0.001 
24400 -649.225 0.001 0.001 24600 -649.829 0.001 0.001 24800 -649.086 0.001 0.001 25000 -648.186 0.001 0.001 25200 -649.377 0.001 0.001 25400 -647.628 0.001 0.001 25600 -649.409 0.002 0.001 25800 -649.125 0.002 0.001 26000 -648.632 0.002 0.001 26200 -648.135 0.002 0.001 26400 -649.139 0.002 0.001 26600 -649.310 0.002 0.001 26800 -649.611 0.001 0.001 27000 -648.477 0.001 0.001 27200 -650.008 0.001 0.001 27400 -648.483 0.001 0.001 27600 -648.559 0.001 0.001 27800 -649.814 0.001 0.002 28000 -648.641 0.001 0.002 28200 -649.483 0.001 0.002 28400 -648.510 0.001 0.002 28600 -649.441 0.001 0.001 28800 -649.177 0.001 0.001 29000 -649.217 0.001 0.001 29200 -649.179 0.001 0.001 29400 -648.250 0.001 0.001 29600 -649.271 0.001 0.001 29800 -648.726 0.001 0.001 30000 -649.375 0.001 0.001 Informational Message: The maximum number of iterations is reached! The algorithm may not have converged. This variational approximation is not guaranteed to be meaningful. Drawing a sample of size 1000 from the approximate posterior... COMPLETED. Finished in 3.5 seconds.
In [7]:
Copied!
# PSIS-LOO cross-validation on the fitted draws. pointwise=True keeps the
# per-observation results (including Pareto k), which plot_khat needs.
# NOTE(review): the draws come from a meanfield VI approximation, and in the
# recorded run several observations have Pareto k > 0.67, so elpd_loo may be
# unreliable — interpret with care.
loo_res = az.loo(model.idata, pointwise=True)
print(loo_res)
# Visualize the k-hats
az.plot_khat(loo_res)
plt.show()
# PSIS-LOO cross-validation on the fitted draws. pointwise=True keeps the
# per-observation results (including Pareto k), which plot_khat needs.
# NOTE(review): the draws come from a meanfield VI approximation, and in the
# recorded run several observations have Pareto k > 0.67, so elpd_loo may be
# unreliable — interpret with care.
loo_res = az.loo(model.idata, pointwise=True)
print(loo_res)
# Visualize the k-hats
az.plot_khat(loo_res)
plt.show()
Computed from 1000 posterior samples and 236 observations log-likelihood matrix.
Estimate SE
elpd_loo -623.39 17.86
p_loo 45.52 -
There has been a warning during the calculation. Please check the results.
------
Pareto k diagnostic values:
Count Pct.
(-Inf, 0.67] (good) 227 96.2%
(0.67, 1] (bad) 7 3.0%
(1, Inf) (very bad) 2 0.8%
/Users/sebastian/PycharmProjects/pybrms/.venv/lib/python3.12/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.67 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations. warnings.warn(
In [8]:
Copied!
# Posterior predictive check: overlay 100 posterior predictive draws on the
# observed data. The x-axis is clipped to [0, 100], presumably to keep the
# long right tail of the counts from compressing the plot.
az.plot_ppc(model.idata, num_pp_samples=100)
plt.title("Posterior Predictive Check")
plt.xlim(0, 100)
plt.show()
# Posterior predictive check: overlay 100 posterior predictive draws on the
# observed data. The x-axis is clipped to [0, 100], presumably to keep the
# long right tail of the counts from compressing the plot.
az.plot_ppc(model.idata, num_pp_samples=100)
plt.title("Posterior Predictive Check")
plt.xlim(0, 100)
plt.show()
In [9]:
Copied!
# Draw from the posterior predictive distribution (presumably wraps
# brms::posterior_predict — confirm).
# NOTE(review): `pred` is not used by any cell shown here; confirm it is still needed.
pred = brms.posterior_predict(model)
# Draw from the posterior predictive distribution (presumably wraps
# brms::posterior_predict — confirm).
# NOTE(review): `pred` is not used by any cell shown here; confirm it is still needed.
pred = brms.posterior_predict(model)
In [ ]:
Copied!
# Summarize the fitted model's parameter estimates; the bare expression on the
# last line displays the result as the cell output.
# NOTE(review): the model was fit with algorithm="meanfield", so Rhat/ESS are
# computed on draws from the VI approximation and do not tell you whether
# that approximation matches the true posterior.
summary = brms.summary(model)
summary
# Summarize the fitted model's parameter estimates; the bare expression on the
# last line displays the result as the cell output.
# NOTE(review): the model was fit with algorithm="meanfield", so Rhat/ESS are
# computed on draws from the VI approximation and do not tell you whether
# that approximation matches the true posterior.
summary = brms.summary(model)
summary
[brmspy][iterate_robject_to_dataclass][WARNING] Type of param 'iter' <class 'int'> does not match expected '<class 'float'>'
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS Intercept 2.04 0.12 1.81 2.29 1.00 1113.57 1022.38 zAge 0.10 0.04 0.02 0.18 1.00 1163.52 979.43 zBase 0.65 0.03 0.58 0.71 1.00 926.45 907.70 Trt -0.25 0.07 -0.40 -0.10 1.00 1065.05 980.16 zBase:Trt 0.04 0.02 0.00 0.08 1.00 948.40 940.45
In [ ]:
Copied!