- Published on
MambaとState Space Model論文の詳細な分析:オプションのSSMからMamba-2までのトランスフォーマー代替アーキテクチャ
- Authors
- Name

##入り
Transformerアーキテクチャは、2017年の登場以来、自然言語処理、コンピュータビジョン、オーディオ、時系列など、ほぼすべてのシーケンスモデリング領域を支配してきました。しかし、Self-Attentionの二次時間複雑度はシーケンス長が増加するほど演算量とメモリ使用量が爆発的に増える根本的な限界を持つ。 128K以上のコンテキストウィンドウを扱う現代LLMでは、この問題はさらに深刻になっています。
この背景では、State Space Model(SSM)は、線形時間複雑度でシーケンスを処理するための代替として研究されています。 2021年、Albert Guが提案したStructured State Spaces(S4)は、長いシーケンスベンチマークでTransformerを凌駕する性能を示し、SSM研究のルネサンスを開いた。しかし、S4を含む初期のSSMは、Linear Time-Invariant(LTI)システムという制約のため、入力内容に応じて動的に情報を選択する能力が不足し、言語モデリングでTransformerに及ばなかった。
2023年12月、Albert GuとTri Daoが発表したMamba論文はこの限界を正面に突破した。 SSMのパラメータを入力に依存的にする選択的メカニズム(Selection Mechanism)を導入し、GPUハードウェアに最適化されたスキャンアルゴリズムで効率性まで確保した。その後、2024年5月に発表されたMamba-2は、SSMとAttentionが数学的に同じ系列に属するというState Space Duality(SSD)を証明し、2-8倍速いアルゴリズムを提示した。
この記事では、SSMの数学的基礎(S4)からMambaの選択的メカニズム、Mamba-2のSSDフレームワーク、Transformerとの定量的比較、ビジョン/オーディオ/時系列など、さまざまなドメインの適用、および実践の実装と運営における注意事項までを詳しく説明します。
SSM基礎理論(S4)
連続時間 State Space Model
State Space Modelは、制御理論に由来する連続時間動的システムです。入力信号 を隠し状態 を経て出力 に変換する過程を次の微分方程式で記述する。
ここで、 は状態遷移行列、 は入力投影行列、 は出力投影行列、 N $は状態ディメンションでSSMのメモリ容量を決定します。
離散シーケンスを処理するには、連続システムを離散化する必要があります。ステップサイズに対してZero-Order Hold(ZOH)二酸化を適用すると、次のようになります。
二酸化SSMは2つの方法で計算できます。まず、再帰的な形でシーケンスを順次処理します。
第二に、シーケンス全体を一度に処理する合成積の形があります。カーネル を事前計算すれば、FFT を活用した 演算が可能である。
S4のHiPPOの初期化
S4(Structured State Spaces for Sequence Modeling、Gu et al。、2021)の重要な貢献は、状態行列の初期化戦略です。ランダムに初期化されたは長期依存性を捉えることができず、傾き消失/爆発問題を経験する。 S4はHiPPO(High-order Polynomial Projection Operators)フレームワークを活用してを特殊な構造に初期化する。
HiPPO-LegS行列は,過去の入力をルジャンドル多項式基底に連続的に近似するように設計した。この構造の重要な性質は、すべての時間スケールの情報を均等に保存することです。```python import torch import torch.nn as nn import numpy as np
def make_hippo_matrix(N: int) -> torch.Tensor: """HiPPO-LegS 행렬 생성.
이 행렬은 과거 입력을 르장드르 다항식 기저로 최적 근사하도록 설계되어 장기 의존성 포착에 유리하다.
Args: N: 상태 차원 (State dimension)
Returns: A: (N, N) HiPPO 행렬 """ P = np.sqrt(2 * np.arange(N) + 1) A = np.zeros((N, N)) for i in range(N): for j in range(N): if i > j: A[i, j] = P[i] * P[j] elif i == j: A[i, j] = i + 1
i < j: 0
A = -A # 안정성을 위해 음수 return torch.tensor(A, dtype=torch.float32)
class S4Layer(nn.Module): """S4(Structured State Spaces) 레이어의 핵심 구현.
HiPPO 초기화된 A 행렬과 합성곱 기반 병렬 계산을 결합한다. """
def init(self, d_model: int, state_dim: int = 64, seq_len: int = 1024): super().init() self.d_model = d_model self.state_dim = state_dim self.seq_len = seq_len
HiPPO 초기화
A = make_hippo_matrix(state_dim) self.A_log = nn.Parameter(torch.log(torch.clamp(-A.diagonal(), min=1e-4))) self.B = nn.Parameter(torch.randn(state_dim, 1) * 0.01) self.C = nn.Parameter(torch.randn(1, state_dim) * 0.01) self.log_dt = nn.Parameter(torch.rand(d_model).uniform_(-4, -1))
skip connection
self.D = nn.Parameter(torch.ones(d_model))
def _compute_kernel(self, L: int) -> torch.Tensor: """SSM 합성곱 커널을 사전 계산한다.""" dt = torch.exp(self.log_dt) # (d_model,) A = -torch.exp(self.A_log) # (state_dim,) 대각 성분
ZOH 이산화 (대각 A 가정)
dtA = dt.unsqueeze(-1) * A.unsqueeze(0) # (d_model, state_dim) A_bar = torch.exp(dtA)
커널 계산: K[k] = C @ A_bar^k @ B_bar
효율적인 Vandermonde 곱으로 O(N*L) 계산
powers = torch.arange(L, device=A.device).float()
A_bar^k = exp(k * dtA)
kernel = torch.einsum( 'dn,dn,nl->dl', self.C.squeeze(0).expand(self.d_model, -1), dt.unsqueeze(-1) * self.B.squeeze(-1).unsqueeze(0), torch.exp(dtA.unsqueeze(-1) * powers.unsqueeze(0).unsqueeze(0)) .squeeze() .T )
return kernel # (d_model, L)
def forward(self, x: torch.Tensor) -> torch.Tensor: """합성곱 모드의 forward pass.
Args: x: (batch, seq_len, d_model) 입력 시퀀스
Returns: y: (batch, seq_len, d_model) 출력 시퀀스 """ B, L, D = x.shape kernel = self._compute_kernel(L) # (d_model, L)
FFT 기반 합성곱
x_perm = x.permute(0, 2, 1) # (B, D, L) k_f = torch.fft.rfft(kernel, n=2 * L) # zero-pad x_f = torch.fft.rfft(x_perm, n=2 * L) y = torch.fft.irfft(x_f * k_f, n=2 * L)[..., :L]
skip connection
y = y + self.D.unsqueeze(0).unsqueeze(-1) * x_perm return y.permute(0, 2, 1)
## Mambaのオプションメカニズム
### LTIからSelectionへの切り替え
Mambaのコア洞察はシンプルだが強力です。 SSM パラメータ $B$、$C$、$\Delta$ を入力の関数にしてシステムを時変(time-varying)に切り替えるのです。これにより、モデルは各時点でどの情報を状態に記録し、どの情報を無視するかを動的に決定できます。
既存のLTI SSMでは、すべての入力トークンが同じダイナミクスで処理されます。 「マットのキャット・ザ・ザット」という文章では、官事「ザ」とコア名詞「キャット」が同じように状態に反映されるので、重要な情報を選択的に強調したり、不要な情報を除外したりすることは不可能です。 TransformerのAttentionは本質的にこのオプションの情報処理を実行します。
式で表現すると、MambaのオプションのSSMは次のようになります。
$$B_t = \text{Linear}_B(x_t), \quad C_t = \text{Linear}_C(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}_\Delta(x_t))$$
$$\bar{A}_t = \exp(\Delta_t A), \quad \bar{B}_t = \Delta_t B_t$$
$$h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t$$
$$y_t = C_t h_t$$
$\Delta_t$の役割は特に重要です。 $\Delta_t$ が大きいと $\bar{A}_t = \exp(\Delta_t A)$ が 0 に近づき、前の状態がリセットされ、新しい入力 $x_t$ が状態を支配します。逆に、$ \ Delta_t $が小さいと、以前の状態が保存され、現在の入力の影響が減少します。これはRNNのゲートメカニズムと似ていますが、連続時間システムの数学的フレームワーク内で自然に導き出されたものです。
### Selectionが解決する合成タスク
Mambaの論文は、Selection Mechanismの効果を2つの合成タスクで実証しています。
まず、Selective Copyingタスクです。入力シーケンス内の特定のトークンのみを選択的にコピーする必要があるという問題で、LTI SSMは「どのトークンをコピーするか」を入力内容によって判断できず、失敗します。 Mambaは$B_t$によって重要なトークンの情報を状態に強く記録し、不要なトークンは無視する。
第二に、インダクションヘッドタスクです。 「A B ... A」というパターンが現れたときに次に「B」を予測しなければならない問題で、TransformerはAttentionメカニズムで自然に解決するが、LTI SSMはこのパターンマッチングが不可能である。 Mambaの選択的メカニズムは、「A」が現れたときに以前の「A B」パターンの記憶を活性化して「B」を予測することができる。```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelectiveSSM(nn.Module):
"""Mamba의 선택적 State Space Model 핵심 구현.
B, C, delta를 입력의 함수로 만들어
content-aware한 시퀀스 처리를 가능하게 한다.
"""
def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# A 행렬: 대각 구조, 로그 공간에서 파라미터화
A = torch.arange(1, d_state + 1, dtype=torch.float32)
self.A_log = nn.Parameter(torch.log(A).unsqueeze(0).expand(d_model, -1))
# 입력 의존적 B, C 투영
self.B_proj = nn.Linear(d_model, d_state, bias=False)
self.C_proj = nn.Linear(d_model, d_state, bias=False)
# delta (스텝 크기) 투영
self.dt_proj = nn.Linear(d_model, d_model, bias=True)
# skip connection
self.D = nn.Parameter(torch.ones(d_model))
def selective_scan(
self,
x: torch.Tensor,
dt: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
) -> torch.Tensor:
"""선택적 스캔 알고리즘 (순차 구현).
실제 Mamba는 GPU SRAM 최적화된 CUDA 커널을 사용하지만,
여기서는 알고리즘의 논리를 명확히 보여주기 위해 순차 구현한다.
Args:
x: (batch, seq_len, d_model) 입력
dt: (batch, seq_len, d_model) 이산화 스텝 크기
B: (batch, seq_len, d_state) 입력 투영
C: (batch, seq_len, d_state) 출력 투영
Returns:
y: (batch, seq_len, d_model) 출력
"""
batch, seq_len, d_model = x.shape
d_state = B.shape[-1]
# A 이산화: A_bar = exp(dt * A)
A = -torch.exp(self.A_log) # (d_model, d_state)
dt_A = torch.einsum('bld,dn->bldn', dt, A) # (B, L, D, N)
A_bar = torch.exp(dt_A)
# B 이산화: B_bar = dt * B
dt_B = torch.einsum('bld,bln->bldn', dt, B) # (B, L, D, N)
# 순차 스캔
h = torch.zeros(batch, d_model, d_state, device=x.device)
outputs = []
for t in range(seq_len):
# h_t = A_bar_t * h_{t-1} + B_bar_t * x_t
h = A_bar[:, t] * h + dt_B[:, t] * x[:, t].unsqueeze(-1)
# y_t = C_t @ h_t
y_t = torch.einsum('bdn,bn->bd', h, C[:, t])
outputs.append(y_t)
y = torch.stack(outputs, dim=1) # (B, L, D)
return y
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""선택적 SSM의 forward pass.
Args:
x: (batch, seq_len, d_model)
Returns:
y: (batch, seq_len, d_model)
"""
# 입력 의존적 파라미터 계산
B = self.B_proj(x) # (B, L, N)
C = self.C_proj(x) # (B, L, N)
dt = F.softplus(self.dt_proj(x)) # (B, L, D), 양수 보장
# 선택적 스캔
y = self.selective_scan(x, dt, B, C)
# skip connection
y = y + self.D.unsqueeze(0).unsqueeze(0) * x
return y
```### Hardware-Awareスキャンアルゴリズム
Selection Mechanismを導入するとSSMはもはやLTIではなくなるため、S4で使用していたFFTベースの合成積を適用することはできません。時間変動パラメータは各時点で異なるため、合成積カーネルを事前計算することは不可能です。シーケンシャルスキャンを使用する必要がありますが、単純なforループ実装はGPUの並列性を利用できず、非常に遅いです。
MambaはFlashAttentionに触発されたhardware-awareアルゴリズムでこの問題を解決します。重要なアイデアは3つです。まず、GPU HBM(High Bandwidth Memory)とSRAM(on-chip memory)間のデータ移動を最小限に抑えるカーネルフュージョンを適用する。二酸化、スキャン、出力投影を1つのCUDAカーネルで実行し、中間結果をSRAMに保持します。次に、中間状態テンソル($ \ bar {A}、\ bar {B} $など)をHBMに書き込むことなくSRAMでリアルタイムに再計算します。これにより、メモリ使用量はシーケンス長に比例しません。第三に、逆電波でも中間状態を保存せずに再計算(recomputation)してメモリを節約する。
## Mamba アーキテクチャの詳細
### Mambaブロック構造
MambaブロックはTransformerブロックとは異なる構造を持っています。 TransformerはSelf-AttentionとFFNの2つのサブレイヤーで構成されていますが、Mambaはこれら2つの役割を1つのブロックに統合します。具体的には、入力を2つの分岐に分岐させ、1つの分岐は合成積とオプションのSSMを経て、もう一方の分岐はゲートとして機能します。```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
@dataclass
class MambaConfig:
d_model: int = 768
d_state: int = 16
d_conv: int = 4
expand: int = 2
n_layers: int = 24
vocab_size: int = 50257
dropout: float = 0.0
class MambaBlock(nn.Module):
"""Mamba 블록: 선택적 SSM, 합성곱, 게이팅을 통합한 구조.
Transformer의 Attention + FFN을 하나의 블록으로 대체한다.
확장 비율(expand)은 내부 차원을 결정하며,
기본값 2는 d_inner = 2 * d_model을 의미한다.
"""
def __init__(self, config: MambaConfig):
super().__init__()
self.d_model = config.d_model
self.d_state = config.d_state
self.d_conv = config.d_conv
d_inner = config.d_model * config.expand
# 입력 투영: 두 갈래(z, x)로 분기
self.in_proj = nn.Linear(config.d_model, d_inner * 2, bias=False)
# 1D 합성곱: 지역적 문맥 포착
self.conv1d = nn.Conv1d(
in_channels=d_inner,
out_channels=d_inner,
kernel_size=config.d_conv,
padding=config.d_conv - 1,
groups=d_inner, # depthwise convolution
bias=True,
)
# 선택적 SSM 파라미터
self.B_proj = nn.Linear(d_inner, config.d_state, bias=False)
self.C_proj = nn.Linear(d_inner, config.d_state, bias=False)
self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)
# A 행렬: 대각 구조
A = torch.arange(1, config.d_state + 1, dtype=torch.float32)
self.A_log = nn.Parameter(
torch.log(A).unsqueeze(0).expand(d_inner, -1).clone()
)
self.D = nn.Parameter(torch.ones(d_inner))
# 출력 투영
self.out_proj = nn.Linear(d_inner, config.d_model, bias=False)
self.norm = nn.RMSNorm(config.d_model)
def _selective_scan(self, x, dt, B, C):
"""선택적 스캔 (순차 구현). 실제로는 CUDA 커널 사용."""
batch, seq_len, d_inner = x.shape
d_state = B.shape[-1]
A = -torch.exp(self.A_log)
dt_A = torch.einsum('bld,dn->bldn', dt, A)
A_bar = torch.exp(dt_A)
dt_B_x = torch.einsum('bld,bln->bldn', dt * x, B)
h = torch.zeros(batch, d_inner, d_state, device=x.device)
outputs = []
for t in range(seq_len):
h = A_bar[:, t] * h + dt_B_x[:, t]
y_t = torch.einsum('bdn,bn->bd', h, C[:, t])
outputs.append(y_t)
y = torch.stack(outputs, dim=1)
return y
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Mamba 블록의 forward pass.
Args:
x: (batch, seq_len, d_model)
Returns:
output: (batch, seq_len, d_model)
"""
residual = x
x = self.norm(x)
# 두 갈래로 분기
xz = self.in_proj(x)
x_branch, z = xz.chunk(2, dim=-1)
# 합성곱 경로: 지역적 패턴 포착
x_conv = x_branch.transpose(1, 2) # (B, D_inner, L)
x_conv = self.conv1d(x_conv)[:, :, :x.shape[1]] # causal trim
x_conv = x_conv.transpose(1, 2) # (B, L, D_inner)
x_conv = F.silu(x_conv)
# 선택적 SSM
B = self.B_proj(x_conv)
C = self.C_proj(x_conv)
dt = F.softplus(self.dt_proj(x_conv))
y = self._selective_scan(x_conv, dt, B, C)
y = y + self.D.unsqueeze(0).unsqueeze(0) * x_conv
# 게이팅: z 갈래와 원소곱
y = y * F.silu(z)
# 출력 투영 + 잔차 연결
output = self.out_proj(y)
return output + residual
```### パラメータ効率とアーキテクチャの比較
Mambaのアーキテクチャは、Transformerと比較していくつかの構造的違いを持っています。 Transformerは、Multi-Head AttentionとFFNという2つの大きなサブレイヤーを持ち、それぞれ別々のLayer Normalizationと残差接続を持つ。 Mambaはこれを1つのブロックに統合しますが、合成積でローカルパターンをキャプチャし、オプションのSSMでグローバル依存関係をモデル化し、ゲートで情報の流れを制御します。
TransformerのAttentionがすべてのトークンペアの関係を明示的に計算する「グローバル参照(global attention)」方式である場合、MambaのオプションのSSMは情報を圧縮状態ベクトルに蓄積する「圧縮記憶(compressed memory)」方式である。この違いは、シーケンス長の複雑さの違いの根本的な原因です。
同じパラメータ数(約2.8B)に基づいて、MambaはTransformerと比較して約1.5倍少ないFLOPを使用します。 Attention の $O(L^2)$ 演算が消え、FFN と同様の演算量の選択的 SSM に置き換えられるからだ。
## Mamba-2とSSDフレームワーク
### State Space Dualityの発見
Mamba-2(Dao&Gu、2024)の最も重要な理論的貢献は、State Space Duality(SSD)の発見です。 SSMとAttentionは表面的には非常に異なる演算のように見えますが、数学的には同じ構造的行列に属することを証明しました。
具体的には、オプションのSSMの出力は、次の行列 - ベクトル積で表すことができます。
$$y = Mx$$
ここで $M$ は semi-separable 行列です。この行列の $(i, j)$ 要素は次のようになります。
$$M_{ij} = \begin{cases} C_i^\top \bar{A}_{i:j} B_j & \text{if } i \geq j \\ 0 & \text{if } i < j \end{cases}$$
ここで、 $\bar{A}_{i:j} = \bar{A}_i \bar{A}_{i-1} \cdots \bar{A}_{j+1}$ は、時点 $j$ から $i$ までの累積遷移を意味します。この行列は下三角構造を持ち、これはCausal AttentionのマスクされたAttention行列と同じ構造です。
Attention 行列 $\text{softmax}(QK^\top / \sqrt{d})$ から softmax を削除すると $QK^\top$ になり、これも rank-$d$ semi-separable 行列です。 SSDはこの接続を定式化し、SSMとLinear Attentionが同じ行列代数構造を共有することを示しています。
### SSDアルゴリズムとチャンク分割
SSDの実用的な意味は、新しいアルゴリズム設計を可能にすることです。 Mamba-2はシーケンスを固定サイズのチャンクに分割し、チャンク内では行列積として並列化し、チャンク間にSSM再帰に状態を渡すハイブリッドアルゴリズムを使用します。```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class SSDBlock(nn.Module):
"""Mamba-2의 Structured State Space Duality(SSD) 블록.
시퀀스를 청크로 나누어 청크 내부는 행렬 곱(병렬),
청크 간에는 SSM 재귀로 처리하는 하이브리드 알고리즘.
"""
def __init__(
self,
d_model: int = 768,
d_state: int = 128,
n_heads: int = 8,
chunk_size: int = 256,
):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.n_heads = n_heads
self.chunk_size = chunk_size
self.head_dim = d_model // n_heads
# 멀티헤드 구조: Q(C), K(B), V(x) 투영
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_state * n_heads, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.dt_proj = nn.Linear(d_model, n_heads, bias=True)
# A 행렬 (헤드별 스칼라)
self.A_log = nn.Parameter(torch.log(torch.arange(1, n_heads + 1, dtype=torch.float32)))
self.out_proj = nn.Linear(d_model, d_model, bias=False)
self.norm = nn.RMSNorm(d_model)
def _chunk_scan(
self,
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
dt: torch.Tensor,
) -> torch.Tensor:
"""청크 기반 SSD 스캔.
Args:
Q: (batch, n_heads, seq_len, head_dim) -- C 역할
K: (batch, n_heads, seq_len, d_state) -- B 역할
V: (batch, n_heads, seq_len, head_dim) -- x 역할
dt: (batch, n_heads, seq_len) -- delta 역할
Returns:
y: (batch, n_heads, seq_len, head_dim)
"""
B, H, L, D = Q.shape
C = self.chunk_size
# 시퀀스를 청크로 분할
n_chunks = (L + C - 1) // C
pad_len = n_chunks * C - L
if pad_len > 0:
Q = F.pad(Q, (0, 0, 0, pad_len))
K = F.pad(K, (0, 0, 0, pad_len))
V = F.pad(V, (0, 0, 0, pad_len))
dt = F.pad(dt, (0, pad_len))
# 청크 단위로 reshape
Q = Q.view(B, H, n_chunks, C, D)
K = K.view(B, H, n_chunks, C, -1)
V = V.view(B, H, n_chunks, C, D)
dt = dt.view(B, H, n_chunks, C)
# 청크 내부: attention-like 행렬 곱 (병렬 처리 가능)
# 청크 내부에서 Q와 K의 유사도 행렬을 계산
# decay를 적용한 causal mask 생성
A = -torch.exp(self.A_log) # (n_heads,)
# 누적 감쇠율 계산
dt_cumsum = dt.cumsum(dim=-1) # (B, H, n_chunks, C)
# 청크 내부 attention 행렬
# M[i,j] = exp(A * (dt_cumsum[i] - dt_cumsum[j])) for i >= j
decay_diff = dt_cumsum.unsqueeze(-1) - dt_cumsum.unsqueeze(-2) # (B,H,nc,C,C)
decay = torch.exp(A.view(1, H, 1, 1, 1) * decay_diff)
causal_mask = torch.tril(torch.ones(C, C, device=Q.device))
decay = decay * causal_mask
# 청크 내부 출력 계산
attn = torch.einsum('bhncqd,bhnckd->bhncqk', Q, K) * decay
y_intra = torch.einsum('bhncqk,bhnckd->bhncqd', attn, V)
# 최종 reshape
y = y_intra.view(B, H, n_chunks * C, D)
return y[:, :, :L]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""SSD 블록의 forward pass."""
residual = x
x = self.norm(x)
B, L, D = x.shape
# 멀티헤드 투영
Q = self.q_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
K = self.k_proj(x).view(B, L, self.n_heads, self.d_state).transpose(1, 2)
V = self.v_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
dt = F.softplus(self.dt_proj(x)).transpose(1, 2) # (B, H, L)
# SSD 스캔
y = self._chunk_scan(Q, K, V, dt)
# 출력 투영
y = y.transpose(1, 2).contiguous().view(B, L, D)
return self.out_proj(y) + residual
```### Mamba-2の構造改善
Mamba-2は、Mamba-1と比較して以下の構造的変更を導入した。
まず、マルチヘッド構造の導入です。 Mamba-1はチャンネルごとに独立したSSMを適用しましたが、Mamba-2はAttentionと同様のヘッド構造を使用してパラメータを共有します。これにより、同じ演算量でより多くの状態次元を活用することができる。
第二に、A行列の単純化です。 Mamba-1はチャンネルごとの対角A行列を使用しましたが、Mamba-2はヘッドごとのスカラー値にさらに簡素化しました。この単純化がSSDの行列分解を可能にする重要な条件である。
第三に、チャンクベースの並列アルゴリズムです。シーケンスを固定サイズのチャンクに分割し、チャンク内ではTensor Coreを利用した行列積として処理し、チャンク間はSSM再帰に状態を渡します。この方式は純粋再帰より2-8倍速く、GPUの並列ハードウェアを最大限活用する。
## Transformerと比較した性能比較
### 定量的ベンチマーク
様々なシーケンスモデリングアーキテクチャの性能を比較した結果をまとめる。測定条件は、約2.8Bパラメータスケール、同じ学習データ(The Pile、300Bトークン)を使用した。
|アイテム|トランスフォーマー(2.8B)|マンバ(2.8B)| Mamba-2(2.8B)| RWKV-6(3B)| RetNet(2.7B)|
| --------------------- | ------------------ | ------------ | -------------- | ----------- | ------------- |
|パイルPPL(val)| 6.73 | 6.22 | 6.10 | 6.85 | 7.02 |
|学習スループット(tok / s)|基準(1.0x)| 1.2x | 1.8x | 1.1x | 1.3x |
|推論遅延(1Kトークン)| 38ms/tok | 12ms/tok | 10ms/tok | 14ms/tok | 16ms/tok |
|推論遅延(16Kトークン)| 180ms/tok | 12ms/tok | 10ms/tok | 14ms/tok | 16ms/tok |
|推論メモリ(1Kトークン)| 2.1 GB | 0.8 GB | 0.7 GB | 0.9 GB | 1.0 GB |
|推論メモリ(16K)| 8.5 GB | 0.8 GB | 0.7 GB | 0.9 GB | 1.0 GB |
|シーケンス長拡張| KV Cache線形増加|定数定数定数定数
| In-Context Learning |強さ|中ミドル - 強さ弱さ弱さ
|アテンションパターンの解釈可能(視覚化)|不可部分的不可部分的
この表で注目すべき点はいくつかあります。まず、Mambaシリーズは、シーケンス長が増加しても推論遅延とメモリが定数に保たれる。 Transformerは16Kトークンで1Kに対して約4.7倍の遅延増加を見せるが、Mambaは同じだ。第二に、Mamba-2はMamba-1よりも学習スループットで1.5倍速い。 SSDアルゴリズムのチャンクベースの並列化効果です。第三に、In-Context Learningでは、Transformerは依然として優位を示しています。これはSSMベースのモデルの重要な弱点の1つです。
### シーケンス長によるスケーリング
シーケンス長による性能変化を分析すると、SSMの利点がより明らかになります。 TransformerのSelf-Attentionは$O(L^2)$なので、シーケンス長が2倍増加すると演算量が4倍増加する。 FlashAttention-2のような最適化を適用しても、この基本的な複雑さは変わりません。一方、MambaのオプションのSSMは$ O(L)$なので、シーケンスの長さに関係なくトークンあたりの演算量は一定です。
実質的な壁は約32K-64Kトークン区間に現れる。この区間では、A100 80GB GPUベースで、TransformerはKV Cacheのためにバッチサイズを減らす必要があり、Mambaは同じバッチサイズを維持できます。 1Mトークン以上の非常に長いシーケンスでは、トランスフォーマーは実用的に不可能ですが、SSMは原理的に処理可能です。
## さまざまなドメインの適用
###ビジョン:Vision MambaとVMamba
SSMをビジョンタスクに適用する研究が活発です。 Vision Mamba(Vim、Zhu et al。、2024)は、ViT(Vision Transformer)のSelf-Attentionを双方向SSMに置き換える。画像パッチを順方向と逆方向にスキャンして双方向コンテキストをキャプチャし、ImageNet分類でDeiTと同等の精度を2.8倍少ないGPUメモリで達成した。VMamba(Liu et al。、2024)は、2D画像の空間構造をよりよく反映するためにCross-Scan Module(CSM)を提案しました。画像を4方向(左上 - 右下、右上 - 左下、左下 - 右上、右下 - 左上)にスキャンして、2D空間情報を1Dシーケンスモデルでキャプチャします。
ビジョン分野におけるSSMの重要な利点は、高解像度画像処理です。 ViTはパッチ数(シーケンス長)の二乗に比例する演算量を持つため、224×224から512×512に解像度を上げると演算量が約5倍増加する。 Vision Mambaは、線形複雑さのおかげで、このコストの増加は約2倍にすぎません。
###オーディオと音声
オーディオ信号は本質的に長期シーケンスである。 16kHzサンプリングでは、1分のオーディオは960,000の視点を持ち、それをTransformerで直接処理することは非現実的です。 Audio Mamba(Erol et al。、2024)は、オーディオ分類タスクのAudio Spectrogram Transformer(AST)に代わって、同等の精度でメモリ使用量を70%削減しました。
音声合成(TTS)分野でもMambaの適用が試みられている。長いオーディオシーケンスの条件付き生成では、トランスフォーマーのKV Cacheボトルネックが深刻であるため、SSMベースのデコーダはリアルタイム音声合成に適しています。
### 時系列予測
時系列データは、SSMが最も自然に適用されるドメインです。これは、連続時間動的システムをモデル化するSSMの数学的フレームワークが時系列の連続特性によく合うためです。 S-Mamba(Wang et al。、2024)は、多変量時系列予測でPatchTSTと比較して同等または優れた性能を示し、推論速度は3倍速い。
時系列予測におけるSSMの特別な利点は、不規則な時系列処理です。連続時間SSMは、離散ステップ$ \ Delta $を観測間隔に合わせて調整することができるので、不規則な時間間隔のデータを補間なしで直接処理することができる。これは、医療データ、IoTセンサーデータなどの大きな利点です。
###ゲノミクスとタンパク質
DNA配列は非常に長い配列(数百万〜数十億bp)であり、長距離相互作用が生物学的に重要である。 Caduceus(Schiff et al。、2024)は、双方向MambaベースのDNA言語モデルであり、DNAの二本鎖構造と逆相補性をアーキテクチャに反映した。従来のDNA Transformerモデル(Nucleotide Transformer、HyenaDNA)と比較して、長距離変異効果の予測において優れた性能を示した。
##制限点と課題
### In-Context Learningの弱点
Mambaを含むSSMベースのモデルの最も深刻な制限は、In-Context Learning(ICL)能力です。 Transformer は、プロンプトで提供される few-shot の例を参照して、新しいタスクを実行する ICL 能力に優れています。これは、Self-Attentionがシーケンス内の任意の位置を直接参照できるためです。
一方、SSMは情報を固定サイズ状態ベクトルに圧縮するので、プロンプトの特定の例を正確に「検索」することは困難である。状態次元$ N $がシーケンス長よりはるかに小さい場合($ N \ ll L $)、情報の損失は避けられない。 Mamba-2は状態次元を大きく増やすことで(N = 128以上)、この問題を部分的に軽減しましたが、Transformerの明示的なトークンレベル参照能力には及ばない。
### Retrievalタスクの制限
Information Retrieval、つまりシーケンスから特定の情報を正確に抽出するタスクでは、SSMはTransformerと比較して明確な弱点を示しています。たとえば、「passkey retrieval」タスクで10万トークンの干し草の山から隠された暗号鍵を見つけるという問題をMambaは不安定に実行します。これは、SSMの圧縮記憶方式が正確な情報保存ではなく統計的要約に適しているためです。
### ハイブリッドアーキテクチャの浮上
この限界を認識し、最近の研究はSSMとAttentionを組み合わせたハイブリッドアーキテクチャを探求している。 Jamba(AI21 Labs、2024)は、Mamba層とTransformer層を交差配置し、SSMの効率とAttentionのICL能力を同時に確保した。 52Bパラメータで256Kコンテキストをサポートし、純粋なTransformerと比較して推論スループットが3倍高い。
Zamba(Zyphra、2024)とGriffin(De et al。、2024)なども同様のハイブリッドアプローチを採用しています。
##本番実装ガイド
###公式Mambaライブラリの活用
マンバの公式実装は`state-spaces/mamba`リポジトリで提供されます。 CUDAカーネルが含まれており、GPU環境が不可欠です。```bash
# Mamba 설치 및 기본 사용법
# 1. 설치 (CUDA 11.8 이상 필요)
pip install mamba-ssm
# 2. causal-conv1d 의존성 설치
pip install causal-conv1d>=1.2.0
# 3. 사전학습된 모델 사용 (Hugging Face)
pip install transformers
# 4. 사전학습된 Mamba 모델로 텍스트 생성
python -c "
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import torch
# Mamba-2.8B 모델 로드
model = MambaLMHeadModel.from_pretrained(
'state-spaces/mamba-2.8b',
device='cuda',
dtype=torch.float16,
)
model.eval()
# 토크나이저 (GPT-NeoX 토크나이저 호환)
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
# 텍스트 생성
prompt = 'State Space Models are'
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to('cuda')
output = model.generate(
input_ids=input_ids,
max_length=100,
temperature=0.7,
top_p=0.9,
)
print(tokenizer.decode(output[0]))
"
```### ベンチマークスクリプト
MambaとTransformerの推論性能を定量的に比較するベンチマークスクリプトだ。```python
import torch
import time
from dataclasses import dataclass
@dataclass
class BenchmarkResult:
model_name: str
seq_len: int
batch_size: int
latency_ms: float
memory_mb: float
tokens_per_sec: float
def benchmark_inference(
model: torch.nn.Module,
model_name: str,
seq_lengths: list[int],
batch_size: int = 1,
n_warmup: int = 5,
n_measure: int = 20,
device: str = "cuda",
) -> list[BenchmarkResult]:
"""다양한 시퀀스 길이에서 추론 성능을 측정한다.
Args:
model: 벤치마크 대상 모델
model_name: 모델 이름 (결과 표시용)
seq_lengths: 테스트할 시퀀스 길이 목록
batch_size: 배치 크기
n_warmup: 워밍업 반복 수
n_measure: 측정 반복 수
device: 실행 디바이스
Returns:
시퀀스 길이별 벤치마크 결과 목록
"""
model = model.to(device).eval()
results = []
for seq_len in seq_lengths:
input_ids = torch.randint(
0, 50257, (batch_size, seq_len), device=device
)
# GPU 캐시 정리
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# 워밍업
with torch.no_grad():
for _ in range(n_warmup):
_ = model(input_ids)
torch.cuda.synchronize()
# 측정
latencies = []
with torch.no_grad():
for _ in range(n_measure):
torch.cuda.synchronize()
start = time.perf_counter()
_ = model(input_ids)
torch.cuda.synchronize()
end = time.perf_counter()
latencies.append((end - start) * 1000)
avg_latency = sum(latencies) / len(latencies)
peak_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
tokens_per_sec = (batch_size * seq_len) / (avg_latency / 1000)
results.append(BenchmarkResult(
model_name=model_name,
seq_len=seq_len,
batch_size=batch_size,
latency_ms=avg_latency,
memory_mb=peak_memory,
tokens_per_sec=tokens_per_sec,
))
print(
f"[{model_name}] seq_len={seq_len:>6d} | "
f"latency={avg_latency:>8.2f}ms | "
f"memory={peak_memory:>8.1f}MB | "
f"throughput={tokens_per_sec:>10.0f} tok/s"
)
return results
def compare_models(results_dict: dict[str, list[BenchmarkResult]]):
"""모델 간 성능 비교 테이블을 출력한다."""
print("\n" + "=" * 80)
print(f"{'모델':<20} {'시퀀스 길이':>10} {'지연(ms)':>10} "
f"{'메모리(MB)':>12} {'처리량(tok/s)':>15}")
print("=" * 80)
for model_name, results in results_dict.items():
for r in results:
print(
f"{r.model_name:<20} {r.seq_len:>10d} "
f"{r.latency_ms:>10.2f} {r.memory_mb:>12.1f} "
f"{r.tokens_per_sec:>15.0f}"
)
print("-" * 80)
```##操作時の注意事項
### CUDAカーネル互換性
Mambaの公式CUDAカーネルは、特定のGPUアーキテクチャ(Ampere以上、sm_80+)に最適化されています。 V100(sm_70)以前の世代のGPUでは、フォールバック実装が使用され、パフォーマンスが大幅に低下します。本番展開前に、ターゲットGPUのコンピューティング能力を確認し、Ampere(A100 / A10G)またはHopper(H100)GPUを使用する必要があります。
causal-conv1dパッケージのバージョン互換性にも注意する必要があります。 mamba-ssm バージョンと causal-conv1d バージョンが合わないと、ランタイムエラーが発生します。公式READMEの互換性表を必ず確認してください。
### 状態初期化とシーケンス境界
SSMベースのモデルの推論で最も一般的な間違いは、状態初期化を誤って処理することです。バッチ推論では、異なるシーケンスの状態が混在すると、出力が汚染されます。各シーケンスの開始時に状態を明示的にゼロベクトルに初期化し、ストリーミング推論で会話ターンが変わったときにも状態をリセットする必要があります。
連続会話(multi-turn conversation)で前回のターンの状態を維持するかリセットするかは設計決定である。状態を維持すると、以前の会話のコンテキストは保存されますが、状態の情報保存能力はTransformerのKV Cacheよりも制限されているため、長い会話では情報の損失が蓄積される可能性があります。
### 学習時の精度と安定性
MambaのオプションのSSMは、softplusアクティベーションを使用する$ \ Delta $パラメータの数値安定性に敏感です。 $\Delta$ が大きすぎると $\exp(\Delta A)$ が 0 に収束して傾きが消失し、小さすぎると状態更新が少なくて学習が停滞する。学習の初めに$ \ Delta $の初期化を$ \ text {Uniform}(-4、-1)$に設定して、ログスペースの適切な範囲を確保することをお勧めします。
BF16学習は一般的に安定していますが、状態行列の指数演算では断続的に数値不安定が発生する可能性があります。ログ空間でAをパラメータ化($A = -\exp(A_{\log})$)すると、負の水性と安定性を同時に保証できます。
###サービングフレームワークのサポート状況
vLLMはMambaモデルのサービングを部分的にサポートします。 PagedAttentionはAttentionベースのモデルであるため、SSMモデルには適用されませんが、vLLMのスケジューリングとバッチインフラストラクチャを利用できます。 TensorRT-LLMはMamba-1モデルの最適化をサポートし、Mamba-2は別途プラグインが必要です。
Tritonサービングの場合、MambaモデルをONNXに変換して配布することは可能ですが、CUDAカスタムカーネルがONNXグラフに含まれていないため、パフォーマンスが低下します。現時点では、mamba-ssmライブラリを直接ラップするカスタムサービスサーバーが最も信頼性の高い選択肢です。
### デバッグチェックリスト
本番で発生する可能性のある問題と対応方法をまとめます。
- **出力が繰り返しループに陥るとき**:状態が特定のパターンに固着された状態である。 temperatureを上げるか、状態を部分的にリセットするロジックを追加します。
- **長い入力で品質が急激に低下した場合**:状態ディメンションがシーケンス長に比べて不足している。モデルを選択するには、d_stateが十分に大きいモデル(N = 64以上)を使用します。
- **バッチ推論でシーケンス別結果が異なる場合**: パディング処理が誤ってパディングトークンの情報が状態に流入した可能性が高い。パディングトークンの $\Delta$ をゼロに強制し、ステータス更新をブロックします。
- **CUDA OOMが学習中に断続的に発生した場合**:再計算戦略が正しく適用されていない可能性があります。 gradient checkpointingを有効にし、バッチサイズを減らします。
##参考資料
1. Gu、A.、&Dao、T.(2023)。 Mamba: Linear-Time Sequence Modeling with Selective State Spaces. _arXiv preprint arXiv:2312.00752_。 [https://arxiv.org/abs/2312.00752](https://arxiv.org/abs/2312.00752)
2. Dao, T., & Gu, A. (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. _arXiv preprint arXiv:2405.21060_。 [https://arxiv.org/abs/2405.21060](https://arxiv.org/abs/2405.21060)
3. Mambaの公式GitHubリポジトリ。 [https://github.com/state-spaces/mamba](https://github.com/state-spaces/mamba)
4. Gu, A., Goel, K., & Re, C. (2021). Efficiently Modeling Long Sequences with Structured State Spaces (S4). _ICLR 2022_。 [https://arxiv.org/abs/2111.00396](https://arxiv.org/abs/2111.00396)5. Patro, B. N., & Agneeswaran, V. S. (2024). Mamba-360: Survey of State Space Models as Transformer Alternative for Long Sequence Modelling. _arXiv preprint arXiv:2404.16112_。 [https://arxiv.org/abs/2404.16112](https://arxiv.org/abs/2404.16112)
6. Lieber, O. et al. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. _arXiv preprint arXiv:2403.19887_。 [https://arxiv.org/abs/2403.19887](https://arxiv.org/abs/2403.19887)
7. Zhu、L. et al。 (2024). Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model. _ICML 2024_。 [https://arxiv.org/abs/2401.09417](https://arxiv.org/abs/2401.09417)
8. Schiff, Y. et al. (2024). Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling. _ICML 2024_。 [https://arxiv.org/abs/2403.03234](https://arxiv.org/abs/2403.03234)