[ PROMPT_NODE_22328 ]
Distributed Training Pytorch Lightning 回调
[ SKILL_DOCUMENTATION ]
# PyTorch Lightning 回调 (Callbacks)
## 概述
回调在不修改 LightningModule 的情况下为训练添加功能。它们捕获**非核心逻辑**,如检查点保存、早停和日志记录。
## 内置回调
### 1. ModelCheckpoint
**在训练期间保存最佳模型**:
python
from lightning.pytorch.callbacks import ModelCheckpoint
# 根据验证损失保存前 3 个模型
checkpoint = ModelCheckpoint(
dirpath='checkpoints/',
filename='model-{epoch:02d}-{val_loss:.2f}',
monitor='val_loss',
mode='min',
save_top_k=3,
save_last=True, # 同时保存最后一个 epoch
verbose=True
)
trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader)
**配置选项**:
python
checkpoint = ModelCheckpoint(
monitor='val_acc', # 要监控的指标
mode='max', # 'max' 表示准确率,'min' 表示损失
save_top_k=5, # 保留最好的 5 个模型
save_last=True, # 单独保存最后一个 epoch
every_n_epochs=1, # 每 N 个 epoch 保存一次
save_on_train_epoch_end=False, # 改为在验证结束时保存
filename='best-{epoch}-{val_acc:.3f}', # 命名模式
auto_insert_metric_name=False # 不要在文件名中自动添加指标名称
)
**加载检查点**:
python
# 加载最佳模型
best_model_path = checkpoint.best_model_path
model = LitModel.load_from_checkpoint(best_model_path)
# 恢复训练
trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader, ckpt_path='checkpoints/last.ckpt')
### 2. EarlyStopping
**当指标停止改善时停止训练**:
python
from lightning.pytorch.callbacks import EarlyStopping
early_stop = EarlyStopping(
monitor='val_loss',
patience=5, # 等待 5 个 epoch
mode='min',
min_delta=0.001, # 最小变化量,视为改善
verbose=True,
strict=True, # 如果找不到监控指标则报错
check_on_train_epoch_end=False # 在验证结束时检查
)
trainer = L.Trainer(callbacks=[early_stop])
trainer.fit(model, train_loader, val_loader)
# 如果 5 个 epoch 内没有改善,则自动停止
**高级用法**:
python
early_stop = EarlyStopping(
monitor='val_loss',
patience=10,
min_delta=0.0,
verbose=True,
mode='min',
stopping_threshold=0.1, # 如果 val_loss 5.0 则停止
check_finite=True # 检查是否有限值