- Authors
- Name
- 1. 序論:Transformerの限界とSSMの台頭
- 2. 背景:Structured State Space Models(S4)
- 3. Mamba:Selective State Space Models
- 4. Mamba-2:State Space Duality
- 5. 実践:Mambaモデルを使う
- 6. Mamba vs Transformer比較
- 7. 限界と展望
- 8. クイズ
1. 序論:Transformerの限界とSSMの台頭
Transformerアーキテクチャは2017年の登場以来、NLP、Vision、Audioなどほぼすべてのシーケンスモデリング分野を支配してきた。しかし、根本的な限界がある:
- O(N二乗)の計算量:Self-Attentionの時間・空間計算量がシーケンス長に対して二次的に増加
- 推論の非効率性:各トークン生成時に全KV Cacheを参照する必要がある
- 長いシーケンス処理の困難:128K以上のコンテキストでメモリボトルネック
State Space Model(SSM)は、これらの限界を克服するために提案された代替手法である。特にMamba(Gu & Dao, ICLR 2024)はSSMに選択的(Selective)メカニズムを導入し、Transformerに匹敵する性能をO(N)の計算量で達成した。
2. 背景:Structured State Space Models(S4)
2.1 連続時間SSM
SSMは連続時間システムを離散化したものである:
- :hidden state
- :入力
- :出力
- , ,
2.2 離散化(Zero-Order Hold)
連続システムを離散時間に変換:
離散化された再帰関係:
2.3 S4の核心:HiPPO初期化
S4(Structured State Spaces for Sequence Modeling)の核心的貢献は、**HiPPO(High-order Polynomial Projection Operator)**行列でAを初期化したことである:
import torch
import numpy as np
def make_hippo(N):
"""HiPPO-LegS行列を生成"""
P = np.sqrt(1 + 2 * np.arange(N))
A = np.zeros((N, N))
for i in range(N):
for j in range(N):
if i > j:
A[i, j] = P[i] * P[j]
elif i == j:
A[i, j] = i + 1
# i < j: 0
return -A # 安定性のため負数
2.4 S4の限界
S4は入力に依存しない固定パラメータ(A, B, C, Δ)を使用する。これは2つの問題を引き起こす:
- Content-based reasoningの不可:入力内容に応じて異なる処理が必要なタスク(例:selective copying、induction heads)で性能低下
- Attentionの利点の喪失:Transformerのcontent-awareマッチング能力の損失
3. Mamba:Selective State Space Models
3.1 核心アイデア:Selection Mechanism
Mambaの核心的革新は、SSMパラメータを入力依存にしたことである:
S4: (A, B, C, Δ) = 固定パラメータ
Mamba: (B, C, Δ) = f(input) ← 入力依存!
具体的には:
これにより、モデルがどの情報を記憶し、どの情報を忘れるかを入力に応じて動的に決定できる。
3.2 Selectionの直観
(step size)が鍵となる:
- が大きい場合: → 以前のstateを維持(入力を無視)
- が小さい場合:現在の入力により集中(stateを更新)
これはゲーティングメカニズムと類似している:
LSTMのforget gate ≈ MambaのΔ
3.3 Hardware-Awareアルゴリズム
Selectionを導入するとパラメータが入力依存になるため、convolution trickが使用できなくなる。これは学習時にO(N二乗)計算量に回帰するリスクがある。
Mambaはこれをkernel fusion + recomputationで解決する:
# MambaのSelective Scan(簡略版)
def selective_scan(x, delta, A, B, C):
"""
x: (B, L, D) - 入力
delta: (B, L, D) - step size(入力依存)
A: (D, N) - state matrix
B: (B, L, N) - input matrix(入力依存)
C: (B, L, N) - output matrix(入力依存)
"""
B_batch, L, D = x.shape
N = A.shape[1]
# 離散化
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, D, N)
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, D, N)
# Sequential scan(学習時はparallel scanを使用)
h = torch.zeros(B_batch, D, N, device=x.device)
ys = []
for t in range(L):
h = deltaA[:, t] * h + deltaB[:, t] * x[:, t, :, None]
y = (h * C[:, t, None, :]).sum(-1) # (B, D)
ys.append(y)
return torch.stack(ys, dim=1) # (B, L, D)
実際の実装では、CUDAカーネル内でSRAMにすべての中間状態を保持し、HBMアクセスを最小化する(FlashAttentionと同様のIO-awareアプローチ)。
3.4 Mamba Blockアーキテクチャ
Input
│
├──── Linear(D → 2ED) ──── SiLU ──── Conv1d ──── SiLU ──── SSM ────┐
│ │
└──── Linear(D → ED) ──── SiLU ──────────────────────────── × ───────┘
│
Linear(ED → D)
│
Output
class MambaBlock(nn.Module):
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
super().__init__()
d_inner = d_model * expand
self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv,
padding=d_conv-1, groups=d_inner)
# SSMパラメータ
self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False) # B, C, Δ
self.dt_proj = nn.Linear(1, d_inner, bias=True)
A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(d_inner))
self.out_proj = nn.Linear(d_inner, d_model, bias=False)
def forward(self, x):
# x: (B, L, D)
xz = self.in_proj(x) # (B, L, 2*ED)
x, z = xz.chunk(2, dim=-1)
# Conv + SiLU
x = self.conv1d(x.transpose(1, 2))[:, :, :x.shape[1]]
x = x.transpose(1, 2)
x = F.silu(x)
# SSM
A = -torch.exp(self.A_log)
B_C_dt = self.x_proj(x)
# ... selective scan ...
# Gate
y = y * F.silu(z)
return self.out_proj(y)
4. Mamba-2:State Space Duality
4.1 SSD(State Space Dual)モデル
Mamba-2(Dao & Gu, ICML 2024)は、SSMとAttentionの間の数学的二重性(Duality)を発見した:
- SSMの観点:再帰的state更新 → O(N)推論
- Attentionの観点:特殊な構造の行列積 → 並列学習
SSM recurrence: h_t = A_t h_{t-1} + B_t x_t
y_t = C_t h_t
↕ Dual ↕
Structured Attention: Y = (T ⊙ M) X
where T = lower-triangular causal mask
M = structured (semi-separable) matrix
4.2 性能比較
| モデル | パラメータ | Pile (ppl) | 学習FLOPS/s | 推論 (tok/s) |
|---|---|---|---|---|
| Transformer++ | 2.7B | 6.7 | 1.0x | 1.0x |
| Mamba-1 | 2.8B | 6.2 | 1.3x | 5.2x |
| Mamba-2 | 2.7B | 6.1 | 2.1x | 5.5x |
主な改善点:
- 学習速度2倍向上(structured matrix multiplicationの活用)
- より大きなstate sizeが可能(N=64 → N=256)
- Multi-head構造の導入
5. 実践:Mambaモデルを使う
5.1 インストールと推論
pip install mamba-ssm causal-conv1d>=1.4.0
from mamba_ssm import MambaLMHeadModel
from transformers import AutoTokenizer
# Mamba-2.8Bのロード
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = model.cuda().half()
# 推論
input_ids = tokenizer("The future of AI is", return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids, max_length=100, temperature=0.7)
print(tokenizer.decode(output[0]))
5.2 Mamba-2の使用
from mamba_ssm import Mamba2
# Mamba-2レイヤーの単独使用
layer = Mamba2(
d_model=2048,
d_state=128, # Mamba-2はより大きなstateが可能
d_conv=4,
expand=2,
headdim=64, # Multi-head
).cuda()
x = torch.randn(1, 1024, 2048).cuda() # (batch, seq_len, d_model)
y = layer(x) # (1, 1024, 2048)
6. Mamba vs Transformer比較
| 特性 | Transformer | Mamba |
|---|---|---|
| 学習計算量 | O(N二乗) | O(N) |
| 推論計算量 | O(N) per token(KV Cache) | O(1) per token |
| メモリ(推論) | O(N) KV Cache | O(1) 固定state |
| In-context learning | 強い | やや弱い(改善中) |
| 長いシーケンス | ボトルネック | 効率的 |
| ハードウェア活用 | 並列化に最適 | 再帰ボトルネック(Mamba-2で改善) |
7. 限界と展望
7.1 現在の限界
- In-context learning:固定サイズのstateに無限のコンテキスト情報を圧縮する際の情報損失
- Retrievalタスク:正確なトークン検索が必要な場合はAttentionが優位
- 学習安定性:長いシーケンスでのgradient不安定問題
7.2 ハイブリッドアプローチ
最近のトレンドはMamba + Attentionハイブリッド:
- Jamba(AI21):Mambaレイヤー + Attentionレイヤーの混合
- Zamba(Zyphra):Mambaベース + 少数のShared Attentionレイヤー
8. クイズ
Q1. MambaがS4と比べて核心的に改善した点は?
SSMパラメータ(B, C, Δ)を入力依存的にすることで、content-based reasoningを可能にしたこと。S4は固定パラメータであったため、入力内容に応じた選択的処理ができなかった。
Q2. MambaにおけるΔ(delta)の役割は?
Step sizeとして、ゲーティングメカニズムの役割を果たす。Δが大きい場合は以前のstateを維持(入力を無視)、小さい場合は現在の入力に集中(stateを更新)。LSTMのforget gateに類似。
Q3. Selection導入後にconvolution trickが使用できない理由は?
Convolution trickは**時不変(time-invariant)**システムでのみ動作する。Selectionによりパラメータが入力依存になると時変(time-varying)となり、グローバルコンボリューションが不可能になる。
Q4. MambaのHardware-Awareアルゴリズムの核心は?
CUDAカーネル内でSRAMに中間stateを保持し、HBMアクセスを最小化すること。FlashAttentionと同様のIO-awareアプローチで、memory-bound演算を最適化する。
Q5. Mamba-2のState Space Dualityとは?
SSMの再帰的state更新とstructured attention行列積が数学的に同一であるということ。これにより、学習時は並列化された行列積を、推論時は効率的な再帰をそれぞれ活用できる。
Q6. Mambaに対してTransformerが優位なタスクは?
In-context learningと正確なトークン検索(retrieval)タスク。Attentionはシーケンス内の任意の位置を直接参照できるが、Mambaは固定サイズのstateに情報を圧縮する必要がある。
Q7. Jamba、Zambaのようなハイブリッドモデルの設計原理は?
MambaレイヤーのO(N)効率性 + 少数のAttentionレイヤーの正確な検索能力を組み合わせること。大部分のレイヤーをMambaで構成し、重要な位置にのみAttentionを配置する。