# RWKV State Management
## Understanding RWKV State
Unlike Transformers with KV cache, RWKV maintains a **fixed-size recurrent state** that summarizes all previous context.
### State Components
```python
import torch

n_layers, d_model = 24, 2048  # example dimensions (RWKV-1.5B)

# Simplified view; the reference RWKV-4 code also tracks a per-layer
# running max exponent (see "Numerical Stability" below)
state = {
    'att_aa':     torch.zeros(n_layers, d_model),  # attention numerator accumulator
    'att_ab':     torch.zeros(n_layers, d_model),  # attention denominator accumulator
    'att_x_prev': torch.zeros(n_layers, d_model),  # previous x for time-mixing
    'ffn_x_prev': torch.zeros(n_layers, d_model)   # previous x for channel-mixing
}
```
**Total state size**: `4 × n_layers × d_model` elements (≈2 bytes each in fp16, 4 bytes in fp32)
| Model | Layers | d_model | State Elements |
|-------|--------|---------|----------------|
| RWKV-169M | 12 | 768 | ≈37 K |
| RWKV-430M | 24 | 1024 | ≈98 K |
| RWKV-1.5B | 24 | 2048 | ≈197 K |
| RWKV-3B | 32 | 2560 | ≈328 K |
| RWKV-7B | 32 | 4096 | ≈524 K |
| RWKV-14B | 40 | 5120 | ≈819 K |
**Constant memory** regardless of context length!
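To check these numbers for any configuration, the element count (and the memory cost at a given precision) can be computed directly; a minimal sketch:
```python
def rwkv_state_size(n_layers: int, d_model: int, bytes_per_elem: int = 2):
    """Element count and byte size of the 4-component RWKV state.

    bytes_per_elem: 2 for fp16, 4 for fp32.
    """
    elements = 4 * n_layers * d_model
    return elements, elements * bytes_per_elem

# RWKV-1.5B: 24 layers, d_model = 2048
elems, nbytes = rwkv_state_size(24, 2048)
print(elems, nbytes / 1024)  # 196608 elements, 384.0 KB in fp16
```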
## State Initialization
### Zero State (Default)
```python
from rwkv.model import RWKV
model = RWKV(model='/path/to/RWKV-4-Pile-1B5', strategy='cuda fp16')
# Start with zero state (no context)
state = None
out, state = model.forward(tokens, state)
```
### Warm State (Preloaded Context)
```python
# Load the context once
context = "The capital of France is Paris. The capital of Germany is Berlin."
context_tokens = tokenizer.encode(context)

# Process the context to build state (passing the whole token list in a
# single forward call is faster than feeding one token at a time)
state = None
_, state = model.forward(context_tokens, state)

# Now use the warm state for queries
query = " The capital of Italy is"
query_tokens = tokenizer.encode(query)
out, state = model.forward(query_tokens, state)
# The model "remembers" the Paris and Berlin examples
```
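Since `forward` consumes a state and returns the updated one, a warm state can be snapshotted once and reused for many independent queries. A sketch, assuming the state is a list or tuple of tensors (clone defensively so later forwards cannot mutate the snapshot):
```python
def snapshot(state):
    # Deep-copy every tensor in the state container
    return [t.clone() for t in state]

warm = snapshot(state)  # freeze the preloaded context

for q in [" The capital of Italy is", " The capital of Spain is"]:
    out, _ = model.forward(tokenizer.encode(q), snapshot(warm))
    # Each query starts from the same warm context, unaffected by the others
```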
### Shared State (Multi-turn Conversations)
```python
# Conversation with persistent state
state = None

# Turn 1
user1 = "My name is Alice."
tokens1 = tokenizer.encode(user1)
_, state = model.forward(tokens1, state)

# Turn 2
user2 = "What is my name?"
tokens2 = tokenizer.encode(user2)
out, state = model.forward(tokens2, state)
# `out` holds next-token logits; the state still carries "Alice"
```
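`forward` returns next-token logits, not text, so producing the reply takes a decoding loop. A minimal greedy-decoding sketch, continuing from `out` and `state` above (`tokenizer` is whatever encoder/decoder matches the model):
```python
generated = []
for _ in range(16):  # cap the reply length
    token = int(out.argmax())  # greedy: pick the most likely next token
    generated.append(token)
    out, state = model.forward([token], state)

print(tokenizer.decode(generated))  # should mention "Alice"
```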
## State Update Rules
### Time-Mixing State Update
```python
# WKV output for token t, computed from the state built over tokens < t
# (u is the per-channel bonus applied to the current token)
wkv_t = (att_aa + exp(u + k_t) * v_t) / (att_ab + exp(u + k_t))

# Update state for token t+1
w = -exp(time_decay)  # per-channel decay exponent, always negative
att_aa = exp(w) * att_aa + exp(k_t) * v_t
att_ab = exp(w) * att_ab + exp(k_t)
att_x_prev = x_t
# Note: this naive form can overflow; see "Numerical Stability" below
```
**Effect of the decay exponent `w`** (a quick numerical check follows):
- **w = -0.01** (small decay): per-step retention `exp(-0.01) ≈ 0.99` → state decays slowly → long memory
- **w = -5.0** (large decay): per-step retention `exp(-5.0) ≈ 0.007` → state decays quickly → short memory
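A token's contribution to the state shrinks by `exp(w)` every step, so after `t` further tokens only `exp(w·t)` of it survives:
```python
import math

for w in (-0.01, -5.0):
    for t in (1, 10, 100):
        print(f"w={w:6}: retention after {t:3} tokens = {math.exp(w * t):.3g}")
# w=-0.01: 0.990 after 1 token, 0.368 after 100  -> long memory
# w=-5.0:  0.00674 after a single token          -> short memory
```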
### Channel-Mixing State Update
```python
# Channel-mixing state: simply store the current x for the next token
ffn_x_prev = x_t
```
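For context on why `x_prev` is worth storing: it feeds RWKV's token-shift, which mixes each token's input with the previous token's before the channel-mixing projections. A sketch of one channel-mixing step in the RWKV-4 style (the weight and mixing-parameter names here are illustrative):
```python
# Token shift: interpolate between the current and previous inputs
xk = x * time_mix_k + ffn_x_prev * (1 - time_mix_k)
xr = x * time_mix_r + ffn_x_prev * (1 - time_mix_r)
ffn_x_prev = x  # stored for the next token

r = torch.sigmoid(xr @ W_r)              # receptance gate in (0, 1)
k = torch.square(torch.relu(xk @ W_k))   # squared-ReLU "key" activation
out = r * (k @ W_v)                      # gated channel-mixing output
```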
## State Serialization
### Save/Load State (PyTorch)
```python
import torch

# Save the conversation state
# (since the state is a container of tensors, torch.save(state, path) also works)
state_dict = {
    'att_aa':     state[0],
    'att_ab':     state[1],
    'att_x_prev': state[2],
    'ffn_x_prev': state[3]
}
torch.save(state_dict, 'conversation_123.pt')

# Load the state
loaded = torch.load('conversation_123.pt')
state = (loaded['att_aa'], loaded['att_ab'], loaded['att_x_prev'], loaded['ffn_x_prev'])

# Continue the conversation
out, state = model.forward(new_tokens, state)
```
### State Compression (Optional)
```python
# FP16 state (half size)
state_fp16 = tuple(s.half() for s in state)
torch.save(state_fp16, 'state_compressed.pt')
# Restore
state = tuple(s.float() for s in torch.load('state_compressed.pt'))
```
## Multi-Session State Management
### Session State Store
```python
class StateManager:
    def __init__(self):
        self.sessions = {}  # session_id -> state

    def get_state(self, session_id):
        return self.sessions.get(session_id, None)

    def save_state(self, session_id, state):
        self.sessions[session_id] = state

    def clear_session(self, session_id):
        if session_id in self.sessions:
            del self.sessions[session_id]

# Usage
manager = StateManager()

# User 1 conversation
state1 = manager.get_state('user_1')
out1, state1 = model.forward(tokens1, state1)
manager.save_state('user_1', state1)

# User 2 conversation (independent state)
state2 = manager.get_state('user_2')
out2, state2 = model.forward(tokens2, state2)
manager.save_state('user_2', state2)
```
### State Expiration
```python
import time

class StateManagerWithExpiry:
    def __init__(self, expiry_seconds=3600):
        self.sessions = {}  # session_id -> (state, timestamp)
        self.expiry = expiry_seconds

    def get_state(self, session_id):
        if session_id in self.sessions:
            state, timestamp = self.sessions[session_id]
            if time.time() - timestamp < self.expiry:
                return state
            del self.sessions[session_id]  # expired: drop the stale state
        return None

    def save_state(self, session_id, state):
        self.sessions[session_id] = (state, time.time())
```
## Numerical Stability
### Overflow Prevention
```python
# Issue: the accumulators grow without bound; once att_aa exceeds ~1e10,
# floating-point precision degrades
# Solution 1: periodically renormalize numerator and denominator together
# (dividing both by the same scale leaves the WKV ratio unchanged)
if att_aa.abs().max() > 1e6:
    scale = att_aa.abs().max()
    att_aa = att_aa / scale
    att_ab = att_ab / scale
```
### Underflow Prevention
```python
# Issue: a strongly negative decay exponent w makes exp(w) underflow toward 0,
# wiping the state within a few tokens
# Solution: clamp time_decay (recall w = -exp(time_decay)), which bounds the
# per-step retention exp(w) to roughly [0.40, 0.9997]
time_decay = torch.clamp(time_decay, min=-8.0, max=-0.1)
# Ensures the state never decays too fast (and never stops decaying entirely)
```
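Rather than patching over- and underflow after the fact, the RWKV-4 reference implementation keeps the accumulators in an exponent-shifted form: a third per-channel state variable (`pp` in the reference code) tracks the running maximum exponent, and the accumulators (`aa`, `bb`) are stored relative to it. A condensed sketch of that stabilized step:
```python
# Stabilized WKV step; aa, bb are exponent-shifted accumulators, pp the
# running max exponent; u is the current-token bonus, w < 0 the decay
ww = u + k
p = torch.maximum(pp, ww)          # common exponent for the output
e1 = torch.exp(pp - p)             # rescale the old state
e2 = torch.exp(ww - p)             # rescale the current token
wkv = (e1 * aa + e2 * v) / (e1 * bb + e2)

ww = pp + w                        # decay the old exponent
p = torch.maximum(ww, k)           # common exponent for the new state
e1 = torch.exp(ww - p)
e2 = torch.exp(k - p)
aa = e1 * aa + e2 * v
bb = e1 * bb + e2
pp = p
```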
## State vs KV Cache Comparison
### Memory Usage (8K context)
| Model Type | Model Size | KV Cache Size | RWKV State Size |
|------------|------------|---------------|-----------------|
| Transformer | 1.3B | 4.1 GB | - |
| **RWKV** | **1.5B** | **-** | **~0.4 MB (fp16)** |
| Transformer | 7B | 21.3 GB | - |
| **RWKV** | **7B** | **-** | **~1 MB (fp16)** |
**RWKV advantage**: roughly 10,000× smaller than the KV cache!
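KV-cache figures like those above depend heavily on the assumed architecture, precision, and batch size; a rough back-of-the-envelope formula (the example config is illustrative):
```python
def kv_cache_bytes(n_layers, d_model, seq_len, bytes_per_elem=2, batch=1):
    # One key and one value vector (each d_model wide) per layer per position
    return 2 * n_layers * d_model * seq_len * bytes_per_elem * batch

# e.g. a 7B-class model (32 layers, d_model 4096) at 8K context, fp16, batch 1:
print(kv_cache_bytes(32, 4096, 8192) / 2**30, "GiB")  # 4.0 GiB
```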
### Information Retention
**KV Cache (Transformer)**:
- Perfect: Stores all previous keys and values
- Retrieval: Exact attention to any previous token
- Cost: O(n) memory growth
**RWKV State**:
- Lossy: Compressed representation of history
- Retrieval: Weighted blend of previous tokens (decay-based)
- Cost: O(1) constant memory
**Trade-off**: RWKV sacrifices perfect recall for constant memory
## Resources
- State management examples: https://github.com/BlinkDL/ChatRWKV
- Wiki: https://wiki.rwkv.com/state-management
- Discord: https://discord.gg/bDSBUMeFpc (RWKV community)
Source: claude-code-templates (MIT).