# LightningModule - Complete Guide
## Overview
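`LightningModule` organizes PyTorch code into six logical sections:

- initialization (`__init__` and `setup()`)
- the training loop (`training_step()`)
- the validation loop (`validation_step()`)
- the test loop (`test_step()`)
- the prediction loop (`predict_step()`)
- optimizers and LR schedulers (`configure_optimizers()`)

It adds no new abstraction. The code stays plain PyTorch, just with a consistent structure. The `Trainer` handles device management, distributed sampling, and other infrastructure, while you keep full control over the model.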
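The following plain-PyTorch sketch shows two of the chores the `Trainer` takes over. The first is choosing a device and moving the model and each batch onto it. The second is sharding the dataset across processes with `DistributedSampler`, which Lightning by default inserts into your DataLoaders when training is distributed. The sketch simulates a two-process run by passing `num_replicas` and `rank` explicitly, so no process group is initialized. The toy model and dataset are illustrative assumptions.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Device management: choose a device, then move the model and every batch to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(1, 2).to(device)  # toy model (assumption)

# Toy dataset of 10 single-feature samples (assumption).
dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
(x,) = dataset[0]
out = model(x.to(device))  # omitting .to(device) here is a classic device-mismatch bug

# Distributed sampling: each rank should see a disjoint shard of the dataset.
# num_replicas and rank are passed explicitly to simulate 2 processes without torch.distributed.
shards = {}
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    shards[rank] = list(sampler)
```

With `shuffle=False`, rank 0 receives indices 0, 2, 4, 6, 8 and rank 1 receives indices 1, 3, 5, 7, 9.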
## Core Structure
```python
import lightning as L
import torch
import torch.nn.functional as F


class MyModel(L.LightningModule):
    def __init__(self, learning_rate=0.001):
        super().__init__()
        # Save the __init__ arguments; they become available as self.hparams.learning_rate, etc.
        self.save_hyperparameters()
        # Placeholder network (assumption: flattened 28x28 inputs, 10 classes); replace with your own model
        self.model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss  # Lightning runs backward() and the optimizer step on this loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # ReduceLROnPlateau needs a metric: "monitor" names the logged key it watches
                "monitor": "val_loss",
            },
        }
```
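The `"monitor": "val_loss"` entry is needed because `ReduceLROnPlateau.step()` takes a metric value. At the end of each epoch (the default scheduler `interval`), Lightning passes the logged `val_loss` to the scheduler. The plain-PyTorch sketch below feeds the same scheduler a hypothetical sequence of validation losses by hand. It uses `patience=1` only to keep the demo short (the PyTorch default is 10) and keeps the default `factor=0.1`.

```python
import torch

model = torch.nn.Linear(4, 1)  # toy model (assumption)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# patience=1 is an assumption to shorten the demo; factor defaults to 0.1
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1)

# Hypothetical per-epoch validation losses: one improvement, then a plateau.
val_losses = [1.0, 0.8, 0.8, 0.8]
lrs = []
for val_loss in val_losses:
    # With "monitor": "val_loss", Lightning makes this call for you with the logged value.
    scheduler.step(val_loss)
    lrs.append(optimizer.param_groups[0]["lr"])
```

The first two losses count as improvements. The third is one non-improving epoch, which is still within `patience`. The fourth exceeds `patience`, so the learning rate drops from 1e-3 to 1e-4.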
## Core Methods
### Training Pipeline Methods
#### `training_step(batch, batch_idx)`
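Runs the forward pass and returns the loss. In automatic optimization mode (the default), Lightning performs the backward pass and the optimizer step for you.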
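**Parameters:**
- `batch` - the current training batch produced by the DataLoader
- `batch_idx` - the index of the current batch

**Returns:** a scalar loss tensor, a dictionary containing a `'loss'` key, or `None` to skip the current batch (automatic optimization only)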
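Under automatic optimization, the value returned from `training_step` drives the parameter update. The following sketch is a simplified, pure-PyTorch approximation of that step, not Lightning's actual implementation. `run_training_step` is a hypothetical helper. The real loop also handles gradient accumulation, clipping, mixed precision, and callback hooks.

```python
import torch


def run_training_step(training_step, batch, batch_idx, optimizer):
    """Hypothetical, simplified version of one automatic-optimization step."""
    output = training_step(batch, batch_idx)
    if output is None:
        # Returning None skips the update for this batch
        return None
    # Accept either a scalar loss tensor or a dict that contains a "loss" key
    loss = output["loss"] if isinstance(output, dict) else output
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()


# Usage: one scalar parameter pulled towards 0 with plain SGD (toy setup, assumption)
w = torch.nn.Parameter(torch.tensor(2.0))
opt = torch.optim.SGD([w], lr=0.1)
step_fn = lambda batch, batch_idx: {"loss": (w - batch) ** 2}
first_loss = run_training_step(step_fn, torch.tensor(0.0), 0, opt)
```

Here the loss is (2 - 0)^2 = 4 and the gradient is 2 * 2 = 4, so one SGD step with `lr=0.1` moves `w` from 2.0 to 1.6.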
**Example:**
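```python
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.mse_loss(y_hat, y)
    # Log training metrics per step and as an epoch aggregate, and show them in the progress bar
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
    # Log the current learning rate of the first param group of the (single) optimizer
    self.log("learning_rate", self.optimizers().param_groups[0]["lr"])
    return loss  # without this, training_step returns None and the batch is skipped
```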
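With `on_epoch=True`, Lightning accumulates the per-step values and reduces them at the end of the epoch. With the default `reduce_fx="mean"`, the average is weighted by batch size, which Lightning infers from the batch or takes from the `batch_size` argument of `self.log`. When both `on_step` and `on_epoch` are `True`, the metric is stored under two keys with `_step` and `_epoch` suffixes. The following sketch imitates this bookkeeping in plain Python. `WeightedMean` and `logged_keys` are hypothetical helpers, not Lightning APIs.

```python
import torch


class WeightedMean:
    """Hypothetical accumulator for an epoch-level, batch-size-weighted mean."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, batch_size):
        # Weight each step's value by the number of samples in that batch
        self.total += float(value) * batch_size
        self.count += batch_size

    def compute(self):
        return self.total / self.count


def logged_keys(name, on_step, on_epoch):
    """Hypothetical helper: metric keys produced for a given on_step/on_epoch combination."""
    if on_step and on_epoch:
        return [f"{name}_step", f"{name}_epoch"]
    return [name]


# Usage: two steps, a full batch of 32 and a smaller last batch of 16
epoch_loss = WeightedMean()
epoch_loss.update(torch.tensor(1.0), batch_size=32)
epoch_loss.update(torch.tensor(4.0), batch_size=16)
keys = logged_keys("train_loss", on_step=True, on_epoch=True)
```

The epoch value is (32 * 1.0 + 16 * 4.0) / 48 = 2.0. A plain mean of the two step values would give 2.5, which over-weights the smaller final batch.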