Skip to content
Published on

RWKV: Reinventing RNNs for the Transformer Era — v4에서 v7 Goose까지

Authors
  • Name
    Twitter
RWKV Architecture

논문 개요

  • 제목: RWKV: Reinventing RNNs for the Transformer Era
  • 저자: Bo Peng 외 (RWKV Foundation)
  • 최초 발표: 2023년 5월 (arXiv: 2305.13048), EMNLP 2023
  • 최신 버전: RWKV-7 "Goose" (2025년 3월 발표)
  • 코드: github.com/BlinkDL/RWKV-LM

동기: Transformer vs RNN의 딜레마

현대 LLM은 거의 모두 Transformer 기반이지만 근본적 한계가 있습니다:

  • O(N²) Self-Attention: 시퀀스 길이에 대해 이차 복잡도
  • KV Cache 폭발: 추론 시 메모리가 시퀀스 길이에 비례
  • 긴 컨텍스트 비용: 128K+ 컨텍스트에서 비용/메모리 급증

반면 전통적 RNN은:

  • O(N) 복잡도: 선형 시간/메모리
  • 고정 상태: 일정한 추론 비용
  • BUT: 학습이 병렬화되지 않아 느림, 장기 의존성 포착 취약

RWKV의 질문: "Transformer처럼 병렬 학습하면서 RNN처럼 효율적으로 추론할 수는 없을까?"

RWKV 핵심 아키텍처

RWKV는 이름 자체가 아키텍처의 핵심 요소입니다:

  • R: Receptance (수용 게이트) — 과거 정보 수용 정도 결정
  • W: Weight (시간 감쇠) — 과거 정보의 감쇠율
  • K: Key — 현재 입력의 키
  • V: Value — 현재 입력의 값

WKV (Weighted Key-Value) 메커니즘

RWKV의 핵심인 WKV 연산은 다음과 같습니다:

import torch

def wkv_vanilla(w, u, k, v):
    """
    RWKV의 WKV 메커니즘 (순수 Python 구현)
    w: time decay (음수, 클수록 빠르게 감쇠)
    u: bonus (현재 토큰 보너스)
    k: key
    v: value
    """
    T, C = k.shape
    output = torch.zeros_like(v)

    for c in range(C):
        # 채널별 독립 처리 (O(T) per channel)
        a = 0.0  # 누적 분자
        b = 0.0  # 누적 분모
        p = -1e30  # 최대값 (수치 안정성)

        for t in range(T):
            # 현재 토큰의 기여
            e1 = torch.exp(torch.clamp(u[c] + k[t, c] - p, max=30))
            e2 = torch.exp(torch.clamp(w[c] + p - p, max=30))  # 이전 누적 감쇠

            # WKV 계산
            wkv = (e1 * v[t, c] + e2 * a) / (e1 + e2 * b)
            output[t, c] = wkv

            # 상태 업데이트 (RNN 방식)
            new_p = max(w[c] + p, k[t, c])
            e1 = torch.exp(k[t, c] - new_p)
            e2 = torch.exp(w[c] + p - new_p)

            a = e2 * a + e1 * v[t, c]
            b = e2 * b + e1
            p = new_p

    return output

듀얼 모드: Transformer 모드 vs RNN 모드

# 학습 시: Transformer 모드 (병렬)
# 전체 시퀀스를 한번에 처리, O(T) 복잡도
def rwkv_parallel(x, w, u, k_proj, v_proj, r_proj):
    T, C = x.shape
    k = k_proj(x)  # (T, C)
    v = v_proj(x)  # (T, C)
    r = torch.sigmoid(r_proj(x))  # (T, C) 수용 게이트

    # 병렬 WKV 연산 (CUDA 커널)
    wkv = parallel_wkv_cuda(w, u, k, v)

    return r * wkv  # 수용 게이트 적용


# 추론 시: RNN 모드 (순차, 일정 메모리)
def rwkv_sequential(x_t, state, w, u, k_proj, v_proj, r_proj):
    """한 토큰씩 처리, O(1) 복잡도"""
    k = k_proj(x_t)
    v = v_proj(x_t)
    r = torch.sigmoid(r_proj(x_t))

    # state = (a, b, p) — 고정 크기!
    a, b, p = state

    e1 = torch.exp(u + k - p)
    e2 = torch.exp(w + p - p)

    wkv = (e1 * v + e2 * a) / (e1 + e2 * b)
    output = r * wkv

    # 상태 업데이트
    new_p = torch.maximum(w + p, k)
    e1 = torch.exp(k - new_p)
    e2 = torch.exp(w + p - new_p)
    new_a = e2 * a + e1 * v
    new_b = e2 * b + e1
    new_state = (new_a, new_b, new_p)

    return output, new_state

RWKV 블록 구조

