[ PROMPT_NODE_22634 ]
Mechanistic Interpretability Pyvene 教程
[ SKILL_DOCUMENTATION ]
# pyvene 教程
## 教程 1:基础激活修补 (Activation Patching)
### 目标
将源提示词(source)在指定层产生的激活值,替换到基础提示词(base)的前向传播中,观察输出如何变化,以检验该激活值对预测结果的因果作用。
### 分步指南
python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# 1. Load the model and its tokenizer.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# 2. Prepare the inputs: one base prompt and one source prompt.
base_prompt = "The Colosseum is in the city of"
source_prompt = "The Eiffel Tower is in the city of"
base_inputs = tokenizer(base_prompt, return_tensors="pt")
source_inputs = tokenizer(source_prompt, return_tensors="pt")

# 3. Define the intervention (patch the block output of layer 8).
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=8,
            component="block_output",
            intervention_type=pv.VanillaIntervention,
        )
    ]
)
intervenable = pv.IntervenableModel(config, model)

# 4. Run the intervention.
# This is inference only, so autograd is disabled to avoid building a graph.
# NOTE(review): no `unit_locations` is passed here; confirm that pyvene's
# default covers the token positions this tutorial intends to patch.
with torch.no_grad():
    _, patched_outputs = intervenable(
        base=base_inputs,
        sources=[source_inputs],
    )

# 5. Inspect the prediction at the last position of the first sequence.
patched_logits = patched_outputs.logits
probs = torch.softmax(patched_logits[0, -1], dim=-1)
# Only the first token id of each encoded string is used.
rome_token = tokenizer.encode(" Rome")[0]
paris_token = tokenizer.encode(" Paris")[0]
print(f"P(Rome): {probs[rome_token].item():.4f}")
print(f"P(Paris): {probs[paris_token].item():.4f}")
---
## 教程 2:因果追踪 (ROME 风格)
### 目标
先向嵌入层添加噪声以破坏输入,再逐层恢复干净运行时的激活值,观察正确答案的概率在哪一层得以恢复,从而定位事实关联存储的位置。
### 分步指南
python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")

# 1. Define the prompt.
clean_prompt = "The Space Needle is in downtown"
# The input is corrupted later by adding noise to the embedding layer.
clean_inputs = tokenizer(clean_prompt, return_tensors="pt")
# Only the first token id of the encoded answer string is used.
seattle_token = tokenizer.encode(" Seattle")[0]

# 2. Get the clean baseline: probability of the answer token at the last
#    position of the uncorrupted prompt.
with torch.no_grad():
    clean_outputs = model(**clean_inputs)
clean_prob = torch.softmax(clean_outputs.logits[0, -1], dim=-1)[seattle_token].item()
print(f"Clean P(Seattle): {clean_prob:.4f}")
# 3. 遍历各层 - 破坏输入,并在每一层进行恢复
results = []
for restore_layer in range(model.config.n_layer):
# 配置:在输入处添加噪声,在目标层恢复
config = pv.IntervenableConfig(
representations=[
# 嵌入层的噪声干预
pv