- Published on
Mamba: Linear-Time Sequence Modeling with Selective State Spaces 論文分析
- Authors
- Name
- 論文概要
- 動機:Transformerの限界
- 核心アイデア:Selective State Spaces
- ハードウェア最適化:Selective Scanアルゴリズム
- Mambaブロックアーキテクチャ
- 実験結果
- Mamba vs Transformer:いつどちらを使うべきか?
- 後続研究:Mamba-2とハイブリッド
- まとめ
- クイズ

論文概要
- タイトル: Mamba: Linear-Time Sequence Modeling with Selective State Spaces
- 著者: Albert Gu, Tri Dao (Carnegie Mellon University, Princeton University)
- 発表: 2023年12月 (arXiv: 2312.00752)、ICML 2024 採択
- コード: github.com/state-spaces/mamba
動機:Transformerの限界
Transformerはシーケンスモデリングの王座に君臨していますが、根本的な問題を抱えています:
- O(N²) 計算量: Self-attentionはシーケンス長に対して二次計算量
- 長いシーケンスの処理限界: 128K以上のコンテキストでメモリ・演算量が爆発的に増加
- 推論の非効率性: KV Cacheがシーケンス長に比例して増加
State Space Model(SSM)は理論的にO(N)の計算量を持ちますが、従来のSSM(S4、H3など)は入力に依存しない固定パラメータのため、言語モデリングではTransformerに及びませんでした。
核心アイデア:Selective State Spaces
従来SSMの問題点
従来のSSMは**Linear Time-Invariant(LTI)**システムです:
ここでA、B、Cは固定された行列です。入力が「hello」でも「world」でも同じ方法で処理します。これはcontent-aware reasoningが不可能であることを意味します。
Mambaの解決策:Selection Mechanism
Mambaの核心的イノベーションは、B、C、Δを入力依存にすることです:
# 従来SSM:固定パラメータ
B = nn.Parameter(torch.randn(N)) # 学習されるが入力に依存しない
# Mamba:入力依存パラメータ
B = nn.Linear(D, N)(x) # xに応じてBが変化
C = nn.Linear(D, N)(x) # xに応じてCが変化
delta = nn.Linear(D, D)(x) # xに応じてstep sizeも変化
これがSelective State Spaceの意味です。入力に応じて、どの情報を状態に保存し、どの情報を無視するかを選択します。
直感的理解
Selection Mechanismを直感的に理解すると:
- Δ(delta): 「このトークンをどれほど重要視するか」 — 大きなΔは現在の入力を強く反映、小さなΔは以前の状態を維持
- B: 「この入力のどの部分を状態に記録するか」
- C: 「状態のどの部分を出力として取り出すか」
これはTransformerのattentionが果たす役割と類似していますが、**O(N)**の計算量で実現します。
ハードウェア最適化:Selective Scanアルゴリズム
Selective SSMは入力依存であるため、従来SSMの効率的な畳み込み(convolution)トリックを使用できません。Mambaはこの問題をハードウェアアウェアアルゴリズムで解決します。
問題:HBM ↔ SRAM ボトルネック
GPUのメモリ階層:
┌────────────────────┐
│ SRAM (20MB) │ ← 非常に高速、非常に小さい
├────────────────────┤
│ HBM (80GB) │ ← 低速、大容量
└────────────────────┘
ナイーブな実装では中間状態をHBMに繰り返し保存・読み込みするため、メモリバウンドになります。
解決策:Kernel Fusion + Recomputation
MambaのSelective Scanアルゴリズム:
- SRAMで一括計算: 離散化(discretization)、選択的スキャン、出力計算を1つのカーネルに融合
- 中間状態を保存しない: 順伝播では(N、B、C、Δ)のみHBMに保存し、逆伝播で中間状態を再計算
- 結果: FlashAttentionと同様のIO-aware最適化
# 簡略化したSelective Scan(概念コード)
def selective_scan(x, A, B, C, delta):
"""
x: (batch, length, d_model)
Returns: (batch, length, d_model)
"""
batch, length, d = x.shape
n = A.shape[1] # state dimension
# Discretize A, B using delta
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, D, N)
deltaB_x = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)
# Sequential scan (実際にはCUDAカーネルで並列化)
h = torch.zeros(batch, d, n, device=x.device)
ys = []
for i in range(length):
h = deltaA[:, i] * h + deltaB_x[:, i]
y = (h * C[:, i].unsqueeze(1)).sum(-1)
ys.append(y)
return torch.stack(ys, dim=1)
Mambaブロックアーキテクチャ
MambaはH3アーキテクチャからインスピレーションを得て、簡素化されたブロック設計を採用しています:
Input (x)
│
├──── Linear projection (expand) ────┐
│ │
▼ ▼
Conv1D SiLU (gate)
│ │
▼ │
SiLU │
│ │
▼ │
SSM (selective scan) │
│ │
▼ │
×─────────────────────────────────────┘
│
▼
Linear projection (reduce)
│
▼
Output
PyTorchでの実装:
class MambaBlock(nn.Module):
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
super().__init__()
d_inner = d_model * expand
# Input projection
self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
# Conv1D
self.conv1d = nn.Conv1d(
d_inner, d_inner, d_conv,
padding=d_conv - 1, groups=d_inner
)
# SSM parameters
self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False) # B, C, delta
self.dt_proj = nn.Linear(1, d_inner, bias=True)
# A parameter (structured)
A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(d_inner))
self.out_proj = nn.Linear(d_inner, d_model, bias=False)
def forward(self, x):
batch, length, _ = x.shape
# Dual path
xz = self.in_proj(x)
x, z = xz.chunk(2, dim=-1)
# Conv + activation
x = self.conv1d(x.transpose(1, 2))[:, :, :length].transpose(1, 2)
x = F.silu(x)
# SSM
y = self.ssm(x)
# Gate and project
y = y * F.silu(z)
return self.out_proj(y)
実験結果
言語モデリング(The Pile)
| モデル | パラメータ | Perplexity |
|---|---|---|
| Transformer++ | 1.3B | 8.94 |
| H3 | 1.3B | 9.19 |
| Hyena | 1.3B | 9.07 |
| RWKV-4 | 1.5B | 9.02 |
| Mamba | 1.4B | 8.63 |
Mambaは同規模のTransformerよりも低いperplexityを達成しました。特に2.8Bスケールでは、TransFormer++に対して約0.3 perplexityの改善を示しました。
推論速度
| シーケンス長 | Transformer | Mamba |
|---|---|---|
| 2K | 1x | 1x |
| 8K | 1x | 3倍高速 |
| 64K | 1x | 5倍高速 |
| 1M | OOM | 動作 |
Mambaはシーケンス長が長くなるほど、Transformerに対して圧倒的な速度優位性を示します。
Selective Copying & Induction Heads
論文は2つの核心的な合成タスクでSelection Mechanismの有効性を検証しています:
- Selective Copying: 長いシーケンスから特定のトークンのみを選択的にコピー → 従来SSM(S4)は失敗、Mambaは成功
- Induction Heads: 「A B ... A → B」パターンの認識 → 従来SSMは失敗、Mambaは成功
この2つのタスクはcontent-aware reasoningが必須であり、Selection Mechanismなしでは解決不可能です。
Mamba vs Transformer:いつどちらを使うべきか?
| 基準 | Transformer | Mamba |
|---|---|---|
| 長いシーケンス | △ (O(N²)) | ◎ (O(N)) |
| 推論速度 | △ (KV Cache増加) | ◎ (固定状態) |
| In-context learning | ◎ | ○ |
| 学習並列化 | ◎ (完全並列) | ○ (scan並列化) |
| エコシステム/ツール | ◎ | △ |
後続研究:Mamba-2とハイブリッド
Mamba以降、多くの進展がありました:
- Mamba-2 (Dao & Gu, 2024): SSMとAttentionの理論的な関連性を示す**State Space Duality(SSD)**フレームワークを提案
- Jamba (AI21, 2024): Transformer + Mambaのハイブリッド、256Kコンテキスト
- Zamba (Zyphra, 2024): 小規模ハイブリッドモデルでSOTAを達成
まとめ
MambaはSelection Mechanismというシンプルながら強力なアイデアでSSMの限界を克服し、O(N)の計算量でTransformerに匹敵または凌駕する性能を達成しました。ハードウェアアウェアアルゴリズム設計は、理論的効率性を実際の速度に変換するための鍵でした。
Transformerが依然として主流ですが、Mambaとハイブリッドアーキテクチャは、長いシーケンス処理が重要な領域でますます存在感を高めています。
クイズ
Q1: MambaがTransformerの核心的な限界として解決しようとしているものは?
Self-attentionのO(N²)時間・空間計算量です。シーケンス長が長くなるほど演算量とメモリが二乗で増加し、長いシーケンスの処理が非効率になります。
Q2: Selective State Spaceにおける「Selective」の意味は?
SSMのパラメータ(B、C、Δ)を入力依存にすることで、どの情報を状態に保存し、どの情報を無視するかを選択(select)できるという意味です。
Q3: 従来SSM(S4など)が言語モデリングで振るわなかった根本原因は?
Linear Time-Invariant(LTI)特性のためです。パラメータが入力と無関係に固定されているため、content-aware reasoningが不可能でした。
Q4: MambaにおけるΔ(delta)パラメータの直感的な意味は?
現在の入力をどれほど重要に反映するかを決定する「step size」です。大きなΔは現在の入力を強く反映し、小さなΔは以前の状態を維持します。
Q5: MambaのSelective Scanが従来SSMの畳み込みトリックを使用できない理由は?
入力依存パラメータのため、もはやLTIシステムではなくなるため、畳み込みに変換してFFTで並列計算するトリックが適用できません。
Q6: Mambaのハードウェア最適化がFlashAttentionと共通する核心戦略は?
IO-aware最適化です。SRAMですべての計算を融合(kernel fusion)し、中間状態をHBMに保存せず逆伝播で再計算(recomputation)します。
Q7: Selective Copyingタスクで検証しようとしている能力は?
長いシーケンスから特定のトークンのみを選択的に記憶・コピーするcontent-aware reasoning能力です。固定パラメータのSSMではこのタスクを解決できません。
Q8: Mamba-2の核心的貢献であるState Space Duality(SSD)とは?
SSMとAttentionが数学的に同一の演算の異なる表現であることを示すフレームワークです。これにより、両アプローチの利点を組み合わせるための理論的基盤を提供します。