[ PROMPT_NODE_27138 ]
Pymc 工作流
[ SKILL_DOCUMENTATION ]
# PyMC 工作流与通用模式
本参考提供了在 PyMC 中构建、验证和分析贝叶斯模型的标准工作流与模式。
## 标准贝叶斯工作流
### 完整工作流模板
python
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
# 1. 准备数据
# ===============
X = ... # 预测变量
y = ... # 观测结果
# 标准化预测变量以获得更好的采样效果
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0)
# 2. 构建模型
# ==============
with pm.Model() as model:
# 定义命名维度的坐标
coords = {
'predictors': ['var1', 'var2', 'var3'],
'obs_id': np.arange(len(y))
}
# 先验
alpha = pm.Normal('alpha', mu=0, sigma=1)
beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
sigma = pm.HalfNormal('sigma', sigma=1)
# 线性预测器
mu = alpha + pm.math.dot(X_scaled, beta)
# 似然函数
y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id')
# 3. 先验预测检查
# ==========================
with model:
prior_pred = pm.sample_prior_predictive(samples=1000, random_seed=42)
# 可视化先验预测
az.plot_ppc(prior_pred, group='prior', num_pp_samples=100)
plt.title('先验预测检查')
plt.show()
# 4. 模型拟合
# ============
with model:
# 快速变分推断(VI)探索(可选)
approx = pm.fit(n=20000, random_seed=42)
# 完整 MCMC 推断
idata = pm.sample(
draws=2000,
tune=1000,
chains=4,
target_accept=0.9,
random_seed=42,
idata_kwargs={'log_likelihood': True} # 用于模型比较
)
# 5. 检查诊断
# ====================
# 汇总统计
print(az.summary(idata, var_names=['alpha', 'beta', 'sigma']))
# R-hat 和 ESS
summary = az.summary(idata)
if (summary['r_hat'] > 1.01).any():
print("警告:部分 R-hat 值 > 1.01,链可能未收敛")
if (summary['ess_bulk'] < 400).any():
print("警告:部分 ESS 值 < 400,考虑增加采样数")
# 检查发散
divergences = idata.sample_stats.diverging.sum().item()
print(f"发散次数: {divergences}")
# 轨迹图
az.plot_trace(idata, var_names=['alpha', 'beta', 'sigma'])
plt.tight_layout()
plt.show()
# 6. 后验预测检查
# ==============================
with model:
pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)
# 可视化拟合效果