Skip to content
Published on

RWKVアーキテクチャ徹底解説:Transformerに匹敵する線形アテンションRNN

Authors
  • Name
    Twitter
RWKV Architecture

はじめに

2017年にVaswaniらがTransformerを発表して以来、自己アテンションメカニズムはシーケンスモデリングにおける支配的なパラダイムとなりました。しかし、Transformerには根本的な負担が伴います。シーケンス長に対する二次的な時間・メモリ計算量 O(L^2 * d) は、長いシーケンスでの推論を高コストにし、増大し続けるKVキャッシュを必要とします。リカレントニューラルネットワーク(RNN)はステップあたり O(d) の定数メモリ推論を提供しますが、LSTMのような古典的なRNNは勾配消失問題、逐次的な学習ボトルネック、数億パラメータ以上へのスケーリングの困難さに悩まされてきました。

RWKV(「ルワクブ」と発音)-- Receptance Weighted Key Value -- は、Bo Peng(BlinkDL)によって設計された新しいアーキテクチャであり、この溝を埋めるものです。アテンションを線形リカレンスとして再定式化することで、RWKVは Transformerレベルの言語モデリング品質RNNレベルの推論効率 を両立します。学習の総計算量は O(T*d)(シーケンス長に対して線形)であり、推論はKVキャッシュなしでステップあたり O(d) です。このアーキテクチャは14Bパラメータ(RWKV-4 Eagle)までスケーリングされ、これは史上最大の密なRNNとなっています。ベンチマークでは、同規模のTransformerと標準的なNLPタスクで同等の性能を示しています。

RWKV論文はEMNLP 2023 Findings(arXiv: 2305.13048)で発表され、アーキテクチャはバージョン5(Eagle)、6(Finch)、7(Goose)と進化を続けており、各イテレーションでより表現力豊かな状態進化メカニズムが導入されています。

本記事では、WKVメカニズムの数学的基盤、アーキテクチャの完全なウォークスルー、Transformer・Mamba・LSTMとの比較、実践的な学習・デプロイガイド、および既知の制限事項について解説します。

RWKVアーキテクチャ概要

高レベル構造

RWKVはN個の残差ブロックを積み重ねた構造で、各ブロックには2つのサブレイヤーが含まれます:Time Mixing(アテンションに相当)と Channel Mixing(FFNに相当)です。Transformerとは異なり、位置エンコーディングは存在しません。時間的な情報はリカレントなWKVメカニズムと学習された時間減衰パラメータによって暗黙的に捉えられます。

+--------------------------------------------------+
|                RWKV Block (x N)                  |
|                                                  |
|  +--------------------------------------------+  |
|  |           Layer Norm                       |  |
|  +---------------------+----------------------+  |
|                        |                         |
|  +---------------------v----------------------+  |
|  |       Time Mixing (WKV Attention)          |  |
|  |                                            |  |
|  |   x_t --+-- R (Receptance) -- sigmoid(r)  |  |
|  |         +-- K (Key)                        |  |
|  |         +-- V (Value)                      |  |
|  |         +-- W (Time Decay, learned)        |  |
|  |                                            |  |
|  |   wkv_t = weighted_sum(K, V, W, U)        |  |
|  |   out_t = sigmoid(r_t) * wkv_t            |  |
|  +---------------------+----------------------+  |
|                        | + residual              |
|  +---------------------v----------------------+  |
|  |           Layer Norm                       |  |
|  +---------------------+----------------------+  |
|                        |                         |
|  +---------------------v----------------------+  |
|  |       Channel Mixing (FFN)                 |  |
|  |                                            |  |
|  |   x_t --+-- R (Receptance) -- sigmoid(r)  |  |
|  |         +-- K (Key) -- squared_relu(k)     |  |
|  |                                            |  |
|  |   out_t = sigmoid(r_t) * (W_v * relu2(k)) |  |
|  +---------------------+----------------------+  |
|                        | + residual              |
+--------------------------------------------------+

トークンシフト:秘密の要素

