[ PROMPT_NODE_27186 ]
best_practices
[ SKILL_DOCUMENTATION ]
# 最佳实践 - PyTorch Lightning
## 代码组织
### 1. 将研究与工程分离
**推荐:**
python
class MyModel(L.LightningModule):
    """Recommended layout: the module holds only research code."""

    # Research code (what the model does).
    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and return it.

        Args:
            batch: The batch passed on to ``self.compute_loss``.
            batch_idx: Index of the batch; unused here.

        Returns:
            Whatever ``self.compute_loss(batch)`` produces.
        """
        loss = self.compute_loss(batch)
        return loss
# Engineering code (how to train) lives in the Trainer.
trainer = L.Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",
    max_epochs=100,
)
**不推荐:**
python
# Research and engineering logic mixed together.
class MyModel(L.LightningModule):
    """Counter-example: the statements below are kept on purpose as anti-patterns."""

    def training_step(self, batch, batch_idx):
        """Compute the loss, then (wrongly) handle device and optimizer by hand."""
        loss = self.compute_loss(batch)

        # Don't do device management by hand.
        loss = loss.cuda()

        # Don't run the optimizer step by hand (unless using manual optimization).
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss
### 2. 使用 LightningDataModule
**推荐:**
python
class MyDataModule(L.LightningDataModule):
    """Keep download, dataset construction and loaders for one dataset together.

    Args:
        data_dir: Location passed to ``download_data`` and ``MyDataset``.
        batch_size: Batch size used by every DataLoader built here.
    """

    def __init__(self, data_dir, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        """Download the data once."""
        download_data(self.data_dir)

    def setup(self, stage):
        """Load the train and validation datasets (done in every process).

        Args:
            stage: Stage identifier; not used, both splits are always built.
        """
        self.train_dataset = MyDataset(self.data_dir, split='train')
        self.val_dataset = MyDataset(self.data_dir, split='val')

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training dataset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled DataLoader over the validation dataset.

        Without this hook the ``val_dataset`` built in ``setup`` is never
        served to the trainer.
        """
        return DataLoader(self.val_dataset, batch_size=self.batch_size)
# Reusable and easy to share.
dm = MyDataModule(data_dir="./data", batch_size=32)
trainer.fit(model=model, datamodule=dm)
**不推荐:**
python
# Scattered data logic: datasets and loaders are built ad hoc at the call site.
train_dataset = load_data()
val_dataset = load_data()
train_loader = DataLoader(train_dataset, ...)
val_loader = DataLoader(val_dataset, ...)
trainer.fit(model, train_loader, val_loader)
### 3. 保持模型模块化
python
class Encoder(nn.Module):
    """Encoder kept as its own ``nn.Module`` so it can be composed into a model."""

    def __init__(self):
        super().__init__()
        # `...` is a placeholder: fill in the actual layers.
        self.layers = nn.Sequential(...)

    def forward(self, x):
        """Pass ``x`` through ``self.layers`` and return the result."""
        return self.layers(x)
class Decoder(nn.Module):
    """Decoder kept as its own ``nn.Module`` so it can be composed into a model."""

    def __init__(self):
        super().__init__()
        # `...` is a placeholder: fill in the actual layers.
        self.layers = nn.Sequential(...)

    def forward(self, x):
        """Pass ``x`` through ``self.layers`` and return the result."""
        return self.layers(x)
class MyModel(L.LightningModule):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
def forward(self, x):
z = self.encoder(x)
ret