- Published on
Mamba 논문 리뷰: Selective State Space Models로 Transformer를 넘어서
- Authors
- Name
- 1. 서론: Transformer의 한계와 SSM의 부상
- 2. 배경: Structured State Space Models (S4)
- 3. Mamba: Selective State Space Models
- 4. Mamba-2: State Space Duality
- 5. 실전: Mamba 모델 사용하기
- 6. Mamba vs Transformer 비교
- 7. 한계와 전망
- 8. 퀴즈
1. 서론: Transformer의 한계와 SSM의 부상
Transformer 아키텍처는 2017년 등장 이후 NLP, Vision, Audio 등 거의 모든 시퀀스 모델링 분야를 지배해왔다. 하지만 근본적인 한계가 있다:
- O(N²) 복잡도: Self-Attention의 시간/공간 복잡도가 시퀀스 길이에 이차적으로 증가
- 추론 비효율: 각 토큰 생성 시 전체 KV Cache를 참조해야 함
- 긴 시퀀스 처리 어려움: 128K+ 컨텍스트에서 메모리 병목
**State Space Model(SSM)**은 이러한 한계를 극복하기 위해 제안된 대안이다. 특히 Mamba(Gu & Dao, ICLR 2024)는 SSM에 선택적(Selective) 메커니즘을 도입하여 Transformer에 필적하는 성능을 O(N) 복잡도로 달성했다.
2. 배경: Structured State Space Models (S4)
2.1 연속 시간 SSM
SSM은 연속 시간 시스템을 이산화한 것이다:
- : hidden state
- : input
- : output
- , ,
2.2 이산화 (Zero-Order Hold)
연속 시스템을 이산 시간으로 변환:
이산화된 재귀 관계:
2.3 S4의 핵심: HiPPO 초기화
S4(Structured State Spaces for Sequence Modeling)의 핵심 기여는 HiPPO(High-order Polynomial Projection Operator) 행렬로 A를 초기화한 것이다:
import torch
import numpy as np
def make_hippo(N):
"""HiPPO-LegS 행렬 생성"""
P = np.sqrt(1 + 2 * np.arange(N))
A = np.zeros((N, N))
for i in range(N):
for j in range(N):
if i > j:
A[i, j] = P[i] * P[j]
elif i == j:
A[i, j] = i + 1
# i < j: 0
return -A # 안정성을 위해 음수
2.4 S4의 한계
S4는 입력에 독립적인 고정 파라미터(A, B, C, Δ)를 사용한다. 이는 두 가지 문제를 야기한다:
- Content-based reasoning 불가: 입력 내용에 따라 다르게 처리해야 하는 태스크(예: selective copying, induction heads)에서 성능 저하
- Attention의 장점 포기: Transformer의 content-aware 매칭 능력 상실
3. Mamba: Selective State Space Models
3.1 핵심 아이디어: Selection Mechanism
Mamba의 핵심 혁신은 SSM 파라미터를 입력에 의존하게 만든 것이다:
S4: (A, B, C, Δ) = 고정 파라미터
Mamba: (B, C, Δ) = f(input) ← 입력 의존!
구체적으로:
이렇게 하면 모델이 어떤 정보를 기억하고 어떤 정보를 잊을지 입력에 따라 동적으로 결정할 수 있다.
3.2 Selection의 직관
(step size)가 핵심:
- 가 크면: → 이전 state 유지 (입력 무시)
- 가 작으면: 현재 입력에 더 집중 (state 갱신)
이것은 게이팅 메커니즘과 유사하다:
LSTM의 forget gate ≈ Mamba의 Δ
3.3 Hardware-Aware 알고리즘
Selection을 도입하면 파라미터가 입력 의존이 되어 convolution trick을 사용할 수 없다. 이는 학습 시 O(N²) 복잡도로 회귀할 위험이 있다.
Mamba는 이를 kernel fusion + recomputation으로 해결:
# Mamba의 Selective Scan (simplified)
def selective_scan(x, delta, A, B, C):
"""
x: (B, L, D) - 입력
delta: (B, L, D) - step size (입력 의존)
A: (D, N) - state matrix
B: (B, L, N) - input matrix (입력 의존)
C: (B, L, N) - output matrix (입력 의존)
"""
B_batch, L, D = x.shape
N = A.shape[1]
# 이산화
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, D, N)
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, D, N)
# Sequential scan (학습 시 parallel scan 사용)
h = torch.zeros(B_batch, D, N, device=x.device)
ys = []
for t in range(L):
h = deltaA[:, t] * h + deltaB[:, t] * x[:, t, :, None]
y = (h * C[:, t, None, :]).sum(-1) # (B, D)
ys.append(y)
return torch.stack(ys, dim=1) # (B, L, D)
실제 구현에서는 CUDA 커널에서 SRAM에 모든 중간 상태를 유지하여 HBM 접근을 최소화한다 (FlashAttention과 유사한 IO-aware 접근).
3.4 Mamba Block 아키텍처
Input
│
├──── Linear(D → 2ED) ──── SiLU ──── Conv1d ──── SiLU ──── SSM ────┐
│ │
└──── Linear(D → ED) ──── SiLU ──────────────────────────── × ───────┘
│
Linear(ED → D)
│
Output
class MambaBlock(nn.Module):
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
super().__init__()
d_inner = d_model * expand
self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv,
padding=d_conv-1, groups=d_inner)
# SSM 파라미터
self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False) # B, C, Δ
self.dt_proj = nn.Linear(1, d_inner, bias=True)
A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(d_inner))
self.out_proj = nn.Linear(d_inner, d_model, bias=False)
def forward(self, x):
# x: (B, L, D)
xz = self.in_proj(x) # (B, L, 2*ED)
x, z = xz.chunk(2, dim=-1)
# Conv + SiLU
x = self.conv1d(x.transpose(1, 2))[:, :, :x.shape[1]]
x = x.transpose(1, 2)
x = F.silu(x)
# SSM
A = -torch.exp(self.A_log)
B_C_dt = self.x_proj(x)
# ... selective scan ...
# Gate
y = y * F.silu(z)
return self.out_proj(y)
4. Mamba-2: State Space Duality
4.1 SSD (State Space Dual) 모델
Mamba-2(Dao & Gu, ICML 2024)는 SSM과 Attention 사이의 수학적 이중성(Duality)을 발견했다:
- SSM 관점: 재귀적 state update → O(N) 추론
- Attention 관점: 특수한 구조의 행렬 곱 → 병렬 학습
SSM recurrence: h_t = A_t h_{t-1} + B_t x_t
y_t = C_t h_t
↕ Dual ↕
Structured Attention: Y = (T ⊙ M) X
where T = lower-triangular causal mask
M = structured (semi-separable) matrix
4.2 성능 비교
| 모델 | 파라미터 | Pile (ppl) | 학습 FLOPS/s | 추론 (tok/s) |
|---|---|---|---|---|
| Transformer++ | 2.7B | 6.7 | 1.0x | 1.0x |
| Mamba-1 | 2.8B | 6.2 | 1.3x | 5.2x |
| Mamba-2 | 2.7B | 6.1 | 2.1x | 5.5x |
핵심 개선:
- 학습 속도 2배 향상 (structured matrix multiplication 활용)
- 더 큰 state size 가능 (N=64 → N=256)
- Multi-head 구조 도입
5. 실전: Mamba 모델 사용하기
5.1 설치 및 추론
pip install mamba-ssm causal-conv1d>=1.4.0
from mamba_ssm import MambaLMHeadModel
from transformers import AutoTokenizer
# Mamba-2.8B 로드
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = model.cuda().half()
# 추론
input_ids = tokenizer("The future of AI is", return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids, max_length=100, temperature=0.7)
print(tokenizer.decode(output[0]))
5.2 Mamba-2 사용
from mamba_ssm import Mamba2
# Mamba-2 레이어 단독 사용
layer = Mamba2(
d_model=2048,
d_state=128, # Mamba-2는 더 큰 state 가능
d_conv=4,
expand=2,
headdim=64, # Multi-head
).cuda()
x = torch.randn(1, 1024, 2048).cuda() # (batch, seq_len, d_model)
y = layer(x) # (1, 1024, 2048)
6. Mamba vs Transformer 비교
| 특성 | Transformer | Mamba |
|---|---|---|
| 학습 복잡도 | O(N²) | O(N) |
| 추론 복잡도 | O(N) per token (KV Cache) | O(1) per token |
| 메모리 (추론) | O(N) KV Cache | O(1) fixed state |
| In-context learning | ✅ 강함 | ⚠️ 약함 (개선 중) |
| 긴 시퀀스 | ❌ 병목 | ✅ 효율적 |
| 하드웨어 활용 | ✅ 병렬화 최적 | ⚠️ 재귀 병목 (Mamba-2 개선) |
7. 한계와 전망
7.1 현재 한계
- In-context learning: 고정 크기 state로 무한한 컨텍스트 정보 압축 시 손실
- Retrieval 태스크: 정확한 토큰 검색이 필요한 경우 Attention이 우위
- 학습 안정성: 긴 시퀀스에서 gradient 불안정 문제
7.2 하이브리드 접근
최근 트렌드는 Mamba + Attention 하이브리드:
- Jamba (AI21): Mamba 레이어 + Attention 레이어 혼합
- Zamba (Zyphra): Mamba 기반 + 소수의 Shared Attention 레이어
8. 퀴즈
Q1. Mamba가 S4 대비 핵심적으로 개선한 점은?
SSM 파라미터(B, C, Δ)를 입력 의존적으로 만들어 content-based reasoning을 가능하게 한 것. S4는 고정 파라미터여서 입력 내용에 따른 선택적 처리가 불가했다.
Q2. Mamba에서 Δ(delta)의 역할은?
Step size로, 게이팅 메커니즘 역할. Δ가 크면 이전 state 유지(입력 무시), 작으면 현재 입력에 집중(state 갱신). LSTM의 forget gate와 유사.
Q3. Selection 도입 후 convolution trick을 사용할 수 없는 이유는?
Convolution trick은 시간 불변(time-invariant) 시스템에서만 작동. Selection으로 파라미터가 입력 의존이 되면 시간 가변(time-varying)이 되어 글로벌 컨볼루션 불가.
Q4. Mamba의 Hardware-Aware 알고리즘의 핵심은?
CUDA 커널에서 SRAM에 중간 state를 유지하여 HBM 접근 최소화. FlashAttention과 유사한 IO-aware 접근으로 memory-bound 연산 최적화.
Q5. Mamba-2의 State Space Duality란?
SSM의 재귀적 state update와 structured attention 행렬 곱이 수학적으로 동일하다는 것. 이를 통해 학습 시 병렬화된 행렬 곱, 추론 시 효율적 재귀를 모두 활용.
Q6. Mamba 대비 Transformer가 우위인 태스크는?
In-context learning과 정확한 토큰 검색(retrieval) 태스크. Attention은 시퀀스 내 임의 위치를 직접 참조할 수 있지만, Mamba는 고정 크기 state로 정보를 압축해야 한다.
Q7. Jamba, Zamba 같은 하이브리드 모델의 설계 원리는?
Mamba 레이어의 O(N) 효율성 + 소수의 Attention 레이어의 정확한 검색 능력을 결합. 대부분의 레이어를 Mamba로, 핵심 위치에만 Attention을 배치.