Awesome-Agent-Skills-for-Empirical-Research bayesian-workflow
install
source · Clone the upstream repo
git clone https://github.com/brycewang-stanford/Awesome-Agent-Skills-for-Empirical-Research
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/brycewang-stanford/Awesome-Agent-Skills-for-Empirical-Research "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/23-Learning-Bayesian-Statistics-baygent-skills/bayesian-workflow" ~/.claude/skills/brycewang-stanford-awesome-agent-skills-for-empirical-research-bayesian-workflow && rm -rf "$T"
manifest:
skills/23-Learning-Bayesian-Statistics-baygent-skills/bayesian-workflow/SKILL.md
Bayesian Workflow
Workflow overview
Every Bayesian analysis follows this sequence. Do not skip steps -- especially model criticism.
- Formulate — Define the generative story: what underlying process created the data? That process is exactly what we are trying to model.
- Specify priors — See references/priors.md
- Implement in PyMC — Write the model. Prefer PyMC 5+ syntax. Use the latest version possible.
- Run prior predictive checks — `pm.sample_prior_predictive()`. Verify priors produce plausible data ranges before fitting.
- Inference — `pm.sample(nuts_sampler="nutpie")`. Always use nutpie for speed (the nutpie python package provides cutting-edge sampling). Don't hardcode the number of chains — let the sampler pick the best default for the platform.
- Diagnose convergence — Use `arviz_stats.diagnose(idata)` as the first check (requires arviz-stats >= 1.0.0). It covers R-hat, ESS, divergences, tree depth, and E-BFMI in one call. See references/diagnostics.md and the sketch after this list.
- Criticize the model — See references/model-criticism.md
- Check prior sensitivity — Run `psense_summary(idata)` to verify conclusions are robust to prior choices. Visualize with `plot_psense_dist(idata)` from `arviz_plots`. Requires `log_likelihood` and `log_prior` in the InferenceData — compute them after sampling if needed. See references/sensitivity.md
- Compare models (if applicable) — See references/model-comparison.md
- Report results — See references/reporting.md. When the user asks for a report or mentions a non-technical audience, generate a standalone markdown report file (not just code comments) using the template in reporting.md. Adapt the language to the audience — if they're new to Bayesian stats, include a glossary and plain-language explanations of key concepts.
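A minimal sketch of the diagnose and prior-sensitivity steps above, assuming `idata` comes from sampling (as in the template below) with `log_likelihood` and `log_prior` already computed, and assuming `psense_summary` is exposed by arviz-stats (the exact import location may differ in your version):

```python
# Sketch only: requires arviz-stats >= 1.0.0 and arviz-plots.
import arviz_stats
import arviz_plots

arviz_stats.diagnose(idata)               # R-hat, ESS, divergences, tree depth, E-BFMI in one call
print(arviz_stats.psense_summary(idata))  # power-scaling prior/likelihood sensitivity summary
arviz_plots.plot_psense_dist(idata)       # visual check of which posteriors shift under power-scaling
```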
Installation
Prefer conda-forge / mamba-forge to install PyMC and its dependencies — pip can cause issues with compiled backends (nutpie, JAX). Example:
mamba install -c conda-forge pymc nutpie arviz arviz-stats preliz
PyMC model template
```python
import pymc as pm
import arviz as az
import numpy as np

RANDOM_SEED = sum(map(ord, "churn-logistic-v1"))
rng = np.random.default_rng(RANDOM_SEED)

# always use dimensions and coordinates in PyMC models
with pm.Model(coords=coords) as model:
    # use Data containers when working on a PyMC model
    data = pm.Data("data", df["y"].to_numpy(), dims="obs")

    # --- Priors ---
    # Always document WHY each prior was chosen
    mu = pm.Normal("mu", mu=0, sigma=10)  # Weakly informative: allows wide range

    # --- Data model ---
    pm.Normal("obs", mu=mu, sigma=1, observed=data, dims="obs")

    # --- Prior predictive check ---
    prior_pred = pm.sample_prior_predictive(random_seed=rng)

    # --- Inference ---
    idata = pm.sample(nuts_sampler="nutpie", random_seed=rng)
    idata.extend(prior_pred)

    # --- Posterior predictive check ---
    idata.extend(pm.sample_posterior_predictive(idata, random_seed=rng))

    # --- Compute log-likelihood and log-prior for sensitivity checks & LOO ---
    pm.compute_log_likelihood(idata, model=model)
    pm.compute_log_prior(idata, model=model)

# --- Save immediately after sampling ---
# Late crashes can destroy valid results. Save to disk before any post-processing.
idata.to_netcdf("model_output.nc")
```
Critical rules
- Always run prior predictive checks before sampling. If prior predictions span implausible ranges, fix priors first. If you have issues or doubts about some priors, use the PreliZ package to elicit them from the user.
- Always check convergence before interpreting results. R-hat > 1.01 or ESS < 100 × the number of chains means the results are unreliable.
- Always run posterior predictive checks. A model that fits well numerically but cannot reproduce the data is useless.
- Always run calibration checks (PIT / coverage). Use ArviZ's `plot_ppc_pit` for this — it handles all data types (continuous, binary, count) correctly. See references/model-criticism.md.
- Document every prior choice with a brief justification in a code comment.
- Never report point estimates alone. Always include credible intervals (default: 94% HDI).
- Use `arviz_stats.diagnose(idata)` as the first diagnostic on every model (arviz-stats >= 1.0.0). It checks R-hat, ESS, divergences, tree depth saturation, and E-BFMI in one call. Follow up with `az.plot_trace(idata, kind="rank_vlines")` for visual inspection.
- Don't hardcode the number of chains. Let PyMC / nutpie choose the optimal default for the user's platform. Just call `pm.sample()` without specifying `chains`.
- Use reproducible, descriptive seeds. Never use magic numbers like `42`. Instead, derive a seed from the analysis name: `RANDOM_SEED = sum(map(ord, "my-analysis-name"))`. Pass it to `pm.sample(random_seed=rng)`, `pm.sample_prior_predictive(random_seed=rng)`, and numpy via `rng = np.random.default_rng(RANDOM_SEED)`.
- Save InferenceData immediately after sampling with `idata.to_netcdf("model_output.nc")`. Late crashes or kernel restarts can destroy valid MCMC results — save before any post-processing.
- Use ArviZ for all plots and calibration. Don't write custom plotting code when ArviZ already handles it — including for binary data, count data, and calibration. ArviZ developers have thought through edge cases so you don't have to.
- Prefer xarray over numpy for InferenceData operations. `InferenceData` and `DataTree` objects are backed by xarray — use xarray's labeled indexing (`.sel()`, `.mean(dim=...)`, etc.) instead of converting to numpy arrays. This preserves dimension labels, avoids shape bugs, and makes code more readable. Fall back to numpy only when xarray can't do what you need.
- Always generate analysis notes alongside code. When producing a model script, also produce a companion markdown file (`analysis_notes.md` or similar) that interprets the results — what the diagnostics mean, what the posteriors tell us, what the calibration plots show. Code without interpretation is incomplete.
- Always use the posterior mean (not median) for predictive probabilities. The proper Bayesian predictive distribution averages over the posterior: P(Y=k|x) = (1/S) Σ P(Y=k|x, θₛ). This is the mean, not the median. The median does not correspond to the posterior predictive distribution, can violate probability coherence (probabilities may not sum to 1), and biases calibration due to Jensen's inequality. In code: use `np.mean(probs, axis=sample_axis)`, never `np.median(...)`. See the sketch after this list.
- Use `pm.set_data()` + `pm.sample_posterior_predictive()` for out-of-sample predictions. Don't manually extract posterior samples and recompute predictions — let PyMC propagate uncertainty properly. Define predictors as `pm.Data(...)` during model building, then swap in new data:

```python
# After fitting the model:
with model:
    pm.set_data({"X": X_new, "group_idx": group_idx_new})
    oos_preds = pm.sample_posterior_predictive(idata, predictions=True, random_seed=rng)
```

- Check model identifiability before interpreting components. If two model components always appear together in the likelihood (e.g., a league intercept and a home advantage term when every observation is from home perspective), their individual posteriors reflect prior assumptions, not data signal — only their sum is identified. Use `az.plot_pair()` to check for strong posterior correlations between components. If correlation is near ±1, the components are not separately identifiable — either merge them or restructure the data.
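To make the xarray and posterior-mean rules concrete, here is a sketch; the variable name `p` and the dims `obs` / `category` are hypothetical stand-ins for a model that stores per-category predictive probabilities as a Deterministic:

```python
# Hypothetical names ("p", "obs", "category"); adapt to your model's variables and dims.
import numpy as np

p = idata.posterior["p"]                    # xarray DataArray with dims (chain, draw, obs, category)
pred_probs = p.mean(dim=("chain", "draw"))  # posterior-MEAN predictive probabilities
# not p.median(...): the median is not the posterior predictive and may not sum to 1
assert np.allclose(pred_probs.sum("category"), 1.0)

first_obs = pred_probs.sel(obs=0)           # labeled indexing instead of positional numpy slicing
```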
Common model families
| Problem | Data model | Typical priors | Reference |
|---|---|---|---|
| Continuous outcome | Normal / StudentT | Normal on coefficients; Gamma (bounded away from 0) for positive-constrained parameters | references/priors.md |
| Binary outcome | Bernoulli or Binomial if aggregated, with logit inverse-link | Normal(0, 1.5) on coeffs (see sketch after this table) | references/priors.md |
| Count data | Poisson / NegBinomial | Gamma on rate, avoiding 0 | references/priors.md |
| Count data with excess zeros | ZeroInflatedPoisson / ZeroInflatedNegBinomial | Gamma on rate; Beta or Normal+logit on zero-inflation prob | references/priors.md |
| Positive count data (no zeros) | Hurdle Poisson / Hurdle NegBinomial | Separate zero-gate (Bernoulli) and count (Truncated) components | references/priors.md |
| Ordinal outcome | OrderedLogistic (cumulative link) | Normal on coeffs; Normal with ordered transform on cutpoints | references/priors.md |
| Censored data (survival, limits of detection) | `pm.Censored` around the underlying distribution | Same as uncensored, applied to underlying distribution | references/priors.md |
| Truncated data | `pm.Truncated` around the underlying distribution | Same as underlying distribution | references/priors.md |
| High-dimensional / sparse regression | Normal / StudentT with sparsity prior on coefficients | Regularized Horseshoe or R2-D2 on coeffs | references/priors.md |
| Hierarchical / multilevel | Varies | See partial pooling pattern | references/hierarchical.md |
| Time series | state space models / Gaussian Processes | Problem-specific | references/priors.md |
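As one worked instance of the table, here is a sketch of the binary-outcome family with a logit inverse-link and Normal(0, 1.5) priors on coefficients; the synthetic data and names (`X`, `y`, `coef`) are illustrative only, not from the references:

```python
# Illustrative sketch: Bernoulli likelihood, logit link, Normal(0, 1.5) priors on coefficients.
import numpy as np
import pymc as pm

rng = np.random.default_rng(sum(map(ord, "binary-outcome-sketch")))
X = rng.normal(size=(200, 3))      # already-standardized predictors
y = rng.integers(0, 2, size=200)   # placeholder binary outcome

coords = {"obs": np.arange(200), "coef": ["x1", "x2", "x3"]}
with pm.Model(coords=coords) as logit_model:
    intercept = pm.Normal("intercept", mu=0, sigma=1.5)
    beta = pm.Normal("beta", mu=0, sigma=1.5, dims="coef")
    p = pm.math.invlogit(intercept + pm.math.dot(X, beta))
    pm.Bernoulli("y", p=p, observed=y, dims="obs")
```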
Utility scripts
Run `diagnose_model.py` after sampling to get a structured convergence + diagnostics report:
python scripts/diagnose_model.py --idata path/to/inference_data.nc
Run `calibration_check.py` to generate calibration plots:
python scripts/calibration_check.py --idata path/to/inference_data.nc
See scripts/ for all available utilities.
Common gotchas
These are battle-tested lessons that save hours of debugging:
- nutpie silently ignores `idata_kwargs` for `log_likelihood` and `log_prior`. Always compute them explicitly after sampling: `pm.compute_log_likelihood(idata, model=model)` (needed for LOO-CV) and `pm.compute_log_prior(idata, model=model)` (needed for prior sensitivity checks). Don't assume they're stored automatically.
- `az.plot_khat()` requires the LOO object, not InferenceData. Pass the output of `az.loo(idata, pointwise=True)` to it.
- Flat priors on scale parameters (`HalfCauchy`, `HalfFlat`) cause funnels in hierarchical models. Use `Gamma(2, ...)` or `Exponential` — these avoid the near-zero region that creates sampling problems. If there's no group-level variation to detect, you don't need the hierarchy.
- Python conditionals in models (`if x > 0`) don't work inside PyMC. Use `pm.math.switch` or `pytensor.tensor.where` instead (see the sketch after this list).
- Forgetting to standardize predictors makes shared priors inappropriate and slows sampling. Always standardize before fitting, then back-transform for interpretation.
- Horseshoe priors create a double-funnel geometry that standard NUTS can struggle with. Always use the regularized (Finnish) horseshoe (Piironen & Vehtari, 2017), which adds a slab component that smooths the geometry. Set `target_accept=0.95` or higher. If you see divergences with a horseshoe model, this is almost certainly the cause.
- `np.median` on posterior predictive probabilities is a silent bug. It does not produce the Bayesian predictive distribution and can yield probabilities that don't sum to 1 across categories. Always use `np.mean` over the posterior samples dimension.
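A minimal sketch of the conditional gotcha above; the distribution and threshold are made up purely for illustration:

```python
# A Python `if` on a symbolic variable tries to coerce it to a bool and errors out;
# pm.math.switch keeps the branch inside the computation graph instead.
import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x", mu=0, sigma=1)
    # wrong: rate = 1.0 if x > 0 else 0.1
    rate = pm.math.switch(x > 0, 1.0, 0.1)  # evaluated symbolically for every draw
    pm.Exponential("y", lam=rate)
```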
When things go wrong
| Symptom | Likely cause | Fix |
|---|---|---|
| Divergences | Posterior geometry issue | Reparameterize (non-centered; see sketch below), increase `target_accept` to 0.95-0.99 |
| Low ESS | High autocorrelation | More tuning steps, reparameterize, reduce correlations |
| R-hat > 1.01 | Chains haven't mixed | More draws, better initialization, check for multimodality |
| Prior pred. looks wrong | Bad priors | Tighten or shift priors, use domain knowledge |
| Post. pred. misses data | Model misspecification | Add complexity (varying slopes, different data model, interaction terms) |
| `log_likelihood` missing | nutpie doesn't auto-store it | Call `pm.compute_log_likelihood(idata, model=model)` after sampling |
| Slow model | Large Deterministics or recompilation | Profile with `model.profile(model.logp())`, avoid large arrays |
| Slow to initialize / poor warmup | Bad starting point | Try in , or run first () and pass its estimates as |
| Prior sensitivity flag | Prior-data conflict or strong prior | Check `psense_summary(idata)` — see references/sensitivity.md. Justify or revise the flagged prior |
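For the divergences row, the non-centered reparameterization looks roughly like the sketch below (group names are made up; references/hierarchical.md has the full partial-pooling pattern):

```python
# Non-centered sketch: sample standardized offsets, then rebuild group effects deterministically.
import pymc as pm

with pm.Model(coords={"group": ["a", "b", "c"]}) as m:
    mu = pm.Normal("mu", mu=0, sigma=1)
    sigma = pm.Exponential("sigma", 1)
    # centered (funnel-prone): a = pm.Normal("a", mu=mu, sigma=sigma, dims="group")
    z = pm.Normal("z", mu=0, sigma=1, dims="group")           # standardized offsets
    a = pm.Deterministic("a", mu + sigma * z, dims="group")   # non-centered group effects
    # then: idata = pm.sample(nuts_sampler="nutpie", target_accept=0.95)
```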