Skip to content
Published on

RWKV-7 "Goose" アーキテクチャ分析 — Transformerを超える線形時間モデル

Authors
  • Name
    Twitter
RWKV-7 Goose Architecture

はじめに

Transformerは現代LLMの基盤ですが、O(n²)アテンションコストという根本的な限界があります。RWKV-7 "Goose"はこの問題を解決する新しいシーケンスモデリングアーキテクチャで、定数メモリ使用量トークンあたり一定の推論時間を達成しながら、Transformerに匹敵する性能を実現しています。

2025年7月にICMLで発表されたこの論文("RWKV-7 Goose with Expressive Dynamic State Evolution")を深掘り分析します。

RWKVシリーズの進化

RWKV-4からRWKV-7まで

RWKV(Receptance Weighted Key Value)はRNNとTransformerの長所を融合したアーキテクチャです:

  • RWKV-4 (2023): 基本的なWKVメカニズムを導入。Linear Attentionの変形
  • RWKV-5 "Eagle" (2024): Multi-headed WKV、性能向上
  • RWKV-6 "Finch" (2024): Data-dependent decay、LoRA統合
  • RWKV-7 "Goose" (2025): Dynamic State Evolution、TC0限界突破

Attention vs Linear Attention vs RWKV-7

# Standard Attention: O(n^2) 時間、O(n) メモリ
# output = softmax(Q @ K^T / sqrt(d)) @ V

# Linear Attention: O(n) 時間、O(d^2) メモリ(定数)
# output = (Q @ (K^T @ V)) -- 行列積の順序変更

# RWKV-7: O(n) 時間、O(d^2) メモリ(定数)
# + Dynamic State Evolutionで表現力を最大化

核心メカニズム:Dynamic State Evolution

既存Linear Attentionの限界

既存のLinear Attention(RWKV-4〜6を含む)は**TC0(Threshold Circuit Class 0)**に属します。これは理論的に特定の問題が解けないことを意味します:

# TC0限界の例:状態追跡(state tracking)問題
# 入力:「Aは部屋1にいる。Aは部屋2に移動する。Bは部屋1にいる。Aはどこにいる?」
# TC0モデルはこのような状態変化の追跡が理論的に不可能

RWKV-7のDynamic State Evolution

RWKV-7は状態遷移行列自体を入力に応じて動的に変更します:

import torch
import torch.nn as nn

class RWKV7_TimeMix(nn.Module):
    """RWKV-7の核心:Dynamic State Evolution"""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # 入力依存パラメータ生成器
        self.W_r = nn.Linear(d_model, d_model)  # Receptance
        self.W_k = nn.Linear(d_model, d_model)  # Key
        self.W_v = nn.Linear(d_model, d_model)  # Value
        self.W_a = nn.Linear(d_model, d_model)  # State transition
        self.W_g = nn.Linear(d_model, d_model)  # Gate

        # Dynamic decayパラメータ
        self.w_decay = nn.Parameter(torch.randn(n_heads, self.d_head))

        # State evolution matrix学習パラメータ
        self.A_log = nn.Parameter(torch.randn(n_heads, self.d_head, self.d_head))

    def forward(self, x, state=None):
        B, T, C = x.shape
        H = self.n_heads
        D = self.d_head

        r = self.W_r(x).view(B, T, H, D)
        k = self.W_k(x).view(B, T, H, D)
        v = self.W_v(x).view(B, T, H, D)
        a = torch.sigmoid(self.W_a(x).view(B, T, H, D))
        g = torch.sigmoid(self.W_g(x).view(B, T, H, D))

        if state is None:
            state = torch.zeros(B, H, D, D, device=x.device)

        outputs = []
        for t in range(T):
            # Dynamic State Evolution:核心的革新
            # 従来:state = decay * state + k^T @ v(固定decay)
            # RWKV-7:state = A(x_t) @ state + k^T @ v(動的遷移行列)

            # 入力依存の遷移行列
            A_t = self._compute_transition(a[:, t], state)

            # State更新
            kv = torch.einsum('bhd,bhe->bhde', k[:, t], v[:, t])
            state = A_t * state + kv

            # Output計算
            out = torch.einsum('bhd,bhde->bhe', r[:, t], state)
            out = out * g[:, t]
            outputs.append(out)

        output = torch.stack(outputs, dim=1).view(B, T, C)
        return output, state

    def _compute_transition(self, a_t, state):
        """入力に応じて遷移行列を動的に生成"""
        # a_t: [B, H, D] - 現在の入力から導出された遷移パラメータ
        # これがRWKV-7がTC0を超える鍵
        decay = torch.exp(-torch.exp(self.w_decay))
        A = decay.unsqueeze(-1) * torch.eye(
            self.d_head, device=a_t.device
        ).unsqueeze(0)

        # 入力依存の補正
        A = A + torch.einsum('bhd,bhe->bhde', a_t, a_t) * \
            torch.exp(self.A_log).unsqueeze(0)

        return A

