Skip to content
Published on

Mamba: Linear-Time Sequence Modeling with Selective State Spaces 논문 분석

Authors
  • Name
    Twitter
Mamba SSM

논문 개요

  • 제목: Mamba: Linear-Time Sequence Modeling with Selective State Spaces
  • 저자: Albert Gu, Tri Dao (Carnegie Mellon University, Princeton University)
  • 발표: 2023년 12월 (arXiv: 2312.00752), ICML 2024 채택
  • 코드: github.com/state-spaces/mamba

동기: Transformer의 한계

Transformer는 시퀀스 모델링의 왕좌를 차지하고 있지만, 근본적인 문제가 있습니다:

  • O(N²) 복잡도: Self-attention은 시퀀스 길이에 대해 이차 복잡도
  • 긴 시퀀스 처리 한계: 128K 이상의 컨텍스트에서 메모리/연산 폭발
  • 추론 비효율: KV Cache가 시퀀스 길이에 비례하여 증가

State Space Model(SSM)은 이론적으로 O(N) 복잡도를 가지지만, 기존 SSM(S4, H3 등)은 입력에 무관한 고정 파라미터 때문에 언어 모델링에서 Transformer에 뒤졌습니다.

핵심 아이디어: Selective State Spaces

기존 SSM의 문제

기존 SSM은 Linear Time-Invariant(LTI) 시스템입니다:

h(t)=Ah(t)+Bx(t)h'(t) = Ah(t) + Bx(t) y(t)=Ch(t)y(t) = Ch(t)

여기서 A, B, C는 고정된 행렬입니다. 즉, 입력이 "hello"든 "world"든 같은 방식으로 처리합니다. 이것은 content-aware reasoning이 불가능하다는 의미입니다.

Mamba의 해결책: Selection Mechanism

Mamba의 핵심 혁신은 B, C, Δ를 입력에 의존적으로 만드는 것입니다:

# 기존 SSM: 고정 파라미터
B = nn.Parameter(torch.randn(N))  # 학습되지만 입력과 무관

# Mamba: 입력 의존적 파라미터
B = nn.Linear(D, N)(x)      # x에 따라 B가 달라짐
C = nn.Linear(D, N)(x)      # x에 따라 C가 달라짐
delta = nn.Linear(D, D)(x)  # x에 따라 step size도 달라짐

이것이 Selective State Space의 의미입니다. 입력에 따라 어떤 정보를 상태에 저장하고 어떤 정보를 무시할지 선택합니다.

직관적 이해

Selection Mechanism을 직관적으로 이해하면:

  • Δ(delta): "이 토큰을 얼마나 중요하게 볼 것인가" — 큰 Δ는 현재 입력을 강하게 반영, 작은 Δ는 이전 상태를 유지
  • B: "이 입력의 어떤 부분을 상태에 기록할 것인가"
  • C: "상태에서 어떤 부분을 출력으로 꺼낼 것인가"

이는 Transformer의 attention이 하는 역할과 유사하지만, O(N) 복잡도로 수행합니다.

하드웨어 최적화: Selective Scan 알고리즘

Selective SSM은 입력 의존적이므로 기존 SSM의 효율적인 합성곱(convolution) 트릭을 사용할 수 없습니다. Mamba는 이 문제를 하드웨어 인지 알고리즘으로 해결합니다.

문제: HBM ↔ SRAM 병목

GPU의 메모리 계층:

┌────────────────────┐
SRAM (20MB)      │  ← 매우 빠름, 매우 작음
├────────────────────┤
HBM (80GB)       │  ← 느림,└────────────────────┘

나이브한 구현은 중간 상태를 HBM에 반복 저장/로드하여 메모리 바운드가 됩니다.

해결: Kernel Fusion + Recomputation

Mamba의 Selective Scan 알고리즘:

  1. SRAM에서 한번에 계산: 이산화(discretization), 선택적 스캔, 출력 계산을 하나의 커널로 융합
  2. 중간 상태 저장 안 함: 순전파에서 (N, B, C, Δ)만 HBM에 저장하고, 역전파에서 중간 상태를 재계산
  3. 결과: FlashAttention과 유사한 IO-aware 최적화
# 간소화된 Selective Scan (개념 코드)
def selective_scan(x, A, B, C, delta):
    """
    x: (batch, length, d_model)
    Returns: (batch, length, d_model)
    """
    batch, length, d = x.shape
    n = A.shape[1]  # state dimension

    # Discretize A, B using delta
    deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, D, N)
    deltaB_x = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)

    # Sequential scan (실제로는 CUDA kernel으로 병렬화)
    h = torch.zeros(batch, d, n, device=x.device)
    ys = []
    for i in range(length):
        h = deltaA[:, i] * h + deltaB_x[:, i]
        y = (h * C[:, i].unsqueeze(1)).sum(-1)
        ys.append(y)

    return torch.stack(ys, dim=1)

Mamba 블록 아키텍처

Mamba는 H3 아키텍처에서 영감을 받아 간소화된 블록을 사용합니다:

Input (x)
    ├──── Linear projection (expand) ────┐
    │                                     │
    ▼                                     ▼
  Conv1D                              SiLU (gate)
    │                                     │
    ▼                                     │
  SiLU    │                                     │
    ▼                                     │
  SSM (selective scan)    │                                     │
    ▼                                     │
    ×─────────────────────────────────────┘
  Linear projection (reduce)
  Output

