RWKV: Reinventing RNNs for the Transformer Era — From v4 to v7 Goose

RWKV Architecture

Paper Overview

  • Title: RWKV: Reinventing RNNs for the Transformer Era
  • Authors: Bo Peng et al. (RWKV Foundation)
  • Initial Publication: May 2023 (arXiv: 2305.13048), EMNLP 2023
  • Latest Version: RWKV-7 "Goose" (Released March 2025)
  • Code: github.com/BlinkDL/RWKV-LM

Motivation: The Transformer vs RNN Dilemma

Nearly all modern LLMs are built on Transformers, but they have fundamental limitations:

  • O(N²) Self-Attention: Quadratic complexity with respect to sequence length
  • KV Cache Explosion: Memory scales proportionally to sequence length during inference
  • Long Context Cost: Cost and memory spike dramatically at 128K+ context lengths

On the other hand, traditional RNNs offer:

  • O(N) Complexity: Linear time and memory
  • Fixed State: Constant inference cost
  • BUT: Training is not parallelizable (slow), and they struggle to capture long-range dependencies

RWKV's question: "Can we train in parallel like a Transformer while inferring as efficiently as an RNN?"

RWKV Core Architecture

The name RWKV itself encodes the core components of the architecture:

  • R: Receptance (reception gate) — determines how much past information to accept
  • W: Weight (time decay) — decay rate for past information
  • K: Key — key of the current input
  • V: Value — value of the current input
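
Concretely, in the v4 paper's notation (sketched here; $\mu$ are the learned token-shift ratios, $\odot$ is elementwise product, and $\sigma$ is the sigmoid), the time-mixing branch forms:

```latex
\begin{aligned}
r_t &= W_r\,(\mu_r \odot x_t + (1-\mu_r) \odot x_{t-1}) \\
k_t &= W_k\,(\mu_k \odot x_t + (1-\mu_k) \odot x_{t-1}) \\
v_t &= W_v\,(\mu_v \odot x_t + (1-\mu_v) \odot x_{t-1}) \\
o_t &= W_o\,(\sigma(r_t) \odot \mathrm{wkv}_t)
\end{aligned}
```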

WKV (Weighted Key-Value) Mechanism

The WKV operation is at the heart of RWKV:

import torch

def wkv_vanilla(w, u, k, v):
    """
    RWKV's WKV mechanism (naive pure-PyTorch reference implementation)
    w: time decay (negative; larger magnitude = faster decay), shape (C,)
    u: bonus for the current token, shape (C,)
    k: keys, shape (T, C)
    v: values, shape (T, C)
    """
    T, C = k.shape
    output = torch.zeros_like(v)

    for c in range(C):
        # Independent processing per channel (O(T) per channel)
        a = torch.tensor(0.0)    # accumulated numerator
        b = torch.tensor(0.0)    # accumulated denominator
        p = torch.tensor(-1e30)  # running max exponent (numerical stability)

        for t in range(T):
            # Output for token t: the current token gets the bonus u
            # instead of the decay; rescale both terms by the shared max
            q = torch.maximum(p, u[c] + k[t, c])
            e1 = torch.exp(p - q)               # scale of the accumulated state
            e2 = torch.exp(u[c] + k[t, c] - q)  # scale of the current token
            output[t, c] = (e1 * a + e2 * v[t, c]) / (e1 * b + e2)

            # State update (RNN-style): decay the past state by w
            q = torch.maximum(p + w[c], k[t, c])
            e1 = torch.exp(p + w[c] - q)
            e2 = torch.exp(k[t, c] - q)
            a = e1 * a + e2 * v[t, c]
            b = e1 * b + e2
            p = q

    return output
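
Unrolled, the recurrence computes a softmax-like weighted average over past values per channel. With this implementation's convention that $w$ is negative, the closed form is:

```latex
\mathrm{wkv}_t =
\frac{\sum_{i=1}^{t-1} e^{(t-1-i)w + k_i}\, v_i \;+\; e^{u + k_t}\, v_t}
     {\sum_{i=1}^{t-1} e^{(t-1-i)w + k_i} \;+\; e^{u + k_t}}
```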

Dual Mode: Transformer Mode vs RNN Mode

# Training: Transformer mode (parallel)
# Process the entire sequence at once, O(T) complexity
def rwkv_parallel(x, w, u, k_proj, v_proj, r_proj):
    T, C = x.shape
    k = k_proj(x)  # (T, C)
    v = v_proj(x)  # (T, C)
    r = torch.sigmoid(r_proj(x))  # (T, C) reception gate

    # Parallel WKV operation (CUDA kernel)
    wkv = parallel_wkv_cuda(w, u, k, v)

    return r * wkv  # Apply reception gate


# Inference: RNN mode (sequential, constant memory)
def rwkv_sequential(x_t, state, w, u, k_proj, v_proj, r_proj):
    """Process one token at a time, O(1) complexity"""
    k = k_proj(x_t)
    v = v_proj(x_t)
    r = torch.sigmoid(r_proj(x_t))

    # state = (a, b, p) — fixed size!
    a, b, p = state

    # Output: rescale accumulated state and current token by their shared max
    q = torch.maximum(p, u + k)
    e1 = torch.exp(p - q)       # scale of the accumulated state
    e2 = torch.exp(u + k - q)   # scale of the current token
    wkv = (e1 * a + e2 * v) / (e1 * b + e2)
    output = r * wkv

    # State update: decay the past state by w
    new_p = torch.maximum(p + w, k)
    e1 = torch.exp(p + w - new_p)
    e2 = torch.exp(k - new_p)
    new_a = e1 * a + e2 * v
    new_b = e1 * b + e2
    new_state = (new_a, new_b, new_p)

    return output, new_state
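
The two modes compute the same function. As a self-contained sanity check (a minimal re-implementation of the stable recurrence, not the library's CUDA kernels; `wkv_step` mirrors the sequential state handling), we can compare the step-by-step recurrence against the direct weighted-average form:

```python
import torch

torch.manual_seed(0)
T, C = 8, 4
w = -torch.rand(C)            # negative time decay
u = torch.randn(C)            # current-token bonus
k, v = torch.randn(T, C), torch.randn(T, C)

def wkv_step(k_t, v_t, state):
    """One recurrent WKV step with the numerically stable (a, b, p) state."""
    a, b, p = state
    q = torch.maximum(p, u + k_t)
    out = (torch.exp(p - q) * a + torch.exp(u + k_t - q) * v_t) / \
          (torch.exp(p - q) * b + torch.exp(u + k_t - q))
    q2 = torch.maximum(p + w, k_t)
    new_state = (torch.exp(p + w - q2) * a + torch.exp(k_t - q2) * v_t,
                 torch.exp(p + w - q2) * b + torch.exp(k_t - q2),
                 q2)
    return out, new_state

# RNN mode: one token at a time, constant-size state
state = (torch.zeros(C), torch.zeros(C), torch.full((C,), -1e30))
seq_out = []
for t in range(T):
    o, state = wkv_step(k[t], v[t], state)
    seq_out.append(o)
seq_out = torch.stack(seq_out)

# Direct form: softmax-like weighted average over the whole prefix
par_out = torch.zeros(T, C)
for t in range(T):
    idx = torch.arange(t + 1)
    # log-weights: (t-1-i)*w + k_i for past tokens, u + k_t for the current one
    logw = torch.where(idx[:, None] < t,
                       (t - 1 - idx)[:, None] * w + k[:t + 1],
                       u + k[:t + 1])
    wts = torch.softmax(logw, dim=0)
    par_out[t] = (wts * v[:t + 1]).sum(0)

assert torch.allclose(seq_out, par_out, atol=1e-4)
```

Both paths agree to floating-point tolerance, which is exactly what lets RWKV train in parallel and infer recurrently.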

RWKV Block Structure

import torch
import torch.nn as nn

class RWKVBlock(nn.Module):
    """
    Basic block structure of RWKV

    Comparison with Transformer Block:
    Transformer: LayerNorm -> Attention -> Add -> LayerNorm -> FFN -> Add
    RWKV:        LayerNorm -> TimeMix   -> Add -> LayerNorm -> ChannelMix -> Add
    """

    def __init__(self, dim, layer_id):
        super().__init__()
        # Time Mixing (replaces Attention)
        self.time_mix = TimeMixing(dim, layer_id)
        # Channel Mixing (replaces FFN)
        self.channel_mix = ChannelMixing(dim, layer_id)
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)

    def forward(self, x, state):
        # Time Mixing (mixes information from past tokens)
        dx, state = self.time_mix(self.ln1(x), state)
        x = x + dx

        # Channel Mixing (mixes information across channels)
        dx = self.channel_mix(self.ln2(x))
        x = x + dx

        return x, state


class TimeMixing(nn.Module):
    """
    Token Shift + WKV
    Linearly interpolates between the current and previous token
    """

    def __init__(self, dim, layer_id):
        super().__init__()
        self.mix_r = nn.Parameter(torch.ones(dim))  # interpolation ratios
        self.mix_k = nn.Parameter(torch.ones(dim))
        self.mix_v = nn.Parameter(torch.ones(dim))
        self.W_r = nn.Linear(dim, dim, bias=False)
        self.W_k = nn.Linear(dim, dim, bias=False)
        self.W_v = nn.Linear(dim, dim, bias=False)

    def forward(self, x, state):
        # Token Shift: weighted average of current and previous token
        x_prev = state.shift  # previous token
        xr = x * self.mix_r + x_prev * (1 - self.mix_r)
        xk = x * self.mix_k + x_prev * (1 - self.mix_k)
        xv = x * self.mix_v + x_prev * (1 - self.mix_v)

        r = torch.sigmoid(self.W_r(xr))
        k = self.W_k(xk)
        v = self.W_v(xv)

        # compute_wkv applies the WKV recurrence and advances the state
        wkv, new_state = compute_wkv(k, v, state)
        return r * wkv, new_state
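
ChannelMixing is referenced above but not defined. A minimal sketch, assuming the v4-style squared-ReLU form (`hidden_mult` and the explicit `x_prev` argument are illustrative choices, not the reference implementation):

```python
import torch
import torch.nn as nn

class ChannelMixing(nn.Module):
    """Sketch of RWKV's ChannelMix (the FFN replacement), v4-style."""

    def __init__(self, dim, layer_id=0, hidden_mult=4):
        super().__init__()
        self.mix_k = nn.Parameter(torch.ones(dim))   # token-shift ratios
        self.mix_r = nn.Parameter(torch.ones(dim))
        self.W_k = nn.Linear(dim, hidden_mult * dim, bias=False)
        self.W_v = nn.Linear(hidden_mult * dim, dim, bias=False)
        self.W_r = nn.Linear(dim, dim, bias=False)

    def forward(self, x, x_prev):
        # Token Shift, just like in TimeMixing
        xk = x * self.mix_k + x_prev * (1 - self.mix_k)
        xr = x * self.mix_r + x_prev * (1 - self.mix_r)
        r = torch.sigmoid(self.W_r(xr))              # reception gate
        k = torch.relu(self.W_k(xk)) ** 2            # squared ReLU
        return r * self.W_v(k)

cm = ChannelMixing(8)
y = cm(torch.randn(8), torch.zeros(8))
assert y.shape == (8,)
```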

Version-by-Version Evolution

RWKV-4

  • Introduced the foundational WKV mechanism
  • Replaced positional encoding with Token Shift
  • Up to 14B parameters
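
Token Shift, which replaced positional encoding, is easy to see in isolation: in parallel mode the "previous token" stream is just the sequence shifted one step with zero padding (a minimal sketch with a fixed interpolation ratio standing in for the learned one):

```python
import torch

T, C = 4, 3
x = torch.randn(T, C)
x_prev = torch.cat([torch.zeros(1, C), x[:-1]], dim=0)   # shift right by one
mu = torch.full((C,), 0.5)          # learned interpolation ratio (fixed here)
shifted = mu * x + (1 - mu) * x_prev
assert torch.allclose(shifted[0], 0.5 * x[0])   # first token mixes with zeros
assert torch.allclose(shifted[1], 0.5 * (x[1] + x[0]))
```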

RWKV-5 (Eagle)

# RWKV-5: Multi-headed State
# Maintains multiple independent states for improved expressiveness
class RWKV5_TimeMix(nn.Module):
    def __init__(self, dim, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        # Each head has its own independent decay rate
        self.time_decay = nn.Parameter(torch.randn(n_heads, self.head_dim))

RWKV-6 (Finch)

# RWKV-6: Data-dependent time decay
# Decay rate changes dynamically based on input (similar to Mamba's selective mechanism)
class RWKV6_TimeMix(nn.Module):
    def forward(self, x, state):
        # Input-dependent decay logits instead of a fixed decay
        time_decay = self.W_decay(x)  # different decay for each input!

        # LoRA-style low-rank modulation of the decay
        decay_lora = torch.tanh(self.decay_lora_a(x)) @ self.decay_lora_b
        time_decay = time_decay + decay_lora

        # Squash into (0, 1) only after all modulation is applied
        time_decay = torch.exp(-torch.exp(time_decay))
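
The double-exponential map `exp(-exp(x))` is what lets an unconstrained learned projection act as a decay factor: it squashes any real logit into (0, 1) and is monotone, so larger logits always mean faster decay (a sketch of the property, not RWKV-6's full parameterization):

```python
import torch

x = torch.linspace(-4, 4, 9)
decay = torch.exp(-torch.exp(x))             # double-exponential squashing
assert ((decay > 0) & (decay < 1)).all()     # always a valid decay factor
assert (decay[1:] < decay[:-1]).all()        # larger logit -> faster decay
```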

RWKV-7 (Goose) — March 2025

# RWKV-7: State Transition Matrix
# Represents state transitions as matrices for richer state updates
class RWKV7_TimeMix(nn.Module):
    """
    Key innovations in RWKV-7:
    1. State Transition Matrix: state transitions via matrices instead of scalars
    2. Enhanced In-Context Learning: dynamically adjusts learning rules
    3. Improved Token Mixing: more sophisticated inter-token information flow
    """

    def forward(self, x, state):
        B, T, D = x.shape

        a = torch.sigmoid(self.W_a(x))  # (B, T, D) per-channel transition gate
        k = self.W_k(x)                 # (B, T, D)
        v = self.W_v(x)                 # (B, T, D)
        r = self.W_r(x)                 # (B, T, D) read-out (receptance)

        # State is a (B, D, D) matrix, not a vector!
        # Simplified matrix-form update: s_t = diag(a_t) @ s_{t-1} + k_t^T @ v_t
        outputs = []
        for t in range(T):
            state = a[:, t].unsqueeze(-1) * state + \
                    k[:, t].unsqueeze(-1) @ v[:, t].unsqueeze(-2)
            outputs.append(torch.einsum('bd,bde->be', r[:, t], state))
        output = torch.stack(outputs, dim=1)

        return output, state
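
Why a matrix-valued state helps: it behaves like an associative memory. Writing with outer products k^T v and reading back with a query recovers the stored value exactly when keys are orthogonal (a toy sketch of the mechanism, not RWKV-7's full transition rule):

```python
import torch

D = 8
S = torch.zeros(D, D)                 # matrix-valued state (one head)
keys = torch.eye(D)[:3]               # three orthonormal keys
vals = torch.randn(3, D)
for k_t, v_t in zip(keys, vals):
    S = S + k_t[:, None] @ v_t[None, :]   # outer-product write: S += k^T v
recall = keys[1] @ S                  # read with the second key as query
assert torch.allclose(recall, vals[1])    # exact recall under orthogonal keys
```

A scalar-decay vector state could not hold three separate key-to-value bindings at once this way.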

Performance Comparison with Transformers

Performance by model size (Pile dataset perplexity, lower is better):

| Model Size | Transformer | RWKV-4 | RWKV-6 | RWKV-7 |
|-----------|-------------|--------|--------|--------|
| 169M      | 17.2        | 18.1   | 17.4   | 17.0   |
| 430M      | 13.8        | 14.5   | 13.9   | 13.6   |
| 1.5B      | 11.2        | 11.8   | 11.3   | 11.0   |
| 3B        | 9.8         | 10.3   | 9.9    | 9.6    |
| 7B        | 8.5         | 9.0    | 8.6    | 8.3    |
| 14B       | 7.8         | 8.2    | 7.9    | 7.6    |

Inference efficiency (tokens/sec, A100 GPU):

| Sequence Length | Transformer | RWKV-7  |
|----------------|-------------|---------|
| 1K             | 1000        | 1200    |
| 4K             | 800         | 1200    |
| 16K            | 300         | 1200    |
| 64K            | 50          | 1200    |
| 128K+          | OOM         | 1200    |
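
For intuition on why the Transformer column hits OOM: a rough fp16 KV-cache estimate for an illustrative 7B-class configuration (layer, head, and dimension counts here are assumptions for arithmetic, not any specific model's):

```python
# Illustrative 7B-class Transformer config (assumed, not a specific model)
layers, kv_heads, head_dim, bytes_per = 32, 32, 128, 2   # fp16

def kv_cache_bytes(seq_len):
    # 2 tensors (K and V) per layer, each (kv_heads, seq_len, head_dim)
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_per

print(kv_cache_bytes(128_000) / 2**30)   # 62.5 GiB at a 128K context
```

The cache grows linearly with sequence length, while RWKV's recurrent state stays the same size regardless of context length.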

Practical Usage: Working with RWKV

Using with HuggingFace

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load RWKV-7 model
model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-7-world-3b",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
    "RWKV/rwkv-7-world-3b",
    trust_remote_code=True
)

# Text generation
prompt = "To implement Pod autoscaling in Kubernetes"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(
    **inputs,
    max_new_tokens=200,
    temperature=0.7,
    top_p=0.9
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Running Locally with RWKV Runner

# Install RWKV Runner (GUI tool)
git clone https://github.com/josStorer/RWKV-Runner
cd RWKV-Runner

# Or ChatRWKV (CLI)
pip install rwkv

# Use directly from Python
python3 << 'EOF'
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

model = RWKV(model='/path/to/RWKV-7-World-3B.pth', strategy='cuda fp16')
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

result = pipeline.generate(
    "Please explain the current state of the AI industry in Korea.",
    token_count=200,
    temperature=0.8
)
print(result)
EOF

RWKV vs Mamba vs Transformer

| Property            | Transformer    | Mamba          | RWKV-7        |
|---------------------|----------------|----------------|---------------|
| Training Complexity | O(N²)         | O(N)           | O(N)          |
| Inference Complexity| O(N) per token | O(1) per token | O(1) per token|
| Memory              | O(N) KV Cache  | O(1) state     | O(1) state    |
| Parallel Training   | Yes            | Yes (scan)     | Yes (WKV)     |
| Long-range Deps     | Strong         | Good           | Good          |
| In-Context Learning | Strong         | Good           | Enhanced in v7|
| Implementation      | Low            | Medium (CUDA)  | Medium (CUDA) |
| Community           | Massive        | Growing        | Active        |

Limitations and Future Directions

Current Limitations

  1. Complex retrieval tasks: Weaker at specific pattern retrieval compared to Transformer's Full Attention
  2. CUDA kernel dependency: Custom CUDA kernels required for optimal performance
  3. Ecosystem: Fewer tools and libraries compared to Transformers

Future Directions

  • Hybrid architectures: Combining RWKV with a small amount of Attention
  • Hardware optimization: Optimizing for new chips like Groq and Cerebras
  • Multimodal: Expanding to other modalities such as vision and audio

Quiz

Q1. What do R, W, K, and V stand for in RWKV?

R: Receptance (reception gate), W: Weight (time decay), K: Key, V: Value.

Q2. In which mode does RWKV operate during training and inference respectively?

During training, it operates in parallel (Transformer-like) mode processing the entire sequence at once. During inference, it operates in sequential (RNN-like) mode with O(1) cost per token.

Q3. What is the Data-dependent time decay introduced in RWKV-6?

Instead of a fixed decay rate, the decay rate changes dynamically based on the input data. This is an idea similar to Mamba's selective mechanism.

Q4. What is the key innovation of RWKV-7 Goose?

It introduces a State Transition Matrix that represents state transitions using matrices rather than scalars. This enables richer state updates and enhanced In-Context Learning.

Q5. What is the biggest advantage of RWKV over Transformers?

Inference cost remains constant regardless of sequence length (O(1) per token). Even at 128K+ tokens, it does not run into OOM issues like Transformers.

Q6. What is the role of the Token Shift mechanism?

It blends the current token with the previous token via a weighted average, enabling positional information to be conveyed within the sequence without explicit positional encoding.

Q7. What is one current limitation of RWKV?

Compared to Transformer's Full Self-Attention, RWKV may underperform on complex pattern retrieval and precise information retrieval tasks.

Conclusion

RWKV is an innovative architecture that challenges the conventional wisdom that "RNNs are dead." By combining the parallel training efficiency of Transformers with the inference efficiency of RNNs, it excels particularly in long sequence processing and edge device deployment. With the introduction of the State Transition Matrix in v7 Goose, the performance gap with Transformers has been nearly closed.
