Skip to content
Published on

Mamba論文レビュー:Selective State Space ModelsでTransformerを超える

Authors
  • Name
    Twitter

1. 序論:Transformerの限界とSSMの台頭

Transformerアーキテクチャは2017年の登場以来、NLP、Vision、Audioなどほぼすべてのシーケンスモデリング分野を支配してきた。しかし、根本的な限界がある:

  • O(N二乗)の計算量:Self-Attentionの時間・空間計算量がシーケンス長に対して二次的に増加
  • 推論の非効率性:各トークン生成時に全KV Cacheを参照する必要がある
  • 長いシーケンス処理の困難:128K以上のコンテキストでメモリボトルネック

State Space Model(SSM)は、これらの限界を克服するために提案された代替手法である。特にMamba(Gu & Dao, ICLR 2024)はSSMに選択的(Selective)メカニズムを導入し、Transformerに匹敵する性能をO(N)の計算量で達成した。

2. 背景:Structured State Space Models(S4)

2.1 連続時間SSM

SSMは連続時間システムを離散化したものである:

h(t)=Ah(t)+Bx(t)h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t) y(t)=Ch(t)y(t) = \mathbf{C}h(t)
  • h(t)RNh(t) \in \mathbb{R}^N:hidden state
  • x(t)Rx(t) \in \mathbb{R}:入力
  • y(t)Ry(t) \in \mathbb{R}:出力
  • ARN×N\mathbf{A} \in \mathbb{R}^{N \times N}, BRN×1\mathbf{B} \in \mathbb{R}^{N \times 1}, CR1×N\mathbf{C} \in \mathbb{R}^{1 \times N}

2.2 離散化(Zero-Order Hold)

連続システムを離散時間に変換:

A=exp(ΔA)\overline{\mathbf{A}} = \exp(\Delta \mathbf{A}) B=(ΔA)1(exp(ΔA)I)ΔB\overline{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B}

離散化された再帰関係:

ht=Aht1+Bxth_t = \overline{\mathbf{A}} h_{t-1} + \overline{\mathbf{B}} x_t yt=Chty_t = \mathbf{C} h_t

2.3 S4の核心:HiPPO初期化

S4(Structured State Spaces for Sequence Modeling)の核心的貢献は、**HiPPO(High-order Polynomial Projection Operator)**行列でAを初期化したことである:

import torch
import numpy as np

def make_hippo(N):
    """HiPPO-LegS行列を生成"""
    P = np.sqrt(1 + 2 * np.arange(N))
    A = np.zeros((N, N))
    for i in range(N):
        for j in range(N):
            if i > j:
                A[i, j] = P[i] * P[j]
            elif i == j:
                A[i, j] = i + 1
            # i < j: 0
    return -A  # 安定性のため負数

2.4 S4の限界

S4は入力に依存しない固定パラメータ(A, B, C, Δ)を使用する。これは2つの問題を引き起こす:

  1. Content-based reasoningの不可:入力内容に応じて異なる処理が必要なタスク(例:selective copying、induction heads)で性能低下
  2. Attentionの利点の喪失:Transformerのcontent-awareマッチング能力の損失

3. Mamba:Selective State Space Models

3.1 核心アイデア:Selection Mechanism

Mambaの核心的革新は、SSMパラメータを入力依存にしたことである:

S4:    (A, B, C, Δ) = 固定パラメータ
Mamba: (B, C, Δ) = f(input)  ← 入力依存!

具体的には:

Bt=LinearB(xt),Ct=LinearC(xt),Δt=softplus(LinearΔ(xt))\mathbf{B}_t = \text{Linear}_B(x_t), \quad \mathbf{C}_t = \text{Linear}_C(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}_\Delta(x_t))

これにより、モデルがどの情報を記憶し、どの情報を忘れるかを入力に応じて動的に決定できる。

3.2 Selectionの直観

Δt\Delta_t(step size)が鍵となる:

  • Δt\Delta_tが大きい場合AI\overline{\mathbf{A}} \approx \mathbf{I} → 以前のstateを維持(入力を無視)
  • Δt\Delta_tが小さい場合:現在の入力により集中(stateを更新)

これはゲーティングメカニズムと類似している:

LSTMのforget gate ≈ MambaのΔ

3.3 Hardware-Awareアルゴリズム

Selectionを導入するとパラメータが入力依存になるため、convolution trickが使用できなくなる。これは学習時にO(N二乗)計算量に回帰するリスクがある。

Mambaはこれをkernel fusion + recomputationで解決する:

# MambaのSelective Scan(簡略版)
def selective_scan(x, delta, A, B, C):
    """
    x: (B, L, D)  - 入力
    delta: (B, L, D) - step size(入力依存)
    A: (D, N)      - state matrix
    B: (B, L, N)   - input matrix(入力依存)
    C: (B, L, N)   - output matrix(入力依存)
    """
    B_batch, L, D = x.shape
    N = A.shape[1]

    # 離散化
    deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, D, N)
    deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, D, N)

    # Sequential scan(学習時はparallel scanを使用)
    h = torch.zeros(B_batch, D, N, device=x.device)
    ys = []
    for t in range(L):
        h = deltaA[:, t] * h + deltaB[:, t] * x[:, t, :, None]
        y = (h * C[:, t, None, :]).sum(-1)  # (B, D)
        ys.append(y)

    return torch.stack(ys, dim=1)  # (B, L, D)

実際の実装では、CUDAカーネル内でSRAMにすべての中間状態を保持し、HBMアクセスを最小化する(FlashAttentionと同様のIO-awareアプローチ)。

3.4 Mamba Blockアーキテクチャ

