Skip to content

필사 모드: Mamba: Linear-Time Sequence Modeling with Selective State Spaces 論文分析

日本語
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

論文概要

- **タイトル**: Mamba: Linear-Time Sequence Modeling with Selective State Spaces

- **著者**: Albert Gu, Tri Dao (Carnegie Mellon University, Princeton University)

- **発表**: 2023年12月 (arXiv: 2312.00752)、ICML 2024 採択

- **コード**: [github.com/state-spaces/mamba](https://github.com/state-spaces/mamba)

動機:Transformerの限界

Transformerはシーケンスモデリングの王座に君臨していますが、根本的な問題を抱えています:

- **O(N²) 計算量**: Self-attentionはシーケンス長に対して二次計算量

- **長いシーケンスの処理限界**: 128K以上のコンテキストでメモリ・演算量が爆発的に増加

- **推論の非効率性**: KV Cacheがシーケンス長に比例して増加

State Space Model(SSM)は理論的にO(N)の計算量を持ちますが、従来のSSM(S4、H3など)は**入力に依存しない固定パラメータ**のため、言語モデリングではTransformerに及びませんでした。

核心アイデア:Selective State Spaces

従来SSMの問題点

従来のSSMは**Linear Time-Invariant(LTI)**システムです:

$$h'(t) = Ah(t) + Bx(t)$$

$$y(t) = Ch(t)$$

ここでA、B、Cは**固定された**行列です。入力が「hello」でも「world」でも同じ方法で処理します。これは**content-aware reasoning**が不可能であることを意味します。

Mambaの解決策:Selection Mechanism

Mambaの核心的イノベーションは、**B、C、Δを入力依存にする**ことです:

従来SSM:固定パラメータ

B = nn.Parameter(torch.randn(N)) # 学習されるが入力に依存しない

Mamba:入力依存パラメータ

B = nn.Linear(D, N)(x) # xに応じてBが変化

C = nn.Linear(D, N)(x) # xに応じてCが変化

delta = nn.Linear(D, D)(x) # xに応じてstep sizeも変化

これが**Selective** State Spaceの意味です。入力に応じて、どの情報を状態に保存し、どの情報を無視するかを**選択**します。

直感的理解

Selection Mechanismを直感的に理解すると:

- **Δ(delta)**: 「このトークンをどれほど重要視するか」 — 大きなΔは現在の入力を強く反映、小さなΔは以前の状態を維持

- **B**: 「この入力のどの部分を状態に記録するか」

- **C**: 「状態のどの部分を出力として取り出すか」

これはTransformerのattentionが果たす役割と類似していますが、**O(N)**の計算量で実現します。

ハードウェア最適化:Selective Scanアルゴリズム

Selective SSMは入力依存であるため、従来SSMの効率的な畳み込み(convolution)トリックを使用できません。Mambaはこの問題を**ハードウェアアウェアアルゴリズム**で解決します。

問題:HBM ↔ SRAM ボトルネック

GPUのメモリ階層:

┌────────────────────┐

│ SRAM (20MB) │ ← 非常に高速、非常に小さい

├────────────────────┤

│ HBM (80GB) │ ← 低速、大容量

└────────────────────┘

ナイーブな実装では中間状態をHBMに繰り返し保存・読み込みするため、**メモリバウンド**になります。

解決策:Kernel Fusion + Recomputation

MambaのSelective Scanアルゴリズム:

1. **SRAMで一括計算**: 離散化(discretization)、選択的スキャン、出力計算を1つのカーネルに融合

2. **中間状態を保存しない**: 順伝播では(N、B、C、Δ)のみHBMに保存し、逆伝播で中間状態を再計算

3. **結果**: FlashAttentionと同様のIO-aware最適化

簡略化したSelective Scan(概念コード)

def selective_scan(x, A, B, C, delta):

"""

x: (batch, length, d_model)

Returns: (batch, length, d_model)

"""

batch, length, d = x.shape

n = A.shape[1] # state dimension

Discretize A, B using delta

deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, D, N)

deltaB_x = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)

Sequential scan (実際にはCUDAカーネルで並列化)

h = torch.zeros(batch, d, n, device=x.device)

ys = []

for i in range(length):

h = deltaA[:, i] * h + deltaB_x[:, i]

y = (h * C[:, i].unsqueeze(1)).sum(-1)

ys.append(y)

return torch.stack(ys, dim=1)

Mambaブロックアーキテクチャ

MambaはH3アーキテクチャからインスピレーションを得て、簡素化されたブロック設計を採用しています:

Input (x)

├──── Linear projection (expand) ────┐

│ │

▼ ▼

Conv1D SiLU (gate)

│ │

▼ │

SiLU │

│ │

▼ │

SSM (selective scan) │

│ │

▼ │

×─────────────────────────────────────┘

Linear projection (reduce)

Output

PyTorchでの実装:

class MambaBlock(nn.Module):

def __init__(self, d_model, d_state=16, d_conv=4, expand=2):

super().__init__()

d_inner = d_model * expand

Input projection

self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

Conv1D

self.conv1d = nn.Conv1d(

d_inner, d_inner, d_conv,

padding=d_conv - 1, groups=d_inner

)