各サブレイヤーでR、K、Vを計算する前に、RWKVはトークンシフト(タイムシフトまたは線形補間とも呼ばれる)を適用します。現在のトークン埋め込み x_t のみを使用する代わりに、RWKVはそれを前のトークンと混合します:

x'_t = mu * x_t + (1 - mu) * x_{t-1}

ここで mu はチャネルごとの学習可能な補間重みです。この単純な操作により、キーとバリューの計算前にバイグラムレベルの情報にアクセスでき、完全なアテンションの不在を補う安価な形のローカルコンテキストを提供します。

WKVメカニズム:数学的定式化

コア計算

WKV(Weighted Key Value)演算子はRWKVの核心です。softmaxアテンションを指数関数的に減衰する重み付き和で置き換えます。位置tにおけるWKV出力は以下の通りです:

wkv_t = (sum_{i=1}^{t-1} exp(-(t-1-i)*w + k_i) * v_i  +  exp(u + k_t) * v_t)
        / (sum_{i=1}^{t-1} exp(-(t-1-i)*w + k_i)  +  exp(u + k_t))

ここで:

  • w時間減衰 パラメータ(チャネルごと、学習される)、w = exp(decay) により常に正
  • u は現在のトークンに追加の重みを与える ボーナス パラメータ
  • k_i, v_i は位置iにおけるキーとバリュー

分子は過去の全バリューの重み付き和を蓄積し、各過去トークンの重みは距離に応じて指数関数的に減衰します。現在のトークンには通常の減衰の代わりに特別なボーナス u が適用されます。

リカレント定式化

RNNモード推論を可能にする重要な洞察は、WKV計算をリカレンスとして表現できることです。累積分子 alpha_t と累積分母 beta_t を定義します:

alpha_t = exp(-w) * alpha_{t-1} + exp(k_t) * v_t
beta_t  = exp(-w) * beta_{t-1}  + exp(k_t)

すると、各ステップの出力は以下のようになります:

wkv_t = (exp(-w) * alpha_{t-1} + exp(u + k_t) * v_t)
      / (exp(-w) * beta_{t-1}  + exp(u + k_t))

これにはステップあたり O(d) の計算と O(d) のメモリのみが必要で、累積状態alphaとbetaはそれぞれd次元ベクトルです。

対数空間計算による数値安定性

指数関数の直接計算はオーバーフローする可能性があります。RWKVは数値安定性を維持するために対数空間のトリックを使用します:

import torch

def rwkv_wkv_single_step(w, u, k, v, alpha_prev, beta_prev, log_max_prev):
    """
    Single-step WKV computation with numerical stability.

    Args:
        w: time decay (d,) -- positive values
        u: bonus parameter (d,)
        k: key vector (d,)
        v: value vector (d,)
        alpha_prev: running numerator state (d,)
        beta_prev: running denominator state (d,)
        log_max_prev: log of max exponent for stability (d,)

    Returns:
        wkv: output vector (d,)
        alpha_new, beta_new, log_max_new: updated states
    """
    # Compute log-space exponents
    log_ew = -w  # log(e^(-w))
    log_ek = k   # log(e^k)

    # For the past terms: e^(-w) * prev = e^(log_ew + log_max_prev)
    log_past = log_ew + log_max_prev
    # For the current term: e^(u + k)
    log_curr = u + k

    # Numerically stable log-sum-exp
    log_max_new = torch.max(log_past, log_curr)

    past_scale = torch.exp(log_past - log_max_new)
    curr_scale = torch.exp(log_curr - log_max_new)

    wkv = (past_scale * alpha_prev + curr_scale * v) / \
          (past_scale * beta_prev + curr_scale)

    # Update running states
    log_max_alpha = torch.max(log_past, log_ek)
    alpha_new = torch.exp(log_past - log_max_alpha) * alpha_prev + \
                torch.exp(log_ek - log_max_alpha) * v
    beta_new = torch.exp(log_past - log_max_alpha) * beta_prev + \
               torch.exp(log_ek - log_max_alpha)

    return wkv, alpha_new, beta_new, log_max_alpha

