Skip to content

필사 모드: RWKV: Transformer時代のRNN再発明 — v4からv7 Gooseまで

日本語
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

論文概要

- **タイトル**: RWKV: Reinventing RNNs for the Transformer Era

- **著者**: Bo Peng 他 (RWKV Foundation)

- **初回発表**: 2023年5月 (arXiv: 2305.13048)、EMNLP 2023

- **最新バージョン**: RWKV-7 "Goose"(2025年3月発表)

- **コード**: [github.com/BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM)

動機:Transformer vs RNNのジレンマ

現代のLLMはほぼすべてTransformerベースですが、根本的な限界があります:

- **O(N²) Self-Attention**: シーケンス長に対して二次計算量

- **KV Cacheの爆発**: 推論時にメモリがシーケンス長に比例

- **長コンテキストのコスト**: 128K以上のコンテキストでコスト・メモリが急増

一方、伝統的なRNNは:

- **O(N) 計算量**: 線形時間・メモリ

- **固定状態**: 一定の推論コスト

- **しかし**: 学習が並列化できず遅い、長距離依存の捕捉が弱い

**RWKVの問い**: 「Transformerのように並列学習しながら、RNNのように効率的に推論することはできないか?」

RWKVコアアーキテクチャ

RWKVという名前自体がアーキテクチャの核心要素を示しています:

- **R**: Receptance(受容ゲート)— 過去の情報をどの程度受け入れるかを決定

- **W**: Weight(時間減衰)— 過去の情報の減衰率

- **K**: Key — 現在の入力のキー

- **V**: Value — 現在の入力の値

WKV(Weighted Key-Value)メカニズム

RWKVの核心であるWKV演算は以下の通りです:

def wkv_vanilla(w, u, k, v):

"""

RWKVのWKVメカニズム(純粋Python実装)

w: time decay(負値、絶対値が大きいほど速く減衰)

u: bonus(現在トークンのボーナス)

k: key

v: value

"""

T, C = k.shape

output = torch.zeros_like(v)

for c in range(C):

チャネルごとに独立処理(O(T) per channel)

a = 0.0 # 累積分子

b = 0.0 # 累積分母

p = -1e30 # 最大値(数値安定性)

for t in range(T):

現在トークンの寄与

e1 = torch.exp(torch.clamp(u[c] + k[t, c] - p, max=30))

e2 = torch.exp(torch.clamp(w[c] + p - p, max=30)) # 前の累積の減衰

WKV計算

wkv = (e1 * v[t, c] + e2 * a) / (e1 + e2 * b)

output[t, c] = wkv

状態更新(RNN方式)

new_p = max(w[c] + p, k[t, c])

e1 = torch.exp(k[t, c] - new_p)

e2 = torch.exp(w[c] + p - new_p)

a = e2 * a + e1 * v[t, c]

b = e2 * b + e1

p = new_p

return output

デュアルモード:Transformerモード vs RNNモード

学習時:Transformerモード(並列)

シーケンス全体を一度に処理、O(T)計算量

def rwkv_parallel(x, w, u, k_proj, v_proj, r_proj):

T, C = x.shape

k = k_proj(x) # (T, C)

v = v_proj(x) # (T, C)

r = torch.sigmoid(r_proj(x)) # (T, C) 受容ゲート

並列WKV演算(CUDAカーネル)

wkv = parallel_wkv_cuda(w, u, k, v)

return r * wkv # 受容ゲート適用

推論時:RNNモード(逐次、一定メモリ)

def rwkv_sequential(x_t, state, w, u, k_proj, v_proj, r_proj):

"""1トークンずつ処理、O(1)計算量"""

k = k_proj(x_t)

v = v_proj(x_t)

r = torch.sigmoid(r_proj(x_t))

state = (a, b, p) — 固定サイズ!

a, b, p = state

e1 = torch.exp(u + k - p)

e2 = torch.exp(w + p - p)

wkv = (e1 * v + e2 * a) / (e1 + e2 * b)

output = r * wkv

状態更新

new_p = torch.maximum(w + p, k)

e1 = torch.exp(k - new_p)

e2 = torch.exp(w + p - new_p)

new_a = e2 * a + e1 * v

new_b = e2 * b + e1

new_state = (new_a, new_b, new_p)

return output, new_state

RWKVブロック構造

class RWKVBlock:

"""

RWKVの基本ブロック構造

Transformer Blockとの比較:

Transformer: LayerNorm → Attention → Add → LayerNorm → FFN → Add

RWKV: LayerNorm → TimeMix → Add → LayerNorm → ChannelMix → Add

"""

def __init__(self, dim, layer_id):

