Skip to content
Published on

Mamba: Linear-Time Sequence Modeling with Selective State Spaces 論文分析

Authors
  • Name
    Twitter
Mamba SSM

論文概要

  • タイトル: Mamba: Linear-Time Sequence Modeling with Selective State Spaces
  • 著者: Albert Gu, Tri Dao (Carnegie Mellon University, Princeton University)
  • 発表: 2023年12月 (arXiv: 2312.00752)、ICML 2024 採択
  • コード: github.com/state-spaces/mamba

動機:Transformerの限界

Transformerはシーケンスモデリングの王座に君臨していますが、根本的な問題を抱えています:

  • O(N²) 計算量: Self-attentionはシーケンス長に対して二次計算量
  • 長いシーケンスの処理限界: 128K以上のコンテキストでメモリ・演算量が爆発的に増加
  • 推論の非効率性: KV Cacheがシーケンス長に比例して増加

State Space Model(SSM)は理論的にO(N)の計算量を持ちますが、従来のSSM(S4、H3など)は入力に依存しない固定パラメータのため、言語モデリングではTransformerに及びませんでした。

核心アイデア:Selective State Spaces

従来SSMの問題点

従来のSSMは**Linear Time-Invariant(LTI)**システムです:

h(t)=Ah(t)+Bx(t)h'(t) = Ah(t) + Bx(t) y(t)=Ch(t)y(t) = Ch(t)

ここでA、B、Cは固定された行列です。入力が「hello」でも「world」でも同じ方法で処理します。これはcontent-aware reasoningが不可能であることを意味します。

Mambaの解決策:Selection Mechanism

Mambaの核心的イノベーションは、B、C、Δを入力依存にすることです:

# 従来SSM:固定パラメータ
B = nn.Parameter(torch.randn(N))  # 学習されるが入力に依存しない

# Mamba:入力依存パラメータ
B = nn.Linear(D, N)(x)      # xに応じてBが変化
C = nn.Linear(D, N)(x)      # xに応じてCが変化
delta = nn.Linear(D, D)(x)  # xに応じてstep sizeも変化

これがSelective State Spaceの意味です。入力に応じて、どの情報を状態に保存し、どの情報を無視するかを選択します。

直感的理解

Selection Mechanismを直感的に理解すると:

  • Δ(delta): 「このトークンをどれほど重要視するか」 — 大きなΔは現在の入力を強く反映、小さなΔは以前の状態を維持
  • B: 「この入力のどの部分を状態に記録するか」
  • C: 「状態のどの部分を出力として取り出すか」

これはTransformerのattentionが果たす役割と類似していますが、**O(N)**の計算量で実現します。

ハードウェア最適化:Selective Scanアルゴリズム

Selective SSMは入力依存であるため、従来SSMの効率的な畳み込み(convolution)トリックを使用できません。Mambaはこの問題をハードウェアアウェアアルゴリズムで解決します。

問題:HBM ↔ SRAM ボトルネック

GPUのメモリ階層:

┌────────────────────┐
SRAM (20MB)      │  ← 非常に高速、非常に小さい
├────────────────────┤
HBM (80GB)       │  ← 低速、大容量
└────────────────────┘

ナイーブな実装では中間状態をHBMに繰り返し保存・読み込みするため、メモリバウンドになります。

解決策:Kernel Fusion + Recomputation

MambaのSelective Scanアルゴリズム:

  1. SRAMで一括計算: 離散化(discretization)、選択的スキャン、出力計算を1つのカーネルに融合
  2. 中間状態を保存しない: 順伝播では(N、B、C、Δ)のみHBMに保存し、逆伝播で中間状態を再計算
  3. 結果: FlashAttentionと同様のIO-aware最適化
# 簡略化したSelective Scan(概念コード)
def selective_scan(x, A, B, C, delta):
    """
    x: (batch, length, d_model)
    Returns: (batch, length, d_model)
    """
    batch, length, d = x.shape
    n = A.shape[1]  # state dimension

    # Discretize A, B using delta
    deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, D, N)
    deltaB_x = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)

    # Sequential scan (実際にはCUDAカーネルで並列化)
    h = torch.zeros(batch, d, n, device=x.device)
    ys = []
    for i in range(length):
        h = deltaA[:, i] * h + deltaB_x[:, i]
        y = (h * C[:, i].unsqueeze(1)).sum(-1)
        ys.append(y)

    return torch.stack(ys, dim=1)

Mambaブロックアーキテクチャ

MambaはH3アーキテクチャからインスピレーションを得て、簡素化されたブロック設計を採用しています:

Input (x)
    ├──── Linear projection (expand) ────┐
    │                                     │
    ▼                                     ▼
  Conv1D                              SiLU (gate)
    │                                     │
    ▼                                     │
  SiLU    │                                     │
    ▼                                     │
  SSM (selective scan)    │                                     │
    ▼                                     │
    ×─────────────────────────────────────┘
  Linear projection (reduce)
  Output