class RWKVBlock:
    """
    RWKV의 기본 블록 구조

    Transformer Block과 비교:
    Transformer: LayerNorm → Attention → Add → LayerNorm → FFN → Add
    RWKV:        LayerNorm → TimeMix  → Add → LayerNorm → ChannelMix → Add
    """

    def __init__(self, dim, layer_id):
        # Time Mixing (Attention 대체)
        self.time_mix = TimeMixing(dim, layer_id)
        # Channel Mixing (FFN 대체)
        self.channel_mix = ChannelMixing(dim, layer_id)
        self.ln1 = LayerNorm(dim)
        self.ln2 = LayerNorm(dim)

    def forward(self, x, state):
        # Time Mixing (과거 토큰 정보 혼합)
        dx, state = self.time_mix(self.ln1(x), state)
        x = x + dx

        # Channel Mixing (채널 간 정보 혼합)
        dx = self.channel_mix(self.ln2(x))
        x = x + dx

        return x, state


class TimeMixing:
    """
    Token Shift + WKV
    현재 토큰과 이전 토큰을 선형 보간하여 사용
    """

    def __init__(self, dim, layer_id):
        self.mix_r = nn.Parameter(torch.ones(dim))  # 보간 비율
        self.mix_k = nn.Parameter(torch.ones(dim))
        self.mix_v = nn.Parameter(torch.ones(dim))

    def forward(self, x, state):
        # Token Shift: 현재 토큰과 이전 토큰의 가중 평균
        x_prev = state.shift  # 이전 토큰
        xr = x * self.mix_r + x_prev * (1 - self.mix_r)
        xk = x * self.mix_k + x_prev * (1 - self.mix_k)
        xv = x * self.mix_v + x_prev * (1 - self.mix_v)

        r = torch.sigmoid(self.W_r(xr))
        k = self.W_k(xk)
        v = self.W_v(xv)

        wkv = compute_wkv(k, v, state)
        return r * wkv, new_state

버전별 진화

RWKV-4 (Eagle)

  • 기본 WKV 메커니즘 도입
  • Token Shift로 위치 인코딩 대체
  • 최대 14B 파라미터

RWKV-5 (Eagle)

# RWKV-5: Multi-headed State
# 여러 개의 독립적인 상태를 유지하여 표현력 향상
class RWKV5_TimeMix:
    def __init__(self, dim, n_heads=8):
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        # 각 헤드가 독립적인 decay rate 보유
        self.time_decay = nn.Parameter(torch.randn(n_heads, self.head_dim))

RWKV-6 (Finch)

# RWKV-6: Data-dependent time decay
# 입력에 따라 감쇠율이 동적으로 변화 (Mamba의 선택적 메커니즘과 유사)
class RWKV6_TimeMix:
    def forward(self, x, state):
        # 고정 decay가 아닌, 입력 의존적 decay
        time_decay = self.W_decay(x)  # 입력마다 다른 감쇠율!
        time_decay = torch.exp(-torch.exp(time_decay))

        # LoRA 스타일 decay 변조
        decay_lora = self.decay_lora_a(x)
        decay_lora = torch.tanh(decay_lora) @ self.decay_lora_b
        time_decay = time_decay + decay_lora

RWKV-7 (Goose) — 2025년 3월

# RWKV-7: State Transition Matrix
# 상태 전이를 행렬로 표현하여 더 풍부한 상태 업데이트
class RWKV7_TimeMix:
    """
    RWKV-7 핵심 혁신:
    1. State Transition Matrix: 스칼라가 아닌 행렬로 상태 전이
    2. In-Context Learning 강화: 동적으로 학습 규칙 조정
    3. Improved Token Mixing: 더 정교한 토큰 간 정보 흐름
    """

    def forward(self, x, state):
        # 상태 전이 행렬 계산
        a = torch.sigmoid(self.W_a(x))  # (B, T, D)
        b = self.W_b(x)                 # (B, T, D)

        # 행렬 형태의 상태 업데이트
        # s_{t+1} = diag(a_t) @ s_t + k_t^T @ v_t
        k = self.W_k(x)
        v = self.W_v(x)

        # 상태: (B, H, D/H, D/H) 행렬!
        for t in range(T):
            state = torch.diag(a[:, t]) @ state + \
                    k[:, t].unsqueeze(-1) @ v[:, t].unsqueeze(-2)

        return output, state

Transformer와 성능 비교

모델 크기별 성능 (Pile dataset perplexity, ↓ 낮을수록 좋음):

| 모델 크기  | Transformer | RWKV-4 | RWKV-6 | RWKV-7 |
|-----------|-------------|--------|--------|--------|
| 169M      | 17.2        | 18.1   | 17.4   | 17.0   |
| 430M      | 13.8        | 14.5   | 13.9   | 13.6   |
| 1.5B      | 11.2        | 11.8   | 11.3   | 11.0   |
| 3B        | 9.8         | 10.3   | 9.9    | 9.6    |
| 7B        | 8.5         | 9.0    | 8.6    | 8.3    |
| 14B       | 7.8         | 8.2    | 7.9    | 7.6    |

추론 효율 (토큰/, A100 GPU):

