[ PROMPT_NODE_27198 ]
trainer
[ SKILL_DOCUMENTATION ]
# Trainer - 全面指南
## 概述
Trainer(训练器)在将 PyTorch 代码组织为 LightningModule 后,可自动执行训练工作流。它自动处理循环细节、设备管理、回调、梯度操作、检查点保存和分布式训练。
## 核心目的
Trainer 管理以下内容:
- 自动启用/禁用梯度
- 运行训练、验证和测试数据加载器
- 在适当时间调用回调
- 将批次放置在正确的设备上
- 编排分布式训练
- 进度条和日志记录
- 检查点保存和提前停止
## 主要方法
### `fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None)`
运行完整的训练例程,包括可选的验证。
**参数:**
- `model` - 要训练的 LightningModule
- `train_dataloaders` - 训练 DataLoader(s)
- `val_dataloaders` - 可选的验证 DataLoader(s)
- `datamodule` - 可选的 LightningDataModule(替换数据加载器)
**示例:**
python
# 使用 DataLoaders
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)
# 使用 DataModule
trainer.fit(model, datamodule=dm)
# 从检查点继续训练
trainer.fit(model, train_loader, ckpt_path="checkpoint.ckpt")
### `validate(model=None, dataloaders=None, datamodule=None)`
在不进行训练的情况下运行验证循环。
**示例:**
python
trainer = L.Trainer()
trainer.validate(model, val_loader)
### `test(model=None, dataloaders=None, datamodule=None)`
运行测试循环。仅在发布结果前使用。
**示例:**
python
trainer = L.Trainer()
trainer.test(model, test_loader)
### `predict(model=None, dataloaders=None, datamodule=None)`
对数据运行推理并返回预测结果。
**示例:**
python
trainer = L.Trainer()
predictions = trainer.predict(model, predict_loader)
## 关键参数
### 训练时长
#### `max_epochs` (int)
训练的最大轮数。默认:1000
python
trainer = L.Trainer(max_epochs=100)
#### `min_epochs` (int)
训练的最小轮数。默认:None
python
trainer = L.Trainer(min_epochs=10, max_epochs=100)
#### `max_steps` (int)
优化器的最大步数。会覆盖 max_epochs。默认:-1(无限制)
python
trainer = L.Trainer(max_steps=10000)
#### `max_time` (str 或 dict)
最大训练时间。适用于有时间限制的集群。
python
# 字符串格式
trainer = L.Trainer(max_time="00:12:00:00") # 12 小时
# 字典格式