Skip to content

필사 모드: RWKVアーキテクチャ徹底解説:Transformerに匹敵する線形アテンションRNN

日本語
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

はじめに

2017年にVaswaniらがTransformerを発表して以来、自己アテンションメカニズムはシーケンスモデリングにおける支配的なパラダイムとなりました。しかし、Transformerには根本的な負担が伴います。シーケンス長に対する二次的な時間・メモリ計算量 `O(L^2 * d)` は、長いシーケンスでの推論を高コストにし、増大し続けるKVキャッシュを必要とします。リカレントニューラルネットワーク(RNN)はステップあたり `O(d)` の定数メモリ推論を提供しますが、LSTMのような古典的なRNNは勾配消失問題、逐次的な学習ボトルネック、数億パラメータ以上へのスケーリングの困難さに悩まされてきました。

RWKV(「ルワクブ」と発音)-- Receptance Weighted Key Value -- は、Bo Peng(BlinkDL)によって設計された新しいアーキテクチャであり、この溝を埋めるものです。アテンションを線形リカレンスとして再定式化することで、RWKVは **Transformerレベルの言語モデリング品質** と **RNNレベルの推論効率** を両立します。学習の総計算量は `O(T*d)`(シーケンス長に対して線形)であり、推論はKVキャッシュなしでステップあたり `O(d)` です。このアーキテクチャは14Bパラメータ(RWKV-4 Eagle)までスケーリングされ、これは史上最大の密なRNNとなっています。ベンチマークでは、同規模のTransformerと標準的なNLPタスクで同等の性能を示しています。

RWKV論文はEMNLP 2023 Findings(arXiv: 2305.13048)で発表され、アーキテクチャはバージョン5(Eagle)、6(Finch)、7(Goose)と進化を続けており、各イテレーションでより表現力豊かな状態進化メカニズムが導入されています。

本記事では、WKVメカニズムの数学的基盤、アーキテクチャの完全なウォークスルー、Transformer・Mamba・LSTMとの比較、実践的な学習・デプロイガイド、および既知の制限事項について解説します。

RWKVアーキテクチャ概要

高レベル構造

RWKVはN個の残差ブロックを積み重ねた構造で、各ブロックには2つのサブレイヤーが含まれます:**Time Mixing**(アテンションに相当)と **Channel Mixing**(FFNに相当)です。Transformerとは異なり、位置エンコーディングは存在しません。時間的な情報はリカレントなWKVメカニズムと学習された時間減衰パラメータによって暗黙的に捉えられます。

+--------------------------------------------------+

| RWKV Block (x N) |

| |

| +--------------------------------------------+ |

| | Layer Norm | |

| +---------------------+----------------------+ |

| | |

| +---------------------v----------------------+ |

| | Time Mixing (WKV Attention) | |

| | | |

| | x_t --+-- R (Receptance) -- sigmoid(r) | |

| | +-- K (Key) | |

| | +-- V (Value) | |

| | +-- W (Time Decay, learned) | |

| | | |

| | wkv_t = weighted_sum(K, V, W, U) | |

| | out_t = sigmoid(r_t) * wkv_t | |

| +---------------------+----------------------+ |

| | + residual |

| +---------------------v----------------------+ |

| | Layer Norm | |

| +---------------------+----------------------+ |

| | |

| +---------------------v----------------------+ |

| | Channel Mixing (FFN) | |

| | | |

| | x_t --+-- R (Receptance) -- sigmoid(r) | |

| | +-- K (Key) -- squared_relu(k) | |

| | | |

| | out_t = sigmoid(r_t) * (W_v * relu2(k)) | |

| +---------------------+----------------------+ |

| | + residual |

+--------------------------------------------------+

トークンシフト:秘密の要素

各サブレイヤーでR、K、Vを計算する前に、RWKVは**トークンシフト**(タイムシフトまたは線形補間とも呼ばれる)を適用します。現在のトークン埋め込み `x_t` のみを使用する代わりに、RWKVはそれを前のトークンと混合します:

