[ PROMPT_NODE_22338 ]
knowledge-distillation
[ SKILL_DOCUMENTATION ]
# 知识蒸馏:压缩大语言模型 (LLM)
## 何时使用此技能
当你需要以下操作时,请使用知识蒸馏:
- **压缩模型**:从 70B 压缩至 7B,同时保留 90% 以上的性能
- **迁移能力**:将专有模型 (GPT-4) 的能力迁移至开源模型 (LLaMA, Mistral)
- **降低推理成本**:通过部署较小的学生模型
- **创建专用模型**:通过蒸馏特定领域的知识
- **改进小模型**:使用来自大型教师模型的合成数据
**关键技术**: 温度缩放 (Temperature scaling)、软目标 (Soft targets)、反向 KLD (MiniLLM)、Logit 蒸馏、响应蒸馏
**论文**: Hinton et al. 2015 (arXiv 1503.02531), MiniLLM (arXiv 2306.08543), KD Survey (arXiv 2402.13116)
## 安装
bash
# 标准 transformers
pip install transformers datasets accelerate
# 用于训练
pip install torch deepspeed wandb
# 可选: MiniLLM 实现
git clone https://github.com/microsoft/LMOps
cd LMOps/minillm
pip install -e .
## 快速开始
### 基础知识蒸馏
python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
# 1. 加载教师 (大型) 和学生 (小型) 模型
teacher = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b-hf", # 大型教师模型
torch_dtype=torch.float16,
device_map="auto"
)
student = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", # 小型学生模型
torch_dtype=torch.float16,
device_map="cuda:0"
)
tokenizer = AutoTokenizer.from_pretrained(
"meta-llama/Llama-2-70b-hf"
)
# 2. 定义蒸馏损失函数
def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
"""
结合硬损失 (交叉熵) 与软损失 (KL 散度)。
参数:
temperature: 平滑概率分布 (值越大越平滑)
alpha: 蒸馏损失的权重 (1-alpha 为硬损失权重)
"""
# 硬损失: 使用真实标签的标准交叉熵
hard_loss = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), labels.view(-1))
# 软损失: 学生与教师之间的 KL 散度
soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
soft_student = F.log_softmax(student_logits / temperature, dim=-1)
soft_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean') * (temperature ** 2)
# 组合损失
return alpha * soft_loss + (1 - alpha) * hard_loss
# 3. Tr