並列学習の定式化

学習時には、WKVをプレフィックスサム(スキャン)操作を使用してシーケンス全体にわたって並列に計算できます。exp(-w) がチャネルごとの固定減衰として機能するため、累積重みは効率的に計算可能な等比級数を形成します:

def rwkv_wkv_parallel(w, u, k, v):
    """
    Parallel WKV computation for training.

    Args:
        w: time decay (d,)
        u: bonus (d,)
        k: keys (T, d)
        v: values (T, d)

    Returns:
        wkv: output (T, d)
    """
    T, d = k.shape
    ew = torch.exp(-w)  # per-channel decay factor (d,)
    ek = torch.exp(k)   # (T, d)
    ekv = ek * v        # (T, d)

    # Conceptual O(T*d) scan -- in practice uses a custom CUDA kernel
    alpha = torch.zeros(d, device=k.device)
    beta = torch.zeros(d, device=k.device)
    wkv = torch.zeros(T, d, device=k.device)

    for t in range(T):
        # Current token with bonus
        euk = torch.exp(u + k[t])

        wkv[t] = (alpha + euk * v[t]) / (beta + euk)

        # Update running sums (without bonus)
        alpha = ew * alpha + ek[t] * v[t]
        beta = ew * beta + ek[t]

    return wkv

実際には、RWKVはこれらの操作を高効率なスキャンに融合するカスタムCUDAカーネルを使用し、GPU上で準線形的なスピードアップを実現しています。

二重性:Transformerモード vs RNNモード

RWKVの最もエレガントな側面の一つは、同じモデルが2つのモードで動作できることです:

+-----------------------------------------------------------+
|               RWKV Dual Operation Modes                   |
+---------------------------+-------------------------------+
|   Training Mode           |   Inference Mode              |
|   (Transformer-like)      |   (RNN-like)                  |
|                           |                               |
|   +---+---+---+---+      |   +---+                       |
|   |t=1|t=2|t=3|t=4|      |   |t=n|                       |
|   +-+-+-+-+-+-+-+-++      |   +-+-+                       |
|     |   |   |   |         |     |                         |
|     v   v   v   v         |     v                         |
|   +---------------+      |   +-------+   +---------+     |
|   | Parallel Scan |      |   | WKV   |-->| State   |     |
|   | (all at once) |      |   | step  |   | (a, b)  |     |
|   +-------+-------+      |   +---+---+   +---------+     |
|           |               |       |                       |
|           v               |       v                       |
|   +---------------+      |   +-----------+               |
|   |  O(Td) total  |      |   | O(d)/step |               |
|   |  parallelized |      |   | constant  |               |
|   +---------------+      |   +-----------+               |
+---------------------------+-------------------------------+
|  Same weights, same results, different compute pattern    |
+-----------------------------------------------------------+

学習時:スキャン定式化を使用してシーケンス全体を並列に処理します。これはリカレンスと数学的に等価ですが、GPUの並列性を活用できます。計算量:合計 O(Td)

推論時:1トークンずつ処理し、固定サイズの隠れ状態を更新します。KVキャッシュは不要です。計算量:トークンあたり O(d)、コンテキスト長に関係なくメモリは O(d)

この二重性がRWKVの根本的な利点です。妥協なく両方の世界のベストを得ることができます。

アーキテクチャ比較

RWKV vs Transformer vs Mamba vs LSTM

