[ PROMPT_NODE_27137 ]
PyMC – Workflows
[ SKILL_DOCUMENTATION ]
# PyMC Workflows and Common Patterns
This reference provides standard workflows and patterns for building, validating, and analyzing Bayesian models in PyMC.
## Standard Bayesian Workflow
### Complete Workflow Template
```python
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
# 1. PREPARE DATA
# ===============
X = ...  # Predictor variables (one column per entry in coords['predictors'])
y = ...  # Observed outcomes (one value per row of X)

# Standardize predictors for better sampling
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0)

# 2. BUILD MODEL
# ==============
# Coordinates for named dimensions. They are handed to pm.Model so that the
# `dims=` arguments below can be resolved; a dict that is merely created
# inside the `with` block is never registered with the model.
coords = {
    'predictors': ['var1', 'var2', 'var3'],
    'obs_id': np.arange(len(y)),
}

with pm.Model(coords=coords) as model:
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    sigma = pm.HalfNormal('sigma', sigma=1)

    # Linear predictor
    mu = alpha + pm.math.dot(X_scaled, beta)

    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id')

# 3. PRIOR PREDICTIVE CHECK
# ==========================
with model:
    prior_pred = pm.sample_prior_predictive(samples=1000, random_seed=42)

# Visualize prior predictions
az.plot_ppc(prior_pred, group='prior', num_pp_samples=100)
plt.title('Prior Predictive Check')
plt.show()

# 4. FIT MODEL
# ============
with model:
    # Quick VI exploration (optional)
    approx = pm.fit(n=20000, random_seed=42)

    # Full MCMC inference
    idata = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        random_seed=42,
        idata_kwargs={'log_likelihood': True},  # For model comparison
    )

# 5. CHECK DIAGNOSTICS
# ====================
# Summary statistics
print(az.summary(idata, var_names=['alpha', 'beta', 'sigma']))

# R-hat and ESS
summary = az.summary(idata)
if (summary['r_hat'] > 1.01).any():
    print("WARNING: Some R-hat values > 1.01, chains may not have converged")
if (summary['ess_bulk'] < 400).any():
    print("WARNING: Some ESS values < 400, consider drawing more samples")

# NOTE(review): the original text between the ESS warning and the Pareto-k
# check was lost during extraction (everything between a "<" and a later ">"
# was stripped), including the remaining workflow steps and the start of the
# model-comparison example. The loop below is a reconstruction of the
# surviving fragment. `idatas` is assumed to map model names to InferenceData
# objects sampled with log_likelihood=True -- confirm against the original.
for name, model_idata in idatas.items():
    loo = az.loo(model_idata, pointwise=True)
    # Count observations whose Pareto-k exceeds the 0.7 threshold
    high_pareto_k = (loo.pareto_k > 0.7).sum().item()
    if high_pareto_k > 0:
        print(f"Warning: {name} has {high_pareto_k} observations with high Pareto-k")
```
### Model Weights
```python
# Per-model weights from the comparison table.
# NOTE(review): `comparison` is presumably the DataFrame returned by
# az.compare; whether these are stacking or pseudo-BMA weights depends on the
# `method` used there -- confirm where `comparison` is built.
weights = comparison['weight'].values
print("Model probabilities:")
for model_name, share in comparison['weight'].items():
    print(f" {model_name}: {share:.2%}")
# Model averaging (weighted predictions)
def weighted_predictions(idatas, weights):
    """Return the weighted sum of each model's mean posterior prediction.

    Args:
        idatas: Mapping of model name to an InferenceData-like object that
            exposes ``posterior_predictive['y_obs']``.
        weights: Either a sequence of weights given in the iteration order of
            ``idatas``, or a mapping of model name to weight. A mapping is
            matched by name, which avoids pairing a weight with the wrong
            model when the two orders differ.

    Returns:
        The sum over models of ``weight * mean(y_obs over chain and draw)``;
        ``0`` when ``idatas`` is empty.

    Raises:
        ValueError: If a weight sequence does not have one entry per model.
        KeyError: If a weight mapping lacks one of the model names.
    """
    from collections.abc import Mapping

    if isinstance(weights, Mapping):
        ordered = [weights[name] for name in idatas]
    else:
        ordered = list(weights)
        # zip() would silently drop the surplus models or weights.
        if len(ordered) != len(idatas):
            raise ValueError(
                f"expected {len(idatas)} weights, got {len(ordered)}"
            )

    return sum(
        weight * idata.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
        for idata, weight in zip(idatas.values(), ordered)
    )
# `weights` follows the row order of `comparison`, which need not match the
# insertion order of `idatas` (az.compare ranks its rows from best to worst
# model). Look each weight up by model name so it is paired with the right
# model. Assumes the index of `comparison` holds the keys of `idatas`.
aligned_weights = [comparison.loc[name, 'weight'] for name in idatas]
averaged_pred = weighted_predictions(idatas, aligned_weights)
```
## Diagnostics and Troubleshooting
### Diagnosing Sampling Problems
```python
def diagnose_sampling(idata, var_names=None, max_treedepth=10):
    """Print sampling diagnostics and return the summary table.

    Reports R-hat, bulk effective sample size, divergent transitions and
    NUTS tree depth.

    Args:
        idata: InferenceData whose ``sample_stats`` group provides
            ``diverging`` and ``tree_depth``.
        var_names: Optional variable names forwarded to ``az.summary``.
        max_treedepth: Tree-depth limit the sampler was run with. The
            default of 10 is assumed to match PyMC's NUTS default; pass the
            actual value if the sampler was configured differently.

    Returns:
        The table produced by ``az.summary(idata, var_names=var_names)``.
    """
    # Check convergence
    summary = az.summary(idata, var_names=var_names)

    print("=== Convergence Diagnostics ===")
    bad_rhat = summary[summary['r_hat'] > 1.01]
    if len(bad_rhat) > 0:
        print(f"⚠️ {len(bad_rhat)} variables with R-hat > 1.01")
        print(bad_rhat[['r_hat']])
    else:
        print("✓ All R-hat values < 1.01")

    # Check effective sample size
    print("\n=== Effective Sample Size ===")
    low_ess = summary[summary['ess_bulk'] < 400]
    if len(low_ess) > 0:
        print(f"⚠️ {len(low_ess)} variables with ESS < 400")
        print(low_ess[['ess_bulk']])
    else:
        print("✓ All ESS values >= 400")

    # Check divergences
    print("\n=== Divergences ===")
    divergences = idata.sample_stats.diverging.sum().item()
    if divergences > 0:
        print(f"⚠️ {divergences} divergent transitions")
        print(" Consider: increase target_accept, reparameterize, or stronger priors")
    else:
        print("✓ No divergences")

    # Check tree depth. Compare against the configured limit, not the largest
    # depth observed: the observed maximum is always attained at least once,
    # so comparing against it would flag every run.
    print("\n=== NUTS Statistics ===")
    deepest = idata.sample_stats.tree_depth.max().item()
    hits_max = (idata.sample_stats.tree_depth >= max_treedepth).sum().item()
    if hits_max > 0:
        print(f"⚠️ Hit max treedepth {hits_max} times")
        print(" Consider: reparameterize or increase max_treedepth")
    else:
        print(f"✓ No max treedepth issues (max: {deepest})")

    return summary


# Usage
diagnose_sampling(idata, var_names=['alpha', 'beta', 'sigma'])
```
### Common Fixes
| Problem | Solution |
|---------|----------|
| Divergences | Increase `target_accept` (e.g. to 0.95), use non-centered parameterization |
| Low ESS | Sample more draws, reparameterize to reduce correlation |
| High R-hat | Run longer chains, check for multimodality, improve initialization |
| Slow sampling | Use ADVI initialization, reparameterize, reduce model complexity |
| Biased posterior | Check prior predictive, ensure likelihood is correct |
## Using Named Dimensions (dims)
### Benefits of dims
- More readable code
- Easier subsetting and analysis
- Better xarray integration
```python
import pandas as pd  # needed for pd.date_range below

# Define coordinates
coords = {
    'predictors': ['age', 'income', 'education'],
    'groups': ['A', 'B', 'C'],
    'time': pd.date_range('2020-01-01', periods=100, freq='D'),
}

with pm.Model(coords=coords) as model:
    # Use dims instead of shape
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    alpha = pm.Normal('alpha', mu=0, sigma=1, dims='groups')
    # `data` is the user's observed array; its shape has to match
    # dims ['groups', 'time'], i.e. (3, 100) for the coords above.
    y = pm.Normal('y', mu=0, sigma=1, dims=['groups', 'time'], observed=data)

    # After sampling, dimensions are preserved.
    # Sampling runs inside the model context so pm.sample() can find the model.
    idata = pm.sample()

# Easy subsetting
beta_age = idata.posterior['beta'].sel(predictors='age')
group_A = idata.posterior['alpha'].sel(groups='A')
```
## Saving and Loading Results
```python
# Write the InferenceData to NetCDF, then read it back
idata.to_netcdf('results.nc')
loaded_idata = az.from_netcdf('results.nc')

# Bundle the model and its InferenceData into one pickle for later predictions
import pickle

bundle = {'model': model, 'idata': idata}
with open('model.pkl', 'wb') as fh:
    fh.write(pickle.dumps(bundle))

# Restore both objects. Unpickling can execute arbitrary code, so only load
# pickle files that you created yourself.
with open('model.pkl', 'rb') as fh:
    saved = pickle.loads(fh.read())
model, idata = saved['model'], saved['idata']
```
Source: claude-code-templates (MIT). See About Us for full credits.