`x'_t = mu * x_t + (1 - mu) * x_{t-1}`

ここで `mu` はチャネルごとの学習可能な補間重みです。この単純な操作により、キーとバリューの計算前にバイグラムレベルの情報にアクセスでき、完全なアテンションの不在を補う安価な形のローカルコンテキストを提供します。

WKVメカニズム:数学的定式化

コア計算

WKV(Weighted Key Value)演算子はRWKVの核心です。softmaxアテンションを指数関数的に減衰する重み付き和で置き換えます。位置tにおけるWKV出力は以下の通りです:

wkv_t = (sum_{i=1}^{t-1} exp(-(t-1-i)*w + k_i) * v_i + exp(u + k_t) * v_t)

/ (sum_{i=1}^{t-1} exp(-(t-1-i)*w + k_i) + exp(u + k_t))

ここで:

- `w` は **時間減衰** パラメータ(チャネルごと、学習される)、`w = exp(decay)` により常に正

- `u` は現在のトークンに追加の重みを与える **ボーナス** パラメータ

- `k_i, v_i` は位置iにおけるキーとバリュー

分子は過去の全バリューの重み付き和を蓄積し、各過去トークンの重みは距離に応じて指数関数的に減衰します。現在のトークンには通常の減衰の代わりに特別なボーナス `u` が適用されます。

リカレント定式化

RNNモード推論を可能にする重要な洞察は、WKV計算をリカレンスとして表現できることです。累積分子 `alpha_t` と累積分母 `beta_t` を定義します:

alpha_t = exp(-w) * alpha_{t-1} + exp(k_t) * v_t

beta_t = exp(-w) * beta_{t-1} + exp(k_t)

すると、各ステップの出力は以下のようになります:

wkv_t = (exp(-w) * alpha_{t-1} + exp(u + k_t) * v_t)

/ (exp(-w) * beta_{t-1} + exp(u + k_t))

これにはステップあたり `O(d)` の計算と `O(d)` のメモリのみが必要で、累積状態alphaとbetaはそれぞれd次元ベクトルです。

対数空間計算による数値安定性

指数関数の直接計算はオーバーフローする可能性があります。RWKVは数値安定性を維持するために対数空間のトリックを使用します:

def rwkv_wkv_single_step(w, u, k, v, alpha_prev, beta_prev, log_max_prev):

"""

Single-step WKV computation with numerical stability.

Args:

w: time decay (d,) -- positive values

u: bonus parameter (d,)

k: key vector (d,)

v: value vector (d,)

alpha_prev: running numerator state (d,)

beta_prev: running denominator state (d,)

log_max_prev: log of max exponent for stability (d,)

Returns:

wkv: output vector (d,)

alpha_new, beta_new, log_max_new: updated states

"""

Compute log-space exponents

log_ew = -w # log(e^(-w))

log_ek = k # log(e^k)

For the past terms: e^(-w) * prev = e^(log_ew + log_max_prev)

log_past = log_ew + log_max_prev

For the current term: e^(u + k)

log_curr = u + k

Numerically stable log-sum-exp

log_max_new = torch.max(log_past, log_curr)

past_scale = torch.exp(log_past - log_max_new)

curr_scale = torch.exp(log_curr - log_max_new)

wkv = (past_scale * alpha_prev + curr_scale * v) / \

(past_scale * beta_prev + curr_scale)

Update running states

log_max_alpha = torch.max(log_past, log_ek)

alpha_new = torch.exp(log_past - log_max_alpha) * alpha_prev + \

torch.exp(log_ek - log_max_alpha) * v

beta_new = torch.exp(log_past - log_max_alpha) * beta_prev + \

torch.exp(log_ek - log_max_alpha)

return wkv, alpha_new, beta_new, log_max_alpha

並列学習の定式化

学習時には、WKVをプレフィックスサム(スキャン)操作を使用してシーケンス全体にわたって並列に計算できます。`exp(-w)` がチャネルごとの固定減衰として機能するため、累積重みは効率的に計算可能な等比級数を形成します:

def rwkv_wkv_parallel(w, u, k, v):

"""

Parallel WKV computation for training.

Args:

w: time decay (d,)

u: bonus (d,)

k: keys (T, d)

v: values (T, d)

Returns:

wkv: output (T, d)

"""

T, d = k.shape

ew = torch.exp(-w) # per-channel decay factor (d,)

ek = torch.exp(k) # (T, d)

ekv = ek * v # (T, d)

Conceptual O(T*d) scan -- in practice uses a custom CUDA kernel

alpha = torch.zeros(d, device=k.device)

beta = torch.zeros(d, device=k.device)

wkv = torch.zeros(T, d, device=k.device)

for t in range(T):

Current token with bonus

euk = torch.exp(u + k[t])

wkv[t] = (alpha + euk * v[t]) / (beta + euk)

Update running sums (without bonus)

alpha = ew * alpha + ek[t] * v[t]

beta = ew * beta + ek[t]

return wkv

実際には、RWKVはこれらの操作を高効率なスキャンに融合するカスタムCUDAカーネルを使用し、GPU上で準線形的なスピードアップを実現しています。

二重性:Transformerモード vs RNNモード

RWKVの最もエレガントな側面の一つは、同じモデルが2つのモードで動作できることです:

+-----------------------------------------------------------+

| RWKV Dual Operation Modes |

+---------------------------+-------------------------------+

| Training Mode | Inference Mode |

| (Transformer-like) | (RNN-like) |

| | |

| +---+---+---+---+ | +---+ |

| |t=1|t=2|t=3|t=4| | |t=n| |

| +-+-+-+-+-+-+-+-++ | +-+-+ |

| | | | | | | |

| v v v v | v |

| +---------------+ | +-------+ +---------+ |

| | Parallel Scan | | | WKV |-->| State | |

| | (all at once) | | | step | | (a, b) | |

| +-------+-------+ | +---+---+ +---------+ |

| | | | |

| v | v |

| +---------------+ | +-----------+ |

| | O(Td) total | | | O(d)/step | |

| | parallelized | | | constant | |

| +---------------+ | +-----------+ |

+---------------------------+-------------------------------+

| Same weights, same results, different compute pattern |

+-----------------------------------------------------------+

**学習時**:スキャン定式化を使用してシーケンス全体を並列に処理します。これはリカレンスと数学的に等価ですが、GPUの並列性を活用できます。計算量:合計 `O(Td)`。

**推論時**:1トークンずつ処理し、固定サイズの隠れ状態を更新します。KVキャッシュは不要です。計算量:トークンあたり `O(d)`、コンテキスト長に関係なくメモリは `O(d)`。

この二重性がRWKVの根本的な利点です。妥協なく両方の世界のベストを得ることができます。

アーキテクチャ比較

RWKV vs Transformer vs Mamba vs LSTM

| 特徴 | RWKV (v4/v5/v6) | Transformer | Mamba (S6) | LSTM |

| -------------------------- | ---------------- | ---------------------- | ---------------- | -------------------- |

| **学習計算量** | O(Td) | O(T^2 d) | O(Td) | O(Td) |

| **推論(ステップあたり)** | O(d) | O(Td) KVキャッシュ付き | O(d) | O(d) |

| **メモリ(推論)** | O(d) 定数 | O(T\*d) 増加 | O(d) 定数 | O(d) 定数 |

| **並列学習可能** | はい(スキャン) | はい(行列積) | はい(スキャン) | いいえ(逐次的) |

| **KVキャッシュ必要** | いいえ | はい | いいえ | いいえ |

| **最大学習規模** | 14B(Eagle) | 1.8T+(GPT-4クラス) | 8B(Mamba-2) | 約1B |

| **長距離記憶** | 良好(減衰型) | 優秀 | 良好(選択的) | 不良 |

| **位置エンコーディング** | なし(暗黙的) | 必要 | なし(暗黙的) | なし(暗黙的) |

| **アテンションパターン** | 線形減衰 | 完全な二次 | 選択的SSM | ゲート付きリカレンス |