特徴RWKV (v4/v5/v6)TransformerMamba (S6)LSTM
学習計算量O(Td)O(T^2 d)O(Td)O(Td)
推論(ステップあたり)O(d)O(Td) KVキャッシュ付きO(d)O(d)
メモリ(推論)O(d) 定数O(T*d) 増加O(d) 定数O(d) 定数
並列学習可能はい(スキャン)はい(行列積)はい(スキャン)いいえ(逐次的)
KVキャッシュ必要いいえはいいいえいいえ
最大学習規模14B(Eagle)1.8T+(GPT-4クラス)8B(Mamba-2)約1B
長距離記憶良好(減衰型)優秀良好(選択的)不良
位置エンコーディングなし(暗黙的)必要なし(暗黙的)なし(暗黙的)
アテンションパターン線形減衰完全な二次選択的SSMゲート付きリカレンス
コンテキスト長無制限(理論上)固定ウィンドウ無制限(理論上)制限あり(勾配消失)
HuggingFaceサポートありありありあり
エコシステム成熟度成長中支配的成長中成熟(レガシー)
Needle-in-Haystack弱い強い中程度非常に弱い

主要なトレードオフ

RWKV vs Transformer:RWKVは推論効率で決定的に勝ります(定数メモリ、KVキャッシュ不要、線形生成時間)。Transformerはコンテキスト内の任意の位置からの精密な検索を必要とするタスク(needle-in-a-haystack)で勝ります。ほとんどの生成タスクでは、同等規模での品質差は小さいです。

RWKV vs Mamba:両者とも線形計算量と定数メモリ推論を実現します。Mambaは入力依存のSSMパラメータ(選択メカニズム)を使用し、RWKVは学習済み補間を持つ固定時間減衰を使用します。Mambaは強いコンテンツベースの選択を必要とするタスクでやや優れる傾向がありますが、RWKVはより成熟したエコシステムを持ち、より大きなスケールまで学習されています。RWKV-7 Gooseは動的状態進化によりこのギャップを大幅に縮めています。

RWKV vs LSTM:RWKVは厳密に上位互換です。並列学習が可能(LSTMは不可)、数十億パラメータまでスケール可能、はるかに良いパープレキシティを達成します。WKVメカニズムはゲート付きリカレンスのより表現力豊かな形態です。

RWKVバージョンの進化

v4からv7まで

# Version evolution summary
rwkv_versions = {
    "v4 (Pile)": {
        "year": 2023,
        "max_params": "14B",
        "key_feature": "Original WKV with fixed time decay",
        "state_size": "d per layer (scalar decay)"
    },
    "v5 (Eagle)": {
        "year": 2024,
        "max_params": "7.5B",
        "key_feature": "Multi-headed WKV, matrix-valued states",
        "state_size": "h * (d/h)^2 per layer"
    },
    "v6 (Finch)": {
        "year": 2024,
        "max_params": "7.5B",
        "key_feature": "Data-dependent time decay, LoRA-style mixing",
        "state_size": "h * (d/h)^2 per layer"
    },
    "v7 (Goose)": {
        "year": 2025,
        "max_params": "2.9B (scaling ongoing)",
        "key_feature": "Dynamic state evolution via generalized delta rule",
        "state_size": "h * (d/h)^2 per layer (dynamic)"
    }
}

RWKV-7 Gooseは特に重要なステップです。ベクトル値ゲーティングとインコンテキスト学習率を持つ一般化デルタルールを導入しています。これにより隠れ状態が入力内容に基づいて動的に進化でき、固定された線形アテンションのTC0表現力の制限を克服します。Pileベンチマークの3Bパラメータ規模では、RWKV-7はパープレキシティ9.6を達成し、Transformerの9.8やRWKV-6の9.9を上回っています。

推論パフォーマンスベンチマーク

トークン生成速度

以下のベンチマークは単一のNVIDIA A100 80GB GPU上で、3Bパラメータ規模のモデルを比較して収集されました:

Tokens/sec at various sequence lengths (3B params, A100 GPU):

Seq Length  | RWKV-6  | Transformer | Mamba-2  | Note
------------|---------|-------------|----------|------------------
512         |  2,400  |    2,200    |  2,500   | All comparable
2,048       |  2,350  |    1,800    |  2,450   | Transformer slowing
8,192       |  2,300  |      950    |  2,400   | KV cache pressure
32,768      |  2,250  |      280    |  2,350   | Transformer struggling
131,072     |  2,200  |      OOM    |  2,300   | Transformer OOM
524,288     |  2,100  |      OOM    |  2,200   | Both linear models OK

