- Published on
RWKV: Reinventing RNNs for the Transformer Era — v4에서 v7 Goose까지
- Authors
- Name
- 논문 개요
- 동기: Transformer vs RNN의 딜레마
- RWKV 핵심 아키텍처
- RWKV 블록 구조
- 버전별 진화
- Transformer와 성능 비교
- 실전: RWKV 사용하기
- RWKV vs Mamba vs Transformer
- 한계와 미래
- 퀴즈
- 마무리
- 참고 자료

논문 개요
- 제목: RWKV: Reinventing RNNs for the Transformer Era
- 저자: Bo Peng 외 (RWKV Foundation)
- 최초 발표: 2023년 5월 (arXiv: 2305.13048), EMNLP 2023
- 최신 버전: RWKV-7 "Goose" (2025년 3월 발표)
- 코드: github.com/BlinkDL/RWKV-LM
동기: Transformer vs RNN의 딜레마
현대 LLM은 거의 모두 Transformer 기반이지만 근본적 한계가 있습니다:
- O(N²) Self-Attention: 시퀀스 길이에 대해 이차 복잡도
- KV Cache 폭발: 추론 시 메모리가 시퀀스 길이에 비례
- 긴 컨텍스트 비용: 128K+ 컨텍스트에서 비용/메모리 급증
반면 전통적 RNN은:
- O(N) 복잡도: 선형 시간/메모리
- 고정 상태: 일정한 추론 비용
- BUT: 학습이 병렬화되지 않아 느림, 장기 의존성 포착 취약
RWKV의 질문: "Transformer처럼 병렬 학습하면서 RNN처럼 효율적으로 추론할 수는 없을까?"
RWKV 핵심 아키텍처
RWKV는 이름 자체가 아키텍처의 핵심 요소입니다:
- R: Receptance (수용 게이트) — 과거 정보 수용 정도 결정
- W: Weight (시간 감쇠) — 과거 정보의 감쇠율
- K: Key — 현재 입력의 키
- V: Value — 현재 입력의 값
WKV (Weighted Key-Value) 메커니즘
RWKV의 핵심인 WKV 연산은 다음과 같습니다:
import torch
def wkv_vanilla(w, u, k, v):
"""
RWKV의 WKV 메커니즘 (순수 Python 구현)
w: time decay (음수, 클수록 빠르게 감쇠)
u: bonus (현재 토큰 보너스)
k: key
v: value
"""
T, C = k.shape
output = torch.zeros_like(v)
for c in range(C):
# 채널별 독립 처리 (O(T) per channel)
a = 0.0 # 누적 분자
b = 0.0 # 누적 분모
p = -1e30 # 최대값 (수치 안정성)
for t in range(T):
# 현재 토큰의 기여
e1 = torch.exp(torch.clamp(u[c] + k[t, c] - p, max=30))
e2 = torch.exp(torch.clamp(w[c] + p - p, max=30)) # 이전 누적 감쇠
# WKV 계산
wkv = (e1 * v[t, c] + e2 * a) / (e1 + e2 * b)
output[t, c] = wkv
# 상태 업데이트 (RNN 방식)
new_p = max(w[c] + p, k[t, c])
e1 = torch.exp(k[t, c] - new_p)
e2 = torch.exp(w[c] + p - new_p)
a = e2 * a + e1 * v[t, c]
b = e2 * b + e1
p = new_p
return output
듀얼 모드: Transformer 모드 vs RNN 모드
# 학습 시: Transformer 모드 (병렬)
# 전체 시퀀스를 한번에 처리, O(T) 복잡도
def rwkv_parallel(x, w, u, k_proj, v_proj, r_proj):
T, C = x.shape
k = k_proj(x) # (T, C)
v = v_proj(x) # (T, C)
r = torch.sigmoid(r_proj(x)) # (T, C) 수용 게이트
# 병렬 WKV 연산 (CUDA 커널)
wkv = parallel_wkv_cuda(w, u, k, v)
return r * wkv # 수용 게이트 적용
# 추론 시: RNN 모드 (순차, 일정 메모리)
def rwkv_sequential(x_t, state, w, u, k_proj, v_proj, r_proj):
"""한 토큰씩 처리, O(1) 복잡도"""
k = k_proj(x_t)
v = v_proj(x_t)
r = torch.sigmoid(r_proj(x_t))
# state = (a, b, p) — 고정 크기!
a, b, p = state
e1 = torch.exp(u + k - p)
e2 = torch.exp(w + p - p)
wkv = (e1 * v + e2 * a) / (e1 + e2 * b)
output = r * wkv
# 상태 업데이트
new_p = torch.maximum(w + p, k)
e1 = torch.exp(k - new_p)
e2 = torch.exp(w + p - new_p)
new_a = e2 * a + e1 * v
new_b = e2 * b + e1
new_state = (new_a, new_b, new_p)
return output, new_state
RWKV 블록 구조
class RWKVBlock:
"""
RWKV의 기본 블록 구조
Transformer Block과 비교:
Transformer: LayerNorm → Attention → Add → LayerNorm → FFN → Add
RWKV: LayerNorm → TimeMix → Add → LayerNorm → ChannelMix → Add
"""
def __init__(self, dim, layer_id):
# Time Mixing (Attention 대체)
self.time_mix = TimeMixing(dim, layer_id)
# Channel Mixing (FFN 대체)
self.channel_mix = ChannelMixing(dim, layer_id)
self.ln1 = LayerNorm(dim)
self.ln2 = LayerNorm(dim)
def forward(self, x, state):
# Time Mixing (과거 토큰 정보 혼합)
dx, state = self.time_mix(self.ln1(x), state)
x = x + dx
# Channel Mixing (채널 간 정보 혼합)
dx = self.channel_mix(self.ln2(x))
x = x + dx
return x, state
class TimeMixing:
"""
Token Shift + WKV
현재 토큰과 이전 토큰을 선형 보간하여 사용
"""
def __init__(self, dim, layer_id):
self.mix_r = nn.Parameter(torch.ones(dim)) # 보간 비율
self.mix_k = nn.Parameter(torch.ones(dim))
self.mix_v = nn.Parameter(torch.ones(dim))
def forward(self, x, state):
# Token Shift: 현재 토큰과 이전 토큰의 가중 평균
x_prev = state.shift # 이전 토큰
xr = x * self.mix_r + x_prev * (1 - self.mix_r)
xk = x * self.mix_k + x_prev * (1 - self.mix_k)
xv = x * self.mix_v + x_prev * (1 - self.mix_v)
r = torch.sigmoid(self.W_r(xr))
k = self.W_k(xk)
v = self.W_v(xv)
wkv = compute_wkv(k, v, state)
return r * wkv, new_state
버전별 진화
RWKV-4 (Eagle)
- 기본 WKV 메커니즘 도입
- Token Shift로 위치 인코딩 대체
- 최대 14B 파라미터
RWKV-5 (Eagle)
# RWKV-5: Multi-headed State
# 여러 개의 독립적인 상태를 유지하여 표현력 향상
class RWKV5_TimeMix:
def __init__(self, dim, n_heads=8):
self.n_heads = n_heads
self.head_dim = dim // n_heads
# 각 헤드가 독립적인 decay rate 보유
self.time_decay = nn.Parameter(torch.randn(n_heads, self.head_dim))
RWKV-6 (Finch)
# RWKV-6: Data-dependent time decay
# 입력에 따라 감쇠율이 동적으로 변화 (Mamba의 선택적 메커니즘과 유사)
class RWKV6_TimeMix:
def forward(self, x, state):
# 고정 decay가 아닌, 입력 의존적 decay
time_decay = self.W_decay(x) # 입력마다 다른 감쇠율!
time_decay = torch.exp(-torch.exp(time_decay))
# LoRA 스타일 decay 변조
decay_lora = self.decay_lora_a(x)
decay_lora = torch.tanh(decay_lora) @ self.decay_lora_b
time_decay = time_decay + decay_lora
RWKV-7 (Goose) — 2025년 3월
# RWKV-7: State Transition Matrix
# 상태 전이를 행렬로 표현하여 더 풍부한 상태 업데이트
class RWKV7_TimeMix:
"""
RWKV-7 핵심 혁신:
1. State Transition Matrix: 스칼라가 아닌 행렬로 상태 전이
2. In-Context Learning 강화: 동적으로 학습 규칙 조정
3. Improved Token Mixing: 더 정교한 토큰 간 정보 흐름
"""
def forward(self, x, state):
# 상태 전이 행렬 계산
a = torch.sigmoid(self.W_a(x)) # (B, T, D)
b = self.W_b(x) # (B, T, D)
# 행렬 형태의 상태 업데이트
# s_{t+1} = diag(a_t) @ s_t + k_t^T @ v_t
k = self.W_k(x)
v = self.W_v(x)
# 상태: (B, H, D/H, D/H) 행렬!
for t in range(T):
state = torch.diag(a[:, t]) @ state + \
k[:, t].unsqueeze(-1) @ v[:, t].unsqueeze(-2)
return output, state
Transformer와 성능 비교
모델 크기별 성능 (Pile dataset perplexity, ↓ 낮을수록 좋음):
| 모델 크기 | Transformer | RWKV-4 | RWKV-6 | RWKV-7 |
|-----------|-------------|--------|--------|--------|
| 169M | 17.2 | 18.1 | 17.4 | 17.0 |
| 430M | 13.8 | 14.5 | 13.9 | 13.6 |
| 1.5B | 11.2 | 11.8 | 11.3 | 11.0 |
| 3B | 9.8 | 10.3 | 9.9 | 9.6 |
| 7B | 8.5 | 9.0 | 8.6 | 8.3 |
| 14B | 7.8 | 8.2 | 7.9 | 7.6 |
추론 효율 (토큰/초, A100 GPU):
| 시퀀스 길이 | Transformer | RWKV-7 |
|-----------|-------------|---------|
| 1K | 1000 | 1200 |
| 4K | 800 | 1200 |
| 16K | 300 | 1200 |
| 64K | 50 | 1200 |
| 128K+ | OOM | 1200 |
실전: RWKV 사용하기
HuggingFace에서 사용
from transformers import AutoModelForCausalLM, AutoTokenizer
# RWKV-7 모델 로드
model = AutoModelForCausalLM.from_pretrained(
"RWKV/rwkv-7-world-3b",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
"RWKV/rwkv-7-world-3b",
trust_remote_code=True
)
# 텍스트 생성
prompt = "Kubernetes에서 Pod 오토스케일링을 구현하려면"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(
**inputs,
max_new_tokens=200,
temperature=0.7,
top_p=0.9
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
RWKV Runner로 로컬 실행
# RWKV Runner 설치 (GUI 도구)
git clone https://github.com/josStorer/RWKV-Runner
cd RWKV-Runner
# 또는 ChatRWKV (CLI)
pip install rwkv
# Python에서 직접 사용
python3 << 'EOF'
from rwkv.model import RWKV
from rwkv.utils import PIPELINE
model = RWKV(model='/path/to/RWKV-7-World-3B.pth', strategy='cuda fp16')
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
result = pipeline.generate(
"한국의 AI 산업 현황에 대해 설명해주세요.",
token_count=200,
temperature=0.8
)
print(result)
EOF
RWKV vs Mamba vs Transformer
| 특성 | Transformer | Mamba | RWKV-7 |
|---------------|----------------|----------------|---------------|
| 학습 복잡도 | O(N²) | O(N) | O(N) |
| 추론 복잡도 | O(N) per token | O(1) per token | O(1) per token|
| 메모리 | O(N) KV Cache | O(1) 상태 | O(1) 상태 |
| 병렬 학습 | ✅ | ✅ (scan) | ✅ (WKV) |
| 장기 의존성 | ✅ 강함 | ⭕ 양호 | ⭕ 양호 |
| In-Context Learning | ✅ 강함 | ⭕ 양호 | ✅ v7에서 강화 |
| 구현 복잡도 | 낮음 | 중간 (CUDA) | 중간 (CUDA) |
| 커뮤니티 | 거대 | 성장 중 | 활발 |
한계와 미래
현재 한계
- 복잡한 검색 태스크: Transformer의 Full Attention 대비 특정 패턴 검색에 약함
- CUDA 커널 의존성: 최적 성능을 위해 커스텀 CUDA 커널 필요
- 생태계: Transformer 대비 도구/라이브러리 부족
미래 방향
- 하이브리드 아키텍처: RWKV + 소량 Attention 결합
- 하드웨어 최적화: Groq, Cerebras 등 새로운 칩에 최적화
- 멀티모달: 비전, 오디오 등 다른 모달리티 확장
퀴즈
Q1. RWKV에서 R, W, K, V는 각각 무엇을 의미하나요?
R: Receptance (수용 게이트), W: Weight (시간 감쇠), K: Key, V: Value입니다.
Q2. RWKV가 학습과 추론에서 각각 어떤 모드로 동작하나요?
학습 시에는 Transformer처럼 병렬(parallel) 모드로 전체 시퀀스를 한번에 처리하고, 추론 시에는 RNN처럼 순차(sequential) 모드로 토큰당 O(1) 비용으로 동작합니다.
Q3. RWKV-6에서 도입된 Data-dependent time decay란?
기존의 고정된 감쇠율 대신, 입력 데이터에 따라 동적으로 감쇠율이 변하는 메커니즘입니다. Mamba의 선택적(Selective) 메커니즘과 유사한 아이디어입니다.
Q4. RWKV-7 Goose의 핵심 혁신은?
State Transition Matrix를 도입하여 상태 전이를 스칼라가 아닌 행렬로 표현합니다. 이를 통해 더 풍부한 상태 업데이트와 강화된 In-Context Learning이 가능합니다.
Q5. Transformer 대비 RWKV의 가장 큰 장점은?
시퀀스 길이에 관계없이 추론 비용이 일정(O(1) per token)합니다. 128K+ 토큰에서도 Transformer처럼 OOM이 발생하지 않습니다.
Q6. Token Shift 메커니즘의 역할은?
현재 토큰과 이전 토큰을 가중 평균으로 혼합하여, 위치 인코딩 없이도 시퀀스 내 위치 정보를 전달합니다.
Q7. RWKV의 현재 한계 중 하나는?
Transformer의 Full Self-Attention 대비 복잡한 패턴 검색이나 정확한 정보 검색 태스크에서 성능이 떨어질 수 있습니다.
마무리
RWKV는 "RNN은 죽었다"는 통념을 뒤집는 혁신적 아키텍처입니다. Transformer의 병렬 학습 효율성과 RNN의 추론 효율성을 결합하여, 특히 긴 시퀀스 처리와 엣지 디바이스 배포에서 강점을 보입니다. v7 Goose에서 State Transition Matrix 도입으로 Transformer와의 성능 격차가 거의 해소되었습니다.