| 시퀀스 길이 | Transformer | RWKV-7  |
|-----------|-------------|---------|
| 1K        | 1000        | 1200    |
| 4K        | 800         | 1200    |
| 16K       | 300         | 1200    |
| 64K       | 50          | 1200    |
| 128K+     | OOM         | 1200    |

실전: RWKV 사용하기

HuggingFace에서 사용

from transformers import AutoModelForCausalLM, AutoTokenizer

# RWKV-7 모델 로드
model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-7-world-3b",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
    "RWKV/rwkv-7-world-3b",
    trust_remote_code=True
)

# 텍스트 생성
prompt = "Kubernetes에서 Pod 오토스케일링을 구현하려면"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(
    **inputs,
    max_new_tokens=200,
    temperature=0.7,
    top_p=0.9
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

RWKV Runner로 로컬 실행

# RWKV Runner 설치 (GUI 도구)
git clone https://github.com/josStorer/RWKV-Runner
cd RWKV-Runner

# 또는 ChatRWKV (CLI)
pip install rwkv

# Python에서 직접 사용
python3 << 'EOF'
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

model = RWKV(model='/path/to/RWKV-7-World-3B.pth', strategy='cuda fp16')
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

result = pipeline.generate(
    "한국의 AI 산업 현황에 대해 설명해주세요.",
    token_count=200,
    temperature=0.8
)
print(result)
EOF

RWKV vs Mamba vs Transformer

| 특성           | Transformer    | Mamba          | RWKV-7        |
|---------------|----------------|----------------|---------------|
| 학습 복잡도     | O(N²)         | O(N)           | O(N)          |
| 추론 복잡도     | O(N) per token | O(1) per token | O(1) per token|
| 메모리          | O(N) KV Cache  | O(1) 상태     | O(1) 상태     |
| 병렬 학습       ||  (scan)      |  (WKV)      |
| 장기 의존성      | ✅ 강함        | ⭕ 양호        | ⭕ 양호       |
| In-Context Learning | ✅ 강함   | ⭕ 양호        | ✅ v7에서 강화 |
| 구현 복잡도     | 낮음           | 중간 (CUDA)    | 중간 (CUDA)   |
| 커뮤니티       | 거대           | 성장 중        | 활발          |

한계와 미래

현재 한계

  1. 복잡한 검색 태스크: Transformer의 Full Attention 대비 특정 패턴 검색에 약함
  2. CUDA 커널 의존성: 최적 성능을 위해 커스텀 CUDA 커널 필요
  3. 생태계: Transformer 대비 도구/라이브러리 부족

미래 방향

  • 하이브리드 아키텍처: RWKV + 소량 Attention 결합
  • 하드웨어 최적화: Groq, Cerebras 등 새로운 칩에 최적화
  • 멀티모달: 비전, 오디오 등 다른 모달리티 확장

퀴즈

Q1. RWKV에서 R, W, K, V는 각각 무엇을 의미하나요?

R: Receptance (수용 게이트), W: Weight (시간 감쇠), K: Key, V: Value입니다.

Q2. RWKV가 학습과 추론에서 각각 어떤 모드로 동작하나요?

학습 시에는 Transformer처럼 병렬(parallel) 모드로 전체 시퀀스를 한번에 처리하고, 추론 시에는 RNN처럼 순차(sequential) 모드로 토큰당 O(1) 비용으로 동작합니다.

Q3. RWKV-6에서 도입된 Data-dependent time decay란?

기존의 고정된 감쇠율 대신, 입력 데이터에 따라 동적으로 감쇠율이 변하는 메커니즘입니다. Mamba의 선택적(Selective) 메커니즘과 유사한 아이디어입니다.

Q4. RWKV-7 Goose의 핵심 혁신은?

State Transition Matrix를 도입하여 상태 전이를 스칼라가 아닌 행렬로 표현합니다. 이를 통해 더 풍부한 상태 업데이트와 강화된 In-Context Learning이 가능합니다.

Q5. Transformer 대비 RWKV의 가장 큰 장점은?

시퀀스 길이에 관계없이 추론 비용이 일정(O(1) per token)합니다. 128K+ 토큰에서도 Transformer처럼 OOM이 발생하지 않습니다.

Q6. Token Shift 메커니즘의 역할은?

현재 토큰과 이전 토큰을 가중 평균으로 혼합하여, 위치 인코딩 없이도 시퀀스 내 위치 정보를 전달합니다.

Q7. RWKV의 현재 한계 중 하나는?

Transformer의 Full Self-Attention 대비 복잡한 패턴 검색이나 정확한 정보 검색 태스크에서 성능이 떨어질 수 있습니다.

마무리

RWKV는 "RNN은 죽었다"는 통념을 뒤집는 혁신적 아키텍처입니다. Transformer의 병렬 학습 효율성과 RNN의 추론 효율성을 결합하여, 특히 긴 시퀀스 처리와 엣지 디바이스 배포에서 강점을 보입니다. v7 Goose에서 State Transition Matrix 도입으로 Transformer와의 성능 격차가 거의 해소되었습니다.

참고 자료