Memory Usage (inference state):

Model       | 1K ctx  | 8K ctx  | 128K ctx | 1M ctx
------------|---------|---------|----------|--------
RWKV-6 3B   | 6.2 GB  | 6.2 GB  |  6.2 GB  | 6.2 GB  (constant!)
Transformer | 6.5 GB  | 8.1 GB  | 32.4 GB  |  OOM
Mamba-2 3B  | 6.3 GB  | 6.3 GB  |  6.3 GB  | 6.3 GB  (constant)

重要なポイント:RWKVはコンテキスト長に関係なく定数メモリと準定数速度を維持する一方、Transformerは8Kトークンを超えると急速に劣化します。

パープレキシティ比較(The Pile)

Model               | Params | Pile Val PPL | LAMBADA | HellaSwag
--------------------|--------|--------------|---------|----------
RWKV-4              |  7B    |    8.28      |  67.2%  |  52.5%
RWKV-5 Eagle        |  7B    |    8.15      |  68.1%  |  53.2%
RWKV-6 Finch        |  7B    |    8.05      |  69.0%  |  54.1%
RWKV-7 Goose        |  2.9B  |    9.60      |  65.8%  |  51.0%
Pythia (Transformer) |  6.9B |    8.25      |  67.1%  |  52.0%
LLaMA-like          |  7B    |    7.95      |  73.0%  |  56.4%
Mamba               |  2.8B  |    9.80      |  64.9%  |  50.3%

同等規模では、RWKVはTransformerベースラインと競争力があります。ただし、広範なデータキュレーションを行った最良のTransformerモデル(LLaMAファミリーなど)はまだわずかに優位を保っています。

学習とファインチューニングの実践ガイド

環境セットアップ

# Clone the RWKV-LM repository
git clone https://github.com/BlinkDL/RWKV-LM.git
cd RWKV-LM/RWKV-v7

# Install dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install lightning deepspeed wandb ninja

# For custom CUDA kernel compilation
pip install triton

# Verify CUDA kernel builds
python -c "from rwkv.model import RWKV; print('RWKV loaded successfully')"

スクラッチからのフル学習

# train_config.py -- Example training configuration
import lightning as L
from rwkv.model import RWKV

# Model configuration
model_config = {
    "n_layer": 24,         # Number of RWKV blocks
    "n_embd": 2048,        # Embedding dimension
    "vocab_size": 65536,   # Vocabulary size
    "ctx_len": 4096,       # Context length for training
    "head_size": 64,       # Head size for multi-headed WKV (v5+)
}

# Training hyperparameters
train_config = {
    "learning_rate": 6e-4,          # Peak LR
    "lr_schedule": "cosine",        # Cosine decay
    "warmup_steps": 1000,
    "batch_size": 16,
    "accumulate_grad_batches": 4,   # Effective batch = 64
    "max_steps": 100000,
    "precision": "bf16-mixed",      # BF16 mixed precision
    "gradient_clip_val": 1.0,
    "weight_decay": 0.1,
    "beta1": 0.9,
    "beta2": 0.99,
}

# DeepSpeed configuration for multi-GPU
deepspeed_config = {
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
    },
    "bf16": {"enabled": True},
}

RWKV-PEFTによるLoRAファインチューニング

コミュニティが開発したRWKV-PEFTプロジェクトは、RWKVモデルの効率的なLoRAファインチューニングを提供します:

# Clone RWKV-PEFT
git clone https://github.com/JL-er/RWKV-PEFT.git
cd RWKV-PEFT

# Prepare training data in binidx format
python tools/preprocess_data.py \
    --input your_data.jsonl \
    --output-prefix train_data \
    --tokenizer-type RWKVTokenizer \
    --vocab-size 65536

