[ PROMPT_NODE_22626 ]
Mechanistic Interpretability Nnsight 教程
[ SKILL_DOCUMENTATION ]
# nnsight 教程
## 教程 1:基础激活分析
### 目标
加载模型,访问内部激活值并进行分析。
### 步骤
python
from nnsight import LanguageModel
import torch
# 1. Load the model
model = LanguageModel("openai-community/gpt2", device_map="auto")

# 2. Trace the prompt and collect activations
prompt = "The capital of France is"
with model.trace(prompt) as tracer:
    # Collect element 0 of every transformer block's output
    activations = {}
    for i in range(12):  # GPT-2 has 12 layers
        activations[i] = model.transformer.h[i].output[0].save()
    # Save the final logits from the LM head. `model.output` is the
    # Hugging Face output object, which does not support the tuple
    # indexing (`logits[0, -1]`) used in step 4; the LM head's output
    # is the plain logits tensor.
    logits = model.lm_head.output.save()

# 3. Analyse (outside the tracing context)
print("逐层激活范数:")
for layer, act in activations.items():
    print(f" Layer {layer}: {act.norm().item():.2f}")

# 4. Inspect the predictions for the last position
probs = torch.softmax(logits[0, -1], dim=-1)
top_tokens = probs.topk(5)
# Leading "\n" separates this section from the per-layer norms above
print("\nTop 预测:")
for token_id, prob in zip(top_tokens.indices, top_tokens.values):
    token_str = model.tokenizer.decode(token_id)
    print(f" {token_str!r}: {prob.item():.3f}")
---
## 教程 2:激活修补 (Activation Patching)
### 目标
将一个提示词的激活值修补到另一个提示词中,以测试因果关系。
### 步骤
python
from nnsight import LanguageModel
import torch
model = LanguageModel("gpt2", device_map="auto")
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"

# 1. Run the clean prompt and keep its activations
with model.trace(clean_prompt) as tracer:
    clean_hidden = model.transformer.h[8].output[0].save()
    # Save the LM head output (the logits tensor) rather than
    # `model.output`: the metric below indexes `[0, -1, token]`, which
    # the Hugging Face output object does not support.
    clean_logits = model.lm_head.output.save()

# 2. Define the metric
paris_token = model.tokenizer.encode(" Paris")[0]
rome_token = model.tokenizer.encode(" Rome")[0]


def logit_diff(logits):
    """Return the " Paris" logit minus the " Rome" logit as a float.

    Both logits are read at index ``[0, -1]`` of ``logits``, i.e. the
    first item on the leading axis and the last position on the second.
    """
    return (logits[0, -1, paris_token] - logits[0, -1, rome_token]).item()


print(f"干净的 logit 差值: {logit_diff(clean_logits):.3f}")

# 3. Patch the clean activations into the corrupted prompt
with model.trace(corrupted_prompt) as tracer:
    # Overwrite the layer-8 output with the clean activations.
    # NOTE(review): the in-place slice assignment presumably requires
    # both prompts to tokenize to the same length -- confirm.
    model.transformer.h[8].output[0][:] = clean_hidden
    patched_logits = model.lm_head.output.save()

print(f"修补后的 logit 差值: {logit_diff(patched_logits):.3f}")
# 4. 系统性修补扫描
results = torch.zeros(12) # 12 层
for layer in range(12):
# 获取该层的干净激活值
with model.trace(clean_prompt) as tracer:
clean_act = model.transformer.h[layer].output[0].save()
# 修补到损坏的提示词中
with model.trace(corrupted_prompt) as tra