- Paper Overview
- Motivation: The Transformer vs RNN Dilemma
- RWKV Core Architecture
- RWKV Block Structure
- Version-by-Version Evolution
- Performance Comparison with Transformers
- Practical Usage: Working with RWKV
- RWKV vs Mamba vs Transformer
- Limitations and Future Directions
- Quiz
- Conclusion
- References

Paper Overview
- Title: RWKV: Reinventing RNNs for the Transformer Era
- Authors: Bo Peng et al. (RWKV Foundation)
- Initial Publication: May 2023 (arXiv: 2305.13048), EMNLP 2023
- Latest Version: RWKV-7 "Goose" (Released March 2025)
- Code: github.com/BlinkDL/RWKV-LM
Motivation: The Transformer vs RNN Dilemma
Nearly all modern LLMs are built on the Transformer, but the architecture has fundamental limitations:
- O(N²) Self-Attention: Quadratic complexity with respect to sequence length
- KV Cache Explosion: Memory scales proportionally to sequence length during inference
- Long Context Cost: Cost and memory spike dramatically at 128K+ context lengths
On the other hand, traditional RNNs offer:
- O(N) Complexity: Linear time and memory
- Fixed State: Constant inference cost
- BUT: Training is not parallelizable (slow), and they struggle to capture long-range dependencies
RWKV's question: "Can we train in parallel like a Transformer while inferring as efficiently as an RNN?"
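To put the KV-cache point in numbers: for a hypothetical 7B-class configuration (32 layers, hidden size 4096, fp16 — illustrative values, not taken from the paper), cache memory grows linearly with context while an RNN-style state stays fixed. A back-of-the-envelope sketch:

```python
# KV cache: 2 tensors (K and V) per layer, each seq_len x hidden, 2 bytes in fp16
layers, hidden, bytes_fp16 = 32, 4096, 2  # hypothetical 7B-class config

def kv_cache_bytes(seq_len):
    return 2 * layers * seq_len * hidden * bytes_fp16

for n in (1_024, 16_384, 131_072):
    print(f"{n:>7} tokens -> {kv_cache_bytes(n) / 2**30:.1f} GiB")
# An RNN-style state, by contrast, is O(1): a few tensors of size ~hidden
# per layer, independent of how many tokens have been processed.
```

At 128K tokens this works out to roughly 64 GiB of cache for K and V alone, which is where the "KV Cache Explosion" bullet above comes from.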
RWKV Core Architecture
The name RWKV itself encodes the core components of the architecture:
- R: Receptance (reception gate) — determines how much past information to accept
- W: Weight (time decay) — decay rate for past information
- K: Key — key of the current input
- V: Value — value of the current input
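Concretely, R, K, and V come from linear projections of the input, while W and the bonus U are learned per-channel parameters. A minimal sketch (the names `W_r`, `W_k`, `W_v` and the dimension are illustrative, not the repository's exact layout):

```python
import torch
import torch.nn as nn

dim = 8
W_r = nn.Linear(dim, dim, bias=False)  # receptance projection
W_k = nn.Linear(dim, dim, bias=False)  # key projection
W_v = nn.Linear(dim, dim, bias=False)  # value projection
w = -torch.rand(dim)                   # per-channel time decay (kept negative)
u = torch.zeros(dim)                   # per-channel bonus for the current token

x_t = torch.randn(dim)                 # current token's hidden state
r = torch.sigmoid(W_r(x_t))            # reception gate, squashed into (0, 1)
k, v = W_k(x_t), W_v(x_t)
```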
WKV (Weighted Key-Value) Mechanism
The WKV operation is at the heart of RWKV:
import math
import torch

def wkv_vanilla(w, u, k, v):
    """
    RWKV's WKV mechanism (naive reference implementation).
    w: time decay, shape (C,) -- negative; larger magnitude = faster decay
    u: bonus for the current token, shape (C,)
    k: keys, shape (T, C)
    v: values, shape (T, C)
    """
    T, C = k.shape
    output = torch.zeros_like(v)
    for c in range(C):
        # Each channel is processed independently (O(T) per channel)
        a = 0.0    # accumulated numerator
        b = 0.0    # accumulated denominator
        p = -1e30  # running max exponent (numerical stability)
        for t in range(T):
            kt, vt = k[t, c].item(), v[t, c].item()
            # Output: the current token gets the bonus u; (a, b) hold the past
            q = max(p, u[c].item() + kt)
            e1 = math.exp(p - q)                 # weight of the accumulated state
            e2 = math.exp(u[c].item() + kt - q)  # weight of the current token
            output[t, c] = (e1 * a + e2 * vt) / (e1 * b + e2)
            # State update (RNN-style): decay the old state by w, absorb token t
            new_p = max(p + w[c].item(), kt)
            e1 = math.exp(p + w[c].item() - new_p)
            e2 = math.exp(kt - new_p)
            a = e1 * a + e2 * vt
            b = e1 * b + e2
            p = new_p
    return output
Dual Mode: Transformer Mode vs RNN Mode
# Training: Transformer mode (parallel)
# Process the entire sequence at once: O(T) total work, parallelized over it
def rwkv_parallel(x, w, u, k_proj, v_proj, r_proj):
    T, C = x.shape
    k = k_proj(x)                 # (T, C)
    v = v_proj(x)                 # (T, C)
    r = torch.sigmoid(r_proj(x))  # (T, C) reception gate
    # Parallel WKV operation (a custom CUDA kernel in the real implementation)
    wkv = parallel_wkv_cuda(w, u, k, v)
    return r * wkv                # apply reception gate

# Inference: RNN mode (sequential, constant memory)
def rwkv_sequential(x_t, state, w, u, k_proj, v_proj, r_proj):
    """Process one token at a time, O(1) per token"""
    k = k_proj(x_t)
    v = v_proj(x_t)
    r = torch.sigmoid(r_proj(x_t))
    # state = (a, b, p) -- fixed size, no matter how long the sequence is!
    a, b, p = state
    q = torch.maximum(p, u + k)
    e1 = torch.exp(p - q)         # weight of the accumulated state
    e2 = torch.exp(u + k - q)     # weight of the current token (with bonus u)
    wkv = (e1 * a + e2 * v) / (e1 * b + e2)
    output = r * wkv
    # State update
    new_p = torch.maximum(w + p, k)
    e1 = torch.exp(w + p - new_p)
    e2 = torch.exp(k - new_p)
    new_a = e1 * a + e2 * v
    new_b = e1 * b + e2
    return output, (new_a, new_b, new_p)
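Both modes compute the same quantity. A self-contained sanity check (simplified: no max-subtraction trick, small magnitudes, toy sizes of my choosing) that the quadratic "attention-like" form and the O(1)-state recurrence agree:

```python
import torch

torch.manual_seed(0)
T, C = 8, 4
w = -torch.rand(C)            # negative per-channel decay
u = torch.randn(C) * 0.1      # current-token bonus
k = torch.randn(T, C) * 0.1
v = torch.randn(T, C)

def wkv_direct(w, u, k, v):
    # Quadratic form: token t weights each past token i by exp((t-1-i)*w + k_i)
    out = torch.zeros_like(v)
    for t in range(len(k)):
        num = torch.exp(u + k[t]) * v[t]
        den = torch.exp(u + k[t])
        for i in range(t):
            e = torch.exp((t - 1 - i) * w + k[i])
            num, den = num + e * v[i], den + e
        out[t] = num / den
    return out

def wkv_recurrent(w, u, k, v):
    # Same result from a fixed-size state (a, b), updated once per token
    out = torch.zeros_like(v)
    a = torch.zeros(v.shape[1])
    b = torch.zeros(v.shape[1])
    for t in range(len(k)):
        e_cur = torch.exp(u + k[t])
        out[t] = (a + e_cur * v[t]) / (b + e_cur)
        a = torch.exp(w) * a + torch.exp(k[t]) * v[t]
        b = torch.exp(w) * b + torch.exp(k[t])
    return out

assert torch.allclose(wkv_direct(w, u, k, v), wkv_recurrent(w, u, k, v), atol=1e-5)
```

The recurrence is what makes RNN-mode inference O(1) per token: the double loop collapses into two running sums per channel.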
RWKV Block Structure
class RWKVBlock(nn.Module):
    """
    Basic block structure of RWKV.
    Comparison with a Transformer block:
      Transformer: LayerNorm -> Attention -> Add -> LayerNorm -> FFN        -> Add
      RWKV:        LayerNorm -> TimeMix   -> Add -> LayerNorm -> ChannelMix -> Add
    """
    def __init__(self, dim, layer_id):
        super().__init__()
        self.time_mix = TimeMixing(dim, layer_id)        # replaces Attention
        self.channel_mix = ChannelMixing(dim, layer_id)  # replaces FFN
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)

    def forward(self, x, state):
        # Time Mixing (mixes information from past tokens)
        dx, state = self.time_mix(self.ln1(x), state)
        x = x + dx
        # Channel Mixing (mixes information across channels)
        dx = self.channel_mix(self.ln2(x))
        x = x + dx
        return x, state

class TimeMixing(nn.Module):
    """
    Token Shift + WKV:
    linearly interpolates between the current and previous token
    before projecting to r, k, and v.
    """
    def __init__(self, dim, layer_id):
        super().__init__()
        self.mix_r = nn.Parameter(torch.ones(dim))  # interpolation ratios
        self.mix_k = nn.Parameter(torch.ones(dim))
        self.mix_v = nn.Parameter(torch.ones(dim))
        self.W_r = nn.Linear(dim, dim, bias=False)
        self.W_k = nn.Linear(dim, dim, bias=False)
        self.W_v = nn.Linear(dim, dim, bias=False)

    def forward(self, x, state):
        # Token Shift: weighted average of current and previous token
        x_prev = state.shift  # previous token
        xr = x * self.mix_r + x_prev * (1 - self.mix_r)
        xk = x * self.mix_k + x_prev * (1 - self.mix_k)
        xv = x * self.mix_v + x_prev * (1 - self.mix_v)
        r = torch.sigmoid(self.W_r(xr))
        k = self.W_k(xk)
        v = self.W_v(xv)
        wkv, new_state = compute_wkv(k, v, state)  # WKV recurrence (see above)
        new_state.shift = x                        # remember this token for the next step
        return r * wkv, new_state
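In the parallel training mode, the "previous token" tensor for the whole sequence is just the input shifted one step along the time axis with zero padding at the front; a minimal sketch (the 0.7 mix ratio is an arbitrary illustrative value):

```python
import torch
import torch.nn.functional as F

x = torch.arange(12.0).reshape(4, 3)  # (T, C): 4 tokens, 3 channels
# Previous-token tensor: shift one step in time, zeros for the first token
x_prev = F.pad(x, (0, 0, 1, -1))      # pad one row on top, crop one at the bottom
mix = torch.full((3,), 0.7)           # learned per-channel interpolation ratio
xk = x * mix + x_prev * (1 - mix)     # the Token Shift blend, e.g. for the key
```

Because each position only ever sees its immediate predecessor, Token Shift conveys relative position without any explicit positional encoding.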
Version-by-Version Evolution
RWKV-4
- Introduced the foundational WKV mechanism
- Replaced positional encoding with Token Shift
- Up to 14B parameters
RWKV-5 (Eagle)
# RWKV-5: Multi-headed state
# Maintains multiple independent states for improved expressiveness
class RWKV5_TimeMix:
    def __init__(self, dim, n_heads=8):
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        # Each head has its own independent decay rate
        self.time_decay = nn.Parameter(torch.randn(n_heads, self.head_dim))
RWKV-6 (Finch)
# RWKV-6: Data-dependent time decay
# The decay rate changes dynamically with the input (similar to Mamba's selective mechanism)
class RWKV6_TimeMix:
    def forward(self, x, state):
        # LoRA-style low-rank modulation of the decay, computed from the input
        decay_lora = torch.tanh(self.decay_lora_a(x)) @ self.decay_lora_b
        # Input-dependent decay instead of a fixed one: a different decay per token!
        # (time_decay_base is the learned per-channel baseline parameter)
        time_decay = self.time_decay_base + decay_lora
        time_decay = torch.exp(-torch.exp(time_decay))  # squash into (0, 1)
RWKV-7 (Goose) — March 2025
# RWKV-7: State Transition Matrix
# Represents state transitions as matrices for richer state updates
class RWKV7_TimeMix:
    """
    Key innovations in RWKV-7:
    1. State Transition Matrix: state transitions via matrices instead of scalars
    2. Enhanced In-Context Learning: dynamically adjusts its update rule
    3. Improved Token Mixing: more sophisticated inter-token information flow
    """
    def forward(self, x, state):
        B, T, D = x.shape
        a = torch.sigmoid(self.W_a(x))  # (B, T, D) per-token diagonal decay
        k = self.W_k(x)                 # (B, T, D)
        v = self.W_v(x)                 # (B, T, D)
        r = self.W_r(x)                 # (B, T, D) receptance / readout
        # Matrix-form state update (simplified, single head):
        #   S_t = diag(a_t) @ S_{t-1} + k_t^T @ v_t
        # state: a (B, D, D) matrix instead of a (B, D) vector!
        outputs = []
        for t in range(T):
            state = a[:, t].unsqueeze(-1) * state + \
                    k[:, t].unsqueeze(-1) @ v[:, t].unsqueeze(-2)
            outputs.append((r[:, t].unsqueeze(-2) @ state).squeeze(-2))  # read out with r_t
        return torch.stack(outputs, dim=1), state
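Stripped of the module plumbing, the diagonal-decay-plus-outer-product recurrence can be run on its own. A toy single-head sketch (dimensions and random inputs are illustrative, not from the paper):

```python
import torch

torch.manual_seed(0)
D = 4                                  # toy head dimension
state = torch.zeros(D, D)              # matrix-valued state: D x D, not just D
for t in range(16):
    a = torch.sigmoid(torch.randn(D))  # per-channel decay in (0, 1)
    k, v, r = torch.randn(D), torch.randn(D), torch.randn(D)
    # S_t = diag(a_t) @ S_{t-1} + k_t^T v_t  (decay, then outer-product write)
    state = a.unsqueeze(1) * state + torch.outer(k, v)
    out = r @ state                    # readout with the receptance vector
# The state never grows: it is D x D after 16 tokens or 16 million
```

Compared to the scalar-decayed vector state of RWKV-4/5/6, a matrix state can store D separate value directions per key channel, which is what enables the richer updates and stronger in-context learning claimed for v7.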
Performance Comparison with Transformers
Performance by model size (Pile dataset perplexity, lower is better):
| Model Size | Transformer | RWKV-4 | RWKV-6 | RWKV-7 |
|-----------|-------------|--------|--------|--------|
| 169M | 17.2 | 18.1 | 17.4 | 17.0 |
| 430M | 13.8 | 14.5 | 13.9 | 13.6 |
| 1.5B | 11.2 | 11.8 | 11.3 | 11.0 |
| 3B | 9.8 | 10.3 | 9.9 | 9.6 |
| 7B | 8.5 | 9.0 | 8.6 | 8.3 |
| 14B | 7.8 | 8.2 | 7.9 | 7.6 |
Inference efficiency (tokens/sec, A100 GPU):
| Sequence Length | Transformer | RWKV-7 |
|----------------|-------------|---------|
| 1K | 1000 | 1200 |
| 4K | 800 | 1200 |
| 16K | 300 | 1200 |
| 64K | 50 | 1200 |
| 128K+ | OOM | 1200 |
Practical Usage: Working with RWKV
Using with HuggingFace
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load RWKV-7 model
model = AutoModelForCausalLM.from_pretrained(
"RWKV/rwkv-7-world-3b",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
"RWKV/rwkv-7-world-3b",
trust_remote_code=True
)
# Text generation
prompt = "To implement Pod autoscaling in Kubernetes"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(
**inputs,
max_new_tokens=200,
temperature=0.7,
top_p=0.9
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Running Locally with RWKV Runner
# Install RWKV Runner (GUI tool)
git clone https://github.com/josStorer/RWKV-Runner
cd RWKV-Runner
# Or ChatRWKV (CLI)
pip install rwkv
# Use directly from Python
python3 << 'EOF'
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
model = RWKV(model='/path/to/RWKV-7-World-3B.pth', strategy='cuda fp16')
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
# Sampling options are passed via PIPELINE_ARGS rather than keyword arguments
args = PIPELINE_ARGS(temperature=0.8, top_p=0.9)
result = pipeline.generate(
    "Please explain the current state of the AI industry in Korea.",
    token_count=200,
    args=args
)
print(result)
EOF
RWKV vs Mamba vs Transformer
| Property | Transformer | Mamba | RWKV-7 |
|---------------------|----------------|----------------|---------------|
| Training Complexity | O(N²) | O(N) | O(N) |
| Inference Complexity| O(N) per token | O(1) per token | O(1) per token|
| Memory | O(N) KV Cache | O(1) state | O(1) state |
| Parallel Training | Yes | Yes (scan) | Yes (WKV) |
| Long-range Deps | Strong | Good | Good |
| In-Context Learning | Strong | Good | Enhanced in v7|
| Implementation Difficulty | Low | Medium (CUDA) | Medium (CUDA) |
| Community | Massive | Growing | Active |
Limitations and Future Directions
Current Limitations
- Complex retrieval tasks: Weaker at specific pattern retrieval compared to Transformer's Full Attention
- CUDA kernel dependency: Custom CUDA kernels required for optimal performance
- Ecosystem: Fewer tools and libraries compared to Transformers
Future Directions
- Hybrid architectures: Combining RWKV with a small amount of Attention
- Hardware optimization: Optimizing for new chips like Groq and Cerebras
- Multimodal: Expanding to other modalities such as vision and audio
Quiz
Q1. What do R, W, K, and V stand for in RWKV?
R: Receptance (reception gate), W: Weight (time decay), K: Key, V: Value.
Q2. In which mode does RWKV operate during training and inference respectively?
During training, it operates in parallel (Transformer-like) mode processing the entire sequence at once. During inference, it operates in sequential (RNN-like) mode with O(1) cost per token.
Q3. What is the Data-dependent time decay introduced in RWKV-6?
Instead of a fixed decay rate, the decay rate changes dynamically based on the input data. This is an idea similar to Mamba's selective mechanism.
Q4. What is the key innovation of RWKV-7 Goose?
It introduces a State Transition Matrix that represents state transitions using matrices rather than scalars. This enables richer state updates and enhanced In-Context Learning.
Q5. What is the biggest advantage of RWKV over Transformers?
Inference cost remains constant regardless of sequence length (O(1) per token). Even at 128K+ tokens, it does not run into OOM issues like Transformers.
Q6. What is the role of the Token Shift mechanism?
It blends the current token with the previous token via a weighted average, enabling positional information to be conveyed within the sequence without explicit positional encoding.
Q7. What is one current limitation of RWKV?
Compared to Transformer's Full Self-Attention, RWKV may underperform on complex pattern retrieval and precise information retrieval tasks.
Conclusion
RWKV is an innovative architecture that challenges the conventional wisdom that "RNNs are dead." By combining the parallel training efficiency of Transformers with the inference efficiency of RNNs, it excels particularly in long sequence processing and edge device deployment. With the introduction of the State Transition Matrix in v7 Goose, the performance gap with Transformers has been nearly closed.