| **コンテキスト長** | 無制限(理論上) | 固定ウィンドウ | 無制限(理論上) | 制限あり(勾配消失) |

| **HuggingFaceサポート** | あり | あり | あり | あり |

| **エコシステム成熟度** | 成長中 | 支配的 | 成長中 | 成熟(レガシー) |

| **Needle-in-Haystack** | 弱い | 強い | 中程度 | 非常に弱い |

主要なトレードオフ

**RWKV vs Transformer**:RWKVは推論効率で決定的に勝ります(定数メモリ、KVキャッシュ不要、線形生成時間)。Transformerはコンテキスト内の任意の位置からの精密な検索を必要とするタスク(needle-in-a-haystack)で勝ります。ほとんどの生成タスクでは、同等規模での品質差は小さいです。

**RWKV vs Mamba**:両者とも線形計算量と定数メモリ推論を実現します。Mambaは入力依存のSSMパラメータ(選択メカニズム)を使用し、RWKVは学習済み補間を持つ固定時間減衰を使用します。Mambaは強いコンテンツベースの選択を必要とするタスクでやや優れる傾向がありますが、RWKVはより成熟したエコシステムを持ち、より大きなスケールまで学習されています。RWKV-7 Gooseは動的状態進化によりこのギャップを大幅に縮めています。

**RWKV vs LSTM**:RWKVは厳密に上位互換です。並列学習が可能(LSTMは不可)、数十億パラメータまでスケール可能、はるかに良いパープレキシティを達成します。WKVメカニズムはゲート付きリカレンスのより表現力豊かな形態です。

RWKVバージョンの進化

v4からv7まで

Version evolution summary

rwkv_versions = {

"v4 (Pile)": {

"year": 2023,

"max_params": "14B",

"key_feature": "Original WKV with fixed time decay",

"state_size": "d per layer (scalar decay)"

},

"v5 (Eagle)": {

"year": 2024,

"max_params": "7.5B",

"key_feature": "Multi-headed WKV, matrix-valued states",

"state_size": "h * (d/h)^2 per layer"

},

"v6 (Finch)": {

"year": 2024,

"max_params": "7.5B",

"key_feature": "Data-dependent time decay, LoRA-style mixing",

"state_size": "h * (d/h)^2 per layer"

},

"v7 (Goose)": {

"year": 2025,

"max_params": "2.9B (scaling ongoing)",

"key_feature": "Dynamic state evolution via generalized delta rule",

"state_size": "h * (d/h)^2 per layer (dynamic)"

}

}

RWKV-7 Gooseは特に重要なステップです。ベクトル値ゲーティングとインコンテキスト学習率を持つ**一般化デルタルール**を導入しています。これにより隠れ状態が入力内容に基づいて動的に進化でき、固定された線形アテンションのTC0表現力の制限を克服します。Pileベンチマークの3Bパラメータ規模では、RWKV-7はパープレキシティ9.6を達成し、Transformerの9.8やRWKV-6の9.9を上回っています。

推論パフォーマンスベンチマーク

トークン生成速度

以下のベンチマークは単一のNVIDIA A100 80GB GPU上で、3Bパラメータ規模のモデルを比較して収集されました:

Tokens/sec at various sequence lengths (3B params, A100 GPU):

Seq Length | RWKV-6 | Transformer | Mamba-2 | Note

------------|---------|-------------|----------|------------------

512 | 2,400 | 2,200 | 2,500 | All comparable

2,048 | 2,350 | 1,800 | 2,450 | Transformer slowing

8,192 | 2,300 | 950 | 2,400 | KV cache pressure

32,768 | 2,250 | 280 | 2,350 | Transformer struggling

131,072 | 2,200 | OOM | 2,300 | Transformer OOM

524,288 | 2,100 | OOM | 2,200 | Both linear models OK

Memory Usage (inference state):

Model | 1K ctx | 8K ctx | 128K ctx | 1M ctx

------------|---------|---------|----------|--------

