Skip to content
Published on

RWKV: Transformer時代のRNN再発明 — v4からv7 Gooseまで

Authors
  • Name
    Twitter
RWKV Architecture

論文概要

  • タイトル: RWKV: Reinventing RNNs for the Transformer Era
  • 著者: Bo Peng 他 (RWKV Foundation)
  • 初回発表: 2023年5月 (arXiv: 2305.13048)、EMNLP 2023
  • 最新バージョン: RWKV-7 "Goose"(2025年3月発表)
  • コード: github.com/BlinkDL/RWKV-LM

動機:Transformer vs RNNのジレンマ

現代のLLMはほぼすべてTransformerベースですが、根本的な限界があります:

  • O(N²) Self-Attention: シーケンス長に対して二次計算量
  • KV Cacheの爆発: 推論時にメモリがシーケンス長に比例
  • 長コンテキストのコスト: 128K以上のコンテキストでコスト・メモリが急増

一方、伝統的なRNNは:

  • O(N) 計算量: 線形時間・メモリ
  • 固定状態: 一定の推論コスト
  • しかし: 学習が並列化できず遅い、長距離依存の捕捉が弱い

RWKVの問い: 「Transformerのように並列学習しながら、RNNのように効率的に推論することはできないか?」

RWKVコアアーキテクチャ

RWKVという名前自体がアーキテクチャの核心要素を示しています:

  • R: Receptance(受容ゲート)— 過去の情報をどの程度受け入れるかを決定
  • W: Weight(時間減衰)— 過去の情報の減衰率
  • K: Key — 現在の入力のキー
  • V: Value — 現在の入力の値

WKV(Weighted Key-Value)メカニズム

RWKVの核心であるWKV演算は以下の通りです:

import torch

def wkv_vanilla(w, u, k, v):
    """
    RWKVのWKVメカニズム(純粋Python実装)
    w: time decay(負値、絶対値が大きいほど速く減衰)
    u: bonus(現在トークンのボーナス)
    k: key
    v: value
    """
    T, C = k.shape
    output = torch.zeros_like(v)

    for c in range(C):
        # チャネルごとに独立処理(O(T) per channel)
        a = 0.0  # 累積分子
        b = 0.0  # 累積分母
        p = -1e30  # 最大値(数値安定性)

        for t in range(T):
            # 現在トークンの寄与
            e1 = torch.exp(torch.clamp(u[c] + k[t, c] - p, max=30))
            e2 = torch.exp(torch.clamp(w[c] + p - p, max=30))  # 前の累積の減衰

            # WKV計算
            wkv = (e1 * v[t, c] + e2 * a) / (e1 + e2 * b)
            output[t, c] = wkv

            # 状態更新(RNN方式)
            new_p = max(w[c] + p, k[t, c])
            e1 = torch.exp(k[t, c] - new_p)
            e2 = torch.exp(w[c] + p - new_p)

            a = e2 * a + e1 * v[t, c]
            b = e2 * b + e1
            p = new_p

    return output

デュアルモード:Transformerモード vs RNNモード

# 学習時:Transformerモード(並列)
# シーケンス全体を一度に処理、O(T)計算量
def rwkv_parallel(x, w, u, k_proj, v_proj, r_proj):
    T, C = x.shape
    k = k_proj(x)  # (T, C)
    v = v_proj(x)  # (T, C)
    r = torch.sigmoid(r_proj(x))  # (T, C) 受容ゲート

    # 並列WKV演算(CUDAカーネル)
    wkv = parallel_wkv_cuda(w, u, k, v)

    return r * wkv  # 受容ゲート適用


# 推論時:RNNモード(逐次、一定メモリ)
def rwkv_sequential(x_t, state, w, u, k_proj, v_proj, r_proj):
    """1トークンずつ処理、O(1)計算量"""
    k = k_proj(x_t)
    v = v_proj(x_t)
    r = torch.sigmoid(r_proj(x_t))

    # state = (a, b, p) — 固定サイズ!
    a, b, p = state

    e1 = torch.exp(u + k - p)
    e2 = torch.exp(w + p - p)

    wkv = (e1 * v + e2 * a) / (e1 + e2 * b)
    output = r * wkv

    # 状態更新
    new_p = torch.maximum(w + p, k)
    e1 = torch.exp(k - new_p)
    e2 = torch.exp(w + p - new_p)
    new_a = e2 * a + e1 * v
    new_b = e2 * b + e1
    new_state = (new_a, new_b, new_p)

    return output, new_state

RWKVブロック構造

