Skip to content

필사 모드: Mamba 논문 리뷰: Selective State Space Models로 Transformer를 넘어서

한국어
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

1. 서론: Transformer의 한계와 SSM의 부상

Transformer 아키텍처는 2017년 등장 이후 NLP, Vision, Audio 등 거의 모든 시퀀스 모델링 분야를 지배해왔다. 하지만 근본적인 한계가 있다:

- **O(N²) 복잡도**: Self-Attention의 시간/공간 복잡도가 시퀀스 길이에 이차적으로 증가

- **추론 비효율**: 각 토큰 생성 시 전체 KV Cache를 참조해야 함

- **긴 시퀀스 처리 어려움**: 128K+ 컨텍스트에서 메모리 병목

**State Space Model(SSM)**은 이러한 한계를 극복하기 위해 제안된 대안이다. 특히 **Mamba**(Gu & Dao, ICLR 2024)는 SSM에 **선택적(Selective)** 메커니즘을 도입하여 Transformer에 필적하는 성능을 **O(N) 복잡도**로 달성했다.

2. 배경: Structured State Space Models (S4)

2.1 연속 시간 SSM

SSM은 연속 시간 시스템을 이산화한 것이다:

$$

h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)

$$

$$

y(t) = \mathbf{C}h(t)

$$

- $h(t) \in \mathbb{R}^N$: hidden state

- $x(t) \in \mathbb{R}$: input

- $y(t) \in \mathbb{R}$: output

- $\mathbf{A} \in \mathbb{R}^{N \times N}$, $\mathbf{B} \in \mathbb{R}^{N \times 1}$, $\mathbf{C} \in \mathbb{R}^{1 \times N}$

2.2 이산화 (Zero-Order Hold)

연속 시스템을 이산 시간으로 변환:

$$

\overline{\mathbf{A}} = \exp(\Delta \mathbf{A})

$$

$$

\overline{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B}

$$

이산화된 재귀 관계:

$$

h_t = \overline{\mathbf{A}} h_{t-1} + \overline{\mathbf{B}} x_t

$$

$$

y_t = \mathbf{C} h_t

$$

2.3 S4의 핵심: HiPPO 초기화

S4(Structured State Spaces for Sequence Modeling)의 핵심 기여는 **HiPPO(High-order Polynomial Projection Operator)** 행렬로 A를 초기화한 것이다:

def make_hippo(N):

"""HiPPO-LegS 행렬 생성"""

P = np.sqrt(1 + 2 * np.arange(N))

A = np.zeros((N, N))

for i in range(N):

for j in range(N):

if i > j:

A[i, j] = P[i] * P[j]

elif i == j:

A[i, j] = i + 1

i < j: 0

return -A # 안정성을 위해 음수

2.4 S4의 한계

S4는 **입력에 독립적인 고정 파라미터**(A, B, C, Δ)를 사용한다. 이는 두 가지 문제를 야기한다:

1. **Content-based reasoning 불가**: 입력 내용에 따라 다르게 처리해야 하는 태스크(예: selective copying, induction heads)에서 성능 저하

2. **Attention의 장점 포기**: Transformer의 content-aware 매칭 능력 상실

3. Mamba: Selective State Space Models

3.1 핵심 아이디어: Selection Mechanism

Mamba의 핵심 혁신은 **SSM 파라미터를 입력에 의존하게** 만든 것이다:

S4: (A, B, C, Δ) = 고정 파라미터

Mamba: (B, C, Δ) = f(input) ← 입력 의존!

구체적으로:

$$

\mathbf{B}_t = \text{Linear}_B(x_t), \quad \mathbf{C}_t = \text{Linear}_C(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}_\Delta(x_t))

$$

이렇게 하면 모델이 **어떤 정보를 기억하고 어떤 정보를 잊을지** 입력에 따라 동적으로 결정할 수 있다.

3.2 Selection의 직관

$\Delta_t$(step size)가 핵심:

- **$\Delta_t$가 크면**: $\overline{\mathbf{A}} \approx \mathbf{I}$ → 이전 state 유지 (입력 무시)