TC0を超える理由

# RWKV-6(従来):state_{t+1} = diag(w) * state_t + k_t * v_t^T
# -> 対角行列積:各次元が独立にdecay
# -> TC0範囲内:有限深度では状態追跡不可

# RWKV-7(新規):state_{t+1} = A(x_t) * state_t + k_t * v_t^T
# -> A(x_t)は入力依存の遷移行列(非対角)
# -> 次元間の相互作用が可能:状態追跡問題を解決可能
# -> TC0を超える表現力

アーキテクチャ全体構造

RWKV-7ブロック構成

class RWKV7Block(nn.Module):
    """RWKV-7基本ブロック"""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.time_mix = RWKV7_TimeMix(d_model, n_heads)
        self.channel_mix = RWKV7_ChannelMix(d_model)

    def forward(self, x, state=None):
        # Time Mixing(トークン間の相互作用)
        h, state = self.time_mix(self.ln1(x), state)
        x = x + h

        # Channel Mixing(FFNの役割)
        x = x + self.channel_mix(self.ln2(x))

        return x, state


class RWKV7_ChannelMix(nn.Module):
    """チャネルミキシング(SwiGLU変形)"""

    def __init__(self, d_model, expand=3.5):
        super().__init__()
        hidden = int(d_model * expand)
        self.W_key = nn.Linear(d_model, hidden)
        self.W_value = nn.Linear(hidden, d_model)
        self.W_gate = nn.Linear(d_model, hidden)

    def forward(self, x):
        k = self.W_key(x)
        v = self.W_value(torch.relu(k) ** 2)  # squared ReLU
        g = torch.sigmoid(self.W_gate(x))
        return v * g

推論モード:RNNとして動作

class RWKV7_Inference:
    """推論時はRNNモードで動作:トークンあたりO(d^2)演算"""

    def __init__(self, model):
        self.model = model
        self.states = None  # [n_layers, n_heads, d_head, d_head]

    def generate_token(self, token_id):
        x = self.model.embed(token_id)

        for i, block in enumerate(self.model.blocks):
            x, self.states[i] = block(x.unsqueeze(0).unsqueeze(0),
                                       self.states[i])
            x = x.squeeze(0).squeeze(0)

        logits = self.model.head(self.model.ln_out(x))
        return logits

    # メモリ使用量:シーケンス長と無関係に定数!
    # Transformer:KV Cacheが O(n * d) メモリ使用
    # RWKV-7:Stateが O(d^2) 固定メモリ

ベンチマーク結果

3Bモデル比較(論文Table 1ベース)

モデルパラメータMMLUHellaSwagARC-E推論メモリ
LLaMA-3.2 3B3.2B63.477.279.5O(n) KV Cache
Mamba-2 2.7B2.7B58.173.874.2O(1) 定数
RWKV-6 3B3.0B58.974.575.1O(1) 定数
RWKV-7 3B3.0B61.276.878.3O(1) 定数

多言語ベンチマーク(RWKV-7の強み)

# RWKV-7は特に多言語で優れた性能を発揮
# 100以上の言語で学習、英語以外の言語でSOTA級