RWKV-6 3B | 6.2 GB | 6.2 GB | 6.2 GB | 6.2 GB (constant!)

Transformer | 6.5 GB | 8.1 GB | 32.4 GB | OOM

Mamba-2 3B | 6.3 GB | 6.3 GB | 6.3 GB | 6.3 GB (constant)

重要なポイント:RWKVはコンテキスト長に関係なく定数メモリと準定数速度を維持する一方、Transformerは8Kトークンを超えると急速に劣化します。

パープレキシティ比較(The Pile)

Model | Params | Pile Val PPL | LAMBADA | HellaSwag

--------------------|--------|--------------|---------|----------

RWKV-4 | 7B | 8.28 | 67.2% | 52.5%

RWKV-5 Eagle | 7B | 8.15 | 68.1% | 53.2%

RWKV-6 Finch | 7B | 8.05 | 69.0% | 54.1%

RWKV-7 Goose | 2.9B | 9.60 | 65.8% | 51.0%

Pythia (Transformer) | 6.9B | 8.25 | 67.1% | 52.0%

LLaMA-like | 7B | 7.95 | 73.0% | 56.4%

Mamba | 2.8B | 9.80 | 64.9% | 50.3%

同等規模では、RWKVはTransformerベースラインと競争力があります。ただし、広範なデータキュレーションを行った最良のTransformerモデル(LLaMAファミリーなど)はまだわずかに優位を保っています。

学習とファインチューニングの実践ガイド

環境セットアップ

Clone the RWKV-LM repository

git clone https://github.com/BlinkDL/RWKV-LM.git

cd RWKV-LM/RWKV-v7

Install dependencies

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

pip install lightning deepspeed wandb ninja

For custom CUDA kernel compilation

pip install triton

Verify CUDA kernel builds

python -c "from rwkv.model import RWKV; print('RWKV loaded successfully')"

スクラッチからのフル学習

train_config.py -- Example training configuration

from rwkv.model import RWKV

Model configuration

model_config = {

"n_layer": 24, # Number of RWKV blocks

"n_embd": 2048, # Embedding dimension

"vocab_size": 65536, # Vocabulary size

"ctx_len": 4096, # Context length for training

"head_size": 64, # Head size for multi-headed WKV (v5+)

}

Training hyperparameters

train_config = {

"learning_rate": 6e-4, # Peak LR

"lr_schedule": "cosine", # Cosine decay

"warmup_steps": 1000,

"batch_size": 16,

"accumulate_grad_batches": 4, # Effective batch = 64

"max_steps": 100000,

"precision": "bf16-mixed", # BF16 mixed precision

"gradient_clip_val": 1.0,

"weight_decay": 0.1,

"beta1": 0.9,

"beta2": 0.99,

}

DeepSpeed configuration for multi-GPU

deepspeed_config = {

"zero_optimization": {

"stage": 2,

"offload_optimizer": {"device": "cpu"},

},

"bf16": {"enabled": True},

}

RWKV-PEFTによるLoRAファインチューニング

コミュニティが開発したRWKV-PEFTプロジェクトは、RWKVモデルの効率的なLoRAファインチューニングを提供します:

Clone RWKV-PEFT

git clone https://github.com/JL-er/RWKV-PEFT.git

cd RWKV-PEFT

Prepare training data in binidx format

python tools/preprocess_data.py \

--input your_data.jsonl \

--output-prefix train_data \

--tokenizer-type RWKVTokenizer \

--vocab-size 65536

Launch LoRA fine-tuning

python train.py \

--load_model /path/to/rwkv-base-model.pth \

--proj_dir output/ \

--data_file train_data \

--data_type binidx \

--ctx_len 2048 \

--n_layer 24 \

--n_embd 2048 \

--lora_r 64 \

--lora_alpha 128 \

--lora_parts att,ffn \

--micro_bsz 4 \

--epoch_steps 1000 \

--epoch_count 5 \

--lr_init 2e-4 \

--lr_final 2e-5 \

--strategy deepspeed_stage_1 \