- **$\Delta_t$가 작으면**: 현재 입력에 더 집중 (state 갱신)

이것은 **게이팅 메커니즘**과 유사하다:

LSTM의 forget gate ≈ Mamba의 Δ

3.3 Hardware-Aware 알고리즘

Selection을 도입하면 파라미터가 입력 의존이 되어 **convolution trick을 사용할 수 없다**. 이는 학습 시 O(N²) 복잡도로 회귀할 위험이 있다.

Mamba는 이를 **kernel fusion + recomputation**으로 해결:

Mamba의 Selective Scan (simplified)

def selective_scan(x, delta, A, B, C):

"""

x: (B, L, D) - 입력

delta: (B, L, D) - step size (입력 의존)

A: (D, N) - state matrix

B: (B, L, N) - input matrix (입력 의존)

C: (B, L, N) - output matrix (입력 의존)

"""

B_batch, L, D = x.shape

N = A.shape[1]

이산화

deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, D, N)

deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, D, N)

Sequential scan (학습 시 parallel scan 사용)

h = torch.zeros(B_batch, D, N, device=x.device)

ys = []

for t in range(L):

h = deltaA[:, t] * h + deltaB[:, t] * x[:, t, :, None]

y = (h * C[:, t, None, :]).sum(-1) # (B, D)

ys.append(y)

return torch.stack(ys, dim=1) # (B, L, D)

실제 구현에서는 **CUDA 커널에서 SRAM에 모든 중간 상태를 유지**하여 HBM 접근을 최소화한다 (FlashAttention과 유사한 IO-aware 접근).

3.4 Mamba Block 아키텍처

Input

├──── Linear(D → 2ED) ──── SiLU ──── Conv1d ──── SiLU ──── SSM ────┐

│ │

└──── Linear(D → ED) ──── SiLU ──────────────────────────── × ───────┘

Linear(ED → D)

Output

class MambaBlock(nn.Module):

def __init__(self, d_model, d_state=16, d_conv=4, expand=2):

super().__init__()

d_inner = d_model * expand

self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv,

padding=d_conv-1, groups=d_inner)

SSM 파라미터

self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False) # B, C, Δ

self.dt_proj = nn.Linear(1, d_inner, bias=True)

A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)

self.A_log = nn.Parameter(torch.log(A))

self.D = nn.Parameter(torch.ones(d_inner))

self.out_proj = nn.Linear(d_inner, d_model, bias=False)

def forward(self, x):

x: (B, L, D)

xz = self.in_proj(x) # (B, L, 2*ED)

x, z = xz.chunk(2, dim=-1)

Conv + SiLU

x = self.conv1d(x.transpose(1, 2))[:, :, :x.shape[1]]

x = x.transpose(1, 2)

x = F.silu(x)

SSM

A = -torch.exp(self.A_log)

B_C_dt = self.x_proj(x)

... selective scan ...

Gate

y = y * F.silu(z)

return self.out_proj(y)

4. Mamba-2: State Space Duality

4.1 SSD (State Space Dual) 모델

Mamba-2(Dao & Gu, ICML 2024)는 SSM과 Attention 사이의 수학적 이중성(Duality)을 발견했다:

- **SSM 관점**: 재귀적 state update → O(N) 추론

- **Attention 관점**: 특수한 구조의 행렬 곱 → 병렬 학습

SSM recurrence: h_t = A_t h_{t-1} + B_t x_t

y_t = C_t h_t

↕ Dual ↕

Structured Attention: Y = (T ⊙ M) X

where T = lower-triangular causal mask

M = structured (semi-separable) matrix

4.2 성능 비교

| 모델 | 파라미터 | Pile (ppl) | 학습 FLOPS/s | 추론 (tok/s) |

|------|---------|------------|-------------|-------------|

| Transformer++ | 2.7B | 6.7 | 1.0x | 1.0x |

| Mamba-1 | 2.8B | 6.2 | 1.3x | 5.2x |

| Mamba-2 | 2.7B | 6.1 | **2.1x** | **5.5x** |

핵심 개선:

- 학습 속도 2배 향상 (structured matrix multiplication 활용)

- 더 큰 state size 가능 (N=64 → N=256)

- Multi-head 구조 도입

5. 실전: Mamba 모델 사용하기

5.1 설치 및 추론

pip install mamba-ssm causal-conv1d>=1.4.0

from mamba_ssm import MambaLMHeadModel

from transformers import AutoTokenizer

Mamba-2.8B 로드

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

model = model.cuda().half()

추론

input_ids = tokenizer("The future of AI is", return_tensors="pt").input_ids.cuda()

output = model.generate(input_ids, max_length=100, temperature=0.7)

print(tokenizer.decode(output[0]))

5.2 Mamba-2 사용

from mamba_ssm import Mamba2

Mamba-2 레이어 단독 사용

layer = Mamba2(

d_model=2048,

d_state=128, # Mamba-2는 더 큰 state 가능

d_conv=4,

expand=2,

headdim=64, # Multi-head

).cuda()

x = torch.randn(1, 1024, 2048).cuda() # (batch, seq_len, d_model)

y = layer(x) # (1, 1024, 2048)

6. Mamba vs Transformer 비교

| 특성 | Transformer | Mamba |

|------|------------|-------|

| 학습 복잡도 | O(N²) | O(N) |

| 추론 복잡도 | O(N) per token (KV Cache) | O(1) per token |

| 메모리 (추론) | O(N) KV Cache | O(1) fixed state |

| In-context learning | ✅ 강함 | ⚠️ 약함 (개선 중) |

| 긴 시퀀스 | ❌ 병목 | ✅ 효율적 |

| 하드웨어 활용 | ✅ 병렬화 최적 | ⚠️ 재귀 병목 (Mamba-2 개선) |

7. 한계와 전망

7.1 현재 한계

- **In-context learning**: 고정 크기 state로 무한한 컨텍스트 정보 압축 시 손실

- **Retrieval 태스크**: 정확한 토큰 검색이 필요한 경우 Attention이 우위

- **학습 안정성**: 긴 시퀀스에서 gradient 불안정 문제

7.2 하이브리드 접근

최근 트렌드는 **Mamba + Attention 하이브리드**:

- **Jamba** (AI21): Mamba 레이어 + Attention 레이어 혼합

- **Zamba** (Zyphra): Mamba 기반 + 소수의 Shared Attention 레이어

8. 퀴즈

SSM 파라미터(B, C, Δ)를 **입력 의존적으로** 만들어 content-based reasoning을 가능하게 한 것. S4는 고정 파라미터여서 입력 내용에 따른 선택적 처리가 불가했다.

Step size로, **게이팅 메커니즘** 역할. Δ가 크면 이전 state 유지(입력 무시), 작으면 현재 입력에 집중(state 갱신). LSTM의 forget gate와 유사.

Convolution trick은 **시간 불변(time-invariant)** 시스템에서만 작동. Selection으로 파라미터가 입력 의존이 되면 시간 가변(time-varying)이 되어 글로벌 컨볼루션 불가.

CUDA 커널에서 **SRAM에 중간 state를 유지**하여 HBM 접근 최소화. FlashAttention과 유사한 IO-aware 접근으로 memory-bound 연산 최적화.

SSM의 재귀적 state update와 structured attention 행렬 곱이 **수학적으로 동일**하다는 것. 이를 통해 학습 시 병렬화된 행렬 곱, 추론 시 효율적 재귀를 모두 활용.

In-context learning과 정확한 토큰 검색(retrieval) 태스크. Attention은 시퀀스 내 임의 위치를 직접 참조할 수 있지만, Mamba는 고정 크기 state로 정보를 압축해야 한다.

Mamba 레이어의 O(N) 효율성 + 소수의 Attention 레이어의 정확한 검색 능력을 결합. 대부분의 레이어를 Mamba로, 핵심 위치에만 Attention을 배치.

현재 단락 (1/170)

Transformer 아키텍처는 2017년 등장 이후 NLP, Vision, Audio 등 거의 모든 시퀀스 모델링 분야를 지배해왔다. 하지만 근본적인 한계가 있다:

작성 글자: 0원문 글자: 6,343작성 단락: 0/170