[ PROMPT_NODE_22358 ]
model-pruning
[ SKILL_DOCUMENTATION ]
# 模型剪枝:压缩大语言模型
## 何时使用此技能
当您需要执行以下操作时,请使用模型剪枝:
- **减小模型大小** 40-60%,且精度损失 <1%
- **加速推理** 使用硬件友好的稀疏性(2-4 倍加速)
- **部署在受限硬件上**(移动设备、边缘设备)
- **无需重新训练即可压缩** 使用一次性(one-shot)方法
- **实现高效服务** 降低内存占用
**关键技术**: Wanda (权重 × 激活值), SparseGPT (二阶), 结构化剪枝, N:M 稀疏度
**论文**: Wanda ICLR 2024 (arXiv 2306.11695), SparseGPT (arXiv 2301.00774)
## 安装
bash
# Wanda 实现
git clone https://github.com/locuslab/wanda
cd wanda
pip install -r requirements.txt
# 可选: SparseGPT
git clone https://github.com/IST-DASLab/sparsegpt
cd sparsegpt
pip install -e .
# 依赖
pip install torch transformers accelerate
## 快速开始
### Wanda 剪枝(一次性,无需重新训练)
**来源**: ICLR 2024 (arXiv 2306.11695)
python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载模型
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# 校准数据(用于激活统计的小数据集)
calib_data = [
"The quick brown fox jumps over the lazy dog.",
"Machine learning is transforming the world.",
"Artificial intelligence powers modern applications.",
]
# Wanda 剪枝函数
def wanda_prune(model, calib_data, sparsity=0.5):
"""
Wanda: 按权重幅度 × 输入激活值进行剪枝。
参数:
sparsity: 要剪枝的权重比例 (0.5 = 50%)
"""
# 1. 收集激活统计信息
activations = {}
def hook_fn(name):
def hook(module, input, output):
# 存储输入激活范数
activations[name] = input[0].detach().abs().mean(dim=0)
return hook
# 为所有线性层注册钩子
hooks = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
hooks.append(module.register_forward_hook(hook_fn(name)))
# 运行校准数据
model.eval()
with torch.no_grad():
for text in calib_data:
inputs = tokenizer(text, return_tensors="pt").to(model.device)
model(**inputs)
# 移除钩子
for hook in hooks:
ho