# 韓国語・日本語・中国語のベンチマークで
# 同サイズのTransformerを上回る性能
# -> 多言語トークン化の効率性 + 長コンテキスト処理

推論効率

# シーケンス長別の推論コスト比較
#
# Sequence Length | Transformer | RWKV-7
# 1K             | 1x          | 0.8x
# 4K             | 4x          | 0.8x
# 16K            | 16x         | 0.8x
# 64K            | 64x         | 0.8x
# 1M             | OOM         | 0.8x
#
# RWKV-7はシーケンス長に関係なく一定のコスト!

実践:RWKV-7を使う

Hugging Faceからモデルをロード

# pip install rwkv torch transformers

from rwkv.model import RWKV
from rwkv.utils import PIPELINE

# モデルロード
model = RWKV(
    model='/path/to/RWKV-7-World-3B',
    strategy='cuda fp16'  # GPU FP16
)

pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

# テキスト生成
context = "Kubernetes VPAの利点は"
result = pipeline.generate(
    context,
    token_count=200,
    temperature=0.8,
    top_p=0.9
)
print(result)

vLLMでサービング

# RWKV-7 vLLMサービング(実験的サポート)
pip install vllm>=0.6.0

python -m vllm.entrypoints.openai.api_server \
  --model RWKV/rwkv-7-world-3b \
  --tokenizer RWKV/rwkv-7-world-3b \
  --dtype float16 \
  --max-model-len 32768

Ollamaでローカル実行

# RWKV-7 GGUFモデルをダウンロード後
ollama create rwkv7 -f Modelfile

# Modelfileの例:
# FROM rwkv-7-world-3b-q4_k_m.gguf
# PARAMETER temperature 0.7
# PARAMETER num_ctx 32768

ollama run rwkv7

RWKV-7 vs Mamba-2 vs Transformer

アーキテクチャ比較

特性TransformerMamba-2RWKV-7
時間計算量O(n²)O(n)O(n)
推論メモリO(n·d) KV CacheO(d²) 定数O(d²) 定数
並列学習完全並列チャンク並列完全並列
表現力TC0以上TC0TC0以上
ハードウェア最適化成熟発展中発展中

どのような場合にRWKV-7を選ぶべきか?

# RWKV-7が適しているシナリオ:
# 1. 非常に長いコンテキスト処理(100K以上のトークン)
# 2. エッジデバイス推論(定数メモリ)
# 3. 多言語サービス(韓国語・日本語・中国語など)
# 4. リアルタイムストリーミング(トークンあたり一定時間)

# Transformerがまだ優れているシナリオ:
# 1. 最大性能が必要な場合(特に英語)
# 2. 短いコンテキスト(1K以下)
# 3. 既存のエコシステム・ツールの活用

まとめ

RWKV-7 Gooseは、定数メモリ+線形時間という効率性を維持しながらTransformer水準の性能を達成した革新的アーキテクチャです。Dynamic State Evolutionにより既存のLinear Attentionの理論的限界(TC0)を突破し、特に多言語と長コンテキストで強みを発揮します。


クイズ(7問)

Q1. RWKV-7の推論時メモリ計算量は? O(d²) 定数 — シーケンス長とは無関係

Q2. TC0(Threshold Circuit Class 0)の限界とは? 有限深度のLinear Attentionでは状態追跡(state tracking)のような問題を理論的に解けない

Q3. RWKV-7がTC0を超える核心メカニズムは? Dynamic State Evolution — 入力依存の遷移行列A(x_t)で次元間の相互作用が可能

Q4. Transformerの推論メモリがO(n·d)である理由は? KV Cacheがシーケンス長に比例して増加するため

Q5. RWKV-7が学習時に並列処理可能な理由は? WKV演算をチャンク単位で並列化できる構造

Q6. RWKV-7が多言語で特に強い理由は? 100以上の言語で学習されたWorldトークナイザー+効率的な長コンテキスト処理

Q7. RWKV-7のChannel Mixで使用される活性化関数は? Squared ReLU(ReLUの二乗)