[ PROMPT_NODE_22704 ]
distributed-training
[ SKILL_DOCUMENTATION ]
# 分布式训练
LitGPT 中用于扩展到多个 GPU 和节点的 FSDP(完全分片数据并行)分布式训练指南。
## 概述
LitGPT 使用 **Lightning Fabric** 和 **FSDP** 将训练任务分布到多个 GPU 上。FSDP 对模型参数、梯度和优化器状态进行分片,从而支持训练超过单 GPU 显存限制的模型。
**何时使用 FSDP**:
- 模型无法放入单个 GPU
- 希望通过多 GPU 加快训练速度
- 训练参数量 >7B 的模型
- 需要跨多个节点进行扩展
## 快速入门
### 单节点多 GPU
bash
# 在 4 个 GPU 上训练 Llama 2 7B
litgpt finetune_lora meta-llama/Llama-2-7b-hf
--devices 4
--data JSON
--data.json_path data/alpaca.json
当 `devices > 1` 时,FSDP 会**自动启用**。
### 多节点训练
bash
# 在 2 个节点上训练,每个节点 8 个 GPU(总共 16 个)
litgpt finetune_lora meta-llama/Llama-2-70b-hf
--devices 8
--num_nodes 2
--data JSON
--data.json_path data/alpaca.json
## FSDP 配置
### 默认 FSDP 策略
当使用多个设备时,LitGPT 应用以下 FSDP 配置:
python
from lightning.fabric.strategies import FSDPStrategy
from litgpt.model import Block
strategy = FSDPStrategy(
auto_wrap_policy={Block},
state_dict_type="full",
sharding_strategy="HYBRID_SHARD"
)
**参数**:
- `auto_wrap_policy={Block}`:自动使用 FSDP 包装每个 transformer `Block`
- `state_dict_type="full"`:保存完整模型(在 rank 0 上组装)以便于部署
- `sharding_strategy="HYBRID_SHARD"`:对参数、梯度和优化器状态进行分片
### 分片策略
| 策略 | 分片内容 | 通信 | 使用场景 |
|----------|--------|---------------|----------|
| `FULL_SHARD` (ZeRO-3) | 参数 + 梯度 + 优化器 | 前向/后向传播前 All-gather | 最大限度节省内存 |
| `SHARD_GRAD_OP` (ZeRO-2) | 仅梯度 + 优化器 | 后向传播后 Reduce-scatter | 比 FULL_SHARD 更快 |
| `HYBRID_SHARD` (默认) | 全部(跨节点混合) | 针对多节点优化 | 集群环境最佳 |
| `NO_SHARD` | 无 | 广播 | 单 GPU(无 FSDP) |
**建议**:多节点使用默认的 `HYBRID_SHARD`,单节点多 GPU 使用 `FULL_SHARD`。
### 状态字典类型
| 类型 | 行为 | 使用场景 |
|------|----------|----------|
| `full` (默认) | 在 rank 0 上收集所有分片,保存为单个文件 | 易于部署、推理 |
| `sharded` | 每个 rank 分别保存其分片 | 更快的检查点保存