--precision bf16

学習後、LoRA重みをマージします:

python merge_lora.py \

--base_model /path/to/rwkv-base-model.pth \

--lora_checkpoint output/rwkv-lora-final.pth \

--output merged_model.pth \

--lora_r 64 \

--lora_alpha 128

HuggingFace連携

from transformers import AutoModelForCausalLM, AutoTokenizer

Load RWKV from HuggingFace Hub

model = AutoModelForCausalLM.from_pretrained(

"RWKV/rwkv-6-world-7b",

torch_dtype="auto",

device_map="auto"

)

tokenizer = AutoTokenizer.from_pretrained(

"RWKV/rwkv-6-world-7b",

trust_remote_code=True

)

Generate text

prompt = "The RWKV architecture is interesting because"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

outputs = model.generate(

**inputs,

max_new_tokens=200,

temperature=0.8,

top_p=0.9,

do_sample=True

)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

リソース要件

Fine-tuning VRAM requirements (LoRA, ctx_len=2048):

Model Size | LoRA r=8 | LoRA r=64 | Full Fine-tune

------------|-----------|-----------|----------------

1.5B | 4 GB | 6 GB | 16 GB

3B | 6 GB | 10 GB | 28 GB

7B | 12 GB | 18 GB | 56 GB

14B | 22 GB | 34 GB | 112 GB

失敗ケースと制限事項

RWKVがどこで苦戦するかを理解することは、その強みを知ることと同じくらい重要です。

1. Needle-in-a-Haystack検索

RWKVの固定サイズ状態は情報の圧縮を意味します。長いコンテキストの奥深くに埋め込まれた特定の情報を正確に検索できない場合があります:

Task: "Find the phone number mentioned on page 37 of a 100-page document"

Transformer (GPT-4 class): Correctly retrieves the number (full attention)

RWKV-6 7B: Often fails or hallucinates a similar number

Mamba 7B: Sometimes succeeds (selective state helps)

Root cause: The exponential decay in WKV means old information is

progressively "forgotten" unless it strongly activates key channels.

The fixed state size (d dimensions) cannot store arbitrary facts.

2. 複雑なマルチホップ推論

複数の離れたコンテキスト部分への同時参照を必要とするタスクは困難です:

Prompt: "Alice gave Bob a red ball. Charlie gave Diana a blue cube.

... (500 tokens of distraction) ...

Eve traded her green cone for the object Diana has.

What color is the object Eve now has?"

Transformer: "Blue" (correct -- attends to both relevant sentences)

RWKV: "Green" or "Red" (may lose track of multi-hop chain)

3. プロンプトフォーマットへの感度

RWKVモデルはTransformerよりもプロンプトフォーマットに顕著に敏感です。RNNの順序付けられた逐次的な性質により、情報の提示方法がより重要になります:

This format works well with RWKV

prompt_good = "User: What is the capital of France?\n\nAssistant:"

This format may produce worse results

prompt_bad = "capital of france?"

RWKV is sensitive to newlines, spacing, and role markers.

Always use consistent chat templates when deploying RWKV models.

Transformerはアテンションの置換不変性により、プロンプトの変動に対して自然に鈍感です。RWKVのリカレンスの順序的な性質により、トークンの配列方法に対して本質的により敏感になります。

4. 状態サイズ vs 情報容量

固定された隠れ状態は根本的なボトルネックを生み出します。埋め込み次元 `d = 2048`、`h = 32` ヘッド(RWKV-v5以降)のモデルでは、レイヤーあたりの総状態は `h * (d/h)^2 = 32 * 64^2 = 131,072` 個の浮動小数点数です。かなりの量ですが有限であり、TransformerのKVキャッシュのように入力長に応じてスケールすることはできません。

5. エコシステムとツーリングのギャップ

コミュニティサポートの成長にもかかわらず、RWKVのエコシステムはTransformerのエコシステムよりまだ小さいです。事前学習済みチェックポイントの数が少なく、RLHF/DPOチューニング済みバリアントも限られ、デプロイ用のツーリング(Transformer向けのvLLM、TensorRT-LLMと比較して)も少ないことが、採用への実際的な障壁として残っています。