Input
  ├──── Linear(D → 2ED) ──── SiLU ──── Conv1d ──── SiLU ──── SSM ────┐
  │                                                                     │
  └──── Linear(DED) ──── SiLU ──────────────────────────── × ───────┘
                                                          Linear(EDD)
                                                            Output
class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand

        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv,
                                padding=d_conv-1, groups=d_inner)

        # SSMパラメータ
        self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False)  # B, C, Δ
        self.dt_proj = nn.Linear(1, d_inner, bias=True)

        A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(d_inner))

        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        # x: (B, L, D)
        xz = self.in_proj(x)  # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1)

        # Conv + SiLU
        x = self.conv1d(x.transpose(1, 2))[:, :, :x.shape[1]]
        x = x.transpose(1, 2)
        x = F.silu(x)

        # SSM
        A = -torch.exp(self.A_log)
        B_C_dt = self.x_proj(x)
        # ... selective scan ...

        # Gate
        y = y * F.silu(z)
        return self.out_proj(y)

4. Mamba-2:State Space Duality

4.1 SSD(State Space Dual)モデル

Mamba-2(Dao & Gu, ICML 2024)は、SSMとAttentionの間の数学的二重性(Duality)を発見した:

  • SSMの観点:再帰的state更新 → O(N)推論
  • Attentionの観点:特殊な構造の行列積 → 並列学習
SSM recurrence: h_t = A_t h_{t-1} + B_t x_t
                 y_t = C_t h_t

Dual
Structured Attention: Y = (TM) X
  where T = lower-triangular causal mask
        M = structured (semi-separable) matrix

4.2 性能比較

モデルパラメータPile (ppl)学習FLOPS/s推論 (tok/s)
Transformer++2.7B6.71.0x1.0x
Mamba-12.8B6.21.3x5.2x
Mamba-22.7B6.12.1x5.5x

主な改善点:

  • 学習速度2倍向上(structured matrix multiplicationの活用)
  • より大きなstate sizeが可能(N=64 → N=256)
  • Multi-head構造の導入

5. 実践:Mambaモデルを使う

5.1 インストールと推論

pip install mamba-ssm causal-conv1d>=1.4.0
from mamba_ssm import MambaLMHeadModel
from transformers import AutoTokenizer

# Mamba-2.8Bのロード
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

model = model.cuda().half()

# 推論
input_ids = tokenizer("The future of AI is", return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids, max_length=100, temperature=0.7)
print(tokenizer.decode(output[0]))

5.2 Mamba-2の使用

from mamba_ssm import Mamba2

# Mamba-2レイヤーの単独使用
layer = Mamba2(
    d_model=2048,
    d_state=128,    # Mamba-2はより大きなstateが可能
    d_conv=4,
    expand=2,
    headdim=64,     # Multi-head
).cuda()

x = torch.randn(1, 1024, 2048).cuda()  # (batch, seq_len, d_model)
y = layer(x)  # (1, 1024, 2048)

6. Mamba vs Transformer比較

特性TransformerMamba
学習計算量O(N二乗)O(N)
推論計算量O(N) per token(KV Cache)O(1) per token
メモリ(推論)O(N) KV CacheO(1) 固定state
In-context learning強いやや弱い(改善中)
長いシーケンスボトルネック効率的
ハードウェア活用並列化に最適再帰ボトルネック(Mamba-2で改善)

7. 限界と展望

7.1 現在の限界

  • In-context learning:固定サイズのstateに無限のコンテキスト情報を圧縮する際の情報損失
  • Retrievalタスク:正確なトークン検索が必要な場合はAttentionが優位
  • 学習安定性:長いシーケンスでのgradient不安定問題

7.2 ハイブリッドアプローチ

最近のトレンドはMamba + Attentionハイブリッド

  • Jamba(AI21):Mambaレイヤー + Attentionレイヤーの混合
  • Zamba(Zyphra):Mambaベース + 少数のShared Attentionレイヤー

8. クイズ

Q1. MambaがS4と比べて核心的に改善した点は?

SSMパラメータ(B, C, Δ)を入力依存的にすることで、content-based reasoningを可能にしたこと。S4は固定パラメータであったため、入力内容に応じた選択的処理ができなかった。

Q2. MambaにおけるΔ(delta)の役割は?

Step sizeとして、ゲーティングメカニズムの役割を果たす。Δが大きい場合は以前のstateを維持(入力を無視)、小さい場合は現在の入力に集中(stateを更新)。LSTMのforget gateに類似。

Q3. Selection導入後にconvolution trickが使用できない理由は?

Convolution trickは**時不変(time-invariant)**システムでのみ動作する。Selectionによりパラメータが入力依存になると時変(time-varying)となり、グローバルコンボリューションが不可能になる。

Q4. MambaのHardware-Awareアルゴリズムの核心は?

CUDAカーネル内でSRAMに中間stateを保持し、HBMアクセスを最小化すること。FlashAttentionと同様のIO-awareアプローチで、memory-bound演算を最適化する。

Q5. Mamba-2のState Space Dualityとは?

SSMの再帰的state更新とstructured attention行列積が数学的に同一であるということ。これにより、学習時は並列化された行列積を、推論時は効率的な再帰をそれぞれ活用できる。

Q6. Mambaに対してTransformerが優位なタスクは?

In-context learningと正確なトークン検索(retrieval)タスク。Attentionはシーケンス内の任意の位置を直接参照できるが、Mambaは固定サイズのstateに情報を圧縮する必要がある。

Q7. Jamba、Zambaのようなハイブリッドモデルの設計原理は?

MambaレイヤーのO(N)効率性 + 少数のAttentionレイヤーの正確な検索能力を組み合わせること。大部分のレイヤーをMambaで構成し、重要な位置にのみAttentionを配置する。