1. 서론: Transformer의 한계와 SSM의 부상
Transformer 아키텍처는 2017년 등장 이후 NLP, Vision, Audio 등 거의 모든 시퀀스 모델링 분야를 지배해왔다. 하지만 근본적인 한계가 있다:
- **O(N²) 복잡도**: Self-Attention의 시간/공간 복잡도가 시퀀스 길이에 이차적으로 증가
- **추론 비효율**: 각 토큰 생성 시 전체 KV Cache를 참조해야 함
- **긴 시퀀스 처리 어려움**: 128K+ 컨텍스트에서 메모리 병목
**State Space Model(SSM)**은 이러한 한계를 극복하기 위해 제안된 대안이다. 특히 **Mamba**(Gu & Dao, ICLR 2024)는 SSM에 **선택적(Selective)** 메커니즘을 도입하여 Transformer에 필적하는 성능을 **O(N) 복잡도**로 달성했다.
2. 배경: Structured State Space Models (S4)
2.1 연속 시간 SSM
SSM은 연속 시간 시스템을 이산화한 것이다:
$$
h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)
$$
$$
y(t) = \mathbf{C}h(t)
$$
- $h(t) \in \mathbb{R}^N$: hidden state
- $x(t) \in \mathbb{R}$: input
- $y(t) \in \mathbb{R}$: output
- $\mathbf{A} \in \mathbb{R}^{N \times N}$, $\mathbf{B} \in \mathbb{R}^{N \times 1}$, $\mathbf{C} \in \mathbb{R}^{1 \times N}$
2.2 이산화 (Zero-Order Hold)
연속 시스템을 이산 시간으로 변환:
$$
\overline{\mathbf{A}} = \exp(\Delta \mathbf{A})
$$
$$
\overline{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B}
$$
이산화된 재귀 관계:
$$
h_t = \overline{\mathbf{A}} h_{t-1} + \overline{\mathbf{B}} x_t
$$
$$
y_t = \mathbf{C} h_t
$$
2.3 S4의 핵심: HiPPO 초기화
S4(Structured State Spaces for Sequence Modeling)의 핵심 기여는 **HiPPO(High-order Polynomial Projection Operator)** 행렬로 A를 초기화한 것이다:
def make_hippo(N):
"""HiPPO-LegS 행렬 생성"""
P = np.sqrt(1 + 2 * np.arange(N))
A = np.zeros((N, N))
for i in range(N):
for j in range(N):
if i > j:
A[i, j] = P[i] * P[j]
elif i == j:
A[i, j] = i + 1
i < j: 0
return -A # 안정성을 위해 음수
2.4 S4의 한계
S4는 **입력에 독립적인 고정 파라미터**(A, B, C, Δ)를 사용한다. 이는 두 가지 문제를 야기한다:
1. **Content-based reasoning 불가**: 입력 내용에 따라 다르게 처리해야 하는 태스크(예: selective copying, induction heads)에서 성능 저하
2. **Attention의 장점 포기**: Transformer의 content-aware 매칭 능력 상실
3. Mamba: Selective State Space Models
3.1 핵심 아이디어: Selection Mechanism
Mamba의 핵심 혁신은 **SSM 파라미터를 입력에 의존하게** 만든 것이다:
S4: (A, B, C, Δ) = 고정 파라미터
Mamba: (B, C, Δ) = f(input) ← 입력 의존!
구체적으로:
$$
\mathbf{B}_t = \text{Linear}_B(x_t), \quad \mathbf{C}_t = \text{Linear}_C(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}_\Delta(x_t))
$$
이렇게 하면 모델이 **어떤 정보를 기억하고 어떤 정보를 잊을지** 입력에 따라 동적으로 결정할 수 있다.
3.2 Selection의 직관
$\Delta_t$(step size)가 핵심:
- **$\Delta_t$가 크면**: $\overline{\mathbf{A}} \approx \mathbf{I}$ → 이전 state 유지 (입력 무시)
- **$\Delta_t$가 작으면**: 현재 입력에 더 집중 (state 갱신)
이것은 **게이팅 메커니즘**과 유사하다:
LSTM의 forget gate ≈ Mamba의 Δ
3.3 Hardware-Aware 알고리즘
Selection을 도입하면 파라미터가 입력 의존이 되어 **convolution trick을 사용할 수 없다**. 이는 학습 시 O(N²) 복잡도로 회귀할 위험이 있다.
Mamba는 이를 **kernel fusion + recomputation**으로 해결:
Mamba의 Selective Scan (simplified)
def selective_scan(x, delta, A, B, C):
"""
x: (B, L, D) - 입력
delta: (B, L, D) - step size (입력 의존)
A: (D, N) - state matrix
B: (B, L, N) - input matrix (입력 의존)
C: (B, L, N) - output matrix (입력 의존)
"""
B_batch, L, D = x.shape
N = A.shape[1]
이산화
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, D, N)
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, D, N)
Sequential scan (학습 시 parallel scan 사용)
h = torch.zeros(B_batch, D, N, device=x.device)
ys = []
for t in range(L):
h = deltaA[:, t] * h + deltaB[:, t] * x[:, t, :, None]
y = (h * C[:, t, None, :]).sum(-1) # (B, D)
ys.append(y)
return torch.stack(ys, dim=1) # (B, L, D)
실제 구현에서는 **CUDA 커널에서 SRAM에 모든 중간 상태를 유지**하여 HBM 접근을 최소화한다 (FlashAttention과 유사한 IO-aware 접근).
3.4 Mamba Block 아키텍처
Input
│
├──── Linear(D → 2ED) ──── SiLU ──── Conv1d ──── SiLU ──── SSM ────┐
│ │
└──── Linear(D → ED) ──── SiLU ──────────────────────────── × ───────┘
│
Linear(ED → D)
│
Output
class MambaBlock(nn.Module):
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
super().__init__()
d_inner = d_model * expand
self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv,
padding=d_conv-1, groups=d_inner)
SSM 파라미터
self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False) # B, C, Δ
self.dt_proj = nn.Linear(1, d_inner, bias=True)
A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(d_inner))
self.out_proj = nn.Linear(d_inner, d_model, bias=False)
def forward(self, x):
x: (B, L, D)
xz = self.in_proj(x) # (B, L, 2*ED)
x, z = xz.chunk(2, dim=-1)
Conv + SiLU
x = self.conv1d(x.transpose(1, 2))[:, :, :x.shape[1]]
x = x.transpose(1, 2)
x = F.silu(x)
SSM
A = -torch.exp(self.A_log)
B_C_dt = self.x_proj(x)
... selective scan ...
Gate
y = y * F.silu(z)
return self.out_proj(y)
4. Mamba-2: State Space Duality
4.1 SSD (State Space Dual) 모델
Mamba-2(Dao & Gu, ICML 2024)는 SSM과 Attention 사이의 수학적 이중성(Duality)을 발견했다:
- **SSM 관점**: 재귀적 state update → O(N) 추론
- **Attention 관점**: 특수한 구조의 행렬 곱 → 병렬 학습
SSM recurrence: h_t = A_t h_{t-1} + B_t x_t
y_t = C_t h_t
↕ Dual ↕
Structured Attention: Y = (T ⊙ M) X
where T = lower-triangular causal mask
M = structured (semi-separable) matrix
4.2 성능 비교
| 모델 | 파라미터 | Pile (ppl) | 학습 FLOPS/s | 추론 (tok/s) |
|------|---------|------------|-------------|-------------|
| Transformer++ | 2.7B | 6.7 | 1.0x | 1.0x |
| Mamba-1 | 2.8B | 6.2 | 1.3x | 5.2x |
| Mamba-2 | 2.7B | 6.1 | **2.1x** | **5.5x** |
핵심 개선:
- 학습 속도 2배 향상 (structured matrix multiplication 활용)
- 더 큰 state size 가능 (N=64 → N=256)
- Multi-head 구조 도입
5. 실전: Mamba 모델 사용하기
5.1 설치 및 추론
pip install mamba-ssm causal-conv1d>=1.4.0
from mamba_ssm import MambaLMHeadModel
from transformers import AutoTokenizer
Mamba-2.8B 로드
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = model.cuda().half()
추론
input_ids = tokenizer("The future of AI is", return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids, max_length=100, temperature=0.7)
print(tokenizer.decode(output[0]))
5.2 Mamba-2 사용
from mamba_ssm import Mamba2
Mamba-2 레이어 단독 사용
layer = Mamba2(
d_model=2048,
d_state=128, # Mamba-2는 더 큰 state 가능
d_conv=4,
expand=2,
headdim=64, # Multi-head
).cuda()
x = torch.randn(1, 1024, 2048).cuda() # (batch, seq_len, d_model)
y = layer(x) # (1, 1024, 2048)
6. Mamba vs Transformer 비교
| 특성 | Transformer | Mamba |
|------|------------|-------|
| 학습 복잡도 | O(N²) | O(N) |
| 추론 복잡도 | O(N) per token (KV Cache) | O(1) per token |
| 메모리 (추론) | O(N) KV Cache | O(1) fixed state |
| In-context learning | ✅ 강함 | ⚠️ 약함 (개선 중) |
| 긴 시퀀스 | ❌ 병목 | ✅ 효율적 |
| 하드웨어 활용 | ✅ 병렬화 최적 | ⚠️ 재귀 병목 (Mamba-2 개선) |
7. 한계와 전망
7.1 현재 한계
- **In-context learning**: 고정 크기 state로 무한한 컨텍스트 정보 압축 시 손실
- **Retrieval 태스크**: 정확한 토큰 검색이 필요한 경우 Attention이 우위
- **학습 안정성**: 긴 시퀀스에서 gradient 불안정 문제
7.2 하이브리드 접근
최근 트렌드는 **Mamba + Attention 하이브리드**:
- **Jamba** (AI21): Mamba 레이어 + Attention 레이어 혼합
- **Zamba** (Zyphra): Mamba 기반 + 소수의 Shared Attention 레이어
8. 퀴즈
SSM 파라미터(B, C, Δ)를 **입력 의존적으로** 만들어 content-based reasoning을 가능하게 한 것. S4는 고정 파라미터여서 입력 내용에 따른 선택적 처리가 불가했다.
Step size로, **게이팅 메커니즘** 역할. Δ가 크면 이전 state 유지(입력 무시), 작으면 현재 입력에 집중(state 갱신). LSTM의 forget gate와 유사.
Convolution trick은 **시간 불변(time-invariant)** 시스템에서만 작동. Selection으로 파라미터가 입력 의존이 되면 시간 가변(time-varying)이 되어 글로벌 컨볼루션 불가.
CUDA 커널에서 **SRAM에 중간 state를 유지**하여 HBM 접근 최소화. FlashAttention과 유사한 IO-aware 접근으로 memory-bound 연산 최적화.
SSM의 재귀적 state update와 structured attention 행렬 곱이 **수학적으로 동일**하다는 것. 이를 통해 학습 시 병렬화된 행렬 곱, 추론 시 효율적 재귀를 모두 활용.
In-context learning과 정확한 토큰 검색(retrieval) 태스크. Attention은 시퀀스 내 임의 위치를 직접 참조할 수 있지만, Mamba는 고정 크기 state로 정보를 압축해야 한다.
Mamba 레이어의 O(N) 효율성 + 소수의 Attention 레이어의 정확한 검색 능력을 결합. 대부분의 레이어를 Mamba로, 핵심 위치에만 Attention을 배치.
현재 단락 (1/170)
Transformer 아키텍처는 2017년 등장 이후 NLP, Vision, Audio 등 거의 모든 시퀀스 모델링 분야를 지배해왔다. 하지만 근본적인 한계가 있다: