[ PROMPT_NODE_22732 ]
state-management
[ SKILL_DOCUMENTATION ]
# RWKV 状态管理
## 理解 RWKV 状态
与带有 KV cache 的 Transformer 不同,RWKV 维护一个**固定大小的递归状态**,用于总结之前的所有上下文。
### 状态组件
python
state = {
'att_aa': torch.zeros(n_layers, d_model), # 注意力分子累加器
'att_ab': torch.zeros(n_layers, d_model), # 注意力分母累加器
'att_x_prev': torch.zeros(n_layers, d_model), # 用于 time-mixing 的前一个 x
'ffn_x_prev': torch.zeros(n_layers, d_model) # 用于 channel-mixing 的前一个 x
}
**总状态大小**:`4 × n_layers × d_model` 参数
| 模型 | 层数 | d_model | 状态大小 |
|-------|--------|---------|------------|
| RWKV-169M | 12 | 768 | 37 KB |
| RWKV-430M | 24 | 1024 | 98 KB |
| RWKV-1.5B | 24 | 2048 | 196 KB |
| RWKV-3B | 32 | 2560 | 327 KB |
| RWKV-7B | 32 | 4096 | 524 KB |
| RWKV-14B | 40 | 5120 | 819 KB |
无论上下文长度如何,内存占用都是恒定的!
## 状态初始化
### 零状态 (默认)
python
from rwkv.model import RWKV
model = RWKV(model='/path/to/RWKV-4-Pile-1B5', strategy='cuda fp16')
# 从零状态开始 (无上下文)
state = None
out, state = model.forward(tokens, state)
### 热状态 (预加载上下文)
python
# 加载上下文
context = "The capital of France is Paris. The capital of Germany is Berlin."
context_tokens = tokenizer.encode(context)
# 处理上下文以构建状态
state = None
for token in context_tokens:
_, state = model.forward([token], state)
# 现在使用热状态进行查询
query = " The capital of Italy is"
query_tokens = tokenizer.encode(query)
out, state = model.forward(query_tokens, state)
# 模型“记住”了 Paris 和 Berlin 的示例!
### 共享状态 (多轮对话)
python
# 具有持久状态的对话
state = None
# 第一轮
user1 = "My name is Alice."
tokens1 = tokenizer.encode(user1)
_, state = model.forward(tokens1, state)
# 第二轮
user2 = "What is my name?"
tokens2 = tokenizer.encode(user2)
response, state = model.forward(tokens2, state)
# 回复: "Alice" (状态记住了!)
## 状态更新规则
### Time-Mixing 状态更新
python
# 处理 token t 之前
att_aa_t = att_aa_{t-1} # 前一个分子
att_ab_t = att_ab_{t-1} # 前一个分母
# 计算 WKV
wkv_t = (exp(u) * k_t * v_t + att_aa_t) / (exp(u) * k_t + att_ab_t)
# 更新 token t+1 的状态
w = -exp(time_decay) # 衰减因子
att_aa_{t+1} = exp(w) * att_aa_t + k_t * v_t
att_ab_{t+1} = ex