PyTorch로 구현하면:

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand

        # Input projection
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

        # Conv1D
        self.conv1d = nn.Conv1d(
            d_inner, d_inner, d_conv,
            padding=d_conv - 1, groups=d_inner
        )

        # SSM parameters
        self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False)  # B, C, delta
        self.dt_proj = nn.Linear(1, d_inner, bias=True)

        # A parameter (structured)
        A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))

        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        batch, length, _ = x.shape

        # Dual path
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Conv + activation
        x = self.conv1d(x.transpose(1, 2))[:, :, :length].transpose(1, 2)
        x = F.silu(x)

        # SSM
        y = self.ssm(x)

        # Gate and project
        y = y * F.silu(z)
        return self.out_proj(y)

실험 결과

언어 모델링 (The Pile)

모델파라미터Perplexity
Transformer++1.3B8.94
H31.3B9.19
Hyena1.3B9.07
RWKV-41.5B9.02
Mamba1.4B8.63

Mamba는 같은 크기의 Transformer보다 더 낮은 perplexity를 달성했습니다. 특히 2.8B 규모에서 Mamba는 Transformer++ 대비 약 0.3 perplexity 개선을 보였습니다.

추론 속도

시퀀스 길이TransformerMamba
2K1x1x
8K1x3x faster
64K1x5x faster
1MOOM작동

Mamba는 시퀀스 길이가 길어질수록 Transformer 대비 압도적인 속도 이점을 보여줍니다.

Selective Copying & Induction Heads

논문은 두 가지 핵심 합성 태스크로 Selection Mechanism의 효과를 검증합니다:

  1. Selective Copying: 긴 시퀀스에서 특정 토큰만 선택적으로 복사 → 기존 SSM(S4) 실패, Mamba 성공
  2. Induction Heads: "A B ... A → B" 패턴 인식 → 기존 SSM 실패, Mamba 성공

이 두 태스크는 content-aware reasoning이 필수이며, Selection Mechanism 없이는 해결이 불가능합니다.

Mamba vs Transformer: 언제 무엇을?

기준TransformerMamba
긴 시퀀스△ (O(N²))◎ (O(N))
추론 속도△ (KV Cache 증가)◎ (고정 상태)
In-context learning
학습 병렬화◎ (완전 병렬)○ (scan 병렬화)
생태계/도구

후속 연구: Mamba-2와 하이브리드

Mamba 이후 많은 발전이 있었습니다:

  • Mamba-2 (Dao & Gu, 2024): SSM과 Attention의 이론적 연결을 보여주는 State Space Duality(SSD) 프레임워크 제안
  • Jamba (AI21, 2024): Transformer + Mamba 하이브리드, 256K 컨텍스트
  • Zamba (Zyphra, 2024): 소형 하이브리드 모델에서 SOTA

마무리

Mamba는 Selection Mechanism이라는 단순하지만 강력한 아이디어로 SSM의 한계를 극복하고, Transformer에 필적하거나 능가하는 성능을 O(N) 복잡도로 달성했습니다. 하드웨어 인지 알고리즘 설계는 이론적 효율성을 실제 속도로 전환하는 핵심이었습니다.

Transformer가 여전히 주류이지만, Mamba와 하이브리드 아키텍처는 긴 시퀀스 처리가 중요한 영역에서 점점 더 존재감을 키우고 있습니다.

퀴즈

Q1: Mamba가 해결하려는 Transformer의 핵심 한계는? Self-attention의 O(N²) 시간/공간 복잡도. 시퀀스 길이가 길어질수록 연산량과 메모리가 제곱으로 증가하여 긴 시퀀스 처리가 비효율적입니다.

Q2: Selective State Space에서 "Selective"의 의미는? SSM의 파라미터(B, C, Δ)를 입력에 의존적으로 만들어, 어떤 정보를 상태에 저장하고 어떤 정보를 무시할지 선택(select)할 수 있다는 의미입니다.

Q3: 기존 SSM(S4 등)이 언어 모델링에서 부진했던 근본 원인은? Linear Time-Invariant(LTI) 특성 때문. 파라미터가 입력과 무관하게 고정되어 content-aware reasoning이 불가능했습니다.

Q4: Mamba에서 Δ(delta) 파라미터의 직관적 의미는? 현재 입력을 얼마나 중요하게 반영할지 결정하는 "step size". 큰 Δ는 현재 입력을 강하게 반영하고, 작은 Δ는 이전 상태를 유지합니다.

Q5: Mamba의 Selective Scan이 기존 SSM의 합성곱 트릭을 사용할 수 없는 이유는? 입력 의존적 파라미터 때문에 더 이상 LTI 시스템이 아니므로, 합성곱으로 변환하여 FFT로 병렬 계산하는 트릭이 적용 불가합니다.

Q6: Mamba의 하드웨어 최적화에서 FlashAttention과 유사한 핵심 전략은? IO-aware 최적화. SRAM에서 모든 계산을 융합(kernel fusion)하고, 중간 상태를 HBM에 저장하지 않고 역전파에서 재계산(recomputation)합니다.

Q7: Selective Copying 태스크로 검증하려는 능력은? 긴 시퀀스에서 특정 토큰만 선택적으로 기억하고 복사하는 content-aware reasoning 능력. 고정 파라미터 SSM은 이 태스크를 해결할 수 없습니다.

Q8: Mamba-2의 핵심 기여인 State Space Duality(SSD)란? SSM과 Attention이 수학적으로 동일한 연산의 다른 표현임을 보여주는 프레임워크. 이를 통해 두 접근법의 장점을 결합하는 이론적 기반을 제공합니다.