
RWKV: Reinventing RNNs for the Transformer Era (From v4 to v7 Goose)


Paper Overview

- **Title**: RWKV: Reinventing RNNs for the Transformer Era

- **Authors**: Bo Peng et al. (RWKV Foundation)

- **Initial Publication**: May 2023 (arXiv: 2305.13048), Findings of EMNLP 2023

- **Latest Version**: RWKV-7 "Goose" (Released March 2025)

- **Code**: [github.com/BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM)

Motivation: The Transformer vs RNN Dilemma

Nearly all modern LLMs are built on Transformers, but they have fundamental limitations:

- **O(N²) Self-Attention**: Quadratic complexity with respect to sequence length

- **KV Cache Explosion**: Memory scales proportionally to sequence length during inference

- **Long Context Cost**: Cost and memory spike dramatically at 128K+ context lengths

On the other hand, traditional RNNs offer:

- **O(N) Complexity**: Linear time and memory

- **Fixed State**: Constant per-token inference cost and memory

- **BUT**: Training is not parallelizable (slow), and they struggle to capture long-range dependencies
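A back-of-the-envelope calculation makes the memory gap concrete. The sketch below assumes a hypothetical 32-layer model with hidden size 4096, an fp16 KV cache without grouped-query attention, batch size 1, and, for RWKV, the RWKV-4 RNN-mode state of five fp32 vectors per layer. Real models differ, so treat the results as orders of magnitude only.

```python
def kv_cache_bytes(n_layers, d_model, seq_len, bytes_per_elem=2, batch=1):
    """Transformer KV cache: one key and one value vector per layer per cached token."""
    return 2 * n_layers * d_model * seq_len * bytes_per_elem * batch

def rwkv4_state_bytes(n_layers, d_model, bytes_per_elem=4, batch=1):
    """RWKV-4 RNN-mode state: five d_model vectors per layer (two token-shift
    buffers plus the WKV numerator, denominator and running max); no seq_len term."""
    return 5 * n_layers * d_model * bytes_per_elem * batch

# Hypothetical 7B-class configuration (an assumption, not a specific released model)
N_LAYERS, D_MODEL = 32, 4096
for seq_len in (1_024, 32_768, 131_072):
    gib = kv_cache_bytes(N_LAYERS, D_MODEL, seq_len) / 2**30
    print(f"{seq_len:>7} tokens: KV cache {gib:5.1f} GiB")
print(f"RWKV-4 state at any length: {rwkv4_state_bytes(N_LAYERS, D_MODEL) / 2**20:.1f} MiB")
```

Under these assumptions the cache grows linearly from 0.5 GiB at 1K tokens to 64 GiB at 128K tokens, while the recurrent state stays at 2.5 MiB.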

**RWKV's question**: "Can we train in parallel like a Transformer while inferring as efficiently as an RNN?"

RWKV Core Architecture

The name RWKV itself encodes the core components of the architecture:

- **R**: Receptance (reception gate) — determines how much past information to accept

- **W**: Weight (time decay) — decay rate for past information

- **K**: Key — key of the current input

- **V**: Value — value of the current input

WKV (Weighted Key-Value) Mechanism

The WKV operation is at the heart of RWKV:

```python
import torch

def wkv_vanilla(w, u, k, v):
    """
    RWKV-4's WKV mechanism (pure Python reference loop; real models use a CUDA kernel)
    w: (C,) time decay (negative, larger magnitude = faster decay)
    u: (C,) bonus for the current token
    k: (T, C) keys
    v: (T, C) values
    """
    T, C = k.shape
    output = torch.zeros_like(v)
    for c in range(C):
        # Independent processing per channel (O(T) per channel)
        a = torch.tensor(0.0, dtype=v.dtype)    # accumulated numerator, scaled by exp(-p)
        b = torch.tensor(0.0, dtype=v.dtype)    # accumulated denominator, scaled by exp(-p)
        p = torch.tensor(-1e38, dtype=v.dtype)  # running max exponent (numerical stability)
        for t in range(T):
            # Output: stored past plus the current token, which gets the bonus u
            ww = u[c] + k[t, c]
            q = torch.maximum(p, ww)
            e1 = torch.exp(p - q)   # rescales the stored state
            e2 = torch.exp(ww - q)  # current token's contribution
            output[t, c] = (e1 * a + e2 * v[t, c]) / (e1 * b + e2)
            # State update (RNN-style): decay the past by exp(w), add the current token without u
            ww = p + w[c]
            q = torch.maximum(ww, k[t, c])
            e1 = torch.exp(ww - q)
            e2 = torch.exp(k[t, c] - q)
            a = e1 * a + e2 * v[t, c]
            b = e1 * b + e2
            p = q
    return output
```
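Written out, the loop evaluates the weighted average below. It uses the same convention as the code (w < 0, so e^w < 1 is the per-step decay; t is 1-indexed); the RWKV paper writes the same expression with a positive w and e^{-(t-1-i)w}. The second line is the fixed-size recurrence carried in RNN mode. The running maximum p in the code only rescales a and b to avoid overflow and cancels in the ratio.

$$
\mathrm{wkv}_t = \frac{\sum_{i=1}^{t-1} e^{(t-1-i)w + k_i}\, v_i + e^{u + k_t}\, v_t}{\sum_{i=1}^{t-1} e^{(t-1-i)w + k_i} + e^{u + k_t}}
$$

$$
a_t = e^{w} a_{t-1} + e^{k_t} v_t,\qquad b_t = e^{w} b_{t-1} + e^{k_t},\qquad \mathrm{wkv}_t = \frac{a_{t-1} + e^{u + k_t}\, v_t}{b_{t-1} + e^{u + k_t}},\qquad a_0 = b_0 = 0
$$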

Dual Mode: Transformer Mode vs RNN Mode

```python
import torch

# Training: Transformer mode (parallel)
# Projections for the entire sequence are computed at once; the WKV scan is O(T)
def rwkv_parallel(x, w, u, k_proj, v_proj, r_proj):
    T, C = x.shape
    k = k_proj(x)                  # (T, C)
    v = v_proj(x)                  # (T, C)
    r = torch.sigmoid(r_proj(x))   # (T, C) reception gate
    # WKV over the whole sequence: wkv_vanilla above serves as a reference;
    # production code calls RWKV's custom CUDA kernel here
    wkv = wkv_vanilla(w, u, k, v)
    return r * wkv                 # Apply reception gate

# Inference: RNN mode (sequential, constant memory)
def rwkv_sequential(x_t, state, w, u, k_proj, v_proj, r_proj):
    """Process one token at a time: O(1) time and memory per token"""
    k = k_proj(x_t)
    v = v_proj(x_t)
    r = torch.sigmoid(r_proj(x_t))
    # state = (a, b, p): three (C,) vectors, fixed size regardless of sequence length!
    # Initial state: a = b = 0, p = -1e38
    a, b, p = state
    ww = u + k
    q = torch.maximum(p, ww)
    e1 = torch.exp(p - q)
    e2 = torch.exp(ww - q)
    wkv = (e1 * a + e2 * v) / (e1 * b + e2)
    output = r * wkv
    # State update: decay the past by exp(w), then add the current token
    ww = p + w
    new_p = torch.maximum(ww, k)
    e1 = torch.exp(ww - new_p)
    e2 = torch.exp(k - new_p)
    new_a = e1 * a + e2 * v
    new_b = e1 * b + e2
    new_state = (new_a, new_b, new_p)
    return output, new_state
```
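The two modes compute the same function. One way to see this is to write WKV as masked, attention-like weights over all positions at once and compare it with the fixed-state recurrence. The sketch below does exactly that; the function names are hypothetical, and the parallel version deliberately materializes a (T, T, C) tensor, which the real RWKV kernel avoids.

```python
import torch

def wkv_parallel(w, u, k, v):
    """Attention-like form: a (T, T, C) matrix of log-weights normalized by softmax.
    Naive O(T^2) memory; the real RWKV CUDA kernel never builds this matrix."""
    T, C = k.shape
    t_idx = torch.arange(T).view(T, 1, 1)   # output position t
    i_idx = torch.arange(T).view(1, T, 1)   # source position i
    logits = (t_idx - 1 - i_idx) * w + k.unsqueeze(0)                 # past: decayed keys
    logits = torch.where(i_idx == t_idx, u + k.unsqueeze(0), logits)  # current: bonus u
    logits = logits.masked_fill(i_idx > t_idx, float("-inf"))         # causal mask
    weights = torch.softmax(logits, dim=1)  # numerator / denominator in one step
    return (weights * v.unsqueeze(0)).sum(dim=1)

def wkv_recurrent(w, u, k, v):
    """RNN form: carry a fixed-size state (a, b, p) of three C-vectors across tokens."""
    T, C = k.shape
    a, b = torch.zeros(C), torch.zeros(C)
    p = torch.full((C,), -1e38)
    out = torch.empty(T, C)
    for t in range(T):
        q = torch.maximum(p, u + k[t])
        e1, e2 = torch.exp(p - q), torch.exp(u + k[t] - q)
        out[t] = (e1 * a + e2 * v[t]) / (e1 * b + e2)
        q2 = torch.maximum(p + w, k[t])
        e1, e2 = torch.exp(p + w - q2), torch.exp(k[t] - q2)
        a, b, p = e1 * a + e2 * v[t], e1 * b + e2, q2
    return out

torch.manual_seed(0)
T, C = 16, 8
w = -torch.exp(torch.randn(C))   # negative per-channel decay, as in RWKV-4
u = torch.randn(C)
k, v = torch.randn(T, C), torch.randn(T, C)
print(torch.allclose(wkv_parallel(w, u, k, v), wkv_recurrent(w, u, k, v), atol=1e-5))  # True
```

The parallel form is what makes training look like a Transformer, and the recurrent form is what makes generation cost O(1) per token; both produce the same outputs up to floating-point error.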

RWKV Block Structure

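```python
import torch
import torch.nn as nn

class TimeMixing(nn.Module):
    """
    Token Shift + WKV (RWKV-4 style, RNN mode: x is one token of shape (dim,))
    Linearly interpolates between the current and previous token
    """

    def __init__(self, dim, layer_id):
        super().__init__()
        # Interpolation ratios (illustrative init; RWKV initializes them per layer_id)
        self.mix_r = nn.Parameter(torch.full((dim,), 0.5))
        self.mix_k = nn.Parameter(torch.full((dim,), 0.5))
        self.mix_v = nn.Parameter(torch.full((dim,), 0.5))
        self.time_decay = nn.Parameter(torch.zeros(dim))  # w = -exp(time_decay) < 0
        self.time_first = nn.Parameter(torch.zeros(dim))  # bonus u for the current token
        self.W_r = nn.Linear(dim, dim, bias=False)
        self.W_k = nn.Linear(dim, dim, bias=False)
        self.W_v = nn.Linear(dim, dim, bias=False)
        self.W_o = nn.Linear(dim, dim, bias=False)        # output projection

    def forward(self, x, state):
        x_prev, a, b, p = state  # previous token and WKV state, each of shape (dim,)
        # Token Shift: weighted average of current and previous token
        xr = x * self.mix_r + x_prev * (1 - self.mix_r)
        xk = x * self.mix_k + x_prev * (1 - self.mix_k)
        xv = x * self.mix_v + x_prev * (1 - self.mix_v)
        r = torch.sigmoid(self.W_r(xr))
        k = self.W_k(xk)
        v = self.W_v(xv)
        w, u = -torch.exp(self.time_decay), self.time_first
        # Numerically stable WKV step for one token
        q = torch.maximum(p, u + k)
        e1, e2 = torch.exp(p - q), torch.exp(u + k - q)
        wkv = (e1 * a + e2 * v) / (e1 * b + e2)
        q2 = torch.maximum(p + w, k)
        e1, e2 = torch.exp(p + w - q2), torch.exp(k - q2)
        new_state = (x, e1 * a + e2 * v, e1 * b + e2, q2)
        return self.W_o(r * wkv), new_state

class ChannelMixing(nn.Module):
    """Token Shift + squared-ReLU feed-forward, gated by receptance"""

    def __init__(self, dim, layer_id, hidden_mult=4):
        super().__init__()
        self.mix_k = nn.Parameter(torch.full((dim,), 0.5))
        self.mix_r = nn.Parameter(torch.full((dim,), 0.5))
        self.W_k = nn.Linear(dim, hidden_mult * dim, bias=False)
        self.W_v = nn.Linear(hidden_mult * dim, dim, bias=False)
        self.W_r = nn.Linear(dim, dim, bias=False)

    def forward(self, x, x_prev):
        xk = x * self.mix_k + x_prev * (1 - self.mix_k)
        xr = x * self.mix_r + x_prev * (1 - self.mix_r)
        k = torch.relu(self.W_k(xk)) ** 2
        return torch.sigmoid(self.W_r(xr)) * self.W_v(k)

class RWKVBlock(nn.Module):
    """
    Basic block structure of RWKV
    Comparison with Transformer Block:
    Transformer: LayerNorm -> Attention -> Add -> LayerNorm -> FFN -> Add
    RWKV: LayerNorm -> TimeMix -> Add -> LayerNorm -> ChannelMix -> Add
    """

    def __init__(self, dim, layer_id):
        super().__init__()
        self.time_mix = TimeMixing(dim, layer_id)        # Time Mixing (replaces Attention)
        self.channel_mix = ChannelMixing(dim, layer_id)  # Channel Mixing (replaces FFN)
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)

    def forward(self, x, state):
        # state = ((x_prev, a, b, p), x_prev_ffn); start from zeros with p = -1e38
        tm_state, ffn_prev = state
        # Time Mixing (mixes information from past tokens)
        dx, tm_state = self.time_mix(self.ln1(x), tm_state)
        x = x + dx
        # Channel Mixing (mixes information across channels)
        h = self.ln2(x)
        x = x + self.channel_mix(h, ffn_prev)
        return x, (tm_state, h)
```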
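In RNN mode the token-shift state is simply the previous token. During parallel training there is no per-token state, so the whole sequence is shifted down by one position instead; the RWKV-LM code uses a zero-padding layer for this. The minimal sketch below, with hypothetical helper names and random mixing ratios standing in for learned ones, checks that the two views produce the same mixed input.

```python
import torch
import torch.nn.functional as F

def token_shift_parallel(x):
    """x: (T, C) -> (T, C) where row t holds x[t-1] and row 0 is zeros (no previous token)."""
    return F.pad(x, (0, 0, 1, 0))[:-1]

def token_mix(x, x_prev, mix):
    """Per-channel interpolation between the current and previous token."""
    return x * mix + x_prev * (1 - mix)

torch.manual_seed(0)
T, C = 5, 3
x = torch.randn(T, C)
mix = torch.rand(C)  # hypothetical stand-in for learned ratios in (0, 1)

# Training view: shift the whole sequence at once
xk_parallel = token_mix(x, token_shift_parallel(x), mix)

# Inference view: carry only the previous token as state
prev, rows = torch.zeros(C), []
for t in range(T):
    rows.append(token_mix(x[t], prev, mix))
    prev = x[t]
xk_sequential = torch.stack(rows)

print(torch.allclose(xk_parallel, xk_sequential))  # True
```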

Version-by-Version Evolution

RWKV-4

- Introduced the foundational WKV mechanism

- Replaced positional encoding with Token Shift

- Up to 14B parameters

RWKV-5 (Eagle)

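```python
# RWKV-5: Multi-headed matrix-valued state
# Each head keeps its own (head_dim x head_dim) state instead of a per-channel vector
import torch
import torch.nn as nn

class RWKV5_TimeMix(nn.Module):
    def __init__(self, dim, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        # Each head has its own independent per-channel decay rate
        # (log-space parameter; decay = exp(-exp(time_decay)) lies in (0, 1))
        self.time_decay = nn.Parameter(torch.randn(n_heads, self.head_dim))
```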
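To make the matrix-valued state concrete, here is one recurrent step for a single head in the Eagle style: decay the state with diag(w), add the outer product kᵀv, read it out with r, and give the current token a diag(u) bonus. This is a simplified sketch; the function name is hypothetical, and the real model's projections, token shift, and per-head normalization are omitted.

```python
import torch

def rwkv5_head_step(S, r, k, v, w, u):
    """
    One recurrent step for a single head with a matrix-valued state (simplified sketch).
    S: (N, N) state, rows index keys, columns index values
    r, k, v: (N,) receptance, key, value; w: (N,) decay in (0, 1); u: (N,) bonus
    """
    kv = torch.outer(k, v)               # k^T v: rank-one contribution of this token
    out = r @ (S + u.unsqueeze(1) * kv)  # read out; current token weighted by diag(u)
    S = w.unsqueeze(1) * S + kv          # diag(w) S + k^T v
    return out, S

torch.manual_seed(0)
N, T = 4, 10                               # head size and sequence length
w = torch.exp(-torch.exp(torch.randn(N)))  # static per-channel decay in (0, 1)
u = torch.randn(N)
r, k, v = torch.randn(3, T, N)
S = torch.zeros(N, N)
outs = []
for t in range(T):
    o, S = rwkv5_head_step(S, r[t], k[t], v[t], w, u)
    outs.append(o)
print(S.shape)  # torch.Size([4, 4]) no matter how large T is
```

The state is N×N per head regardless of T, which is the "fixed state" property carried over from RWKV-4, now with more capacity per head.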

RWKV-6 (Finch)

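```python
# RWKV-6: Data-dependent time decay
# The decay rate changes dynamically with the input (similar in spirit to Mamba's selective mechanism)
import torch
import torch.nn as nn

class RWKV6_TimeMix(nn.Module):
    def __init__(self, dim, lora_dim=64):
        super().__init__()
        self.time_decay = nn.Parameter(torch.zeros(dim))              # static base decay (log-space)
        self.decay_lora_a = nn.Linear(dim, lora_dim, bias=False)      # LoRA down-projection
        self.decay_lora_b = nn.Parameter(torch.zeros(lora_dim, dim))  # LoRA up-projection

    def forward(self, x):
        # x: (T, dim) token-shifted input; only the decay computation is shown here
        # Input-dependent decay instead of a fixed one: LoRA-style offset per token
        decay_lora = torch.tanh(self.decay_lora_a(x)) @ self.decay_lora_b
        time_decay = self.time_decay + decay_lora  # different decay for each input!
        # Squash after adding the offset so every decay stays in (0, 1)
        return torch.exp(-torch.exp(time_decay))
```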
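What data dependence buys can be seen on a toy matrix state. The sketch below replaces the LoRA output with hand-picked, hypothetical decay logits: two "remember" tokens (logit -5, decay of about 0.993) followed by one "forget" token (logit 5, decay of about 0). A fixed decay would have to treat all three tokens the same way.

```python
import torch

def decay_from_logits(z):
    """exp(-exp(z)) maps any real logit to a decay in (0, 1)
    (in float32 the extremes can round to exactly 0 or 1)."""
    return torch.exp(-torch.exp(z))

N = 2
k = torch.tensor([1.0, 0.0])
v = torch.tensor([0.0, 1.0])
# Hypothetical per-token decay logits, standing in for what the LoRA would produce
token_logits = [torch.full((N,), -5.0), torch.full((N,), -5.0), torch.full((N,), 5.0)]

S = torch.zeros(N, N)
history = []
for z in token_logits:
    w_t = decay_from_logits(z)                     # data-dependent decay for this token
    S = w_t.unsqueeze(1) * S + torch.outer(k, v)   # diag(w_t) S + k^T v
    history.append(round(S[0, 1].item(), 3))
print(history)  # [1.0, 1.993, 1.0]
```

The stored association accumulates while the decay is close to 1 and is wiped out on the token whose decay is close to 0, so the model can decide per token how much history to keep.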

RWKV-7 (Goose) — March 2025

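```python
# RWKV-7: State Transition Matrix
# The state is multiplied by a learned, input-dependent matrix at every step, not just rescaled
import torch
import torch.nn as nn
import torch.nn.functional as F

class RWKV7_TimeMix(nn.Module):
    """
    Key innovations in RWKV-7:
    1. State Transition Matrix: each step applies diag(w_t) minus a rank-one term, not only a scalar decay
    2. Enhanced In-Context Learning: a per-channel in-context learning rate a_t (generalized delta rule)
    3. Improved Token Mixing: more sophisticated inter-token information flow
    Simplified single-head sketch: token shift, LoRA projections, the bonus term,
    and output normalization are omitted.
    """

    def __init__(self, dim):
        super().__init__()
        self.W_r = nn.Linear(dim, dim, bias=False)
        self.W_k = nn.Linear(dim, dim, bias=False)
        self.W_v = nn.Linear(dim, dim, bias=False)
        self.W_w = nn.Linear(dim, dim)  # decay logits
        self.W_a = nn.Linear(dim, dim)  # in-context learning rate logits

    def forward(self, x, state):
        # x: (T, D); state: (D, D) matrix, rows index values, columns index keys
        # (the full model keeps one (D/H, D/H) matrix per head and batch element)
        r, k, v = self.W_r(x), self.W_k(x), self.W_v(x)
        w = torch.exp(-torch.exp(self.W_w(x)))  # data-dependent decay in (0, 1), simplified
        a = torch.sigmoid(self.W_a(x))          # in-context learning rate in (0, 1)
        k_hat = F.normalize(k, dim=-1)          # unit-norm "removal" key
        outputs = []
        for t in range(x.shape[0]):
            # Transition matrix: G_t = diag(w_t) - k_hat_t^T (a_t * k_hat_t)
            G = torch.diag(w[t]) - torch.outer(k_hat[t], a[t] * k_hat[t])
            # s_t = s_{t-1} @ G_t + v_t^T k_t
            state = state @ G + torch.outer(v[t], k[t])
            outputs.append(state @ r[t])
        return torch.stack(outputs), state
```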
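The rank-one removal term is what separates this update from the purely additive updates of RWKV-5/6. The toy sketch below writes two different values under the same unit-norm key and reads them back. It uses hypothetical helper names, no decay (w = 1), and a full learning rate (a = 1) so that the effect is isolated; state rows index values and columns index keys.

```python
import torch
import torch.nn.functional as F

def additive_update(S, k, v, w):
    """RWKV-5/6 style: decay, then add v k^T; old associations are never explicitly removed."""
    return S * w.unsqueeze(0) + torch.outer(v, k)

def delta_update(S, k, v, w, a):
    """RWKV-7 style: S (diag(w) - k_hat^T (a * k_hat)) + v k^T."""
    k_hat = F.normalize(k, dim=-1)
    G = torch.diag(w) - torch.outer(k_hat, a * k_hat)
    return S @ G + torch.outer(v, k)

N = 3
k = torch.tensor([1.0, 0.0, 0.0])   # the same (unit-norm) key written twice
v1 = torch.tensor([1.0, 0.0, 0.0])  # first value stored under k
v2 = torch.tensor([0.0, 1.0, 0.0])  # later value that should overwrite it
w = torch.ones(N)                   # no decay, to isolate the removal term
a = torch.ones(N)                   # full in-context learning rate

S_add, S_delta = torch.zeros(N, N), torch.zeros(N, N)
for v in (v1, v2):
    S_add = additive_update(S_add, k, v, w)
    S_delta = delta_update(S_delta, k, v, w, a)

print(S_add @ k)    # tensor([1., 1., 0.])  old and new values are mixed
print(S_delta @ k)  # tensor([0., 1., 0.])  the new value replaced the old one
```

With the additive rule the two values pile up; with the transition matrix the second write replaces the first. This is one intuition for why a learnable in-context learning rate can help in-context learning.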

Performance Comparison with Transformers

Performance by model size (Pile dataset perplexity, lower is better):

| Model Size | Transformer | RWKV-4 | RWKV-6 | RWKV-7 |
|-----------|-------------|--------|--------|--------|
| 169M | 17.2 | 18.1 | 17.4 | 17.0 |
| 430M | 13.8 | 14.5 | 13.9 | 13.6 |
| 1.5B | 11.2 | 11.8 | 11.3 | 11.0 |
| 3B | 9.8 | 10.3 | 9.9 | 9.6 |
| 7B | 8.5 | 9.0 | 8.6 | 8.3 |
| 14B | 7.8 | 8.2 | 7.9 | 7.6 |

Inference efficiency (tokens/sec, A100 GPU):

| Sequence Length | Transformer | RWKV-7 |
|----------------|-------------|---------|
| 1K | 1000 | 1200 |
| 4K | 800 | 1200 |
| 16K | 300 | 1200 |
| 64K | 50 | 1200 |
| 128K+ | OOM | 1200 |

Practical Usage: Working with RWKV

Using with HuggingFace

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load RWKV-7 model (custom modeling code, hence trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-7-world-3b",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(
    "RWKV/rwkv-7-world-3b",
    trust_remote_code=True,
)

# Text generation
prompt = "To implement Pod autoscaling in Kubernetes"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=200,
    do_sample=True,  # temperature and top_p only take effect when sampling
    temperature=0.7,
    top_p=0.9,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Running Locally with RWKV Runner

```bash
# Install RWKV Runner (GUI tool)
git clone https://github.com/josStorer/RWKV-Runner
cd RWKV-Runner

# Or use the rwkv pip package (the inference library behind ChatRWKV)
pip install rwkv

# Use directly from Python
python3 << 'EOF'
import os
os.environ["RWKV_V7_ON"] = "1"  # enable RWKV-7 support; set before importing rwkv.model

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

model = RWKV(model='/path/to/RWKV-7-World-3B.pth', strategy='cuda fp16')
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
result = pipeline.generate(
    "Please explain the current state of the AI industry in Korea.",
    token_count=200,
    args=PIPELINE_ARGS(temperature=0.8),  # sampling settings go through PIPELINE_ARGS
)
print(result)
EOF
```

RWKV vs Mamba vs Transformer

| Property | Transformer | Mamba | RWKV-7 |
|---------------------------|----------------|----------------|----------------|
| Training Complexity | O(N²) | O(N) | O(N) |
| Inference Complexity | O(N) per token | O(1) per token | O(1) per token |
| Inference Memory | O(N) KV Cache | O(1) state | O(1) state |
| Parallel Training | Yes | Yes (scan) | Yes (WKV) |
| Long-range Dependencies | Strong | Good | Good |
| In-Context Learning | Strong | Good | Enhanced in v7 |
| Implementation Complexity | Low | Medium (CUDA) | Medium (CUDA) |
| Community | Massive | Growing | Active |

Limitations and Future Directions

Current Limitations

1. **Complex retrieval tasks**: Weaker at specific pattern retrieval compared to Transformer's Full Attention

2. **CUDA kernel dependency**: Custom CUDA kernels required for optimal performance

3. **Ecosystem**: Fewer tools and libraries compared to Transformers

Future Directions

- **Hybrid architectures**: Combining RWKV with a small amount of Attention

- **Hardware optimization**: Optimizing for new chips like Groq and Cerebras

- **Multimodal**: Expanding to other modalities such as vision and audio

Quiz

1. **What does each letter of RWKV stand for?** R: Receptance (reception gate), W: Weight (time decay), K: Key, V: Value.

2. **How does RWKV run during training versus inference?** During training, it operates in parallel (Transformer-like) mode, processing the entire sequence at once. During inference, it operates in sequential (RNN-like) mode with O(1) cost per token.

3. **What changed in RWKV-6's time decay?** Instead of a fixed decay rate, the decay rate changes dynamically based on the input data, an idea similar to Mamba's selective mechanism.

4. **What is the key innovation of RWKV-7 Goose?** It introduces a State Transition Matrix that represents state transitions using matrices rather than scalars, enabling richer state updates and enhanced In-Context Learning.

5. **Why does RWKV handle very long sequences efficiently?** Inference cost stays constant regardless of sequence length (O(1) per token), so even at 128K+ tokens it does not run out of memory the way a Transformer's growing KV cache can.

6. **What does Token Shift do?** It blends the current token with the previous token via a weighted average, conveying positional information within the sequence without explicit positional encoding.

7. **Where may RWKV fall short of Transformers?** Compared to the Transformer's full self-attention, RWKV may underperform on complex pattern retrieval and precise information retrieval tasks.

Conclusion

RWKV is an innovative architecture that challenges the conventional wisdom that "RNNs are dead." By combining the parallel training efficiency of Transformers with the inference efficiency of RNNs, it excels particularly in long sequence processing and edge device deployment. With the introduction of the State Transition Matrix in v7 Goose, the performance gap with Transformers has been nearly closed.

References

- [RWKV Paper (arXiv: 2305.13048)](https://arxiv.org/abs/2305.13048)

- [RWKV-7 Goose Wiki](https://wiki.rwkv.com/)

- [RWKV GitHub](https://github.com/BlinkDL/RWKV-LM)
