- Authors

- Name
- Youngju Kim
- @fjvbn20031
개요
2017년 "Attention is All You Need" 논문 이후, Transformer는 자연어 처리부터 이미지, 코드, 오디오까지 거의 모든 시퀀스 모델링 태스크를 지배해왔습니다. 그러나 Transformer는 근본적인 약점이 있습니다. 바로 셀프 어텐션의 이차 복잡도(O(n^2)) 문제입니다.
2023년 Albert Gu와 Tri Dao가 발표한 Mamba는 이 문제를 우아하게 해결하며 딥러닝 커뮤니티에 큰 충격을 주었습니다. Mamba는 상태 공간 모델(State Space Model, SSM)이라는 고전 제어 이론에서 온 개념을 현대 딥러닝에 접목하여, 선형 시간 복잡도로 긴 시퀀스를 처리할 수 있게 했습니다.
이 가이드에서는 SSM의 수학적 기초부터 Mamba의 핵심 혁신, 그리고 실전 구현까지 완전히 다룹니다.
1. Transformer의 한계와 SSM의 등장
셀프 어텐션의 O(n^2) 복잡도 문제
Transformer의 핵심은 셀프 어텐션(Self-Attention) 메커니즘입니다. 길이 n인 시퀀스에 대해 어텐션 행렬을 계산할 때, 모든 토큰 쌍의 관계를 계산하므로 시간과 메모리 복잡도가 **O(n^2)**이 됩니다.
어텐션 행렬 크기: n × n
= 1,000 토큰 → 1,000,000 원소
= 10,000 토큰 → 100,000,000 원소
= 100,000 토큰 → 10,000,000,000 원소 (불가능!)
이 문제로 인해 실제 LLM들은 컨텍스트 윈도우 크기에 제약을 받습니다. GPT-4는 128K 토큰을 지원하지만, 이를 위해 엄청난 계산 비용과 최적화 기법이 필요합니다.
긴 시퀀스 처리의 어려움
긴 문서 요약, 코드 베이스 이해, 장시간 대화 유지 등 실제 애플리케이션에서는 매우 긴 컨텍스트가 필요합니다. 하지만 Transformer의 이차 복잡도는 실용적 사용을 어렵게 만듭니다.
예를 들어, 전체 책 한 권(약 100,000 단어 = 130,000 토큰)을 한 번에 처리하려면, 표준 Transformer는 현재 하드웨어로 처리 불가능에 가까운 메모리를 요구합니다.
순환 모델(RNN, LSTM)의 한계
순환 신경망(RNN)과 LSTM은 이론적으로 O(n) 복잡도를 가집니다. 각 타임스텝에서 고정된 크기의 숨겨진 상태(hidden state)를 유지하기 때문입니다. 그러나 RNN/LSTM은 다음 문제가 있습니다:
- 기울기 소실/폭발 (Vanishing/Exploding Gradient): 긴 시퀀스에서 역전파 시 기울기가 사라지거나 폭발
- 병렬화 불가능: 각 타임스텝이 이전 상태에 의존하므로 GPU 병렬화가 어려움
- 장기 의존성 포착 어려움: 멀리 떨어진 정보를 연결하기 어려움
SSM이 제시하는 해결책
상태 공간 모델(SSM)은 이 두 접근법의 장점을 결합합니다:
- 훈련 시: 합성곱(Convolution)으로 완전 병렬화 → Transformer에 버금가는 훈련 효율
- 추론 시: 순환(Recurrence)으로 O(1) 메모리와 O(n) 계산 → RNN처럼 효율적인 추론
- 긴 시퀀스: 선형 복잡도로 매우 긴 시퀀스 처리 가능
2. 상태 공간 모델(SSM) 수학적 기초
연속 시간 SSM
SSM의 기원은 1960년대 Rudolf Kalman의 제어 이론으로 거슬러 올라갑니다. 연속 시간 선형 시스템은 다음과 같이 표현됩니다:
x'(t) = A·x(t) + B·u(t)
y(t) = C·x(t) + D·u(t)
여기서:
u(t): 입력 신호 (input)x(t): 상태 벡터 (latent state), 크기 Ny(t): 출력 신호 (output)A: N×N 상태 전이 행렬 (state transition matrix)B: N×1 입력 행렬 (input projection)C: 1×N 출력 행렬 (output projection)D: 직접 전달 항 (skip connection, 보통 0으로 설정)
이 시스템은 입력 u(t)를 받아 내부 상태 x(t)를 업데이트하고 출력 y(t)를 생성합니다. 상태 x(t)는 과거 정보의 "요약"이라고 볼 수 있습니다.
이산화 (Discretization)
연속 시간 SSM을 디지털 컴퓨터에서 사용하려면 이산화(discretization)가 필요합니다. 샘플링 간격을 Delta라 할 때, 두 가지 주요 이산화 방법이 있습니다.
Zero-Order Hold (ZOH):
A_bar = exp(Delta · A)
B_bar = (Delta · A)^(-1) · (exp(Delta · A) - I) · Delta · B
Bilinear (Tustin) 변환:
A_bar = (I - Delta/2 · A)^(-1) · (I + Delta/2 · A)
B_bar = (I - Delta/2 · A)^(-1) · Delta · B
이산화 후 시스템은:
x[t] = A_bar · x[t-1] + B_bar · u[t]
y[t] = C · x[t]
이 형태는 선형 순환 (Linear Recurrence)으로, 각 타임스텝에서 상태를 업데이트합니다.
합성곱 커널로서의 SSM
SSM의 강력한 점은 훈련 시 합성곱으로 계산 가능하다는 것입니다. 초기 상태 x[0] = 0 으로 시작하면:
y[0] = C · B_bar · u[0]
y[1] = C · A_bar · B_bar · u[0] + C · B_bar · u[1]
y[2] = C · A_bar^2 · B_bar · u[0] + C · A_bar · B_bar · u[1] + C · B_bar · u[2]
...
이를 합성곱 커널 K로 표현하면:
K = (C·B_bar, C·A_bar·B_bar, C·A_bar^2·B_bar, ...)
y = K * u (합성곱)
이 커널 K는 병렬로 계산할 수 있으며, FFT를 사용하면 O(n log n)으로 계산 가능합니다.
이것이 SSM의 핵심 이중성입니다:
- 추론: 순환(Recurrence)으로 O(1) 메모리
- 훈련: 합성곱(Convolution)으로 O(n log n) 병렬 계산
3. S4 (Structured State Space Sequence Model)
S4는 2021년 Albert Gu 등이 발표한 논문으로, SSM을 딥러닝에 실용적으로 적용한 첫 번째 중요한 작업입니다.
HiPPO 행렬 초기화
S4의 핵심 기여 중 하나는 HiPPO(High-order Polynomial Projection Operators) 행렬 초기화입니다. 단순히 무작위로 A를 초기화하면 기울기 소실 문제가 발생합니다. HiPPO는 과거 입력을 다항식으로 근사하도록 설계된 특별한 A 행렬을 제공합니다.
HiPPO-LegS (Legendre 다항식 기반):
A[n,k] = -sqrt((2n+1)(2k+1)) if n > k
A[n,k] = -(n+1) if n == k
A[n,k] = 0 if n < k
이 초기화는 상태 x[t]가 모든 과거 입력 u[0..t]에 대한 최적 다항식 근사를 유지하게 합니다.
구조화된 행렬 A (DPLR)
A 행렬의 합성곱 커널을 효율적으로 계산하기 위해, S4는 A를 Diagonal Plus Low-Rank (DPLR) 형태로 표현합니다:
A = Λ - P·Q^T
여기서 Λ는 대각 행렬이고 P, Q는 낮은 랭크의 벡터입니다. 이 구조를 이용하면 커널 계산을 O(N)으로 줄일 수 있습니다.
효율적 계산
import torch
import torch.nn as nn
import numpy as np
class S4Layer(nn.Module):
"""S4 레이어의 간소화된 구현"""
def __init__(self, d_model, d_state=64, dropout=0.0):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# HiPPO-LegS 초기화
A = self._make_hippo(d_state)
# DPLR 분해
self.A = nn.Parameter(torch.tensor(A, dtype=torch.float32))
self.B = nn.Parameter(torch.randn(d_state, 1) * 0.01)
self.C = nn.Parameter(torch.randn(1, d_state) * 0.01)
self.log_delta = nn.Parameter(torch.zeros(d_model))
self.D = nn.Parameter(torch.ones(d_model))
def _make_hippo(self, N):
"""HiPPO-LegS 행렬 생성"""
P = np.sqrt(1 + 2 * np.arange(N))
A = P[:, np.newaxis] * P[np.newaxis, :]
A = np.tril(A) - np.diag(np.arange(N))
return -A
def discretize(self, A, B, C, delta):
"""ZOH 이산화"""
dA = torch.matrix_exp(delta.unsqueeze(-1) * A)
dB = torch.linalg.solve(A, (dA - torch.eye(A.shape[0])) @ B)
return dA, dB, C
def forward(self, u):
# u: (batch, seq_len, d_model)
delta = torch.exp(self.log_delta)
# 이산화 및 합성곱 계산
# 실제 구현은 더 복잡하지만 여기서는 개념 설명
return u # 간소화
4. H3 (Hungry Hungry Hippos)
2022년 발표된 H3는 S4를 언어 모델링에 더 적합하게 개선한 모델입니다. 제목은 HiPPO의 줄임말을 유머러스하게 표현한 것입니다.
S4와의 차이점
S4는 긴 범위 의존성을 잘 포착하지만, 언어 모델링에서 중요한 토큰 간 상호작용이 부족했습니다. 어텐션은 "이 단어가 저 단어와 얼마나 관련 있나?"를 직접 계산하지만, S4는 순수한 순환으로 이를 포착하기 어렵습니다.
게이팅 메커니즘
H3는 두 개의 SSM을 사용하고 그 사이에 게이팅을 추가합니다:
class H3Layer(nn.Module):
"""H3 레이어"""
def __init__(self, d_model, d_state=64):
super().__init__()
# 두 개의 SSM (shift SSM + diagonal SSM)
self.shift_ssm = S4Layer(d_model, d_state)
self.diag_ssm = S4Layer(d_model, d_state)
# 프로젝션
self.Q_proj = nn.Linear(d_model, d_model)
self.K_proj = nn.Linear(d_model, d_model)
self.V_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x):
# Q, K, V 프로젝션
q = self.Q_proj(x)
k = self.K_proj(x)
v = self.V_proj(x)
# SSM 1: K에 Shift SSM 적용
k_ssm = self.shift_ssm(k)
# 곱셈 게이팅: Q ⊙ K_ssm
gated = q * k_ssm
# SSM 2: V에 Diagonal SSM 적용 후 gated와 곱
v_ssm = self.diag_ssm(v)
output = gated * v_ssm
return self.out_proj(output)
언어 모델링 개선
H3는 GPT-2 크기의 언어 모델에서 Transformer에 근접한 성능을 달성하면서도, 추론 시 훨씬 빠른 속도를 보였습니다. 이는 SSM이 언어 모델링에서 실용적임을 보여주는 중요한 이정표였습니다.
5. Mamba (Selective State Space Model)
2023년 12월, Albert Gu와 Tri Dao가 발표한 Mamba는 SSM 연구의 가장 중요한 진전입니다. Mamba의 핵심 혁신은 **선택적 메커니즘(Selective Mechanism)**입니다.
핵심 혁신: 선택적 메커니즘
기존 S4와 H3의 근본적인 한계는 행렬 A, B, C가 **입력에 독립적(input-independent)**이었다는 점입니다. 합성곱으로 훈련하기 위해서는 커널이 고정되어야 했기 때문입니다.
Mamba는 이 제약을 깨고 **S6 (Selective SSM)**를 도입합니다: B, C, Delta 행렬을 입력 u(t)의 함수로 만들었습니다.
B(t) = Linear_B(u(t)) # 입력에 따라 달라지는 B
C(t) = Linear_C(u(t)) # 입력에 따라 달라지는 C
Delta(t) = softplus(Linear_Delta(u(t))) # 입력에 따라 달라지는 Delta
이렇게 하면 모델이 어떤 정보를 상태에 저장하고 어떤 정보를 무시할지를 입력에 따라 동적으로 결정할 수 있습니다.
직관적으로:
- Delta가 크면 → 현재 입력이 상태에 강하게 반영됨 (정보 기억)
- Delta가 작으면 → 상태 변화가 적음 (정보 무시)
이는 LSTM의 게이트(input gate, forget gate)와 유사한 역할을 하지만, 연속 시간 SSM 프레임워크 안에서 이루어집니다.
Hardware-Aware Algorithm
선택적 메커니즘을 도입하면 B, C가 입력에 의존하므로 더 이상 합성곱으로 계산할 수 없습니다. 이는 심각한 계산 비효율을 초래할 수 있습니다.
Mamba는 이를 해결하기 위해 Hardware-Aware Algorithm을 사용합니다. 이는 GPU의 메모리 계층 구조(HBM vs SRAM)를 활용한 Kernel Fusion 기법입니다:
문제: 각 타임스텝의 순환 계산에서 중간 상태를 HBM(느린 메모리)에 저장하면 메모리 대역폭 병목이 발생
해결:
- 모든 중간 계산을 빠른 SRAM(온칩 메모리)에서 수행
- 완전한 합성곱이 아닌 Parallel Scan 알고리즘 사용
- 최종 출력만 HBM에 저장
# Parallel Scan의 개념 (실제 구현은 CUDA로)
def parallel_scan(gates, tokens):
"""
선형 순환을 병렬로 계산
x[t] = gates[t] * x[t-1] + tokens[t]
이진 트리 구조로 O(log n) 병렬 깊이 달성
"""
n = len(tokens)
# 상향(up-sweep) 단계
log_n = int(np.log2(n))
for d in range(log_n):
step = 2 ** (d + 1)
for i in range(step - 1, n, step):
gates[i] = gates[i] * gates[i - 2**d]
tokens[i] = gates[i - 2**d] * tokens[i - 2**d] + tokens[i]
# 하향(down-sweep) 단계 (생략)
return tokens
Mamba 블록 구조
입력 x (B, L, D)
│
├──────────────────────┐
│ │
Linear(D→ED) Linear(D→ED)
+ SiLU activation │
│ SSM (S6)
│ │
└─────── ⊙ ────────────┘
(element-wise 곱셈)
│
Linear(ED→D)
│
출력 y (B, L, D)
여기서 E는 확장 비율(보통 2), D는 모델 차원입니다.
완전한 Mamba 구현
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
class MambaBlock(nn.Module):
"""
Mamba Block 구현
논문: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
"""
def __init__(
self,
d_model, # 모델 차원 D
d_state=16, # SSM 상태 차원 N
d_conv=4, # 합성곱 커널 크기
expand=2, # 확장 비율 E
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
bias=False,
):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
if dt_rank == "auto":
self.dt_rank = max(1, int(d_model / 16))
else:
self.dt_rank = dt_rank
# 입력 프로젝션 (D → 2*ED, 두 브랜치를 한번에 계산)
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=bias)
# 로컬 합성곱 (depthwise conv)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
bias=True,
kernel_size=d_conv,
groups=self.d_inner,
padding=d_conv - 1, # causal padding
)
# SSM 파라미터 프로젝션
self.x_proj = nn.Linear(
self.d_inner, self.dt_rank + self.d_state * 2, bias=False
)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
# A 초기화 (HiPPO 기반)
A = repeat(
torch.arange(1, self.d_state + 1),
"n -> d n",
d=self.d_inner
)
self.A_log = nn.Parameter(torch.log(A))
self.A_log._no_weight_decay = True
# D (skip connection)
self.D = nn.Parameter(torch.ones(self.d_inner))
self.D._no_weight_decay = True
# 출력 프로젝션
self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
def forward(self, hidden_states):
"""
hidden_states: (B, L, D)
Returns: (B, L, D)
"""
batch, seqlen, dim = hidden_states.shape
# 입력 프로젝션
xz = self.in_proj(hidden_states)
x, z = xz.chunk(2, dim=-1) # 각각 (B, L, ED)
# 합성곱 (causal 1D conv)
x = rearrange(x, "b l d -> b d l")
x = self.conv1d(x)[:, :, :seqlen] # causal trimming
x = rearrange(x, "b d l -> b l d")
x = F.silu(x)
# SSM
y = self.ssm(x)
# 게이팅 (z 브랜치와 element-wise 곱)
y = y * F.silu(z)
# 출력 프로젝션
output = self.out_proj(y)
return output
def ssm(self, x):
"""선택적 SSM (S6) 계산"""
d_in, n = self.d_inner, self.d_state
# A 행렬 (-exp(-A_log)로 항상 음수 보장 = 안정성)
A = -torch.exp(self.A_log.float()) # (ED, N)
# x_proj: Delta, B, C 계산
x_dbl = self.x_proj(x) # (B, L, dt_rank + 2N)
delta, B, C = x_dbl.split(
[self.dt_rank, n, n], dim=-1
)
# Delta: softplus로 양수 보장
delta = F.softplus(self.dt_proj(delta)) # (B, L, ED)
# ZOH 이산화 및 선택적 스캔
y = self.selective_scan(x, delta, A, B, C, self.D)
return y
def selective_scan(self, u, delta, A, B, C, D):
"""
선택적 스캔 알고리즘
실제 구현에서는 mamba_ssm의 CUDA 커널 사용
여기서는 순수 PyTorch로 개념 설명
"""
b, l, d_in = u.shape
n = A.shape[1]
# 이산화
# delta: (B, L, ED), A: (ED, N) → dA: (B, L, ED, N)
deltaA = torch.exp(
torch.einsum("bld,dn->bldn", delta, A)
)
# delta: (B, L, ED), B: (B, L, N), u: (B, L, ED)
# → deltaB_u: (B, L, ED, N)
deltaB_u = torch.einsum("bld,bln,bld->bldn", delta, B, u)
# 순환 스캔
x = torch.zeros(b, d_in, n, device=u.device, dtype=u.dtype)
ys = []
for i in range(l):
x = deltaA[:, i] * x + deltaB_u[:, i]
y = torch.einsum("bdn,bn->bd", x, C[:, i, :])
ys.append(y)
y = torch.stack(ys, dim=1) # (B, L, ED)
# D (skip connection)
y = y + u * D
return y
class MambaModel(nn.Module):
"""스택된 Mamba 블록으로 구성된 시퀀스 모델"""
def __init__(self, d_model, n_layers, d_state=16, expand=2):
super().__init__()
self.layers = nn.ModuleList([
MambaBlock(d_model, d_state=d_state, expand=expand)
for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
def forward(self, x):
"""x: (B, L, D)"""
for layer in self.layers:
x = x + layer(self.norm(x)) # Pre-norm residual
return x
6. Mamba 2
2024년 5월, Tri Dao와 Albert Gu는 Mamba 2를 발표하면서 이론적 기반을 더욱 강화했습니다.
Mamba 1과의 차이
Mamba 1의 선택적 스캔은 각 타임스텝에서 순차적 계산이 필요했습니다. Mamba 2는 **State Space Duality (SSD)**를 발견하여 이를 더욱 효율적으로 만들었습니다.
State Space Duality (SSD)
Mamba 2의 핵심 통찰: 특정 구조를 가진 SSM은 **세미 분리 가능 행렬(Semi-Separable Matrix)**로 표현되며, 이는 특정 형태의 어텐션과 동일합니다.
이 수학적 동치(duality)를 통해:
- SSM 계산을 행렬 곱셈 형태로 표현 가능
- 고도로 최적화된 BLAS 라이브러리 활용 가능
- Tensor Core 활용으로 GPU 효율 극대화
# SSD 연산의 개념
# SSM 순환: x[t] = A[t] * x[t-1] + B[t] * u[t]
# y[t] = C[t]^T * x[t]
#
# 이를 청크(chunk) 단위로 처리:
# - 청크 내부: 행렬 곱셈으로 병렬 계산
# - 청크 간: 순환으로 상태 전파
class Mamba2Block(nn.Module):
"""Mamba 2 블록"""
def __init__(self, d_model, d_state=64, n_heads=8, chunk_size=64):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.n_heads = n_heads
self.chunk_size = chunk_size
self.d_head = d_model // n_heads
# 멀티헤드 구조
self.norm = nn.LayerNorm(d_model)
self.in_proj = nn.Linear(d_model, d_model * 2 + d_state * 2 + n_heads)
self.out_proj = nn.Linear(d_model, d_model)
# A 파라미터 (헤드별)
self.A_log = nn.Parameter(torch.randn(n_heads))
def forward(self, x):
"""x: (B, L, D)"""
# 실제 구현은 mamba_ssm 패키지의 CUDA 커널 사용
return x
다중 헤드 구조
Mamba 2는 다중 헤드 SSM을 도입하여 Transformer의 Multi-Head Attention과 구조적으로 유사해졌습니다. 이는:
- 더 나은 표현력
- 어텐션과의 이론적 연결
- 하이브리드 아키텍처 설계 용이성
Transformer와의 통합 가능성
SSD 이론은 특정 SSM이 특정 어텐션과 동일함을 보여줍니다. 이는 Mamba와 Transformer를 혼합하는 하이브리드 아키텍처의 이론적 근거를 제공합니다.
7. 하이브리드 아키텍처
MambaFormer
MambaFormer는 Mamba 블록과 Transformer 블록을 인터리빙(interleaving)합니다:
레이어 1: Mamba Block (로컬 패턴 포착)
레이어 2: Attention (전역 의존성 포착)
레이어 3: Mamba Block
레이어 4: Attention
...
이 구조는 각 컴포넌트의 장점을 활용합니다:
- Mamba: 효율적인 시퀀스 처리, 로컬 패턴
- Attention: 선택적 정보 검색, 전역 의존성
Jamba (SSM + Transformer + MoE)
2024년 AI21 Labs가 발표한 Jamba는 더 복잡한 하이브리드입니다:
Jamba = Mamba + Transformer + Mixture of Experts
아키텍처:
- 52B 파라미터 (활성화: 12B)
- 레이어 비율: Attention 1 : Mamba 7
- 일부 레이어에 MoE 적용
- 256K 컨텍스트 윈도우 지원
Jamba는 동일 크기 Transformer 대비:
- 추론 처리량 3배 향상
- 긴 컨텍스트에서 메모리 효율 대폭 개선
# Jamba 스타일 하이브리드 블록 (개념)
class JambaLayer(nn.Module):
def __init__(self, d_model, layer_idx, attn_every_n=8):
super().__init__()
self.use_attention = (layer_idx % attn_every_n == 0)
if self.use_attention:
self.mixer = nn.MultiheadAttention(d_model, num_heads=8)
else:
self.mixer = MambaBlock(d_model)
# MoE (일부 레이어에만)
self.use_moe = (layer_idx % 2 == 0)
if self.use_moe:
self.ffn = MixtureOfExperts(d_model, n_experts=16)
else:
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.SiLU(),
nn.Linear(d_model * 4, d_model)
)
def forward(self, x):
x = x + self.mixer(x)
x = x + self.ffn(x)
return x
RWKV (선형 어텐션)
RWKV는 Transformer와 RNN의 하이브리드입니다. "Receptance Weighted Key Value"의 약자로:
- 훈련: Transformer처럼 병렬화 (행렬 형태)
- 추론: RNN처럼 순환 (O(1) 상태)
- 어텐션 없이 토큰 믹싱
최신 버전(RWKV-6, RWKV-7)은 Mamba와 경쟁하는 성능을 보입니다.
RetNet (Retentive Network)
Microsoft의 RetNet은 "Training Parallelism, Inference Efficiency, Competitive Performance"를 동시에 달성하는 것을 목표로 합니다:
3가지 계산 패러다임:
1. Parallel: 훈련 시 O(n^2) 병렬 계산 (Transformer 대비 낮은 상수)
2. Recurrent: 추론 시 O(1) 메모리
3. Chunkwise Recurrent: 균형 잡힌 중간 방법
8. Mamba 성능 비교
추론 속도 (선형 스케일)
Mamba의 가장 큰 강점은 시퀀스 길이에 따른 선형 스케일입니다.
시퀀스 길이 2K일 때 성능을 1로 정규화:
시퀀스 길이 Transformer Mamba
1K 0.5x 0.5x
2K 1.0x 1.0x
4K 3.5x 2.0x ← Transformer 대비 1.75배 빠름
8K 13x 4.0x ← Transformer 대비 3.25배 빠름
16K 50x 8.0x ← Transformer 대비 6.25배 빠름
100K ~1800x ~50x ← Transformer 대비 36배 빠름
메모리 효율성
추론 시 상태(state) 크기 비교:
모델 추론 시 메모리 (1K 토큰당)
Transformer O(n) KV Cache
Mamba O(1) SSM 상태 (고정 크기)
예: 130M 파라미터 모델, 1M 토큰 시퀀스
- Transformer: ~16GB KV Cache
- Mamba: ~1MB 상태 (상수!)
긴 시퀀스 태스크
Long Range Arena (LRA) 벤치마크 (시퀀스 길이 1K-16K):
모델 ListOps Text Retrieval Image Path-X 평균
Transformer 36.4 65.0 57.5 42.4 0.0 40.3
LSTM 35.9 63.7 65.0 43.3 0.0 41.6
S4 59.6 86.8 90.9 88.7 86.1 82.4
Mamba ~S4 수준의 성능, 더 빠른 처리
Mamba는 특히 선택적 복사(Selective Copy), 유도 헤드(Induction Head) 같은 합성 벤치마크에서 S4를 크게 앞서며, 이는 실제 언어 모델링 능력과 더 관련이 깊습니다.
9. Mamba 활용 분야
자연어 처리
Mamba 기반 언어 모델들이 속속 등장하고 있습니다:
- MambaChat: 대화형 AI 어시스턴트
- Falcon Mamba: TII가 오픈소스로 발표한 7B Mamba 모델
- CodeMamba: 코드 생성 특화
긴 문서 처리, 요약, 번역에서 Transformer 대비 효율적입니다.
바이오인포매틱스
게놈 시퀀스 분석은 매우 긴 시퀀스(수백만 bp)를 다루므로 Mamba가 특히 유리합니다:
- Caduceus: 긴 DNA 시퀀스 모델링
- Hyena: 긴 시퀀스 DNA/단백질 분석
- 단백질 구조 예측에서의 활용 가능성
# 바이오 시퀀스를 위한 Mamba 활용 예시
from mamba_ssm import MambaLMHeadModel
# DNA 시퀀스 모델링 (A, T, G, C, N 토큰)
DNA_VOCAB = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4}
VOCAB_SIZE = len(DNA_VOCAB)
# 매우 긴 시퀀스도 효율적으로 처리
model = MambaLMHeadModel.from_pretrained(
"state-spaces/mamba-130m",
device="cuda",
dtype=torch.float16
)
시계열 분석
금융, 기상, IoT 센서 데이터 등 긴 시계열에서 Mamba는 강점을 보입니다:
- TimeMamba: 긴 시계열 예측
- MambaMixer: 다변량 시계열 모델링
- S4/S5 기반 시계열 모델들의 발전
이미지 처리 (VMamba)
VMamba는 Mamba를 2D 이미지 처리에 확장합니다:
# VMamba의 핵심: 2D 선택적 스캔
# 이미지를 4방향으로 스캔하여 2D 구조 포착
# 방향:
# 1. 좌→우, 상→하 (일반적 래스터 스캔)
# 2. 우→좌, 하→상 (역방향)
# 3. 상→하, 좌→우 (열 우선)
# 4. 하→상, 우→좌 (역방향)
class VMambaBlock(nn.Module):
"""VMamba: Visual Mamba Block"""
def __init__(self, d_model, d_state=16):
super().__init__()
self.norm = nn.LayerNorm(d_model)
# 4방향 SSM
self.ssms = nn.ModuleList([
MambaBlock(d_model, d_state) for _ in range(4)
])
self.out_proj = nn.Linear(d_model * 4, d_model)
def forward(self, x):
"""x: (B, H, W, D) - 이미지 패치 임베딩"""
b, h, w, d = x.shape
x_flat = x.view(b, h*w, d)
outputs = []
# 4방향 스캔
for i, ssm in enumerate(self.ssms):
if i == 0: # 정방향
seq = x_flat
elif i == 1: # 역방향
seq = x_flat.flip(1)
elif i == 2: # 열 우선
seq = x.permute(0, 2, 1, 3).reshape(b, h*w, d)
else: # 열 우선 역방향
seq = x.permute(0, 2, 1, 3).reshape(b, h*w, d).flip(1)
out = ssm(seq)
if i % 2 == 1:
out = out.flip(1)
outputs.append(out)
# 4방향 결과 합산
combined = torch.cat(outputs, dim=-1)
return self.out_proj(combined).view(b, h, w, d)
10. 실전 사용
mamba-ssm 패키지 설치
# 필요 패키지
pip install torch torchvision torchaudio
pip install causal-conv1d>=1.2.0
pip install mamba-ssm
# 또는 소스에서 설치 (최신 기능)
git clone https://github.com/state-spaces/mamba
cd mamba
pip install -e ".[dev]"
MambaLMHeadModel 사용
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer
# 사전 학습된 모델 로드
model_name = "state-spaces/mamba-2.8b-slimpj"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
model_name,
device="cuda",
dtype=torch.bfloat16
)
model.eval()
# 텍스트 생성
def generate_text(prompt, max_new_tokens=200, temperature=0.7):
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
with torch.no_grad():
output = model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=0.9,
do_sample=True,
eos_token_id=tokenizer.eos_token_id,
)
generated = output[0][input_ids.shape[1]:]
return tokenizer.decode(generated, skip_special_tokens=True)
# 예시
prompt = "State Space Models are powerful because"
result = generate_text(prompt)
print(result)
파인튜닝 예제
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
class TextDataset(Dataset):
def __init__(self, texts, tokenizer, max_length=512):
self.encodings = tokenizer(
texts,
truncation=True,
max_length=max_length,
padding="max_length",
return_tensors="pt"
)
def __len__(self):
return len(self.encodings['input_ids'])
def __getitem__(self, idx):
return {
'input_ids': self.encodings['input_ids'][idx],
'attention_mask': self.encodings['attention_mask'][idx],
}
def finetune_mamba(
model_name="state-spaces/mamba-130m",
texts=None,
num_epochs=3,
learning_rate=1e-4,
batch_size=8,
):
"""Mamba 파인튜닝"""
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.pad_token = tokenizer.eos_token
# 모델 로드
model = MambaLMHeadModel.from_pretrained(
model_name,
device=device,
dtype=torch.bfloat16
)
# 데이터셋
dataset = TextDataset(texts, tokenizer)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 옵티마이저 (Mamba에는 AdamW 권장)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate,
weight_decay=0.1
)
# 스케줄러
total_steps = len(dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=total_steps // 10,
num_training_steps=total_steps
)
# 훈련 루프
model.train()
for epoch in range(num_epochs):
total_loss = 0
for batch_idx, batch in enumerate(dataloader):
input_ids = batch['input_ids'].to(device)
# 언어 모델링: 다음 토큰 예측
outputs = model(input_ids)
logits = outputs.logits
# Shift for next token prediction
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., 1:].contiguous()
loss = nn.CrossEntropyLoss()(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
total_loss += loss.item()
if batch_idx % 100 == 0:
print(f"Epoch {epoch+1}, Step {batch_idx}: Loss = {loss.item():.4f}")
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1} 완료: 평균 Loss = {avg_loss:.4f}")
return model
커스텀 Mamba 모델 구성
from mamba_ssm import Mamba
from mamba_ssm.models.config_mamba import MambaConfig
# 커스텀 설정
config = MambaConfig(
d_model=1024,
n_layer=48,
vocab_size=50280,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=8,
)
# 모델 생성 (약 1.4B 파라미터)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.bfloat16)
print(f"파라미터 수: {sum(p.numel() for p in model.parameters()):,}")
마무리: Mamba의 미래
Mamba는 딥러닝 시퀀스 모델링 분야에 혁명적인 변화를 가져왔습니다. 핵심 기여를 정리하면:
- 선택적 메커니즘: 입력에 따라 동적으로 변하는 SSM 파라미터
- Hardware-Aware 설계: GPU 메모리 계층을 활용한 효율적 계산
- 이중 표현: 훈련 시 병렬화, 추론 시 순환으로 최적화
- 선형 복잡도: 시퀀스 길이에 선형적인 계산 및 메모리 복잡도
아직 해결해야 할 과제도 있습니다:
- In-context Learning에서 Transformer보다 다소 부족한 성능
- 매우 큰 스케일에서의 검증 필요
- 어텐션 기반 모델과의 명확한 우위 입증
그러나 Mamba와 SSM 계열의 모델들은 긴 시퀀스 처리, 실시간 추론, 엣지 디바이스 배포 등 Transformer가 어려움을 겪는 분야에서 점점 더 중요한 역할을 하게 될 것입니다.
참고 자료
- Mamba 논문: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (Gu & Dao, 2023)
- Mamba 2 논문: "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" (Dao & Gu, 2024)
- S4 논문: "Efficiently Modeling Long Sequences with Structured State Spaces" (Gu et al., 2021)
- H3 논문: "Hungry Hungry Hippos: Towards Language Modeling with State Space Models" (Fu et al., 2022)
- Mamba GitHub: https://github.com/state-spaces/mamba
- Jamba 논문: "Jamba: A Hybrid Transformer-Mamba Language Model" (AI21 Labs, 2024)