PyTorchでの実装:

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand

        # Input projection
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

        # Conv1D
        self.conv1d = nn.Conv1d(
            d_inner, d_inner, d_conv,
            padding=d_conv - 1, groups=d_inner
        )

        # SSM parameters
        self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False)  # B, C, delta
        self.dt_proj = nn.Linear(1, d_inner, bias=True)

        # A parameter (structured)
        A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))

        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        batch, length, _ = x.shape

        # Dual path
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Conv + activation
        x = self.conv1d(x.transpose(1, 2))[:, :, :length].transpose(1, 2)
        x = F.silu(x)

        # SSM
        y = self.ssm(x)

        # Gate and project
        y = y * F.silu(z)
        return self.out_proj(y)

実験結果

言語モデリング(The Pile)

モデルパラメータPerplexity
Transformer++1.3B8.94
H31.3B9.19
Hyena1.3B9.07
RWKV-41.5B9.02
Mamba1.4B8.63

Mambaは同規模のTransformerよりも低いperplexityを達成しました。特に2.8Bスケールでは、TransFormer++に対して約0.3 perplexityの改善を示しました。

推論速度

シーケンス長TransformerMamba
2K1x1x
8K1x3倍高速
64K1x5倍高速
1MOOM動作

Mambaはシーケンス長が長くなるほど、Transformerに対して圧倒的な速度優位性を示します。

Selective Copying & Induction Heads

論文は2つの核心的な合成タスクでSelection Mechanismの有効性を検証しています:

  1. Selective Copying: 長いシーケンスから特定のトークンのみを選択的にコピー → 従来SSM(S4)は失敗、Mambaは成功
  2. Induction Heads: 「A B ... A → B」パターンの認識 → 従来SSMは失敗、Mambaは成功

この2つのタスクはcontent-aware reasoningが必須であり、Selection Mechanismなしでは解決不可能です。

Mamba vs Transformer:いつどちらを使うべきか?

基準TransformerMamba
長いシーケンス△ (O(N²))◎ (O(N))
推論速度△ (KV Cache増加)◎ (固定状態)
In-context learning
学習並列化◎ (完全並列)○ (scan並列化)
エコシステム/ツール

後続研究:Mamba-2とハイブリッド

Mamba以降、多くの進展がありました:

  • Mamba-2 (Dao & Gu, 2024): SSMとAttentionの理論的な関連性を示す**State Space Duality(SSD)**フレームワークを提案
  • Jamba (AI21, 2024): Transformer + Mambaのハイブリッド、256Kコンテキスト
  • Zamba (Zyphra, 2024): 小規模ハイブリッドモデルでSOTAを達成

まとめ

MambaはSelection Mechanismというシンプルながら強力なアイデアでSSMの限界を克服し、O(N)の計算量でTransformerに匹敵または凌駕する性能を達成しました。ハードウェアアウェアアルゴリズム設計は、理論的効率性を実際の速度に変換するための鍵でした。

Transformerが依然として主流ですが、Mambaとハイブリッドアーキテクチャは、長いシーケンス処理が重要な領域でますます存在感を高めています。

クイズ

Q1: MambaがTransformerの核心的な限界として解決しようとしているものは? Self-attentionのO(N²)時間・空間計算量です。シーケンス長が長くなるほど演算量とメモリが二乗で増加し、長いシーケンスの処理が非効率になります。

Q2: Selective State Spaceにおける「Selective」の意味は? SSMのパラメータ(B、C、Δ)を入力依存にすることで、どの情報を状態に保存し、どの情報を無視するかを選択(select)できるという意味です。

Q3: 従来SSM(S4など)が言語モデリングで振るわなかった根本原因は? Linear Time-Invariant(LTI)特性のためです。パラメータが入力と無関係に固定されているため、content-aware reasoningが不可能でした。

Q4: MambaにおけるΔ(delta)パラメータの直感的な意味は? 現在の入力をどれほど重要に反映するかを決定する「step size」です。大きなΔは現在の入力を強く反映し、小さなΔは以前の状態を維持します。

Q5: MambaのSelective Scanが従来SSMの畳み込みトリックを使用できない理由は? 入力依存パラメータのため、もはやLTIシステムではなくなるため、畳み込みに変換してFFTで並列計算するトリックが適用できません。

Q6: Mambaのハードウェア最適化がFlashAttentionと共通する核心戦略は? IO-aware最適化です。SRAMですべての計算を融合(kernel fusion)し、中間状態をHBMに保存せず逆伝播で再計算(recomputation)します。

Q7: Selective Copyingタスクで検証しようとしている能力は? 長いシーケンスから特定のトークンのみを選択的に記憶・コピーするcontent-aware reasoning能力です。固定パラメータのSSMではこのタスクを解決できません。

Q8: Mamba-2の核心的貢献であるState Space Duality(SSD)とは? SSMとAttentionが数学的に同一の演算の異なる表現であることを示すフレームワークです。これにより、両アプローチの利点を組み合わせるための理論的基盤を提供します。