# Launch LoRA fine-tuning
python train.py \
    --load_model /path/to/rwkv-base-model.pth \
    --proj_dir output/ \
    --data_file train_data \
    --data_type binidx \
    --ctx_len 2048 \
    --n_layer 24 \
    --n_embd 2048 \
    --lora_r 64 \
    --lora_alpha 128 \
    --lora_parts att,ffn \
    --micro_bsz 4 \
    --epoch_steps 1000 \
    --epoch_count 5 \
    --lr_init 2e-4 \
    --lr_final 2e-5 \
    --strategy deepspeed_stage_1 \
    --precision bf16

学習後、LoRA重みをマージします:

python merge_lora.py \
    --base_model /path/to/rwkv-base-model.pth \
    --lora_checkpoint output/rwkv-lora-final.pth \
    --output merged_model.pth \
    --lora_r 64 \
    --lora_alpha 128

HuggingFace連携

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load RWKV from HuggingFace Hub
model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-6-world-7b",
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
    "RWKV/rwkv-6-world-7b",
    trust_remote_code=True
)

# Generate text
prompt = "The RWKV architecture is interesting because"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=200,
    temperature=0.8,
    top_p=0.9,
    do_sample=True
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

リソース要件

Fine-tuning VRAM requirements (LoRA, ctx_len=2048):

Model Size  | LoRA r=8  | LoRA r=64 | Full Fine-tune
------------|-----------|-----------|----------------
1.5B        |   4 GB    |   6 GB    |   16 GB
3B          |   6 GB    |  10 GB    |   28 GB
7B          |  12 GB    |  18 GB    |   56 GB
14B         |  22 GB    |  34 GB    |  112 GB

失敗ケースと制限事項

RWKVがどこで苦戦するかを理解することは、その強みを知ることと同じくらい重要です。

1. Needle-in-a-Haystack検索

RWKVの固定サイズ状態は情報の圧縮を意味します。長いコンテキストの奥深くに埋め込まれた特定の情報を正確に検索できない場合があります:

Task: "Find the phone number mentioned on page 37 of a 100-page document"

Transformer (GPT-4 class): Correctly retrieves the number (full attention)
RWKV-6 7B:                 Often fails or hallucinates a similar number
Mamba 7B:                  Sometimes succeeds (selective state helps)

Root cause: The exponential decay in WKV means old information is
progressively "forgotten" unless it strongly activates key channels.
The fixed state size (d dimensions) cannot store arbitrary facts.

2. 複雑なマルチホップ推論

複数の離れたコンテキスト部分への同時参照を必要とするタスクは困難です:

Prompt: "Alice gave Bob a red ball. Charlie gave Diana a blue cube.
         ... (500 tokens of distraction) ...
         Eve traded her green cone for the object Diana has.
         What color is the object Eve now has?"

Transformer: "Blue" (correct -- attends to both relevant sentences)
RWKV:        "Green" or "Red" (may lose track of multi-hop chain)

3. プロンプトフォーマットへの感度

RWKVモデルはTransformerよりもプロンプトフォーマットに顕著に敏感です。RNNの順序付けられた逐次的な性質により、情報の提示方法がより重要になります:

# This format works well with RWKV
prompt_good = "User: What is the capital of France?\n\nAssistant:"

# This format may produce worse results
prompt_bad = "capital of france?"

# RWKV is sensitive to newlines, spacing, and role markers.
# Always use consistent chat templates when deploying RWKV models.

Transformerはアテンションの置換不変性により、プロンプトの変動に対して自然に鈍感です。RWKVのリカレンスの順序的な性質により、トークンの配列方法に対して本質的により敏感になります。

4. 状態サイズ vs 情報容量

固定された隠れ状態は根本的なボトルネックを生み出します。埋め込み次元 d = 2048h = 32 ヘッド(RWKV-v5以降)のモデルでは、レイヤーあたりの総状態は h * (d/h)^2 = 32 * 64^2 = 131,072 個の浮動小数点数です。かなりの量ですが有限であり、TransformerのKVキャッシュのように入力長に応じてスケールすることはできません。

5. エコシステムとツーリングのギャップ