Time Mixing(Attentionの代替)

self.time_mix = TimeMixing(dim, layer_id)

Channel Mixing(FFNの代替)

self.channel_mix = ChannelMixing(dim, layer_id)

self.ln1 = LayerNorm(dim)

self.ln2 = LayerNorm(dim)

def forward(self, x, state):

Time Mixing(過去のトークン情報を混合)

dx, state = self.time_mix(self.ln1(x), state)

x = x + dx

Channel Mixing(チャネル間の情報を混合)

dx = self.channel_mix(self.ln2(x))

x = x + dx

return x, state

class TimeMixing:

"""

Token Shift + WKV

現在のトークンと前のトークンを線形補間して使用

"""

def __init__(self, dim, layer_id):

self.mix_r = nn.Parameter(torch.ones(dim)) # 補間比率

self.mix_k = nn.Parameter(torch.ones(dim))

self.mix_v = nn.Parameter(torch.ones(dim))

def forward(self, x, state):

Token Shift:現在と前のトークンの加重平均

x_prev = state.shift # 前のトークン

xr = x * self.mix_r + x_prev * (1 - self.mix_r)

xk = x * self.mix_k + x_prev * (1 - self.mix_k)

xv = x * self.mix_v + x_prev * (1 - self.mix_v)

r = torch.sigmoid(self.W_r(xr))

k = self.W_k(xk)

v = self.W_v(xv)

wkv = compute_wkv(k, v, state)

return r * wkv, new_state

バージョンごとの進化

RWKV-4 (Eagle)

- 基本的なWKVメカニズムを導入

- Token Shiftで位置エンコーディングを代替

- 最大14Bパラメータ

RWKV-5 (Eagle)

RWKV-5: Multi-headed State

複数の独立した状態を保持し表現力を向上

class RWKV5_TimeMix:

def __init__(self, dim, n_heads=8):

self.n_heads = n_heads

self.head_dim = dim // n_heads

各ヘッドが独立したdecay rateを保持

self.time_decay = nn.Parameter(torch.randn(n_heads, self.head_dim))

RWKV-6 (Finch)

RWKV-6: Data-dependent time decay

入力に応じて減衰率が動的に変化(Mambaの選択的メカニズムと類似)

class RWKV6_TimeMix:

def forward(self, x, state):

固定decayではなく、入力依存のdecay

time_decay = self.W_decay(x) # 入力ごとに異なる減衰率!

time_decay = torch.exp(-torch.exp(time_decay))

LoRAスタイルのdecay変調

decay_lora = self.decay_lora_a(x)

decay_lora = torch.tanh(decay_lora) @ self.decay_lora_b

time_decay = time_decay + decay_lora

RWKV-7 (Goose) — 2025年3月

RWKV-7: State Transition Matrix

状態遷移を行列で表現し、より豊かな状態更新を実現

class RWKV7_TimeMix:

"""

RWKV-7の核心的革新:

1. State Transition Matrix:スカラーではなく行列で状態遷移

2. In-Context Learningの強化:動的に学習ルールを調整

3. Improved Token Mixing:より精巧なトークン間情報フロー

"""

def forward(self, x, state):

状態遷移行列の計算

a = torch.sigmoid(self.W_a(x)) # (B, T, D)

b = self.W_b(x) # (B, T, D)

行列形式の状態更新

s_{t+1} = diag(a_t) @ s_t + k_t^T @ v_t

k = self.W_k(x)

v = self.W_v(x)

状態:(B, H, D/H, D/H) 行列!

for t in range(T):

state = torch.diag(a[:, t]) @ state + \

k[:, t].unsqueeze(-1) @ v[:, t].unsqueeze(-2)

return output, state

Transformerとの性能比較

モデルサイズ別性能(Pile datasetパープレキシティ、低いほど良い):

| モデルサイズ | Transformer | RWKV-4 | RWKV-6 | RWKV-7 |

|-------------|-------------|--------|--------|--------|

| 169M | 17.2 | 18.1 | 17.4 | 17.0 |

| 430M | 13.8 | 14.5 | 13.9 | 13.6 |

| 1.5B | 11.2 | 11.8 | 11.3 | 11.0 |

| 3B | 9.8 | 10.3 | 9.9 | 9.6 |

| 7B | 8.5 | 9.0 | 8.6 | 8.3 |

| 14B | 7.8 | 8.2 | 7.9 | 7.6 |

推論効率(トークン/秒、A100 GPU):

| シーケンス長 | Transformer | RWKV-7 |

|-------------|-------------|---------|

| 1K | 1000 | 1200 |

| 4K | 800 | 1200 |