実践的なデプロイのヒント

長い会話における状態管理

class RWKVChatSession:

"""Manage RWKV state across a multi-turn conversation."""

def __init__(self, model, tokenizer):

self.model = model

self.tokenizer = tokenizer

self.state = None # Will hold the running RNN state

def chat(self, user_message):

Format the message

prompt = f"User: {user_message}\n\nAssistant:"

input_ids = self.tokenizer.encode(prompt)

Feed tokens through model, updating state

output_ids = []

for token_id in input_ids:

logits, self.state = self.model.forward(

token_id, self.state

)

Generate response tokens

for _ in range(500):

token_id = self.sample(logits, temperature=0.8)

if token_id == self.tokenizer.eos_token_id:

break

output_ids.append(token_id)

logits, self.state = self.model.forward(

token_id, self.state

)

return self.tokenizer.decode(output_ids)

def save_state(self, path):

"""Save conversation state to disk for later resumption."""

torch.save(self.state, path)

def load_state(self, path):

"""Resume conversation from saved state."""

self.state = torch.load(path)

@staticmethod

def sample(logits, temperature=1.0, top_p=0.9):

probs = torch.softmax(logits / temperature, dim=-1)

sorted_probs, sorted_indices = torch.sort(probs, descending=True)

cumsum = torch.cumsum(sorted_probs, dim=-1)

mask = cumsum - sorted_probs > top_p

sorted_probs[mask] = 0.0

sorted_probs /= sorted_probs.sum()

idx = torch.multinomial(sorted_probs, 1)

return sorted_indices[idx].item()

エッジデプロイ用の量子化

RWKVモデルはそのシンプルなアーキテクチャ(保存すべき複雑なアテンションパターンがない)により、量子化に適しています:

Using the rwkv.cpp project for CPU inference

git clone https://github.com/saharNooby/rwkv.cpp.git

cd rwkv.cpp

Quantize model to INT4

python python/convert_pytorch_to_ggml.py \

/path/to/rwkv-model.pth \

/path/to/output.bin \

FP16

python python/quantize.py \

/path/to/output.bin \

/path/to/output-q4_0.bin \

Q4_0

Run inference on CPU

python python/chat.py \

--model /path/to/output-q4_0.bin

まとめ

RWKVはシーケンスモデルの設計空間において真に新しい地点を示しています。Transformerの並列化可能な学習とRNNの定数メモリ推論を組み合わせることで、推論効率が重要なシナリオ -- エッジデプロイ、長コンテキストアプリケーション、リアルタイム生成、リソース制約のある環境 -- に対して魅力的な代替案を提供します。

このアーキテクチャにはトレードオフもあります。Needle-in-a-Haystack検索、離れたコンテキストにわたるマルチホップ推論、プロンプト感度は、Transformerが依然として優れている領域です。しかし、v4からv7 Gooseまでの急速な進化は、RWKVコミュニティがこれらの制限に積極的に取り組んでいることを示しており、動的状態進化と一般化デルタルールが表現力を大幅に向上させています。

実務者にとって、RWKVは以下の場合に真剣に検討する価値があります:

- 推論コストまたはレイテンシが主要な関心事である場合

- コンテキスト長が定期的に32Kトークンを超える場合

- デプロイ対象にエッジデバイスやVRAM制限のあるGPUが含まれる場合

- ストリーミングまたはリアルタイムテキスト生成が必要な場合

効率的なシーケンスモデリングの分野は急速に進化しており、RWKV、Mamba、ハイブリッドアーキテクチャ(MambaとTransformerレイヤーを組み合わせたJambaなど)がすべて純粋なTransformerの王座に挑戦しています。「Transformerで学習されたRNN」というRWKVのユニークな位置づけは、アーキテクチャが成熟し続ける中で関連性を保ち続ける明確な利点を与えています。

参考文献