コミュニティサポートの成長にもかかわらず、RWKVのエコシステムはTransformerのエコシステムよりまだ小さいです。事前学習済みチェックポイントの数が少なく、RLHF/DPOチューニング済みバリアントも限られ、デプロイ用のツーリング(Transformer向けのvLLM、TensorRT-LLMと比較して)も少ないことが、採用への実際的な障壁として残っています。

実践的なデプロイのヒント

長い会話における状態管理

class RWKVChatSession:
    """Manage RWKV state across a multi-turn conversation."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.state = None  # Will hold the running RNN state

    def chat(self, user_message):
        # Format the message
        prompt = f"User: {user_message}\n\nAssistant:"
        input_ids = self.tokenizer.encode(prompt)

        # Feed tokens through model, updating state
        output_ids = []
        for token_id in input_ids:
            logits, self.state = self.model.forward(
                token_id, self.state
            )

        # Generate response tokens
        for _ in range(500):
            token_id = self.sample(logits, temperature=0.8)
            if token_id == self.tokenizer.eos_token_id:
                break
            output_ids.append(token_id)
            logits, self.state = self.model.forward(
                token_id, self.state
            )

        return self.tokenizer.decode(output_ids)

    def save_state(self, path):
        """Save conversation state to disk for later resumption."""
        torch.save(self.state, path)

    def load_state(self, path):
        """Resume conversation from saved state."""
        self.state = torch.load(path)

    @staticmethod
    def sample(logits, temperature=1.0, top_p=0.9):
        probs = torch.softmax(logits / temperature, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        mask = cumsum - sorted_probs > top_p
        sorted_probs[mask] = 0.0
        sorted_probs /= sorted_probs.sum()
        idx = torch.multinomial(sorted_probs, 1)
        return sorted_indices[idx].item()

エッジデプロイ用の量子化

RWKVモデルはそのシンプルなアーキテクチャ(保存すべき複雑なアテンションパターンがない)により、量子化に適しています:

# Using the rwkv.cpp project for CPU inference
git clone https://github.com/saharNooby/rwkv.cpp.git
cd rwkv.cpp

# Quantize model to INT4
python python/convert_pytorch_to_ggml.py \
    /path/to/rwkv-model.pth \
    /path/to/output.bin \
    FP16

python python/quantize.py \
    /path/to/output.bin \
    /path/to/output-q4_0.bin \
    Q4_0

# Run inference on CPU
python python/chat.py \
    --model /path/to/output-q4_0.bin

まとめ

RWKVはシーケンスモデルの設計空間において真に新しい地点を示しています。Transformerの並列化可能な学習とRNNの定数メモリ推論を組み合わせることで、推論効率が重要なシナリオ -- エッジデプロイ、長コンテキストアプリケーション、リアルタイム生成、リソース制約のある環境 -- に対して魅力的な代替案を提供します。

このアーキテクチャにはトレードオフもあります。Needle-in-a-Haystack検索、離れたコンテキストにわたるマルチホップ推論、プロンプト感度は、Transformerが依然として優れている領域です。しかし、v4からv7 Gooseまでの急速な進化は、RWKVコミュニティがこれらの制限に積極的に取り組んでいることを示しており、動的状態進化と一般化デルタルールが表現力を大幅に向上させています。

実務者にとって、RWKVは以下の場合に真剣に検討する価値があります:

  • 推論コストまたはレイテンシが主要な関心事である場合
  • コンテキスト長が定期的に32Kトークンを超える場合
  • デプロイ対象にエッジデバイスやVRAM制限のあるGPUが含まれる場合
  • ストリーミングまたはリアルタイムテキスト生成が必要な場合

効率的なシーケンスモデリングの分野は急速に進化しており、RWKV、Mamba、ハイブリッドアーキテクチャ(MambaとTransformerレイヤーを組み合わせたJambaなど)がすべて純粋なTransformerの王座に挑戦しています。「Transformerで学習されたRNN」というRWKVのユニークな位置づけは、アーキテクチャが成熟し続ける中で関連性を保ち続ける明確な利点を与えています。

参考文献