class RWKVBlock:
    """
    RWKVの基本ブロック構造

    Transformer Blockとの比較:
    Transformer: LayerNorm → Attention → Add → LayerNorm → FFN → Add
    RWKV:        LayerNorm → TimeMix   → Add → LayerNorm → ChannelMix → Add
    """

    def __init__(self, dim, layer_id):
        # Time Mixing(Attentionの代替)
        self.time_mix = TimeMixing(dim, layer_id)
        # Channel Mixing(FFNの代替)
        self.channel_mix = ChannelMixing(dim, layer_id)
        self.ln1 = LayerNorm(dim)
        self.ln2 = LayerNorm(dim)

    def forward(self, x, state):
        # Time Mixing(過去のトークン情報を混合)
        dx, state = self.time_mix(self.ln1(x), state)
        x = x + dx

        # Channel Mixing(チャネル間の情報を混合)
        dx = self.channel_mix(self.ln2(x))
        x = x + dx

        return x, state


class TimeMixing:
    """
    Token Shift + WKV
    現在のトークンと前のトークンを線形補間して使用
    """

    def __init__(self, dim, layer_id):
        self.mix_r = nn.Parameter(torch.ones(dim))  # 補間比率
        self.mix_k = nn.Parameter(torch.ones(dim))
        self.mix_v = nn.Parameter(torch.ones(dim))

    def forward(self, x, state):
        # Token Shift:現在と前のトークンの加重平均
        x_prev = state.shift  # 前のトークン
        xr = x * self.mix_r + x_prev * (1 - self.mix_r)
        xk = x * self.mix_k + x_prev * (1 - self.mix_k)
        xv = x * self.mix_v + x_prev * (1 - self.mix_v)

        r = torch.sigmoid(self.W_r(xr))
        k = self.W_k(xk)
        v = self.W_v(xv)

        wkv = compute_wkv(k, v, state)
        return r * wkv, new_state

バージョンごとの進化

RWKV-4 (Eagle)

  • 基本的なWKVメカニズムを導入
  • Token Shiftで位置エンコーディングを代替
  • 最大14Bパラメータ

RWKV-5 (Eagle)

# RWKV-5: Multi-headed State
# 複数の独立した状態を保持し表現力を向上
class RWKV5_TimeMix:
    def __init__(self, dim, n_heads=8):
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        # 各ヘッドが独立したdecay rateを保持
        self.time_decay = nn.Parameter(torch.randn(n_heads, self.head_dim))

RWKV-6 (Finch)

# RWKV-6: Data-dependent time decay
# 入力に応じて減衰率が動的に変化(Mambaの選択的メカニズムと類似)
class RWKV6_TimeMix:
    def forward(self, x, state):
        # 固定decayではなく、入力依存のdecay
        time_decay = self.W_decay(x)  # 入力ごとに異なる減衰率!
        time_decay = torch.exp(-torch.exp(time_decay))

        # LoRAスタイルのdecay変調
        decay_lora = self.decay_lora_a(x)
        decay_lora = torch.tanh(decay_lora) @ self.decay_lora_b
        time_decay = time_decay + decay_lora

RWKV-7 (Goose) — 2025年3月

# RWKV-7: State Transition Matrix
# 状態遷移を行列で表現し、より豊かな状態更新を実現
class RWKV7_TimeMix:
    """
    RWKV-7の核心的革新:
    1. State Transition Matrix:スカラーではなく行列で状態遷移
    2. In-Context Learningの強化:動的に学習ルールを調整
    3. Improved Token Mixing:より精巧なトークン間情報フロー
    """

    def forward(self, x, state):
        # 状態遷移行列の計算
        a = torch.sigmoid(self.W_a(x))  # (B, T, D)
        b = self.W_b(x)                 # (B, T, D)

        # 行列形式の状態更新
        # s_{t+1} = diag(a_t) @ s_t + k_t^T @ v_t
        k = self.W_k(x)
        v = self.W_v(x)

        # 状態:(B, H, D/H, D/H) 行列!
        for t in range(T):
            state = torch.diag(a[:, t]) @ state + \
                    k[:, t].unsqueeze(-1) @ v[:, t].unsqueeze(-2)

        return output, state

Transformerとの性能比較

モデルサイズ別性能(Pile datasetパープレキシティ、低いほど良い):

| モデルサイズ | Transformer | RWKV-4 | RWKV-6 | RWKV-7 |
|-------------|-------------|--------|--------|--------|
| 169M        | 17.2        | 18.1   | 17.4   | 17.0   |
| 430M        | 13.8        | 14.5   | 13.9   | 13.6   |
| 1.5B        | 11.2        | 11.8   | 11.3   | 11.0   |
| 3B          | 9.8         | 10.3   | 9.9    | 9.6    |
| 7B          | 8.5         | 9.0    | 8.6    | 8.3    |
| 14B         | 7.8         | 8.2    | 7.9    | 7.6    |

