# GPTQ Integration Guide
Integration with transformers, PEFT, vLLM, and other frameworks.
## Transformers Integration
### Automatic Detection
```python
from transformers import AutoModelForCausalLM

# Automatically detect and load a GPTQ model
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-13B-GPTQ",
    device_map="auto"
)
```
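Once loaded, the model behaves like any other transformers model. A minimal generation sketch, assuming the GPTQ backend packages (`optimum` and `auto-gptq`) are installed; the prompt and generation settings are illustrative:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-13B-GPTQ")

# Prompt and max_new_tokens are illustrative
inputs = tokenizer("Explain quantization in one sentence.", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```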
### Manual Loading
```python
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "TheBloke/Llama-2-13B-GPTQ",
    device="cuda:0",
    use_exllama=True
)
```
## QLoRA Fine-Tuning
```python
from transformers import AutoModelForCausalLM, TrainingArguments
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from trl import SFTTrainer

# Load the GPTQ model
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-70B-GPTQ",
    device_map="auto"
)

# Prepare the quantized model for training
model = prepare_model_for_kbit_training(model)

# LoRA configuration
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)

# Train (fine-tune a 70B model on a single A100!)
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,  # your training dataset
    max_seq_length=2048,
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        learning_rate=2e-4,
        num_train_epochs=3,
        output_dir="./results"
    )
)
trainer.train()
```
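After training, only the LoRA adapter weights need to be saved; they can later be reattached to the quantized base model. A minimal sketch; the adapter directory name is illustrative:

```python
from peft import PeftModel

# Save only the LoRA adapter weights (directory name is illustrative)
trainer.model.save_pretrained("./llama2-70b-gptq-lora")

# At inference time, reattach the adapter to the quantized base model
base = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-70B-GPTQ", device_map="auto")
model = PeftModel.from_pretrained(base, "./llama2-70b-gptq-lora")
```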
## vLLM Integration
```python
from vllm import LLM, SamplingParams

# Load a GPTQ model in vLLM
llm = LLM(
    model="TheBloke/Llama-2-70B-GPTQ",
    quantization="gptq",
    dtype="float16",
    gpu_memory_utilization=0.95
)

# Generate
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=200
)
outputs = llm.generate(["Explain AI"], sampling_params)
```
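`generate` returns one `RequestOutput` per prompt; the generated text is read from its `outputs` list:

```python
for output in outputs:
    # Each RequestOutput holds the original prompt and its completions
    print(output.prompt)
    print(output.outputs[0].text)
```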
## Text Generation Inference (TGI)
```bash
# Docker with GPTQ support
docker run --gpus all -p 8080:80 \
  -v $PWD/data:/data \
  ghcr.io/huggingface/text-generation-inference:latest \
  --model-id TheBloke/Llama-2-70B-GPTQ \
  --quantize gptq
```
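Once the container is running, the server can be queried from Python. A minimal sketch using `huggingface_hub.InferenceClient`, assuming the server is reachable at localhost:8080:

```python
from huggingface_hub import InferenceClient

# Assumes the TGI container above is listening on localhost:8080
client = InferenceClient("http://localhost:8080")
print(client.text_generation("Explain AI", max_new_tokens=200))
```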
## LangChain Integration
```python
from langchain.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-13B-GPTQ")
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-13B-GPTQ",
    device_map="auto"
)

# Wrap the model in a transformers pipeline, then expose it to LangChain
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200)  # max_new_tokens is illustrative
llm = HuggingFacePipeline(pipeline=pipe)
```
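A short usage sketch with the wrapped `llm`; the prompt template and the classic `LLMChain` wiring are illustrative:

```python
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

# Illustrative prompt template and chain
prompt = PromptTemplate.from_template("Summarize in one sentence: {text}")
chain = LLMChain(llm=llm, prompt=prompt)
print(chain.run(text="GPTQ is a post-training quantization method for large language models."))
```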