[ PROMPT_NODE_22648 ]
Mechanistic Interpretability Transformer Lens API 参考
[ SKILL_DOCUMENTATION ]
# TransformerLens API 参考
## HookedTransformer
机械可解释性的核心类,将 Transformer 模型包装起来,并在每个激活点设置钩子。
### 加载模型
python
from transformer_lens import HookedTransformer
# 基础加载
model = HookedTransformer.from_pretrained("gpt2-small")
# 指定设备/数据类型
model = HookedTransformer.from_pretrained(
"gpt2-medium",
device="cuda",
dtype=torch.float16
)
# 受限访问模型 (LLaMA, Mistral)
import os
os.environ["HF_TOKEN"] = "your_token"
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
### from_pretrained() 参数
| 参数 | 类型 | 默认值 | 描述 |
|-----------|------|---------|-------------|
| `model_name` | str | 必填 | 来自 OFFICIAL_MODEL_NAMES 的模型名称 |
| `fold_ln` | bool | True | 将 LayerNorm 权重折叠到后续层 |
| `center_writing_weights` | bool | True | 中心化残差流写入均值 |
| `center_unembed` | bool | True | 中心化反嵌入权重 |
| `dtype` | torch.dtype | None | 模型精度 |
| `device` | str | None | 目标设备 |
| `n_devices` | int | 1 | 模型并行所需的设备数量 |
### 权重矩阵
| 属性 | 形状 | 描述 |
|----------|-------|-------------|
| `W_E` | [d_vocab, d_model] | Token 嵌入矩阵 |
| `W_U` | [d_model, d_vocab] | 反嵌入矩阵 |
| `W_pos` | [n_ctx, d_model] | 位置嵌入 |
| `W_Q` | [n_layers, n_heads, d_model, d_head] | 查询权重 |
| `W_K` | [n_layers, n_heads, d_model, d_head] | 键权重 |
| `W_V` | [n_layers, n_heads, d_model, d_head] | 值权重 |
| `W_O` | [n_layers, n_heads, d_head, d_model] | 输出权重 |
| `W_in` | [n_layers, d_model, d_mlp] | MLP 输入权重 |
| `W_out` | [n_layers, d_mlp, d_model] | MLP 输出权重 |
### 核心方法
#### forward()
python
logits = model(tokens)
logits = model(tokens, return_type="logits")
loss = model(tokens, return_type="loss")
logits, loss = model(tokens, return_type="both")
参数:
- `input`: Token 张量或字符串
- `return_type`: "logits", "loss", "both", 或 None
- `prepend_bos`: 是否在开头添加 BOS token
- `start_at_layer`: 从特定层开始执行
- `stop_at_layer`: 在特定层停止执行
#### run_with_cache()
python
logits, cache = model.run_with_cache(tokens)
# 选择性缓存(节省内存)
logits, cache = model.run_with_cache(
tokens,
names_filter=lambda name: "resid_post" in name
)