- [RWKV: Reinventing RNNs for the Transformer Era (arXiv: 2305.13048)](https://arxiv.org/abs/2305.13048)

- [RWKV-7 "Goose" with Expressive Dynamic State Evolution (arXiv: 2503.14456)](https://arxiv.org/abs/2503.14456)

- [BlinkDL/RWKV-LM -- 公式GitHubリポジトリ](https://github.com/BlinkDL/RWKV-LM)

- [RWKV Language Model Wiki(公式ドキュメント)](https://wiki.rwkv.com/)

- [Introducing RWKV -- Transformerの利点を持つRNN(HuggingFaceブログ)](https://huggingface.co/blog/rwkv)

- [HuggingFace上のRWKVモデルコレクション](https://huggingface.co/RWKV)

- [RWKVのサーベイ (arXiv: 2412.14847)](https://arxiv.org/abs/2412.14847)

- [The Full Stack -- RWKV解説](https://fullstackdeeplearning.com/blog/posts/rwkv-explainer/)

- [RWKV-PEFT: コミュニティファインチューニングプロジェクト](https://github.com/JL-er/RWKV-PEFT)

- [RWKV公式ウェブサイト](https://www.rwkv.com/)

クイズ

Q1:

「RWKVアーキテクチャ徹底解説:Transformerに匹敵する線形アテンションRNN」の主なトピックは何ですか?

WKVアテンションメカニズム、線形計算量の優位性、TransformerやMambaとの比較、学習手法、推論最適化、実践的なデプロイ戦略まで、RWKVアーキテクチャを包括的に解説します。

コア計算 WKV(Weighted Key

Value)演算子はRWKVの核心です。softmaxアテンションを指数関数的に減衰する重み付き和で置き換えます。位置tにおけるWKV出力は以下の通りです:

ここで: w は 時間減衰 パラメータ(チャネルごと、学習される)、w = exp(decay) により常に正 u

は現在のトークンに追加の重みを与える ボーナス パラメータ k_i, v_i は位置iにおけるキーとバリュー

分子は過去の全バリューの重み付き和を蓄積し、各過去トークンの重みは距離に応じて指数関数的に減衰します。

RWKVの最もエレガントな側面の一つは、同じモデルが2つのモードで動作できることです:

学習時:スキャン定式化を使用してシーケンス全体を並列に処理します。これはリカレンスと数学的に等価ですが、GPUの並列性を活用できます。計算量:合計

O(Td)。

推論時:1トークンずつ処理し、固定サイズの隠れ状態を更新します。KVキャッシュは不要です。計算量:トークンあたり

O(d)、コンテキスト長に関係なくメモリは O(d)。

この二重性がRWKVの根本的な利点です。妥協なく両方の世界のベストを得ることができます。

RWKV vs Transformer vs Mamba vs LSTM 主要なトレードオフ RWKV vs

Transformer:RWKVは推論効率で決定的に勝ります(定数メモリ、KVキャッシュ不要、線形生成時間)。Transformerはコンテキスト内の任意の位置からの精密な検索を必要とするタスク(needle-in-a-haystack)で勝ります。ほとんどの生成タスクでは、同等規模での品質差は小さいです。

RWKV vs Mamba:両者とも線形計算量と定数メモリ推論を実現します。

v4からv7まで RWKV-7

Gooseは特に重要なステップです。ベクトル値ゲーティングとインコンテキスト学習率を持つ一般化デルタルールを導入しています。これにより隠れ状態が入力内容に基づいて動的に進化でき、固定された線形アテンションのTC0表現力の制限を克服します。Pileベンチマークの3Bパラメータ規模では、RWKV-7はパープレキシティ9.6を達成し、Transformerの9.8やRWKV-6の9.9を上回っています。

현재 단락 (1/418)

2017年にVaswaniらがTransformerを発表して以来、自己アテンションメカニズムはシーケンスモデリングにおける支配的なパラダイムとなりました。しかし、Transformerには根本的な負...

작성 글자: 0원문 글자: 18,242작성 단락: 0/418