- Published on
Mamba와 State Space Model 논문 심층 분석: 선택적 SSM부터 Mamba-2까지 Transformer 대안 아키텍처
- Authors
- Name
- 들어가며
- SSM 기초 이론(S4)
- Mamba의 선택적 메커니즘
- Mamba 아키텍처 상세
- Mamba-2와 SSD 프레임워크
- Transformer 대비 성능 비교
- 다양한 도메인 적용
- 한계점과 과제
- 실전 구현 가이드
- 운영 시 주의사항
- 참고자료

들어가며
Transformer 아키텍처는 2017년 등장 이후 자연어 처리, 컴퓨터 비전, 오디오, 시계열 등 거의 모든 시퀀스 모델링 영역을 지배해 왔다. 그러나 Self-Attention의 이차 시간 복잡도 는 시퀀스 길이 이 증가할수록 연산량과 메모리 사용량이 폭발적으로 늘어나는 근본적 한계를 가진다. 128K 이상의 컨텍스트 윈도우를 다루는 현대 LLM에서 이 문제는 더욱 심각해지고 있으며, 추론 시 KV Cache가 시퀀스 길이에 비례하여 증가하는 것도 실질적인 병목이다.
이러한 배경에서 State Space Model(SSM)은 선형 시간 복잡도 로 시퀀스를 처리할 수 있는 대안으로 연구되어 왔다. 2021년 Albert Gu가 제안한 Structured State Spaces(S4)는 긴 시퀀스 벤치마크에서 Transformer를 능가하는 성능을 보여주며 SSM 연구의 르네상스를 열었다. 그러나 S4를 비롯한 초기 SSM들은 Linear Time-Invariant(LTI) 시스템이라는 제약으로 인해, 입력 내용에 따라 동적으로 정보를 선택하는 능력이 부족했고, 언어 모델링에서 Transformer에 미치지 못했다.
2023년 12월, Albert Gu와 Tri Dao가 발표한 Mamba 논문은 이 한계를 정면으로 돌파했다. SSM의 파라미터를 입력에 의존적으로 만드는 선택적 메커니즘(Selection Mechanism)을 도입하고, GPU 하드웨어에 최적화된 스캔 알고리즘으로 효율성까지 확보했다. 이어서 2024년 5월 발표된 Mamba-2는 SSM과 Attention이 수학적으로 동일한 계열에 속한다는 State Space Duality(SSD)를 증명하며, 2-8배 더 빠른 알고리즘을 제시했다.
이 글에서는 SSM의 수학적 기초(S4)부터 Mamba의 선택적 메커니즘, Mamba-2의 SSD 프레임워크, Transformer 대비 정량적 비교, 비전/오디오/시계열 등 다양한 도메인 적용, 그리고 실전 구현과 운영 시 주의사항까지를 심층적으로 다룬다.
SSM 기초 이론(S4)
연속 시간 State Space Model
State Space Model은 제어 이론에서 유래한 연속 시간 동적 시스템이다. 입력 신호 를 숨겨진 상태 를 거쳐 출력 로 변환하는 과정을 다음 미분방정식으로 기술한다.
여기서 은 상태 전이 행렬, 은 입력 투영 행렬, 은 출력 투영 행렬, 은 skip connection 항이다. 은 상태 차원(state dimension)으로 SSM의 메모리 용량을 결정한다.
이산 시퀀스를 처리하려면 연속 시스템을 이산화(discretization)해야 한다. 스텝 크기 에 대해 Zero-Order Hold(ZOH) 이산화를 적용하면 다음과 같다.
이산화된 SSM은 두 가지 방식으로 계산할 수 있다. 첫째, 재귀(recurrence) 형태로 시퀀스를 순차적으로 처리한다.
둘째, 전체 시퀀스를 한 번에 처리하는 합성곱(convolution) 형태가 있다. 커널 를 사전 계산하면 FFT를 활용한 연산이 가능하다.
S4의 HiPPO 초기화
S4(Structured State Spaces for Sequence Modeling, Gu et al., 2021)의 핵심 기여는 상태 행렬 의 초기화 전략이다. 무작위로 초기화된 는 장기 의존성을 포착하지 못하고 기울기 소실/폭발 문제를 겪는다. S4는 HiPPO(High-order Polynomial Projection Operators) 프레임워크를 활용하여 를 특수한 구조로 초기화한다.
HiPPO-LegS 행렬은 과거 입력을 르장드르 다항식 기저로 연속적으로 근사하도록 설계되었다. 이 구조의 핵심 성질은 모든 시간 스케일의 정보를 균등하게 보존한다는 것이다.
import torch
import torch.nn as nn
import numpy as np
def make_hippo_matrix(N: int) -> torch.Tensor:
"""HiPPO-LegS 행렬 생성.
이 행렬은 과거 입력을 르장드르 다항식 기저로 최적 근사하도록
설계되어 장기 의존성 포착에 유리하다.
Args:
N: 상태 차원 (State dimension)
Returns:
A: (N, N) HiPPO 행렬
"""
P = np.sqrt(2 * np.arange(N) + 1)
A = np.zeros((N, N))
for i in range(N):
for j in range(N):
if i > j:
A[i, j] = P[i] * P[j]
elif i == j:
A[i, j] = i + 1
# i < j: 0
A = -A # 안정성을 위해 음수
return torch.tensor(A, dtype=torch.float32)
class S4Layer(nn.Module):
"""S4(Structured State Spaces) 레이어의 핵심 구현.
HiPPO 초기화된 A 행렬과 합성곱 기반 병렬 계산을 결합한다.
"""
def __init__(self, d_model: int, state_dim: int = 64, seq_len: int = 1024):
super().__init__()
self.d_model = d_model
self.state_dim = state_dim
self.seq_len = seq_len
# HiPPO 초기화
A = make_hippo_matrix(state_dim)
self.A_log = nn.Parameter(torch.log(torch.clamp(-A.diagonal(), min=1e-4)))
self.B = nn.Parameter(torch.randn(state_dim, 1) * 0.01)
self.C = nn.Parameter(torch.randn(1, state_dim) * 0.01)
self.log_dt = nn.Parameter(torch.rand(d_model).uniform_(-4, -1))
# skip connection
self.D = nn.Parameter(torch.ones(d_model))
def _compute_kernel(self, L: int) -> torch.Tensor:
"""SSM 합성곱 커널을 사전 계산한다."""
dt = torch.exp(self.log_dt) # (d_model,)
A = -torch.exp(self.A_log) # (state_dim,) 대각 성분
# ZOH 이산화 (대각 A 가정)
dtA = dt.unsqueeze(-1) * A.unsqueeze(0) # (d_model, state_dim)
A_bar = torch.exp(dtA)
# 커널 계산: K[k] = C @ A_bar^k @ B_bar
# 효율적인 Vandermonde 곱으로 O(N*L) 계산
powers = torch.arange(L, device=A.device).float()
# A_bar^k = exp(k * dtA)
kernel = torch.einsum(
'dn,dn,nl->dl',
self.C.squeeze(0).expand(self.d_model, -1),
dt.unsqueeze(-1) * self.B.squeeze(-1).unsqueeze(0),
torch.exp(dtA.unsqueeze(-1) * powers.unsqueeze(0).unsqueeze(0))
.squeeze()
.T
)
return kernel # (d_model, L)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""합성곱 모드의 forward pass.
Args:
x: (batch, seq_len, d_model) 입력 시퀀스
Returns:
y: (batch, seq_len, d_model) 출력 시퀀스
"""
B, L, D = x.shape
kernel = self._compute_kernel(L) # (d_model, L)
# FFT 기반 합성곱
x_perm = x.permute(0, 2, 1) # (B, D, L)
k_f = torch.fft.rfft(kernel, n=2 * L) # zero-pad
x_f = torch.fft.rfft(x_perm, n=2 * L)
y = torch.fft.irfft(x_f * k_f, n=2 * L)[..., :L]
# skip connection
y = y + self.D.unsqueeze(0).unsqueeze(-1) * x_perm
return y.permute(0, 2, 1)
S4의 핵심 장점은 학습 시 합성곱 모드로 의 병렬 처리를 하고, 추론 시 재귀 모드로 의 상수 메모리/연산으로 토큰을 생성할 수 있다는 것이다. 그러나 가 입력과 무관한 고정 파라미터라는 LTI 속성 때문에 콘텐츠 기반 추론에 약하다.
Mamba의 선택적 메커니즘
LTI에서 Selection으로의 전환
Mamba의 핵심 통찰은 간단하지만 강력하다. SSM 파라미터 , , 를 입력의 함수로 만들어 시스템을 시변(time-varying)으로 전환하는 것이다. 이를 통해 모델은 각 시점에서 어떤 정보를 상태에 기록하고 어떤 정보를 무시할지를 동적으로 결정할 수 있다.
기존 LTI SSM에서는 모든 입력 토큰이 동일한 역학(dynamics)으로 처리된다. "The cat sat on the mat"이라는 문장에서 관사 "the"와 핵심 명사 "cat"이 동일한 방식으로 상태에 반영되므로, 중요한 정보를 선택적으로 강조하거나 불필요한 정보를 걸러내는 것이 불가능하다. Transformer의 Attention은 본질적으로 이러한 선택적 정보 처리를 수행하는데, SSM이 이에 맞먹는 표현력을 가지려면 입력 의존적 파라미터가 필수적이다.
수식으로 표현하면 Mamba의 선택적 SSM은 다음과 같다.
의 역할이 특히 중요하다. 가 크면 가 0에 가까워져 이전 상태가 리셋되고, 새로운 입력 가 상태를 지배한다. 반대로 가 작으면 이전 상태가 보존되고 현재 입력의 영향이 줄어든다. 이것은 RNN의 게이팅 메커니즘과 유사하되, 연속 시간 시스템의 수학적 프레임워크 안에서 자연스럽게 도출된 것이다.
Selection이 해결하는 합성 태스크
Mamba 논문은 Selection Mechanism의 효과를 두 가지 합성 태스크로 입증한다.
첫째, Selective Copying 태스크다. 입력 시퀀스에서 특정 토큰만 선택적으로 복사해야 하는 문제로, LTI SSM은 "어떤 토큰을 복사할지"를 입력 내용에 따라 판단할 수 없어 실패한다. Mamba는 를 통해 중요한 토큰의 정보를 상태에 강하게 기록하고, 불필요한 토큰은 무시한다.
둘째, Induction Heads 태스크다. "A B ... A"라는 패턴이 나타났을 때 다음에 "B"를 예측해야 하는 문제로, Transformer는 Attention 메커니즘으로 자연스럽게 해결하지만 LTI SSM은 이 패턴 매칭이 불가능하다. Mamba의 선택적 메커니즘은 "A"가 나타났을 때 이전 "A B" 패턴의 기억을 활성화하여 "B"를 예측할 수 있다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelectiveSSM(nn.Module):
"""Mamba의 선택적 State Space Model 핵심 구현.
B, C, delta를 입력의 함수로 만들어
content-aware한 시퀀스 처리를 가능하게 한다.
"""
def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# A 행렬: 대각 구조, 로그 공간에서 파라미터화
A = torch.arange(1, d_state + 1, dtype=torch.float32)
self.A_log = nn.Parameter(torch.log(A).unsqueeze(0).expand(d_model, -1))
# 입력 의존적 B, C 투영
self.B_proj = nn.Linear(d_model, d_state, bias=False)
self.C_proj = nn.Linear(d_model, d_state, bias=False)
# delta (스텝 크기) 투영
self.dt_proj = nn.Linear(d_model, d_model, bias=True)
# skip connection
self.D = nn.Parameter(torch.ones(d_model))
def selective_scan(
self,
x: torch.Tensor,
dt: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
) -> torch.Tensor:
"""선택적 스캔 알고리즘 (순차 구현).
실제 Mamba는 GPU SRAM 최적화된 CUDA 커널을 사용하지만,
여기서는 알고리즘의 논리를 명확히 보여주기 위해 순차 구현한다.
Args:
x: (batch, seq_len, d_model) 입력
dt: (batch, seq_len, d_model) 이산화 스텝 크기
B: (batch, seq_len, d_state) 입력 투영
C: (batch, seq_len, d_state) 출력 투영
Returns:
y: (batch, seq_len, d_model) 출력
"""
batch, seq_len, d_model = x.shape
d_state = B.shape[-1]
# A 이산화: A_bar = exp(dt * A)
A = -torch.exp(self.A_log) # (d_model, d_state)
dt_A = torch.einsum('bld,dn->bldn', dt, A) # (B, L, D, N)
A_bar = torch.exp(dt_A)
# B 이산화: B_bar = dt * B
dt_B = torch.einsum('bld,bln->bldn', dt, B) # (B, L, D, N)
# 순차 스캔
h = torch.zeros(batch, d_model, d_state, device=x.device)
outputs = []
for t in range(seq_len):
# h_t = A_bar_t * h_{t-1} + B_bar_t * x_t
h = A_bar[:, t] * h + dt_B[:, t] * x[:, t].unsqueeze(-1)
# y_t = C_t @ h_t
y_t = torch.einsum('bdn,bn->bd', h, C[:, t])
outputs.append(y_t)
y = torch.stack(outputs, dim=1) # (B, L, D)
return y
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""선택적 SSM의 forward pass.
Args:
x: (batch, seq_len, d_model)
Returns:
y: (batch, seq_len, d_model)
"""
# 입력 의존적 파라미터 계산
B = self.B_proj(x) # (B, L, N)
C = self.C_proj(x) # (B, L, N)
dt = F.softplus(self.dt_proj(x)) # (B, L, D), 양수 보장
# 선택적 스캔
y = self.selective_scan(x, dt, B, C)
# skip connection
y = y + self.D.unsqueeze(0).unsqueeze(0) * x
return y
Hardware-Aware 스캔 알고리즘
Selection Mechanism을 도입하면 SSM이 더 이상 LTI가 아니므로, S4에서 사용하던 FFT 기반 합성곱을 적용할 수 없다. 시변(time-varying) 파라미터는 매 시점마다 달라지므로 합성곱 커널을 사전 계산하는 것이 불가능하다. 순차 스캔(sequential scan)을 사용해야 하지만, 단순한 for 루프 구현은 GPU의 병렬성을 활용하지 못해 매우 느리다.
Mamba는 FlashAttention에서 영감을 받은 hardware-aware 알고리즘으로 이 문제를 해결한다. 핵심 아이디어는 세 가지이다. 첫째, GPU HBM(High Bandwidth Memory)과 SRAM(on-chip memory) 사이의 데이터 이동을 최소화하는 kernel fusion을 적용한다. 이산화, 스캔, 출력 투영을 하나의 CUDA 커널에서 수행하여 중간 결과를 SRAM에 유지한다. 둘째, 중간 상태 텐서( 등)를 HBM에 기록하지 않고 SRAM에서 실시간으로 재계산한다. 이를 통해 메모리 사용량이 시퀀스 길이에 비례하지 않는다. 셋째, 역전파에서도 중간 상태를 저장하지 않고 재계산(recomputation)하여 메모리를 절약한다.
Mamba 아키텍처 상세
Mamba 블록 구조
Mamba 블록은 Transformer 블록과 다른 구조를 가진다. Transformer는 Self-Attention과 FFN의 두 서브레이어로 구성되지만, Mamba는 이 두 역할을 하나의 블록에 통합한다. 구체적으로 입력을 두 갈래로 분기시키고, 한 갈래는 합성곱과 선택적 SSM을 거치며, 다른 갈래는 게이팅 역할을 한다.
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
@dataclass
class MambaConfig:
d_model: int = 768
d_state: int = 16
d_conv: int = 4
expand: int = 2
n_layers: int = 24
vocab_size: int = 50257
dropout: float = 0.0
class MambaBlock(nn.Module):
"""Mamba 블록: 선택적 SSM, 합성곱, 게이팅을 통합한 구조.
Transformer의 Attention + FFN을 하나의 블록으로 대체한다.
확장 비율(expand)은 내부 차원을 결정하며,
기본값 2는 d_inner = 2 * d_model을 의미한다.
"""
def __init__(self, config: MambaConfig):
super().__init__()
self.d_model = config.d_model
self.d_state = config.d_state
self.d_conv = config.d_conv
d_inner = config.d_model * config.expand
# 입력 투영: 두 갈래(z, x)로 분기
self.in_proj = nn.Linear(config.d_model, d_inner * 2, bias=False)
# 1D 합성곱: 지역적 문맥 포착
self.conv1d = nn.Conv1d(
in_channels=d_inner,
out_channels=d_inner,
kernel_size=config.d_conv,
padding=config.d_conv - 1,
groups=d_inner, # depthwise convolution
bias=True,
)
# 선택적 SSM 파라미터
self.B_proj = nn.Linear(d_inner, config.d_state, bias=False)
self.C_proj = nn.Linear(d_inner, config.d_state, bias=False)
self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)
# A 행렬: 대각 구조
A = torch.arange(1, config.d_state + 1, dtype=torch.float32)
self.A_log = nn.Parameter(
torch.log(A).unsqueeze(0).expand(d_inner, -1).clone()
)
self.D = nn.Parameter(torch.ones(d_inner))
# 출력 투영
self.out_proj = nn.Linear(d_inner, config.d_model, bias=False)
self.norm = nn.RMSNorm(config.d_model)
def _selective_scan(self, x, dt, B, C):
"""선택적 스캔 (순차 구현). 실제로는 CUDA 커널 사용."""
batch, seq_len, d_inner = x.shape
d_state = B.shape[-1]
A = -torch.exp(self.A_log)
dt_A = torch.einsum('bld,dn->bldn', dt, A)
A_bar = torch.exp(dt_A)
dt_B_x = torch.einsum('bld,bln->bldn', dt * x, B)
h = torch.zeros(batch, d_inner, d_state, device=x.device)
outputs = []
for t in range(seq_len):
h = A_bar[:, t] * h + dt_B_x[:, t]
y_t = torch.einsum('bdn,bn->bd', h, C[:, t])
outputs.append(y_t)
y = torch.stack(outputs, dim=1)
return y
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Mamba 블록의 forward pass.
Args:
x: (batch, seq_len, d_model)
Returns:
output: (batch, seq_len, d_model)
"""
residual = x
x = self.norm(x)
# 두 갈래로 분기
xz = self.in_proj(x)
x_branch, z = xz.chunk(2, dim=-1)
# 합성곱 경로: 지역적 패턴 포착
x_conv = x_branch.transpose(1, 2) # (B, D_inner, L)
x_conv = self.conv1d(x_conv)[:, :, :x.shape[1]] # causal trim
x_conv = x_conv.transpose(1, 2) # (B, L, D_inner)
x_conv = F.silu(x_conv)
# 선택적 SSM
B = self.B_proj(x_conv)
C = self.C_proj(x_conv)
dt = F.softplus(self.dt_proj(x_conv))
y = self._selective_scan(x_conv, dt, B, C)
y = y + self.D.unsqueeze(0).unsqueeze(0) * x_conv
# 게이팅: z 갈래와 원소곱
y = y * F.silu(z)
# 출력 투영 + 잔차 연결
output = self.out_proj(y)
return output + residual
파라미터 효율과 아키텍처 비교
Mamba의 아키텍처는 Transformer 대비 몇 가지 구조적 차이를 가진다. Transformer는 Multi-Head Attention과 FFN이라는 두 개의 큰 서브레이어를 가지며, 각각 별도의 Layer Normalization과 잔차 연결을 가진다. Mamba는 이를 하나의 블록으로 통합하되, 합성곱으로 지역 패턴을 포착하고 선택적 SSM으로 전역 의존성을 모델링하며 게이팅으로 정보 흐름을 제어한다.
Transformer의 Attention이 모든 토큰 쌍의 관계를 명시적으로 계산하는 "전역 참조(global attention)" 방식이라면, Mamba의 선택적 SSM은 정보를 압축된 상태 벡터에 누적하는 "압축 기억(compressed memory)" 방식이다. 이 차이가 시퀀스 길이에 대한 복잡도 차이의 근본적 원인이다.
동일 파라미터 수(약 2.8B) 기준으로 Mamba는 Transformer 대비 약 1.5배 적은 FLOPs를 사용한다. Attention의 연산이 사라지고, FFN과 유사한 연산량의 선택적 SSM으로 대체되기 때문이다.
Mamba-2와 SSD 프레임워크
State Space Duality의 발견
Mamba-2(Dao & Gu, 2024)의 가장 중요한 이론적 기여는 State Space Duality(SSD)의 발견이다. SSM과 Attention이 표면적으로는 매우 다른 연산처럼 보이지만, 수학적으로는 동일한 구조적 행렬(structured matrix) 계열에 속한다는 것을 증명했다.
구체적으로, 선택적 SSM의 출력은 다음과 같은 행렬-벡터 곱으로 표현할 수 있다.
여기서 은 semi-separable 행렬이다. 이 행렬의 원소는 다음과 같다.
여기서 은 시점 에서 까지의 누적 전이를 의미한다. 이 행렬은 하삼각(lower triangular) 구조를 가지며, 이는 Causal Attention의 마스크된 Attention 행렬과 동일한 구조이다.
Attention 행렬 에서 softmax를 제거하면 이 되고, 이것도 rank- semi-separable 행렬이다. SSD는 이 연결을 공식화하여, SSM과 Linear Attention이 동일한 행렬 대수 구조를 공유한다는 것을 보여준다.
SSD 알고리즘과 청크 분할
SSD의 실용적 의미는 새로운 알고리즘 설계를 가능하게 한다는 것이다. Mamba-2는 시퀀스를 고정 크기 청크(chunk)로 나누고, 청크 내부에서는 행렬 곱(attention-like 연산)으로 병렬 처리하고, 청크 간에는 SSM 재귀로 상태를 전달하는 하이브리드 알고리즘을 사용한다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class SSDBlock(nn.Module):
"""Mamba-2의 Structured State Space Duality(SSD) 블록.
시퀀스를 청크로 나누어 청크 내부는 행렬 곱(병렬),
청크 간에는 SSM 재귀로 처리하는 하이브리드 알고리즘.
"""
def __init__(
self,
d_model: int = 768,
d_state: int = 128,
n_heads: int = 8,
chunk_size: int = 256,
):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.n_heads = n_heads
self.chunk_size = chunk_size
self.head_dim = d_model // n_heads
# 멀티헤드 구조: Q(C), K(B), V(x) 투영
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_state * n_heads, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.dt_proj = nn.Linear(d_model, n_heads, bias=True)
# A 행렬 (헤드별 스칼라)
self.A_log = nn.Parameter(torch.log(torch.arange(1, n_heads + 1, dtype=torch.float32)))
self.out_proj = nn.Linear(d_model, d_model, bias=False)
self.norm = nn.RMSNorm(d_model)
def _chunk_scan(
self,
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
dt: torch.Tensor,
) -> torch.Tensor:
"""청크 기반 SSD 스캔.
Args:
Q: (batch, n_heads, seq_len, head_dim) -- C 역할
K: (batch, n_heads, seq_len, d_state) -- B 역할
V: (batch, n_heads, seq_len, head_dim) -- x 역할
dt: (batch, n_heads, seq_len) -- delta 역할
Returns:
y: (batch, n_heads, seq_len, head_dim)
"""
B, H, L, D = Q.shape
C = self.chunk_size
# 시퀀스를 청크로 분할
n_chunks = (L + C - 1) // C
pad_len = n_chunks * C - L
if pad_len > 0:
Q = F.pad(Q, (0, 0, 0, pad_len))
K = F.pad(K, (0, 0, 0, pad_len))
V = F.pad(V, (0, 0, 0, pad_len))
dt = F.pad(dt, (0, pad_len))
# 청크 단위로 reshape
Q = Q.view(B, H, n_chunks, C, D)
K = K.view(B, H, n_chunks, C, -1)
V = V.view(B, H, n_chunks, C, D)
dt = dt.view(B, H, n_chunks, C)
# 청크 내부: attention-like 행렬 곱 (병렬 처리 가능)
# 청크 내부에서 Q와 K의 유사도 행렬을 계산
# decay를 적용한 causal mask 생성
A = -torch.exp(self.A_log) # (n_heads,)
# 누적 감쇠율 계산
dt_cumsum = dt.cumsum(dim=-1) # (B, H, n_chunks, C)
# 청크 내부 attention 행렬
# M[i,j] = exp(A * (dt_cumsum[i] - dt_cumsum[j])) for i >= j
decay_diff = dt_cumsum.unsqueeze(-1) - dt_cumsum.unsqueeze(-2) # (B,H,nc,C,C)
decay = torch.exp(A.view(1, H, 1, 1, 1) * decay_diff)
causal_mask = torch.tril(torch.ones(C, C, device=Q.device))
decay = decay * causal_mask
# 청크 내부 출력 계산
attn = torch.einsum('bhncqd,bhnckd->bhncqk', Q, K) * decay
y_intra = torch.einsum('bhncqk,bhnckd->bhncqd', attn, V)
# 최종 reshape
y = y_intra.view(B, H, n_chunks * C, D)
return y[:, :, :L]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""SSD 블록의 forward pass."""
residual = x
x = self.norm(x)
B, L, D = x.shape
# 멀티헤드 투영
Q = self.q_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
K = self.k_proj(x).view(B, L, self.n_heads, self.d_state).transpose(1, 2)
V = self.v_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
dt = F.softplus(self.dt_proj(x)).transpose(1, 2) # (B, H, L)
# SSD 스캔
y = self._chunk_scan(Q, K, V, dt)
# 출력 투영
y = y.transpose(1, 2).contiguous().view(B, L, D)
return self.out_proj(y) + residual
Mamba-2의 구조적 개선
Mamba-2는 Mamba-1 대비 다음과 같은 구조적 변경을 도입했다.
첫째, 멀티헤드 구조의 도입이다. Mamba-1은 채널별로 독립적인 SSM을 적용했지만, Mamba-2는 Attention과 유사한 헤드 구조를 사용하여 파라미터를 공유한다. 이를 통해 동일 연산량에서 더 많은 상태 차원을 활용할 수 있다.
둘째, A 행렬의 단순화이다. Mamba-1은 채널별 대각 A 행렬을 사용했지만, Mamba-2는 헤드별 스칼라 값으로 더욱 단순화했다. 이 단순화가 SSD의 행렬 분해를 가능하게 하는 핵심 조건이다.
셋째, 청크 기반 병렬 알고리즘이다. 시퀀스를 고정 크기 청크로 나누어 청크 내부에서는 Tensor Core를 활용한 행렬 곱으로 처리하고, 청크 간에는 SSM 재귀로 상태를 전달한다. 이 방식은 순수 재귀보다 2-8배 빠르며, GPU의 병렬 하드웨어를 최대한 활용한다.
Transformer 대비 성능 비교
정량적 벤치마크
다양한 시퀀스 모델링 아키텍처의 성능을 비교한 결과를 정리한다. 측정 조건은 약 2.8B 파라미터 규모, 동일한 학습 데이터(The Pile, 300B 토큰)를 사용했다.
| 항목 | Transformer (2.8B) | Mamba (2.8B) | Mamba-2 (2.8B) | RWKV-6 (3B) | RetNet (2.7B) |
|---|---|---|---|---|---|
| Pile PPL (val) | 6.73 | 6.22 | 6.10 | 6.85 | 7.02 |
| 학습 처리량 (tok/s) | 기준 (1.0x) | 1.2x | 1.8x | 1.1x | 1.3x |
| 추론 지연 (1K 토큰) | 38ms/tok | 12ms/tok | 10ms/tok | 14ms/tok | 16ms/tok |
| 추론 지연 (16K 토큰) | 180ms/tok | 12ms/tok | 10ms/tok | 14ms/tok | 16ms/tok |
| 추론 메모리 (1K 토큰) | 2.1 GB | 0.8 GB | 0.7 GB | 0.9 GB | 1.0 GB |
| 추론 메모리 (16K) | 8.5 GB | 0.8 GB | 0.7 GB | 0.9 GB | 1.0 GB |
| 시퀀스 길이 확장 | KV Cache 선형 증가 | 상수 | 상수 | 상수 | 상수 |
| In-Context Learning | 강함 | 중간 | 중간-강함 | 약함 | 약함 |
| 어텐션 패턴 해석 | 가능 (시각화) | 불가 | 부분적 | 불가 | 부분적 |
이 표에서 주목할 점은 여러 가지이다. 첫째, Mamba 계열은 시퀀스 길이가 증가해도 추론 지연과 메모리가 상수로 유지된다. Transformer는 16K 토큰에서 1K 대비 약 4.7배의 지연 증가를 보이지만, Mamba는 동일하다. 둘째, Mamba-2가 Mamba-1보다 학습 처리량에서 1.5배 빠르다. SSD 알고리즘의 청크 기반 병렬화 효과이다. 셋째, In-Context Learning에서 Transformer가 여전히 우위를 보인다. 이는 SSM 기반 모델의 핵심 약점 중 하나이다.
시퀀스 길이별 스케일링
시퀀스 길이에 따른 성능 변화를 분석하면 SSM의 장점이 더 분명해진다. Transformer의 Self-Attention은 이므로 시퀀스 길이가 2배 증가하면 연산량이 4배 증가한다. FlashAttention-2와 같은 최적화를 적용해도 이 근본적 복잡도는 바뀌지 않는다. 반면 Mamba의 선택적 SSM은 이므로 시퀀스 길이와 무관하게 토큰당 연산량이 일정하다.
실질적인 벽은 약 32K-64K 토큰 구간에서 나타난다. 이 구간에서 A100 80GB GPU 기준으로 Transformer는 KV Cache 때문에 배치 크기를 줄여야 하고, Mamba는 동일한 배치 크기를 유지할 수 있다. 1M 토큰 이상의 극단적으로 긴 시퀀스에서는 Transformer가 실용적으로 불가능하지만, SSM은 원리적으로 처리 가능하다.
다양한 도메인 적용
비전(Vision): Vision Mamba와 VMamba
SSM을 비전 태스크에 적용하는 연구가 활발하다. Vision Mamba(Vim, Zhu et al., 2024)는 ViT(Vision Transformer)의 Self-Attention을 양방향(bidirectional) SSM으로 대체한다. 이미지 패치를 순방향과 역방향으로 스캔하여 양방향 문맥을 포착하며, ImageNet 분류에서 DeiT와 동등한 정확도를 2.8배 적은 GPU 메모리로 달성했다.
VMamba(Liu et al., 2024)는 2D 이미지의 공간적 구조를 더 잘 반영하기 위해 Cross-Scan Module(CSM)을 제안했다. 이미지를 네 방향(좌상-우하, 우상-좌하, 좌하-우상, 우하-좌상)으로 스캔하여 2D 공간 정보를 1D 시퀀스 모델에서 포착한다.
비전 분야에서 SSM의 핵심 장점은 고해상도 이미지 처리이다. ViT는 패치 수(시퀀스 길이)의 제곱에 비례하는 연산량을 가지므로, 224x224에서 512x512로 해상도를 올리면 연산량이 약 5배 증가한다. Vision Mamba는 선형 복잡도 덕분에 이 비용 증가가 약 2배에 불과하다.
오디오와 음성
오디오 신호는 본질적으로 장기 시퀀스이다. 16kHz 샘플링에서 1분의 오디오는 960,000개의 시점을 가지며, 이를 Transformer로 직접 처리하는 것은 비현실적이다. Audio Mamba(Erol et al., 2024)는 오디오 분류 태스크에서 Audio Spectrogram Transformer(AST)를 대체하여, 동등한 정확도에서 메모리 사용량을 70% 줄였다.
음성 합성(TTS) 분야에서도 Mamba의 적용이 시도되고 있다. 긴 오디오 시퀀스의 조건부 생성에서 Transformer의 KV Cache 병목이 심각하므로, SSM 기반 디코더는 실시간 음성 합성에 더 적합하다.
시계열 예측
시계열 데이터는 SSM이 가장 자연스럽게 적용되는 도메인이다. 연속 시간 동적 시스템을 모델링하는 SSM의 수학적 프레임워크가 시계열의 연속적 특성과 잘 맞기 때문이다. S-Mamba(Wang et al., 2024)는 다변량 시계열 예측에서 PatchTST 대비 동등하거나 우수한 성능을 보이면서 추론 속도는 3배 빠르다.
시계열 예측에서 SSM의 특별한 장점은 불규칙 시계열(irregularly-sampled time series) 처리이다. 연속 시간 SSM은 이산화 스텝 를 관측 간격에 맞게 조절할 수 있으므로, 불규칙한 시간 간격의 데이터를 보간(interpolation) 없이 직접 처리할 수 있다. 이는 의료 데이터, IoT 센서 데이터 등에서 큰 이점이다.
게노믹스와 단백질
DNA 서열은 극단적으로 긴 시퀀스(수백만~수십억 bp)이며, 장거리 상호작용이 생물학적으로 중요하다. Caduceus(Schiff et al., 2024)는 양방향 Mamba 기반의 DNA 언어 모델로, DNA의 이중 가닥 구조와 역상보성(reverse complementarity)을 아키텍처에 반영했다. 기존 DNA Transformer 모델(Nucleotide Transformer, HyenaDNA) 대비 장거리 변이 효과 예측에서 우수한 성능을 보였다.
한계점과 과제
In-Context Learning의 약점
Mamba를 포함한 SSM 기반 모델의 가장 심각한 한계는 In-Context Learning(ICL) 능력이다. Transformer는 프롬프트에 제공된 few-shot 예제를 참조하여 새로운 태스크를 수행하는 ICL 능력이 뛰어나다. 이는 Self-Attention이 시퀀스 내 임의의 위치를 직접 참조(direct lookup)할 수 있기 때문이다.
반면 SSM은 정보를 고정 크기 상태 벡터에 압축하므로, 프롬프트의 특정 예제를 정확하게 "검색"하는 것이 어렵다. 상태 차원 이 시퀀스 길이보다 훨씬 작을 때(), 정보 손실이 불가피하다. Mamba-2는 상태 차원을 크게 늘림으로써(N=128 이상) 이 문제를 부분적으로 완화했지만, Transformer의 명시적 토큰-레벨 참조 능력에는 미치지 못한다.
Retrieval 태스크의 한계
Information Retrieval, 즉 시퀀스에서 특정 정보를 정확하게 추출하는 태스크에서 SSM은 Transformer 대비 뚜렷한 약점을 보인다. 예를 들어 "passkey retrieval" 태스크에서 10만 토큰의 건초 더미에서 숨겨진 암호 키를 찾는 문제를 Mamba는 불안정하게 수행한다. 이는 SSM의 압축 기억 방식이 정확한 정보 보존보다는 통계적 요약에 적합하기 때문이다.
하이브리드 아키텍처의 부상
이러한 한계를 인식하고, 최근 연구는 SSM과 Attention을 결합한 하이브리드 아키텍처를 탐구하고 있다. Jamba(AI21 Labs, 2024)는 Mamba 레이어와 Transformer 레이어를 교차 배치하여 SSM의 효율성과 Attention의 ICL 능력을 동시에 확보했다. 52B 파라미터에서 256K 컨텍스트를 지원하며, 순수 Transformer 대비 추론 처리량이 3배 높다.
Zamba(Zyphra, 2024)와 Griffin(De et al., 2024) 등도 유사한 하이브리드 접근을 채택하고 있으며, 이는 순수 SSM이나 순수 Transformer보다 실용적인 절충안으로 자리잡아가고 있다.
실전 구현 가이드
공식 Mamba 라이브러리 활용
Mamba의 공식 구현은 state-spaces/mamba 리포지토리에서 제공된다. CUDA 커널이 포함되어 있어 GPU 환경이 필수적이다.
# Mamba 설치 및 기본 사용법
# 1. 설치 (CUDA 11.8 이상 필요)
pip install mamba-ssm
# 2. causal-conv1d 의존성 설치
pip install causal-conv1d>=1.2.0
# 3. 사전학습된 모델 사용 (Hugging Face)
pip install transformers
# 4. 사전학습된 Mamba 모델로 텍스트 생성
python -c "
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import torch
# Mamba-2.8B 모델 로드
model = MambaLMHeadModel.from_pretrained(
'state-spaces/mamba-2.8b',
device='cuda',
dtype=torch.float16,
)
model.eval()
# 토크나이저 (GPT-NeoX 토크나이저 호환)
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
# 텍스트 생성
prompt = 'State Space Models are'
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to('cuda')
output = model.generate(
input_ids=input_ids,
max_length=100,
temperature=0.7,
top_p=0.9,
)
print(tokenizer.decode(output[0]))
"
벤치마크 스크립트
Mamba와 Transformer의 추론 성능을 정량적으로 비교하는 벤치마크 스크립트이다.
import torch
import time
from dataclasses import dataclass
@dataclass
class BenchmarkResult:
model_name: str
seq_len: int
batch_size: int
latency_ms: float
memory_mb: float
tokens_per_sec: float
def benchmark_inference(
model: torch.nn.Module,
model_name: str,
seq_lengths: list[int],
batch_size: int = 1,
n_warmup: int = 5,
n_measure: int = 20,
device: str = "cuda",
) -> list[BenchmarkResult]:
"""다양한 시퀀스 길이에서 추론 성능을 측정한다.
Args:
model: 벤치마크 대상 모델
model_name: 모델 이름 (결과 표시용)
seq_lengths: 테스트할 시퀀스 길이 목록
batch_size: 배치 크기
n_warmup: 워밍업 반복 수
n_measure: 측정 반복 수
device: 실행 디바이스
Returns:
시퀀스 길이별 벤치마크 결과 목록
"""
model = model.to(device).eval()
results = []
for seq_len in seq_lengths:
input_ids = torch.randint(
0, 50257, (batch_size, seq_len), device=device
)
# GPU 캐시 정리
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# 워밍업
with torch.no_grad():
for _ in range(n_warmup):
_ = model(input_ids)
torch.cuda.synchronize()
# 측정
latencies = []
with torch.no_grad():
for _ in range(n_measure):
torch.cuda.synchronize()
start = time.perf_counter()
_ = model(input_ids)
torch.cuda.synchronize()
end = time.perf_counter()
latencies.append((end - start) * 1000)
avg_latency = sum(latencies) / len(latencies)
peak_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
tokens_per_sec = (batch_size * seq_len) / (avg_latency / 1000)
results.append(BenchmarkResult(
model_name=model_name,
seq_len=seq_len,
batch_size=batch_size,
latency_ms=avg_latency,
memory_mb=peak_memory,
tokens_per_sec=tokens_per_sec,
))
print(
f"[{model_name}] seq_len={seq_len:>6d} | "
f"latency={avg_latency:>8.2f}ms | "
f"memory={peak_memory:>8.1f}MB | "
f"throughput={tokens_per_sec:>10.0f} tok/s"
)
return results
def compare_models(results_dict: dict[str, list[BenchmarkResult]]):
"""모델 간 성능 비교 테이블을 출력한다."""
print("\n" + "=" * 80)
print(f"{'모델':<20} {'시퀀스 길이':>10} {'지연(ms)':>10} "
f"{'메모리(MB)':>12} {'처리량(tok/s)':>15}")
print("=" * 80)
for model_name, results in results_dict.items():
for r in results:
print(
f"{r.model_name:<20} {r.seq_len:>10d} "
f"{r.latency_ms:>10.2f} {r.memory_mb:>12.1f} "
f"{r.tokens_per_sec:>15.0f}"
)
print("-" * 80)
운영 시 주의사항
CUDA 커널 호환성
Mamba의 공식 CUDA 커널은 특정 GPU 아키텍처(Ampere 이상, sm_80+)에 최적화되어 있다. V100(sm_70)이나 이전 세대 GPU에서는 fallback 구현이 사용되며 성능이 크게 저하된다. 프로덕션 배포 전 반드시 대상 GPU의 compute capability를 확인하고, Ampere(A100/A10G) 또는 Hopper(H100) GPU를 사용해야 한다.
causal-conv1d 패키지의 버전 호환성도 주의해야 한다. mamba-ssm 버전과 causal-conv1d 버전이 맞지 않으면 런타임 오류가 발생한다. 공식 README의 호환성 표를 반드시 확인한다.
상태 초기화와 시퀀스 경계
SSM 기반 모델의 추론에서 가장 흔한 실수는 상태 초기화를 잘못 처리하는 것이다. 배치 추론에서 서로 다른 시퀀스의 상태가 혼합되면 출력이 오염된다. 각 시퀀스의 시작에서 상태를 명시적으로 영벡터로 초기화해야 하며, 스트리밍 추론에서 대화 턴이 바뀔 때도 상태를 리셋해야 한다.
연속 대화(multi-turn conversation)에서 이전 턴의 상태를 유지할지 리셋할지는 설계 결정이다. 상태를 유지하면 이전 대화의 맥락이 보존되지만, 상태의 정보 보존 능력이 Transformer의 KV Cache보다 제한적이므로 긴 대화에서는 정보 손실이 누적될 수 있다.
학습 시 정밀도와 안정성
Mamba의 선택적 SSM은 softplus 활성화를 사용하는 파라미터의 수치 안정성에 민감하다. 가 너무 크면 가 0으로 수렴하여 기울기가 소실되고, 너무 작으면 상태 업데이트가 미미하여 학습이 정체된다. 학습 초기에 의 초기화를 로 설정하여 로그 공간에서 적절한 범위를 보장하는 것이 권장된다.
BF16 학습은 일반적으로 안정적이지만, 상태 행렬의 지수 연산에서 간헐적으로 수치 불안정이 발생할 수 있다. 로그 공간에서 A를 파라미터화()하면 음수성과 안정성을 동시에 보장할 수 있다.
서빙 프레임워크 지원 현황
vLLM은 Mamba 모델의 서빙을 부분적으로 지원한다. PagedAttention은 Attention 기반 모델을 위한 것이므로 SSM 모델에는 적용되지 않지만, vLLM의 스케줄링과 배칭 인프라를 활용할 수 있다. TensorRT-LLM에서는 Mamba-1 모델의 최적화가 지원되며, Mamba-2는 별도 플러그인이 필요하다.
Triton 서빙의 경우, Mamba 모델을 ONNX로 변환하여 배포하는 것이 가능하지만, CUDA 커스텀 커널이 ONNX 그래프에 포함되지 않으므로 성능 손실이 발생한다. 현재로서는 mamba-ssm 라이브러리를 직접 래핑하는 커스텀 서빙 서버가 가장 안정적인 선택이다.
디버깅 체크리스트
프로덕션에서 발생할 수 있는 문제와 대응 방법을 정리한다.
- 출력이 반복 루프에 빠질 때: 상태가 특정 패턴에 고착(stuck)된 상태이다. temperature를 높이거나, 상태를 부분적으로 리셋하는 로직을 추가한다.
- 긴 입력에서 품질이 급격히 저하될 때: 상태 차원이 시퀀스 길이 대비 부족하다. 모델 선택 시 d_state가 충분히 큰 모델(N=64 이상)을 사용한다.
- 배치 추론에서 시퀀스별 결과가 상이할 때: 패딩 처리가 잘못되어 패딩 토큰의 정보가 상태에 유입되었을 가능성이 높다. 패딩 토큰에 대한 를 0으로 강제하여 상태 업데이트를 차단한다.
- CUDA OOM이 학습 중 간헐적으로 발생할 때: 재계산(recomputation) 전략이 올바르게 적용되지 않았을 수 있다. gradient checkpointing을 활성화하고, 배치 크기를 줄인다.
참고자료
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv preprint arXiv:2312.00752. https://arxiv.org/abs/2312.00752
Dao, T., & Gu, A. (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. arXiv preprint arXiv:2405.21060. https://arxiv.org/abs/2405.21060
Mamba 공식 GitHub 리포지토리. https://github.com/state-spaces/mamba
Gu, A., Goel, K., & Re, C. (2021). Efficiently Modeling Long Sequences with Structured State Spaces (S4). ICLR 2022. https://arxiv.org/abs/2111.00396
Patro, B. N., & Agneeswaran, V. S. (2024). Mamba-360: Survey of State Space Models as Transformer Alternative for Long Sequence Modelling. arXiv preprint arXiv:2404.16112. https://arxiv.org/abs/2404.16112
Lieber, O. et al. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv preprint arXiv:2403.19887. https://arxiv.org/abs/2403.19887
Zhu, L. et al. (2024). Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model. ICML 2024. https://arxiv.org/abs/2401.09417
Schiff, Y. et al. (2024). Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling. ICML 2024. https://arxiv.org/abs/2403.03234