SSM parameters

self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False) # B, C, delta

self.dt_proj = nn.Linear(1, d_inner, bias=True)

A parameter (structured)

A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)

self.A_log = nn.Parameter(torch.log(A))

self.D = nn.Parameter(torch.ones(d_inner))

self.out_proj = nn.Linear(d_inner, d_model, bias=False)

def forward(self, x):

batch, length, _ = x.shape

Dual path

xz = self.in_proj(x)

x, z = xz.chunk(2, dim=-1)

Conv + activation

x = self.conv1d(x.transpose(1, 2))[:, :, :length].transpose(1, 2)

x = F.silu(x)

SSM

y = self.ssm(x)

Gate and project

y = y * F.silu(z)

return self.out_proj(y)

実験結果

言語モデリング(The Pile)

| モデル | パラメータ | Perplexity |

| ------------- | ---------- | ---------- |

| Transformer++ | 1.3B | 8.94 |

| H3 | 1.3B | 9.19 |

| Hyena | 1.3B | 9.07 |

| RWKV-4 | 1.5B | 9.02 |

| **Mamba** | **1.4B** | **8.63** |

Mambaは同規模のTransformerよりも**低いperplexity**を達成しました。特に2.8Bスケールでは、TransFormer++に対して約**0.3 perplexity**の改善を示しました。

推論速度

| シーケンス長 | Transformer | Mamba |

| ------------ | ----------- | ----------- |

| 2K | 1x | 1x |

| 8K | 1x | 3倍高速 |

| 64K | 1x | **5倍高速** |

| 1M | OOM | 動作 |

Mambaはシーケンス長が長くなるほど、Transformerに対して圧倒的な速度優位性を示します。

Selective Copying & Induction Heads

論文は2つの核心的な合成タスクでSelection Mechanismの有効性を検証しています:

1. **Selective Copying**: 長いシーケンスから特定のトークンのみを選択的にコピー → 従来SSM(S4)は失敗、Mambaは成功

2. **Induction Heads**: 「A B ... A → B」パターンの認識 → 従来SSMは失敗、Mambaは成功

この2つのタスクは**content-aware reasoning**が必須であり、Selection Mechanismなしでは解決不可能です。

Mamba vs Transformer:いつどちらを使うべきか?

| 基準 | Transformer | Mamba |

| ------------------- | ---------------- | -------------- |

| 長いシーケンス | △ (O(N²)) | ◎ (O(N)) |

| 推論速度 | △ (KV Cache増加) | ◎ (固定状態) |

| In-context learning | ◎ | ○ |

| 学習並列化 | ◎ (完全並列) | ○ (scan並列化) |

| エコシステム/ツール | ◎ | △ |

後続研究:Mamba-2とハイブリッド

Mamba以降、多くの進展がありました:

- **Mamba-2** (Dao & Gu, 2024): SSMとAttentionの理論的な関連性を示す**State Space Duality(SSD)**フレームワークを提案

- **Jamba** (AI21, 2024): Transformer + Mambaのハイブリッド、256Kコンテキスト

- **Zamba** (Zyphra, 2024): 小規模ハイブリッドモデルでSOTAを達成

まとめ

Mambaは**Selection Mechanism**というシンプルながら強力なアイデアでSSMの限界を克服し、O(N)の計算量でTransformerに匹敵または凌駕する性能を達成しました。ハードウェアアウェアアルゴリズム設計は、理論的効率性を実際の速度に変換するための鍵でした。

Transformerが依然として主流ですが、Mambaとハイブリッドアーキテクチャは、長いシーケンス処理が重要な領域でますます存在感を高めています。

クイズ

Self-attentionのO(N²)時間・空間計算量です。シーケンス長が長くなるほど演算量とメモリが二乗で増加し、長いシーケンスの処理が非効率になります。

SSMのパラメータ(B、C、Δ)を入力依存にすることで、どの情報を状態に保存し、どの情報を無視するかを選択(select)できるという意味です。

Linear

Time-Invariant(LTI)特性のためです。パラメータが入力と無関係に固定されているため、content-aware

reasoningが不可能でした。

現在の入力をどれほど重要に反映するかを決定する「step

size」です。大きなΔは現在の入力を強く反映し、小さなΔは以前の状態を維持します。

入力依存パラメータのため、もはやLTIシステムではなくなるため、畳み込みに変換してFFTで並列計算するトリックが適用できません。

IO-aware最適化です。SRAMですべての計算を融合(kernel

fusion)し、中間状態をHBMに保存せず逆伝播で再計算(recomputation)します。

長いシーケンスから特定のトークンのみを選択的に記憶・コピーするcontent-aware

reasoning能力です。固定パラメータのSSMではこのタスクを解決できません。

SSMとAttentionが数学的に同一の演算の異なる表現であることを示すフレームワークです。これにより、両アプローチの利点を組み合わせるための理論的基盤を提供します。

현재 단락 (1/144)

- **タイトル**: Mamba: Linear-Time Sequence Modeling with Selective State Spaces

작성 글자: 0원문 글자: 5,536작성 단락: 0/144