- Authors
- Name
- 論文概要
- 動機:Transformer vs RNNのジレンマ
- RWKVコアアーキテクチャ
- RWKVブロック構造
- バージョンごとの進化
- Transformerとの性能比較
- 実践:RWKVを使う
- RWKV vs Mamba vs Transformer
- 限界と今後の展望
- クイズ
- まとめ
- 参考資料

論文概要
- タイトル: RWKV: Reinventing RNNs for the Transformer Era
- 著者: Bo Peng 他 (RWKV Foundation)
- 初回発表: 2023年5月 (arXiv: 2305.13048)、EMNLP 2023
- 最新バージョン: RWKV-7 "Goose"(2025年3月発表)
- コード: github.com/BlinkDL/RWKV-LM
動機:Transformer vs RNNのジレンマ
現代のLLMはほぼすべてTransformerベースですが、根本的な限界があります:
- O(N²) Self-Attention: シーケンス長に対して二次計算量
- KV Cacheの爆発: 推論時にメモリがシーケンス長に比例
- 長コンテキストのコスト: 128K以上のコンテキストでコスト・メモリが急増
一方、伝統的なRNNは:
- O(N) 計算量: 線形時間・メモリ
- 固定状態: 一定の推論コスト
- しかし: 学習が並列化できず遅い、長距離依存の捕捉が弱い
RWKVの問い: 「Transformerのように並列学習しながら、RNNのように効率的に推論することはできないか?」
RWKVコアアーキテクチャ
RWKVという名前自体がアーキテクチャの核心要素を示しています:
- R: Receptance(受容ゲート)— 過去の情報をどの程度受け入れるかを決定
- W: Weight(時間減衰)— 過去の情報の減衰率
- K: Key — 現在の入力のキー
- V: Value — 現在の入力の値
WKV(Weighted Key-Value)メカニズム
RWKVの核心であるWKV演算は以下の通りです:
import torch
def wkv_vanilla(w, u, k, v):
"""
RWKVのWKVメカニズム(純粋Python実装)
w: time decay(負値、絶対値が大きいほど速く減衰)
u: bonus(現在トークンのボーナス)
k: key
v: value
"""
T, C = k.shape
output = torch.zeros_like(v)
for c in range(C):
# チャネルごとに独立処理(O(T) per channel)
a = 0.0 # 累積分子
b = 0.0 # 累積分母
p = -1e30 # 最大値(数値安定性)
for t in range(T):
# 現在トークンの寄与
e1 = torch.exp(torch.clamp(u[c] + k[t, c] - p, max=30))
e2 = torch.exp(torch.clamp(w[c] + p - p, max=30)) # 前の累積の減衰
# WKV計算
wkv = (e1 * v[t, c] + e2 * a) / (e1 + e2 * b)
output[t, c] = wkv
# 状態更新(RNN方式)
new_p = max(w[c] + p, k[t, c])
e1 = torch.exp(k[t, c] - new_p)
e2 = torch.exp(w[c] + p - new_p)
a = e2 * a + e1 * v[t, c]
b = e2 * b + e1
p = new_p
return output
デュアルモード:Transformerモード vs RNNモード
# 学習時:Transformerモード(並列)
# シーケンス全体を一度に処理、O(T)計算量
def rwkv_parallel(x, w, u, k_proj, v_proj, r_proj):
T, C = x.shape
k = k_proj(x) # (T, C)
v = v_proj(x) # (T, C)
r = torch.sigmoid(r_proj(x)) # (T, C) 受容ゲート
# 並列WKV演算(CUDAカーネル)
wkv = parallel_wkv_cuda(w, u, k, v)
return r * wkv # 受容ゲート適用
# 推論時:RNNモード(逐次、一定メモリ)
def rwkv_sequential(x_t, state, w, u, k_proj, v_proj, r_proj):
"""1トークンずつ処理、O(1)計算量"""
k = k_proj(x_t)
v = v_proj(x_t)
r = torch.sigmoid(r_proj(x_t))
# state = (a, b, p) — 固定サイズ!
a, b, p = state
e1 = torch.exp(u + k - p)
e2 = torch.exp(w + p - p)
wkv = (e1 * v + e2 * a) / (e1 + e2 * b)
output = r * wkv
# 状態更新
new_p = torch.maximum(w + p, k)
e1 = torch.exp(k - new_p)
e2 = torch.exp(w + p - new_p)
new_a = e2 * a + e1 * v
new_b = e2 * b + e1
new_state = (new_a, new_b, new_p)
return output, new_state
RWKVブロック構造
class RWKVBlock:
"""
RWKVの基本ブロック構造
Transformer Blockとの比較:
Transformer: LayerNorm → Attention → Add → LayerNorm → FFN → Add
RWKV: LayerNorm → TimeMix → Add → LayerNorm → ChannelMix → Add
"""
def __init__(self, dim, layer_id):
# Time Mixing(Attentionの代替)
self.time_mix = TimeMixing(dim, layer_id)
# Channel Mixing(FFNの代替)
self.channel_mix = ChannelMixing(dim, layer_id)
self.ln1 = LayerNorm(dim)
self.ln2 = LayerNorm(dim)
def forward(self, x, state):
# Time Mixing(過去のトークン情報を混合)
dx, state = self.time_mix(self.ln1(x), state)
x = x + dx
# Channel Mixing(チャネル間の情報を混合)
dx = self.channel_mix(self.ln2(x))
x = x + dx
return x, state
class TimeMixing:
"""
Token Shift + WKV
現在のトークンと前のトークンを線形補間して使用
"""
def __init__(self, dim, layer_id):
self.mix_r = nn.Parameter(torch.ones(dim)) # 補間比率
self.mix_k = nn.Parameter(torch.ones(dim))
self.mix_v = nn.Parameter(torch.ones(dim))
def forward(self, x, state):
# Token Shift:現在と前のトークンの加重平均
x_prev = state.shift # 前のトークン
xr = x * self.mix_r + x_prev * (1 - self.mix_r)
xk = x * self.mix_k + x_prev * (1 - self.mix_k)
xv = x * self.mix_v + x_prev * (1 - self.mix_v)
r = torch.sigmoid(self.W_r(xr))
k = self.W_k(xk)
v = self.W_v(xv)
wkv = compute_wkv(k, v, state)
return r * wkv, new_state
バージョンごとの進化
RWKV-4 (Eagle)
- 基本的なWKVメカニズムを導入
- Token Shiftで位置エンコーディングを代替
- 最大14Bパラメータ
RWKV-5 (Eagle)
# RWKV-5: Multi-headed State
# 複数の独立した状態を保持し表現力を向上
class RWKV5_TimeMix:
def __init__(self, dim, n_heads=8):
self.n_heads = n_heads
self.head_dim = dim // n_heads
# 各ヘッドが独立したdecay rateを保持
self.time_decay = nn.Parameter(torch.randn(n_heads, self.head_dim))
RWKV-6 (Finch)
# RWKV-6: Data-dependent time decay
# 入力に応じて減衰率が動的に変化(Mambaの選択的メカニズムと類似)
class RWKV6_TimeMix:
def forward(self, x, state):
# 固定decayではなく、入力依存のdecay
time_decay = self.W_decay(x) # 入力ごとに異なる減衰率!
time_decay = torch.exp(-torch.exp(time_decay))
# LoRAスタイルのdecay変調
decay_lora = self.decay_lora_a(x)
decay_lora = torch.tanh(decay_lora) @ self.decay_lora_b
time_decay = time_decay + decay_lora
RWKV-7 (Goose) — 2025年3月
# RWKV-7: State Transition Matrix
# 状態遷移を行列で表現し、より豊かな状態更新を実現
class RWKV7_TimeMix:
"""
RWKV-7の核心的革新:
1. State Transition Matrix:スカラーではなく行列で状態遷移
2. In-Context Learningの強化:動的に学習ルールを調整
3. Improved Token Mixing:より精巧なトークン間情報フロー
"""
def forward(self, x, state):
# 状態遷移行列の計算
a = torch.sigmoid(self.W_a(x)) # (B, T, D)
b = self.W_b(x) # (B, T, D)
# 行列形式の状態更新
# s_{t+1} = diag(a_t) @ s_t + k_t^T @ v_t
k = self.W_k(x)
v = self.W_v(x)
# 状態:(B, H, D/H, D/H) 行列!
for t in range(T):
state = torch.diag(a[:, t]) @ state + \
k[:, t].unsqueeze(-1) @ v[:, t].unsqueeze(-2)
return output, state
Transformerとの性能比較
モデルサイズ別性能(Pile datasetパープレキシティ、低いほど良い):
| モデルサイズ | Transformer | RWKV-4 | RWKV-6 | RWKV-7 |
|-------------|-------------|--------|--------|--------|
| 169M | 17.2 | 18.1 | 17.4 | 17.0 |
| 430M | 13.8 | 14.5 | 13.9 | 13.6 |
| 1.5B | 11.2 | 11.8 | 11.3 | 11.0 |
| 3B | 9.8 | 10.3 | 9.9 | 9.6 |
| 7B | 8.5 | 9.0 | 8.6 | 8.3 |
| 14B | 7.8 | 8.2 | 7.9 | 7.6 |
推論効率(トークン/秒、A100 GPU):
| シーケンス長 | Transformer | RWKV-7 |
|-------------|-------------|---------|
| 1K | 1000 | 1200 |
| 4K | 800 | 1200 |
| 16K | 300 | 1200 |
| 64K | 50 | 1200 |
| 128K+ | OOM | 1200 |
実践:RWKVを使う
HuggingFaceでの使用
from transformers import AutoModelForCausalLM, AutoTokenizer
# RWKV-7モデルのロード
model = AutoModelForCausalLM.from_pretrained(
"RWKV/rwkv-7-world-3b",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
"RWKV/rwkv-7-world-3b",
trust_remote_code=True
)
# テキスト生成
prompt = "KubernetesでPodオートスケーリングを実装するには"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(
**inputs,
max_new_tokens=200,
temperature=0.7,
top_p=0.9
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
RWKV Runnerでローカル実行
# RWKV Runnerのインストール(GUIツール)
git clone https://github.com/josStorer/RWKV-Runner
cd RWKV-Runner
# またはChatRWKV(CLI)
pip install rwkv
# Pythonから直接使用
python3 << 'EOF'
from rwkv.model import RWKV
from rwkv.utils import PIPELINE
model = RWKV(model='/path/to/RWKV-7-World-3B.pth', strategy='cuda fp16')
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
result = pipeline.generate(
"韓国のAI産業の現状について説明してください。",
token_count=200,
temperature=0.8
)
print(result)
EOF
RWKV vs Mamba vs Transformer
| 特性 | Transformer | Mamba | RWKV-7 |
|---------------------|----------------|----------------|---------------|
| 学習計算量 | O(N²) | O(N) | O(N) |
| 推論計算量 | O(N) per token | O(1) per token | O(1) per token|
| メモリ | O(N) KV Cache | O(1) 状態 | O(1) 状態 |
| 並列学習 | 可能 | 可能 (scan) | 可能 (WKV) |
| 長距離依存 | 強い | 良好 | 良好 |
| In-Context Learning | 強い | 良好 | v7で強化 |
| 実装の複雑さ | 低い | 中程度 (CUDA) | 中程度 (CUDA) |
| コミュニティ | 巨大 | 成長中 | 活発 |
限界と今後の展望
現在の限界
- 複雑な検索タスク: TransformerのFull Attentionと比較して、特定パターン検索で劣る場合がある
- CUDAカーネル依存: 最適な性能を得るにはカスタムCUDAカーネルが必要
- エコシステム: Transformerに比べてツールやライブラリが不足
今後の方向性
- ハイブリッドアーキテクチャ: RWKVに少量のAttentionを組み合わせる
- ハードウェア最適化: Groq、Cerebrasなど新しいチップへの最適化
- マルチモーダル: ビジョン、オーディオなど他のモダリティへの拡張
クイズ
Q1. RWKVのR、W、K、Vはそれぞれ何を意味しますか?
R: Receptance(受容ゲート)、W: Weight(時間減衰)、K: Key、V: Valueです。
Q2. RWKVは学習時と推論時にそれぞれどのモードで動作しますか?
学習時はTransformerのように並列(parallel)モードでシーケンス全体を一度に処理し、推論時はRNNのように逐次(sequential)モードでトークンあたりO(1)コストで動作します。
Q3. RWKV-6で導入されたData-dependent time decayとは?
固定された減衰率の代わりに、入力データに応じて動的に減衰率が変化するメカニズムです。Mambaの選択的(Selective)メカニズムと類似したアイデアです。
Q4. RWKV-7 Gooseの核心的革新は?
State Transition Matrixを導入し、状態遷移をスカラーではなく行列で表現します。これにより、より豊かな状態更新と強化されたIn-Context Learningが可能になります。
Q5. Transformerに対するRWKVの最大の利点は?
シーケンス長に関係なく推論コストが一定(O(1) per token)です。128K以上のトークンでもTransformerのようにOOMが発生しません。
Q6. Token Shiftメカニズムの役割は?
現在のトークンと前のトークンを加重平均で混合し、位置エンコーディングなしでもシーケンス内の位置情報を伝達します。
Q7. RWKVの現在の限界の1つは?
TransformerのFull Self-Attentionと比較して、複雑なパターン検索や正確な情報検索タスクで性能が劣る場合があります。
まとめ
RWKVは「RNNは死んだ」という定説を覆す革新的アーキテクチャです。Transformerの並列学習効率とRNNの推論効率を融合し、特に長いシーケンス処理やエッジデバイスへの展開で強みを発揮します。v7 GooseでState Transition Matrixが導入され、Transformerとの性能差がほぼ解消されました。