[ PROMPT_NODE_27190 ]
data_module
[ SKILL_DOCUMENTATION ]
# LightningDataModule - 综合指南
## 概述
LightningDataModule 是一个可重用、可共享的类,它封装了 PyTorch Lightning 中的所有数据处理步骤。它通过标准化数据集在项目间的管理和共享方式,解决了数据准备逻辑分散的问题。
## 它解决的核心问题
在传统的 PyTorch 工作流中,数据处理分散在多个文件中,导致难以回答以下问题:
- “你使用了什么数据划分?”
- “应用了什么转换?”
- “数据是如何准备的?”
DataModules 将这些信息集中起来,以实现可复现性和可重用性。
## 五个处理步骤
DataModule 将数据处理组织为五个阶段:
1. **下载/分词/处理** - 初始数据获取
2. **清理并保存** - 将处理后的数据持久化到磁盘
3. **加载到 Dataset** - 创建 PyTorch Dataset 对象
4. **应用转换** - 数据增强、归一化等
5. **封装在 DataLoader 中** - 配置批处理和加载
## 主要方法
### `prepare_data()`
下载并处理数据。仅在单个进程上运行一次(非分布式)。
**用途:**
- 下载数据集
- 对文本进行分词
- 将处理后的数据保存到磁盘
**重要:** 不要在此处设置状态(例如 self.x = y)。状态不会传输到其他进程。
**示例:**
python
def prepare_data(self):
# 下载数据(运行一次)
download_dataset("http://example.com/data.zip", "data/")
# 分词并保存(运行一次)
tokenize_and_save("data/raw/", "data/processed/")
### `setup(stage)`
创建数据集并应用转换。在分布式训练的每个进程上运行。
**参数:**
- `stage` - 'fit', 'validate', 'test', 或 'predict'
**用途:**
- 创建训练/验证/测试划分
- 构建 Dataset 对象
- 应用转换
- 设置状态 (self.train_dataset = ...)
**示例:**
python
def setup(self, stage):
if stage == 'fit':
full_dataset = MyDataset("data/processed/")
self.train_dataset, self.val_dataset = random_split(
full_dataset, [0.8, 0.2]
)
if stage == 'test':
self.test_dataset = MyDataset("data/processed/test/")
if stage == 'predict':
self.predict_dataset = MyDataset("data/processed/predict/")
### `train_dataloader()`
返回训练 DataLoader。
**示例:**
python
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,