| 16K | 300 | 1200 |

| 64K | 50 | 1200 |

| 128K+ | OOM | 1200 |

実践:RWKVを使う

HuggingFaceでの使用

from transformers import AutoModelForCausalLM, AutoTokenizer

RWKV-7モデルのロード

model = AutoModelForCausalLM.from_pretrained(

"RWKV/rwkv-7-world-3b",

trust_remote_code=True,

torch_dtype=torch.bfloat16,

device_map="auto"

)

tokenizer = AutoTokenizer.from_pretrained(

"RWKV/rwkv-7-world-3b",

trust_remote_code=True

)

テキスト生成

prompt = "KubernetesでPodオートスケーリングを実装するには"

inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.generate(

**inputs,

max_new_tokens=200,

temperature=0.7,

top_p=0.9

)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

RWKV Runnerでローカル実行

RWKV Runnerのインストール(GUIツール)

git clone https://github.com/josStorer/RWKV-Runner

cd RWKV-Runner

またはChatRWKV(CLI)

pip install rwkv

Pythonから直接使用

python3 << 'EOF'

from rwkv.model import RWKV

from rwkv.utils import PIPELINE

model = RWKV(model='/path/to/RWKV-7-World-3B.pth', strategy='cuda fp16')

pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

result = pipeline.generate(

"韓国のAI産業の現状について説明してください。",

token_count=200,

temperature=0.8

)

print(result)

EOF

RWKV vs Mamba vs Transformer

| 特性 | Transformer | Mamba | RWKV-7 |

|---------------------|----------------|----------------|---------------|

| 学習計算量 | O(N²) | O(N) | O(N) |

| 推論計算量 | O(N) per token | O(1) per token | O(1) per token|

| メモリ | O(N) KV Cache | O(1) 状態 | O(1) 状態 |

| 並列学習 | 可能 | 可能 (scan) | 可能 (WKV) |

| 長距離依存 | 強い | 良好 | 良好 |

| In-Context Learning | 強い | 良好 | v7で強化 |

| 実装の複雑さ | 低い | 中程度 (CUDA) | 中程度 (CUDA) |

| コミュニティ | 巨大 | 成長中 | 活発 |

限界と今後の展望

現在の限界

1. **複雑な検索タスク**: TransformerのFull Attentionと比較して、特定パターン検索で劣る場合がある

2. **CUDAカーネル依存**: 最適な性能を得るにはカスタムCUDAカーネルが必要

3. **エコシステム**: Transformerに比べてツールやライブラリが不足

今後の方向性

- **ハイブリッドアーキテクチャ**: RWKVに少量のAttentionを組み合わせる

- **ハードウェア最適化**: Groq、Cerebrasなど新しいチップへの最適化

- **マルチモーダル**: ビジョン、オーディオなど他のモダリティへの拡張

クイズ

R: Receptance(受容ゲート)、W: Weight(時間減衰)、K: Key、V: Valueです。

学習時はTransformerのように並列(parallel)モードでシーケンス全体を一度に処理し、推論時はRNNのように逐次(sequential)モードでトークンあたりO(1)コストで動作します。

固定された減衰率の代わりに、入力データに応じて動的に減衰率が変化するメカニズムです。Mambaの選択的(Selective)メカニズムと類似したアイデアです。

State Transition Matrixを導入し、状態遷移をスカラーではなく行列で表現します。これにより、より豊かな状態更新と強化されたIn-Context Learningが可能になります。

シーケンス長に関係なく推論コストが一定(O(1) per token)です。128K以上のトークンでもTransformerのようにOOMが発生しません。

現在のトークンと前のトークンを加重平均で混合し、位置エンコーディングなしでもシーケンス内の位置情報を伝達します。

TransformerのFull Self-Attentionと比較して、複雑なパターン検索や正確な情報検索タスクで性能が劣る場合があります。

まとめ

RWKVは「RNNは死んだ」という定説を覆す革新的アーキテクチャです。Transformerの並列学習効率とRNNの推論効率を融合し、特に長いシーケンス処理やエッジデバイスへの展開で強みを発揮します。v7 GooseでState Transition Matrixが導入され、Transformerとの性能差がほぼ解消されました。

参考資料

- [RWKV論文 (arXiv: 2305.13048)](https://arxiv.org/abs/2305.13048)

- [RWKV-7 Goose Wiki](https://wiki.rwkv.com/)

- [RWKV GitHub](https://github.com/BlinkDL/RWKV-LM)

현재 단락 (1/217)

- **タイトル**: RWKV: Reinventing RNNs for the Transformer Era

작성 글자: 0원문 글자: 7,310작성 단락: 0/217