推論効率(トークン/秒、A100 GPU):

| シーケンス長 | Transformer | RWKV-7  |
|-------------|-------------|---------|
| 1K          | 1000        | 1200    |
| 4K          | 800         | 1200    |
| 16K         | 300         | 1200    |
| 64K         | 50          | 1200    |
| 128K+       | OOM         | 1200    |

実践:RWKVを使う

HuggingFaceでの使用

from transformers import AutoModelForCausalLM, AutoTokenizer

# RWKV-7モデルのロード
model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-7-world-3b",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
    "RWKV/rwkv-7-world-3b",
    trust_remote_code=True
)

# テキスト生成
prompt = "KubernetesでPodオートスケーリングを実装するには"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(
    **inputs,
    max_new_tokens=200,
    temperature=0.7,
    top_p=0.9
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

RWKV Runnerでローカル実行

# RWKV Runnerのインストール(GUIツール)
git clone https://github.com/josStorer/RWKV-Runner
cd RWKV-Runner

# またはChatRWKV(CLI)
pip install rwkv

# Pythonから直接使用
python3 << 'EOF'
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

model = RWKV(model='/path/to/RWKV-7-World-3B.pth', strategy='cuda fp16')
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

result = pipeline.generate(
    "韓国のAI産業の現状について説明してください。",
    token_count=200,
    temperature=0.8
)
print(result)
EOF

RWKV vs Mamba vs Transformer

| 特性                | Transformer    | Mamba          | RWKV-7        |
|---------------------|----------------|----------------|---------------|
| 学習計算量           | O(N²)         | O(N)           | O(N)          |
| 推論計算量           | O(N) per token | O(1) per token | O(1) per token|
| メモリ              | O(N) KV Cache  | O(1) 状態     | O(1) 状態     |
| 並列学習            | 可能           | 可能 (scan)    | 可能 (WKV)    |
| 長距離依存          | 強い           | 良好           | 良好          |
| In-Context Learning | 強い           | 良好           | v7で強化      |
| 実装の複雑さ         | 低い           | 中程度 (CUDA)  | 中程度 (CUDA) |
| コミュニティ         | 巨大           | 成長中         | 活発          |

限界と今後の展望

現在の限界

  1. 複雑な検索タスク: TransformerのFull Attentionと比較して、特定パターン検索で劣る場合がある
  2. CUDAカーネル依存: 最適な性能を得るにはカスタムCUDAカーネルが必要
  3. エコシステム: Transformerに比べてツールやライブラリが不足

今後の方向性

  • ハイブリッドアーキテクチャ: RWKVに少量のAttentionを組み合わせる
  • ハードウェア最適化: Groq、Cerebrasなど新しいチップへの最適化
  • マルチモーダル: ビジョン、オーディオなど他のモダリティへの拡張

クイズ

Q1. RWKVのR、W、K、Vはそれぞれ何を意味しますか?

R: Receptance(受容ゲート)、W: Weight(時間減衰)、K: Key、V: Valueです。

Q2. RWKVは学習時と推論時にそれぞれどのモードで動作しますか?

学習時はTransformerのように並列(parallel)モードでシーケンス全体を一度に処理し、推論時はRNNのように逐次(sequential)モードでトークンあたりO(1)コストで動作します。

Q3. RWKV-6で導入されたData-dependent time decayとは?

固定された減衰率の代わりに、入力データに応じて動的に減衰率が変化するメカニズムです。Mambaの選択的(Selective)メカニズムと類似したアイデアです。

Q4. RWKV-7 Gooseの核心的革新は?

State Transition Matrixを導入し、状態遷移をスカラーではなく行列で表現します。これにより、より豊かな状態更新と強化されたIn-Context Learningが可能になります。

Q5. Transformerに対するRWKVの最大の利点は?

シーケンス長に関係なく推論コストが一定(O(1) per token)です。128K以上のトークンでもTransformerのようにOOMが発生しません。

Q6. Token Shiftメカニズムの役割は?

現在のトークンと前のトークンを加重平均で混合し、位置エンコーディングなしでもシーケンス内の位置情報を伝達します。

Q7. RWKVの現在の限界の1つは?

TransformerのFull Self-Attentionと比較して、複雑なパターン検索や正確な情報検索タスクで性能が劣る場合があります。

まとめ

RWKVは「RNNは死んだ」という定説を覆す革新的アーキテクチャです。Transformerの並列学習効率とRNNの推論効率を融合し、特に長いシーケンス処理やエッジデバイスへの展開で強みを発揮します。v7 GooseでState Transition Matrixが導入され、Transformerとの性能差がほぼ解消されました。

参考資料