Skip to content

Split View: Mamba와 상태 공간 모델(SSM) 완전 정복: Transformer를 넘어서

|

Mamba와 상태 공간 모델(SSM) 완전 정복: Transformer를 넘어서

개요

2017년 "Attention is All You Need" 논문 이후, Transformer는 자연어 처리부터 이미지, 코드, 오디오까지 거의 모든 시퀀스 모델링 태스크를 지배해왔습니다. 그러나 Transformer는 근본적인 약점이 있습니다. 바로 셀프 어텐션의 이차 복잡도(O(n^2)) 문제입니다.

2023년 Albert Gu와 Tri Dao가 발표한 Mamba는 이 문제를 우아하게 해결하며 딥러닝 커뮤니티에 큰 충격을 주었습니다. Mamba는 상태 공간 모델(State Space Model, SSM)이라는 고전 제어 이론에서 온 개념을 현대 딥러닝에 접목하여, 선형 시간 복잡도로 긴 시퀀스를 처리할 수 있게 했습니다.

이 가이드에서는 SSM의 수학적 기초부터 Mamba의 핵심 혁신, 그리고 실전 구현까지 완전히 다룹니다.


1. Transformer의 한계와 SSM의 등장

셀프 어텐션의 O(n^2) 복잡도 문제

Transformer의 핵심은 셀프 어텐션(Self-Attention) 메커니즘입니다. 길이 n인 시퀀스에 대해 어텐션 행렬을 계산할 때, 모든 토큰 쌍의 관계를 계산하므로 시간과 메모리 복잡도가 **O(n^2)**이 됩니다.

어텐션 행렬 크기: n × n
= 1,000 토큰 → 1,000,000 원소
= 10,000 토큰 → 100,000,000 원소
= 100,000 토큰 → 10,000,000,000 원소 (불가능!)

이 문제로 인해 실제 LLM들은 컨텍스트 윈도우 크기에 제약을 받습니다. GPT-4는 128K 토큰을 지원하지만, 이를 위해 엄청난 계산 비용과 최적화 기법이 필요합니다.

긴 시퀀스 처리의 어려움

긴 문서 요약, 코드 베이스 이해, 장시간 대화 유지 등 실제 애플리케이션에서는 매우 긴 컨텍스트가 필요합니다. 하지만 Transformer의 이차 복잡도는 실용적 사용을 어렵게 만듭니다.

예를 들어, 전체 책 한 권(약 100,000 단어 = 130,000 토큰)을 한 번에 처리하려면, 표준 Transformer는 현재 하드웨어로 처리 불가능에 가까운 메모리를 요구합니다.

순환 모델(RNN, LSTM)의 한계

순환 신경망(RNN)과 LSTM은 이론적으로 O(n) 복잡도를 가집니다. 각 타임스텝에서 고정된 크기의 숨겨진 상태(hidden state)를 유지하기 때문입니다. 그러나 RNN/LSTM은 다음 문제가 있습니다:

  • 기울기 소실/폭발 (Vanishing/Exploding Gradient): 긴 시퀀스에서 역전파 시 기울기가 사라지거나 폭발
  • 병렬화 불가능: 각 타임스텝이 이전 상태에 의존하므로 GPU 병렬화가 어려움
  • 장기 의존성 포착 어려움: 멀리 떨어진 정보를 연결하기 어려움

SSM이 제시하는 해결책

상태 공간 모델(SSM)은 이 두 접근법의 장점을 결합합니다:

  • 훈련 시: 합성곱(Convolution)으로 완전 병렬화 → Transformer에 버금가는 훈련 효율
  • 추론 시: 순환(Recurrence)으로 O(1) 메모리와 O(n) 계산 → RNN처럼 효율적인 추론
  • 긴 시퀀스: 선형 복잡도로 매우 긴 시퀀스 처리 가능

2. 상태 공간 모델(SSM) 수학적 기초

연속 시간 SSM

SSM의 기원은 1960년대 Rudolf Kalman의 제어 이론으로 거슬러 올라갑니다. 연속 시간 선형 시스템은 다음과 같이 표현됩니다:

x'(t) = A·x(t) + B·u(t)
y(t)  = C·x(t) + D·u(t)

여기서:

  • u(t): 입력 신호 (input)
  • x(t): 상태 벡터 (latent state), 크기 N
  • y(t): 출력 신호 (output)
  • A: N×N 상태 전이 행렬 (state transition matrix)
  • B: N×1 입력 행렬 (input projection)
  • C: 1×N 출력 행렬 (output projection)
  • D: 직접 전달 항 (skip connection, 보통 0으로 설정)

이 시스템은 입력 u(t)를 받아 내부 상태 x(t)를 업데이트하고 출력 y(t)를 생성합니다. 상태 x(t)는 과거 정보의 "요약"이라고 볼 수 있습니다.

이산화 (Discretization)

연속 시간 SSM을 디지털 컴퓨터에서 사용하려면 이산화(discretization)가 필요합니다. 샘플링 간격을 Delta라 할 때, 두 가지 주요 이산화 방법이 있습니다.

Zero-Order Hold (ZOH):

A_bar = exp(Delta · A)
B_bar = (Delta · A)^(-1) · (exp(Delta · A) - I) · Delta · B

Bilinear (Tustin) 변환:

A_bar = (I - Delta/2 · A)^(-1) · (I + Delta/2 · A)
B_bar = (I - Delta/2 · A)^(-1) · Delta · B

이산화 후 시스템은:

x[t] = A_bar · x[t-1] + B_bar · u[t]
y[t] = C · x[t]

이 형태는 선형 순환 (Linear Recurrence)으로, 각 타임스텝에서 상태를 업데이트합니다.

합성곱 커널로서의 SSM

SSM의 강력한 점은 훈련 시 합성곱으로 계산 가능하다는 것입니다. 초기 상태 x[0] = 0 으로 시작하면:

y[0] = C · B_bar · u[0]
y[1] = C · A_bar · B_bar · u[0] + C · B_bar · u[1]
y[2] = C · A_bar^2 · B_bar · u[0] + C · A_bar · B_bar · u[1] + C · B_bar · u[2]
...

이를 합성곱 커널 K로 표현하면:

K = (C·B_bar, C·A_bar·B_bar, C·A_bar^2·B_bar, ...)
y = K * u  (합성곱)

이 커널 K는 병렬로 계산할 수 있으며, FFT를 사용하면 O(n log n)으로 계산 가능합니다.

이것이 SSM의 핵심 이중성입니다:

  • 추론: 순환(Recurrence)으로 O(1) 메모리
  • 훈련: 합성곱(Convolution)으로 O(n log n) 병렬 계산

3. S4 (Structured State Space Sequence Model)

S4는 2021년 Albert Gu 등이 발표한 논문으로, SSM을 딥러닝에 실용적으로 적용한 첫 번째 중요한 작업입니다.

HiPPO 행렬 초기화

S4의 핵심 기여 중 하나는 HiPPO(High-order Polynomial Projection Operators) 행렬 초기화입니다. 단순히 무작위로 A를 초기화하면 기울기 소실 문제가 발생합니다. HiPPO는 과거 입력을 다항식으로 근사하도록 설계된 특별한 A 행렬을 제공합니다.

HiPPO-LegS (Legendre 다항식 기반):

A[n,k] = -sqrt((2n+1)(2k+1))  if n > k
A[n,k] = -(n+1)               if n == k
A[n,k] = 0                    if n < k

이 초기화는 상태 x[t]가 모든 과거 입력 u[0..t]에 대한 최적 다항식 근사를 유지하게 합니다.

구조화된 행렬 A (DPLR)

A 행렬의 합성곱 커널을 효율적으로 계산하기 위해, S4는 A를 Diagonal Plus Low-Rank (DPLR) 형태로 표현합니다:

A = Λ - P·Q^T

여기서 Λ는 대각 행렬이고 P, Q는 낮은 랭크의 벡터입니다. 이 구조를 이용하면 커널 계산을 O(N)으로 줄일 수 있습니다.

효율적 계산

import torch
import torch.nn as nn
import numpy as np

class S4Layer(nn.Module):
    """S4 레이어의 간소화된 구현"""
    def __init__(self, d_model, d_state=64, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # HiPPO-LegS 초기화
        A = self._make_hippo(d_state)
        # DPLR 분해
        self.A = nn.Parameter(torch.tensor(A, dtype=torch.float32))
        self.B = nn.Parameter(torch.randn(d_state, 1) * 0.01)
        self.C = nn.Parameter(torch.randn(1, d_state) * 0.01)
        self.log_delta = nn.Parameter(torch.zeros(d_model))
        self.D = nn.Parameter(torch.ones(d_model))

    def _make_hippo(self, N):
        """HiPPO-LegS 행렬 생성"""
        P = np.sqrt(1 + 2 * np.arange(N))
        A = P[:, np.newaxis] * P[np.newaxis, :]
        A = np.tril(A) - np.diag(np.arange(N))
        return -A

    def discretize(self, A, B, C, delta):
        """ZOH 이산화"""
        dA = torch.matrix_exp(delta.unsqueeze(-1) * A)
        dB = torch.linalg.solve(A, (dA - torch.eye(A.shape[0])) @ B)
        return dA, dB, C

    def forward(self, u):
        # u: (batch, seq_len, d_model)
        delta = torch.exp(self.log_delta)
        # 이산화 및 합성곱 계산
        # 실제 구현은 더 복잡하지만 여기서는 개념 설명
        return u  # 간소화


4. H3 (Hungry Hungry Hippos)

2022년 발표된 H3는 S4를 언어 모델링에 더 적합하게 개선한 모델입니다. 제목은 HiPPO의 줄임말을 유머러스하게 표현한 것입니다.

S4와의 차이점

S4는 긴 범위 의존성을 잘 포착하지만, 언어 모델링에서 중요한 토큰 간 상호작용이 부족했습니다. 어텐션은 "이 단어가 저 단어와 얼마나 관련 있나?"를 직접 계산하지만, S4는 순수한 순환으로 이를 포착하기 어렵습니다.

게이팅 메커니즘

H3는 두 개의 SSM을 사용하고 그 사이에 게이팅을 추가합니다:

class H3Layer(nn.Module):
    """H3 레이어"""
    def __init__(self, d_model, d_state=64):
        super().__init__()
        # 두 개의 SSM (shift SSM + diagonal SSM)
        self.shift_ssm = S4Layer(d_model, d_state)
        self.diag_ssm = S4Layer(d_model, d_state)

        # 프로젝션
        self.Q_proj = nn.Linear(d_model, d_model)
        self.K_proj = nn.Linear(d_model, d_model)
        self.V_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # Q, K, V 프로젝션
        q = self.Q_proj(x)
        k = self.K_proj(x)
        v = self.V_proj(x)

        # SSM 1: K에 Shift SSM 적용
        k_ssm = self.shift_ssm(k)

        # 곱셈 게이팅: Q ⊙ K_ssm
        gated = q * k_ssm

        # SSM 2: V에 Diagonal SSM 적용 후 gated와 곱
        v_ssm = self.diag_ssm(v)
        output = gated * v_ssm

        return self.out_proj(output)

언어 모델링 개선

H3는 GPT-2 크기의 언어 모델에서 Transformer에 근접한 성능을 달성하면서도, 추론 시 훨씬 빠른 속도를 보였습니다. 이는 SSM이 언어 모델링에서 실용적임을 보여주는 중요한 이정표였습니다.


5. Mamba (Selective State Space Model)

2023년 12월, Albert Gu와 Tri Dao가 발표한 Mamba는 SSM 연구의 가장 중요한 진전입니다. Mamba의 핵심 혁신은 **선택적 메커니즘(Selective Mechanism)**입니다.

핵심 혁신: 선택적 메커니즘

기존 S4와 H3의 근본적인 한계는 행렬 A, B, C가 **입력에 독립적(input-independent)**이었다는 점입니다. 합성곱으로 훈련하기 위해서는 커널이 고정되어야 했기 때문입니다.

Mamba는 이 제약을 깨고 **S6 (Selective SSM)**를 도입합니다: B, C, Delta 행렬을 입력 u(t)의 함수로 만들었습니다.

B(t) = Linear_B(u(t))   # 입력에 따라 달라지는 B
C(t) = Linear_C(u(t))   # 입력에 따라 달라지는 C
Delta(t) = softplus(Linear_Delta(u(t)))  # 입력에 따라 달라지는 Delta

이렇게 하면 모델이 어떤 정보를 상태에 저장하고 어떤 정보를 무시할지를 입력에 따라 동적으로 결정할 수 있습니다.

직관적으로:

  • Delta가 크면 → 현재 입력이 상태에 강하게 반영됨 (정보 기억)
  • Delta가 작으면 → 상태 변화가 적음 (정보 무시)

이는 LSTM의 게이트(input gate, forget gate)와 유사한 역할을 하지만, 연속 시간 SSM 프레임워크 안에서 이루어집니다.

Hardware-Aware Algorithm

선택적 메커니즘을 도입하면 B, C가 입력에 의존하므로 더 이상 합성곱으로 계산할 수 없습니다. 이는 심각한 계산 비효율을 초래할 수 있습니다.

Mamba는 이를 해결하기 위해 Hardware-Aware Algorithm을 사용합니다. 이는 GPU의 메모리 계층 구조(HBM vs SRAM)를 활용한 Kernel Fusion 기법입니다:

문제: 각 타임스텝의 순환 계산에서 중간 상태를 HBM(느린 메모리)에 저장하면 메모리 대역폭 병목이 발생

해결:

  • 모든 중간 계산을 빠른 SRAM(온칩 메모리)에서 수행
  • 완전한 합성곱이 아닌 Parallel Scan 알고리즘 사용
  • 최종 출력만 HBM에 저장
# Parallel Scan의 개념 (실제 구현은 CUDA로)
def parallel_scan(gates, tokens):
    """
    선형 순환을 병렬로 계산
    x[t] = gates[t] * x[t-1] + tokens[t]

    이진 트리 구조로 O(log n) 병렬 깊이 달성
    """
    n = len(tokens)
    # 상향(up-sweep) 단계
    log_n = int(np.log2(n))
    for d in range(log_n):
        step = 2 ** (d + 1)
        for i in range(step - 1, n, step):
            gates[i] = gates[i] * gates[i - 2**d]
            tokens[i] = gates[i - 2**d] * tokens[i - 2**d] + tokens[i]
    # 하향(down-sweep) 단계 (생략)
    return tokens

Mamba 블록 구조

입력 x (B, L, D)
    ├──────────────────────┐
    │                      │
  Linear(DED)          Linear(DED)
  + SiLU activation        │
SSM (S6)
    │                      │
    └─────── ⊙ ────────────┘
    (element-wise 곱셈)
  Linear(EDD)
  출력 y (B, L, D)

여기서 E는 확장 비율(보통 2), D는 모델 차원입니다.

완전한 Mamba 구현

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class MambaBlock(nn.Module):
    """
    Mamba Block 구현
    논문: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
    """
    def __init__(
        self,
        d_model,       # 모델 차원 D
        d_state=16,    # SSM 상태 차원 N
        d_conv=4,      # 합성곱 커널 크기
        expand=2,      # 확장 비율 E
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        bias=False,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)

        if dt_rank == "auto":
            self.dt_rank = max(1, int(d_model / 16))
        else:
            self.dt_rank = dt_rank

        # 입력 프로젝션 (D → 2*ED, 두 브랜치를 한번에 계산)
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=bias)

        # 로컬 합성곱 (depthwise conv)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=True,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,  # causal padding
        )

        # SSM 파라미터 프로젝션
        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # A 초기화 (HiPPO 기반)
        A = repeat(
            torch.arange(1, self.d_state + 1),
            "n -> d n",
            d=self.d_inner
        )
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True

        # D (skip connection)
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.D._no_weight_decay = True

        # 출력 프로젝션
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)

    def forward(self, hidden_states):
        """
        hidden_states: (B, L, D)
        Returns: (B, L, D)
        """
        batch, seqlen, dim = hidden_states.shape

        # 입력 프로젝션
        xz = self.in_proj(hidden_states)
        x, z = xz.chunk(2, dim=-1)  # 각각 (B, L, ED)

        # 합성곱 (causal 1D conv)
        x = rearrange(x, "b l d -> b d l")
        x = self.conv1d(x)[:, :, :seqlen]  # causal trimming
        x = rearrange(x, "b d l -> b l d")
        x = F.silu(x)

        # SSM
        y = self.ssm(x)

        # 게이팅 (z 브랜치와 element-wise 곱)
        y = y * F.silu(z)

        # 출력 프로젝션
        output = self.out_proj(y)
        return output

    def ssm(self, x):
        """선택적 SSM (S6) 계산"""
        d_in, n = self.d_inner, self.d_state

        # A 행렬 (-exp(-A_log)로 항상 음수 보장 = 안정성)
        A = -torch.exp(self.A_log.float())  # (ED, N)

        # x_proj: Delta, B, C 계산
        x_dbl = self.x_proj(x)  # (B, L, dt_rank + 2N)
        delta, B, C = x_dbl.split(
            [self.dt_rank, n, n], dim=-1
        )

        # Delta: softplus로 양수 보장
        delta = F.softplus(self.dt_proj(delta))  # (B, L, ED)

        # ZOH 이산화 및 선택적 스캔
        y = self.selective_scan(x, delta, A, B, C, self.D)
        return y

    def selective_scan(self, u, delta, A, B, C, D):
        """
        선택적 스캔 알고리즘
        실제 구현에서는 mamba_ssm의 CUDA 커널 사용
        여기서는 순수 PyTorch로 개념 설명
        """
        b, l, d_in = u.shape
        n = A.shape[1]

        # 이산화
        # delta: (B, L, ED), A: (ED, N) → dA: (B, L, ED, N)
        deltaA = torch.exp(
            torch.einsum("bld,dn->bldn", delta, A)
        )
        # delta: (B, L, ED), B: (B, L, N), u: (B, L, ED)
        # → deltaB_u: (B, L, ED, N)
        deltaB_u = torch.einsum("bld,bln,bld->bldn", delta, B, u)

        # 순환 스캔
        x = torch.zeros(b, d_in, n, device=u.device, dtype=u.dtype)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = torch.einsum("bdn,bn->bd", x, C[:, i, :])
            ys.append(y)
        y = torch.stack(ys, dim=1)  # (B, L, ED)

        # D (skip connection)
        y = y + u * D

        return y


class MambaModel(nn.Module):
    """스택된 Mamba 블록으로 구성된 시퀀스 모델"""
    def __init__(self, d_model, n_layers, d_state=16, expand=2):
        super().__init__()
        self.layers = nn.ModuleList([
            MambaBlock(d_model, d_state=d_state, expand=expand)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """x: (B, L, D)"""
        for layer in self.layers:
            x = x + layer(self.norm(x))  # Pre-norm residual
        return x

6. Mamba 2

2024년 5월, Tri Dao와 Albert Gu는 Mamba 2를 발표하면서 이론적 기반을 더욱 강화했습니다.

Mamba 1과의 차이

Mamba 1의 선택적 스캔은 각 타임스텝에서 순차적 계산이 필요했습니다. Mamba 2는 **State Space Duality (SSD)**를 발견하여 이를 더욱 효율적으로 만들었습니다.

State Space Duality (SSD)

Mamba 2의 핵심 통찰: 특정 구조를 가진 SSM은 **세미 분리 가능 행렬(Semi-Separable Matrix)**로 표현되며, 이는 특정 형태의 어텐션과 동일합니다.

이 수학적 동치(duality)를 통해:

  1. SSM 계산을 행렬 곱셈 형태로 표현 가능
  2. 고도로 최적화된 BLAS 라이브러리 활용 가능
  3. Tensor Core 활용으로 GPU 효율 극대화
# SSD 연산의 개념
# SSM 순환: x[t] = A[t] * x[t-1] + B[t] * u[t]
#            y[t] = C[t]^T * x[t]
#
# 이를 청크(chunk) 단위로 처리:
# - 청크 내부: 행렬 곱셈으로 병렬 계산
# - 청크 간: 순환으로 상태 전파

class Mamba2Block(nn.Module):
    """Mamba 2 블록"""
    def __init__(self, d_model, d_state=64, n_heads=8, chunk_size=64):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.n_heads = n_heads
        self.chunk_size = chunk_size
        self.d_head = d_model // n_heads

        # 멀티헤드 구조
        self.norm = nn.LayerNorm(d_model)
        self.in_proj = nn.Linear(d_model, d_model * 2 + d_state * 2 + n_heads)
        self.out_proj = nn.Linear(d_model, d_model)

        # A 파라미터 (헤드별)
        self.A_log = nn.Parameter(torch.randn(n_heads))

    def forward(self, x):
        """x: (B, L, D)"""
        # 실제 구현은 mamba_ssm 패키지의 CUDA 커널 사용
        return x

다중 헤드 구조

Mamba 2는 다중 헤드 SSM을 도입하여 Transformer의 Multi-Head Attention과 구조적으로 유사해졌습니다. 이는:

  • 더 나은 표현력
  • 어텐션과의 이론적 연결
  • 하이브리드 아키텍처 설계 용이성

Transformer와의 통합 가능성

SSD 이론은 특정 SSM이 특정 어텐션과 동일함을 보여줍니다. 이는 Mamba와 Transformer를 혼합하는 하이브리드 아키텍처의 이론적 근거를 제공합니다.


7. 하이브리드 아키텍처

MambaFormer

MambaFormer는 Mamba 블록과 Transformer 블록을 인터리빙(interleaving)합니다:

레이어 1: Mamba Block  (로컬 패턴 포착)
레이어 2: Attention    (전역 의존성 포착)
레이어 3: Mamba Block
레이어 4: Attention
...

이 구조는 각 컴포넌트의 장점을 활용합니다:

  • Mamba: 효율적인 시퀀스 처리, 로컬 패턴
  • Attention: 선택적 정보 검색, 전역 의존성

Jamba (SSM + Transformer + MoE)

2024년 AI21 Labs가 발표한 Jamba는 더 복잡한 하이브리드입니다:

Jamba = Mamba + Transformer + Mixture of Experts

아키텍처:
- 52B 파라미터 (활성화: 12B)
- 레이어 비율: Attention 1 : Mamba 7
- 일부 레이어에 MoE 적용
- 256K 컨텍스트 윈도우 지원

Jamba는 동일 크기 Transformer 대비:

  • 추론 처리량 3배 향상
  • 긴 컨텍스트에서 메모리 효율 대폭 개선
# Jamba 스타일 하이브리드 블록 (개념)
class JambaLayer(nn.Module):
    def __init__(self, d_model, layer_idx, attn_every_n=8):
        super().__init__()
        self.use_attention = (layer_idx % attn_every_n == 0)

        if self.use_attention:
            self.mixer = nn.MultiheadAttention(d_model, num_heads=8)
        else:
            self.mixer = MambaBlock(d_model)

        # MoE (일부 레이어에만)
        self.use_moe = (layer_idx % 2 == 0)
        if self.use_moe:
            self.ffn = MixtureOfExperts(d_model, n_experts=16)
        else:
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_model * 4),
                nn.SiLU(),
                nn.Linear(d_model * 4, d_model)
            )

    def forward(self, x):
        x = x + self.mixer(x)
        x = x + self.ffn(x)
        return x

RWKV (선형 어텐션)

RWKV는 Transformer와 RNN의 하이브리드입니다. "Receptance Weighted Key Value"의 약자로:

- 훈련: Transformer처럼 병렬화 (행렬 형태)
- 추론: RNN처럼 순환 (O(1) 상태)
- 어텐션 없이 토큰 믹싱

최신 버전(RWKV-6, RWKV-7)은 Mamba와 경쟁하는 성능을 보입니다.

RetNet (Retentive Network)

Microsoft의 RetNet은 "Training Parallelism, Inference Efficiency, Competitive Performance"를 동시에 달성하는 것을 목표로 합니다:

3가지 계산 패러다임:
1. Parallel: 훈련 시 O(n^2) 병렬 계산 (Transformer 대비 낮은 상수)
2. Recurrent: 추론 시 O(1) 메모리
3. Chunkwise Recurrent: 균형 잡힌 중간 방법

8. Mamba 성능 비교

추론 속도 (선형 스케일)

Mamba의 가장 큰 강점은 시퀀스 길이에 따른 선형 스케일입니다.

시퀀스 길이 2K일 때 성능을 1로 정규화:

시퀀스 길이   Transformer   Mamba
1K            0.5x          0.5x
2K            1.0x          1.0x
4K            3.5x          2.0x    ← Transformer 대비 1.75배 빠름
8K            13x           4.0x    ← Transformer 대비 3.25배 빠름
16K           50x           8.0x    ← Transformer 대비 6.25배 빠름
100K          ~1800x        ~50x    ← Transformer 대비 36배 빠름

메모리 효율성

추론 시 상태(state) 크기 비교:

모델           추론 시 메모리 (1K 토큰당)
Transformer    O(n) KV Cache
Mamba          O(1) SSM 상태 (고정 크기)

: 130M 파라미터 모델, 1M 토큰 시퀀스
- Transformer: ~16GB KV Cache
- Mamba: ~1MB 상태 (상수!)

긴 시퀀스 태스크

Long Range Arena (LRA) 벤치마크 (시퀀스 길이 1K-16K):

모델        ListOps  Text  Retrieval  Image  Path-X  평균
Transformer  36.4   65.0    57.5      42.4    0.0    40.3
LSTM         35.9   63.7    65.0      43.3    0.0    41.6
S4           59.6   86.8    90.9      88.7   86.1    82.4
Mamba        ~S4 수준의 성능, 더 빠른 처리

Mamba는 특히 선택적 복사(Selective Copy), 유도 헤드(Induction Head) 같은 합성 벤치마크에서 S4를 크게 앞서며, 이는 실제 언어 모델링 능력과 더 관련이 깊습니다.


9. Mamba 활용 분야

자연어 처리

Mamba 기반 언어 모델들이 속속 등장하고 있습니다:

  • MambaChat: 대화형 AI 어시스턴트
  • Falcon Mamba: TII가 오픈소스로 발표한 7B Mamba 모델
  • CodeMamba: 코드 생성 특화

긴 문서 처리, 요약, 번역에서 Transformer 대비 효율적입니다.

바이오인포매틱스

게놈 시퀀스 분석은 매우 긴 시퀀스(수백만 bp)를 다루므로 Mamba가 특히 유리합니다:

  • Caduceus: 긴 DNA 시퀀스 모델링
  • Hyena: 긴 시퀀스 DNA/단백질 분석
  • 단백질 구조 예측에서의 활용 가능성
# 바이오 시퀀스를 위한 Mamba 활용 예시
from mamba_ssm import MambaLMHeadModel

# DNA 시퀀스 모델링 (A, T, G, C, N 토큰)
DNA_VOCAB = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4}
VOCAB_SIZE = len(DNA_VOCAB)

# 매우 긴 시퀀스도 효율적으로 처리
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m",
    device="cuda",
    dtype=torch.float16
)

시계열 분석

금융, 기상, IoT 센서 데이터 등 긴 시계열에서 Mamba는 강점을 보입니다:

  • TimeMamba: 긴 시계열 예측
  • MambaMixer: 다변량 시계열 모델링
  • S4/S5 기반 시계열 모델들의 발전

이미지 처리 (VMamba)

VMamba는 Mamba를 2D 이미지 처리에 확장합니다:

# VMamba의 핵심: 2D 선택적 스캔
# 이미지를 4방향으로 스캔하여 2D 구조 포착

# 방향:
# 1. 좌→우, 상→하 (일반적 래스터 스캔)
# 2. 우→좌, 하→상 (역방향)
# 3. 상→하, 좌→우 (열 우선)
# 4. 하→상, 우→좌 (역방향)

class VMambaBlock(nn.Module):
    """VMamba: Visual Mamba Block"""
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        # 4방향 SSM
        self.ssms = nn.ModuleList([
            MambaBlock(d_model, d_state) for _ in range(4)
        ])
        self.out_proj = nn.Linear(d_model * 4, d_model)

    def forward(self, x):
        """x: (B, H, W, D) - 이미지 패치 임베딩"""
        b, h, w, d = x.shape
        x_flat = x.view(b, h*w, d)

        outputs = []
        # 4방향 스캔
        for i, ssm in enumerate(self.ssms):
            if i == 0:  # 정방향
                seq = x_flat
            elif i == 1:  # 역방향
                seq = x_flat.flip(1)
            elif i == 2:  # 열 우선
                seq = x.permute(0, 2, 1, 3).reshape(b, h*w, d)
            else:  # 열 우선 역방향
                seq = x.permute(0, 2, 1, 3).reshape(b, h*w, d).flip(1)

            out = ssm(seq)
            if i % 2 == 1:
                out = out.flip(1)
            outputs.append(out)

        # 4방향 결과 합산
        combined = torch.cat(outputs, dim=-1)
        return self.out_proj(combined).view(b, h, w, d)

10. 실전 사용

mamba-ssm 패키지 설치

# 필요 패키지
pip install torch torchvision torchaudio
pip install causal-conv1d>=1.2.0
pip install mamba-ssm

# 또는 소스에서 설치 (최신 기능)
git clone https://github.com/state-spaces/mamba
cd mamba
pip install -e ".[dev]"

MambaLMHeadModel 사용

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer

# 사전 학습된 모델 로드
model_name = "state-spaces/mamba-2.8b-slimpj"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

model = MambaLMHeadModel.from_pretrained(
    model_name,
    device="cuda",
    dtype=torch.bfloat16
)
model.eval()

# 텍스트 생성
def generate_text(prompt, max_new_tokens=200, temperature=0.7):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
        )

    generated = output[0][input_ids.shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)

# 예시
prompt = "State Space Models are powerful because"
result = generate_text(prompt)
print(result)

파인튜닝 예제

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=512):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt"
        )

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
        }

def finetune_mamba(
    model_name="state-spaces/mamba-130m",
    texts=None,
    num_epochs=3,
    learning_rate=1e-4,
    batch_size=8,
):
    """Mamba 파인튜닝"""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokenizer.pad_token = tokenizer.eos_token

    # 모델 로드
    model = MambaLMHeadModel.from_pretrained(
        model_name,
        device=device,
        dtype=torch.bfloat16
    )

    # 데이터셋
    dataset = TextDataset(texts, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # 옵티마이저 (Mamba에는 AdamW 권장)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=0.1
    )

    # 스케줄러
    total_steps = len(dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps
    )

    # 훈련 루프
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            input_ids = batch['input_ids'].to(device)

            # 언어 모델링: 다음 토큰 예측
            outputs = model(input_ids)
            logits = outputs.logits

            # Shift for next token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()

            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}, Step {batch_idx}: Loss = {loss.item():.4f}")

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} 완료: 평균 Loss = {avg_loss:.4f}")

    return model

커스텀 Mamba 모델 구성

from mamba_ssm import Mamba
from mamba_ssm.models.config_mamba import MambaConfig

# 커스텀 설정
config = MambaConfig(
    d_model=1024,
    n_layer=48,
    vocab_size=50280,
    d_state=16,
    d_conv=4,
    expand=2,
    dt_rank="auto",
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    dt_init_floor=1e-4,
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    pad_vocab_size_multiple=8,
)

# 모델 생성 (약 1.4B 파라미터)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.bfloat16)
print(f"파라미터 수: {sum(p.numel() for p in model.parameters()):,}")

마무리: Mamba의 미래

Mamba는 딥러닝 시퀀스 모델링 분야에 혁명적인 변화를 가져왔습니다. 핵심 기여를 정리하면:

  1. 선택적 메커니즘: 입력에 따라 동적으로 변하는 SSM 파라미터
  2. Hardware-Aware 설계: GPU 메모리 계층을 활용한 효율적 계산
  3. 이중 표현: 훈련 시 병렬화, 추론 시 순환으로 최적화
  4. 선형 복잡도: 시퀀스 길이에 선형적인 계산 및 메모리 복잡도

아직 해결해야 할 과제도 있습니다:

  • In-context Learning에서 Transformer보다 다소 부족한 성능
  • 매우 큰 스케일에서의 검증 필요
  • 어텐션 기반 모델과의 명확한 우위 입증

그러나 Mamba와 SSM 계열의 모델들은 긴 시퀀스 처리, 실시간 추론, 엣지 디바이스 배포 등 Transformer가 어려움을 겪는 분야에서 점점 더 중요한 역할을 하게 될 것입니다.

참고 자료

  • Mamba 논문: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (Gu & Dao, 2023)
  • Mamba 2 논문: "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" (Dao & Gu, 2024)
  • S4 논문: "Efficiently Modeling Long Sequences with Structured State Spaces" (Gu et al., 2021)
  • H3 논문: "Hungry Hungry Hippos: Towards Language Modeling with State Space Models" (Fu et al., 2022)
  • Mamba GitHub: https://github.com/state-spaces/mamba
  • Jamba 논문: "Jamba: A Hybrid Transformer-Mamba Language Model" (AI21 Labs, 2024)

Mamba and State Space Models Complete Guide: Beyond Transformers

Overview

Since the landmark "Attention is All You Need" paper in 2017, Transformers have dominated virtually every sequence modeling task — from natural language processing to images, code, and audio. However, Transformers have a fundamental weakness: the quadratic complexity O(n^2) of self-attention.

In December 2023, Mamba, published by Albert Gu and Tri Dao, elegantly solved this problem and sent shockwaves through the deep learning community. Mamba adapts State Space Models (SSMs) — a concept from classical control theory — to modern deep learning, enabling linear-time processing of long sequences.

This guide covers everything from the mathematical foundations of SSMs to Mamba's core innovations and practical implementations.


1. The Limitations of Transformers and the Rise of SSMs

The O(n^2) Complexity Problem of Self-Attention

The core of Transformers is the Self-Attention mechanism. For a sequence of length n, computing the attention matrix requires computing relationships between all pairs of tokens, resulting in both time and memory complexity of O(n^2).

Attention matrix size: n × n
= 1,000 tokens   → 1,000,000 elements
= 10,000 tokens  → 100,000,000 elements
= 100,000 tokens → 10,000,000,000 elements (infeasible!)

Due to this problem, real-world LLMs face constraints on context window size. GPT-4 supports 128K tokens, but this requires enormous computational cost and optimization techniques.

The Difficulty of Processing Long Sequences

Real applications — such as long document summarization, understanding entire codebases, and maintaining long-term conversations — require very long context. But Transformer's quadratic complexity makes practical use challenging.

For example, processing an entire book (approximately 100,000 words = 130,000 tokens) at once would require memory that is virtually infeasible on current hardware with standard Transformers.

The Limitations of Recurrent Models (RNN, LSTM)

Recurrent neural networks (RNNs) and LSTMs theoretically have O(n) complexity because they maintain a fixed-size hidden state at each timestep. However, RNN/LSTM have the following problems:

  • Vanishing/Exploding Gradients: Gradients disappear or explode during backpropagation on long sequences
  • Non-parallelizable: Each timestep depends on the previous state, making GPU parallelization difficult
  • Difficulty capturing long-range dependencies: Hard to connect information far apart in the sequence

The Solution SSMs Offer

State Space Models (SSMs) combine the advantages of both approaches:

  • During training: Full parallelization via convolution → training efficiency comparable to Transformers
  • During inference: Recurrence with O(1) memory and O(n) computation → efficient inference like RNNs
  • Long sequences: Linear complexity enables processing of very long sequences

2. Mathematical Foundations of State Space Models

Continuous-Time SSM

The origins of SSMs trace back to Rudolf Kalman's control theory in the 1960s. A continuous-time linear system is expressed as:

x'(t) = A·x(t) + B·u(t)
y(t)  = C·x(t) + D·u(t)

Where:

  • u(t): input signal
  • x(t): state vector (latent state), size N
  • y(t): output signal
  • A: N×N state transition matrix
  • B: N×1 input projection matrix
  • C: 1×N output projection matrix
  • D: direct feedthrough term (skip connection, usually set to 0)

This system receives input u(t), updates the internal state x(t), and produces output y(t). The state x(t) can be thought of as a "summary" of past information.

Discretization

To use continuous-time SSMs on digital computers, discretization is required. With a sampling interval Delta, there are two main discretization methods.

Zero-Order Hold (ZOH):

A_bar = exp(Delta · A)
B_bar = (Delta · A)^(-1) · (exp(Delta · A) - I) · Delta · B

Bilinear (Tustin) Transform:

A_bar = (I - Delta/2 · A)^(-1) · (I + Delta/2 · A)
B_bar = (I - Delta/2 · A)^(-1) · Delta · B

After discretization, the system becomes:

x[t] = A_bar · x[t-1] + B_bar · u[t]
y[t] = C · x[t]

This form is a Linear Recurrence, updating the state at each timestep.

SSM as a Convolution Kernel

The powerful property of SSMs is that they can be computed as convolutions during training. Starting with initial state x[0] = 0:

y[0] = C · B_bar · u[0]
y[1] = C · A_bar · B_bar · u[0] + C · B_bar · u[1]
y[2] = C · A_bar^2 · B_bar · u[0] + C · A_bar · B_bar · u[1] + C · B_bar · u[2]
...

Expressing this as convolution kernel K:

K = (C·B_bar, C·A_bar·B_bar, C·A_bar^2·B_bar, ...)
y = K * u  (convolution)

This kernel K can be computed in parallel, and using FFT, it can be computed in O(n log n).

This is the fundamental duality of SSMs:

  • Inference: Recurrence with O(1) memory
  • Training: Convolution with O(n log n) parallel computation

3. S4 (Structured State Space Sequence Model)

S4 was published by Albert Gu et al. in 2021 and is the first important work that practically applied SSMs to deep learning.

HiPPO Matrix Initialization

One of S4's key contributions is the HiPPO (High-order Polynomial Projection Operators) matrix initialization. Simply initializing A randomly leads to gradient vanishing problems. HiPPO provides a specially designed A matrix that approximates past inputs using polynomials.

HiPPO-LegS (Legendre polynomial-based):

A[n,k] = -sqrt((2n+1)(2k+1))  if n > k
A[n,k] = -(n+1)               if n == k
A[n,k] = 0                    if n < k

This initialization ensures that state x[t] maintains an optimal polynomial approximation of all past inputs u[0..t].

Structured Matrix A (DPLR)

To efficiently compute the convolution kernel of the A matrix, S4 expresses A in Diagonal Plus Low-Rank (DPLR) form:

A = Λ - P·Q^T

Where Λ is a diagonal matrix and P, Q are low-rank vectors. Using this structure, kernel computation can be reduced to O(N).

Efficient Computation

import torch
import torch.nn as nn
import numpy as np

class S4Layer(nn.Module):
    """Simplified implementation of S4 layer"""
    def __init__(self, d_model, d_state=64, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # HiPPO-LegS initialization
        A = self._make_hippo(d_state)
        # DPLR decomposition
        self.A = nn.Parameter(torch.tensor(A, dtype=torch.float32))
        self.B = nn.Parameter(torch.randn(d_state, 1) * 0.01)
        self.C = nn.Parameter(torch.randn(1, d_state) * 0.01)
        self.log_delta = nn.Parameter(torch.zeros(d_model))
        self.D = nn.Parameter(torch.ones(d_model))

    def _make_hippo(self, N):
        """Generate HiPPO-LegS matrix"""
        P = np.sqrt(1 + 2 * np.arange(N))
        A = P[:, np.newaxis] * P[np.newaxis, :]
        A = np.tril(A) - np.diag(np.arange(N))
        return -A

    def discretize(self, A, B, C, delta):
        """ZOH discretization"""
        dA = torch.matrix_exp(delta.unsqueeze(-1) * A)
        dB = torch.linalg.solve(A, (dA - torch.eye(A.shape[0])) @ B)
        return dA, dB, C

    def forward(self, u):
        # u: (batch, seq_len, d_model)
        delta = torch.exp(self.log_delta)
        # Discretization and convolution computation
        # Actual implementation is more complex, simplified here for concept
        return u

4. H3 (Hungry Hungry Hippos)

Published in 2022, H3 is an improvement of S4 better suited for language modeling. The title is a humorous expansion of the HiPPO acronym.

Differences from S4

S4 captures long-range dependencies well, but lacks the token-to-token interactions important in language modeling. Attention directly computes "how related is this word to that word?", but S4 finds this difficult to capture with pure recurrence.

Gating Mechanism

H3 uses two SSMs and adds gating between them:

class H3Layer(nn.Module):
    """H3 Layer"""
    def __init__(self, d_model, d_state=64):
        super().__init__()
        # Two SSMs (shift SSM + diagonal SSM)
        self.shift_ssm = S4Layer(d_model, d_state)
        self.diag_ssm = S4Layer(d_model, d_state)

        # Projections
        self.Q_proj = nn.Linear(d_model, d_model)
        self.K_proj = nn.Linear(d_model, d_model)
        self.V_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # Q, K, V projections
        q = self.Q_proj(x)
        k = self.K_proj(x)
        v = self.V_proj(x)

        # SSM 1: Apply Shift SSM to K
        k_ssm = self.shift_ssm(k)

        # Multiplicative gating: Q element-wise K_ssm
        gated = q * k_ssm

        # SSM 2: Apply Diagonal SSM to V, then multiply with gated
        v_ssm = self.diag_ssm(v)
        output = gated * v_ssm

        return self.out_proj(output)

Language Modeling Improvements

H3 achieved performance close to Transformers in GPT-2 sized language models while showing significantly faster inference speed. This was an important milestone demonstrating the practicality of SSMs for language modeling.


5. Mamba (Selective State Space Model)

In December 2023, Mamba published by Albert Gu and Tri Dao represents the most important advance in SSM research. Mamba's core innovation is the Selective Mechanism.

Core Innovation: The Selective Mechanism

The fundamental limitation of S4 and H3 was that matrices A, B, C were input-independent. The kernel had to be fixed for convolution-based training.

Mamba breaks this constraint and introduces S6 (Selective SSM): making B, C, and Delta matrices functions of input u(t).

B(t) = Linear_B(u(t))   # B that varies with input
C(t) = Linear_C(u(t))   # C that varies with input
Delta(t) = softplus(Linear_Delta(u(t)))  # Delta that varies with input

This allows the model to dynamically decide, based on input, which information to store in the state and which to discard.

Intuitively:

  • Large Delta → current input strongly influences state (remember information)
  • Small Delta → state changes little (ignore information)

This plays a role similar to LSTM's gates (input gate, forget gate), but within the continuous-time SSM framework.

Hardware-Aware Algorithm

Introducing the selective mechanism means B, C depend on input, so convolution can no longer be used for computation. This can lead to severe computational inefficiency.

Mamba solves this with a Hardware-Aware Algorithm — a kernel fusion technique leveraging GPU memory hierarchy (HBM vs SRAM):

Problem: Storing intermediate states in HBM (slow memory) during per-timestep recurrence creates memory bandwidth bottlenecks.

Solution:

  • Perform all intermediate computations in fast SRAM (on-chip memory)
  • Use Parallel Scan algorithm instead of full convolution
  • Store only the final output in HBM
# Parallel Scan concept (actual implementation in CUDA)
def parallel_scan(gates, tokens):
    """
    Compute linear recurrence in parallel
    x[t] = gates[t] * x[t-1] + tokens[t]

    Achieves O(log n) parallel depth with binary tree structure
    """
    n = len(tokens)
    # Up-sweep phase
    log_n = int(np.log2(n))
    for d in range(log_n):
        step = 2 ** (d + 1)
        for i in range(step - 1, n, step):
            gates[i] = gates[i] * gates[i - 2**d]
            tokens[i] = gates[i - 2**d] * tokens[i - 2**d] + tokens[i]
    # Down-sweep phase (omitted)
    return tokens

Mamba Block Structure

Input x (B, L, D)
    ├──────────────────────┐
    │                      │
  Linear(DED)          Linear(DED)
  + SiLU activation        │
SSM (S6)
    │                      │
    └─────── ⊙ ────────────┘
    (element-wise multiply)
  Linear(EDD)
  Output y (B, L, D)

Where E is the expansion ratio (usually 2) and D is the model dimension.

Complete Mamba Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class MambaBlock(nn.Module):
    """
    Mamba Block implementation
    Paper: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
    """
    def __init__(
        self,
        d_model,       # Model dimension D
        d_state=16,    # SSM state dimension N
        d_conv=4,      # Convolution kernel size
        expand=2,      # Expansion ratio E
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        bias=False,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)

        if dt_rank == "auto":
            self.dt_rank = max(1, int(d_model / 16))
        else:
            self.dt_rank = dt_rank

        # Input projection (D → 2*ED, compute both branches at once)
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=bias)

        # Local convolution (depthwise conv)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=True,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,  # causal padding
        )

        # SSM parameter projections
        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # A initialization (HiPPO-based)
        A = repeat(
            torch.arange(1, self.d_state + 1),
            "n -> d n",
            d=self.d_inner
        )
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True

        # D (skip connection)
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.D._no_weight_decay = True

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)

    def forward(self, hidden_states):
        """
        hidden_states: (B, L, D)
        Returns: (B, L, D)
        """
        batch, seqlen, dim = hidden_states.shape

        # Input projection
        xz = self.in_proj(hidden_states)
        x, z = xz.chunk(2, dim=-1)  # each (B, L, ED)

        # Convolution (causal 1D conv)
        x = rearrange(x, "b l d -> b d l")
        x = self.conv1d(x)[:, :, :seqlen]  # causal trimming
        x = rearrange(x, "b d l -> b l d")
        x = F.silu(x)

        # SSM
        y = self.ssm(x)

        # Gating (element-wise multiply with z branch)
        y = y * F.silu(z)

        # Output projection
        output = self.out_proj(y)
        return output

    def ssm(self, x):
        """Selective SSM (S6) computation"""
        d_in, n = self.d_inner, self.d_state

        # A matrix (-exp(-A_log) ensures always negative = stability)
        A = -torch.exp(self.A_log.float())  # (ED, N)

        # x_proj: compute Delta, B, C
        x_dbl = self.x_proj(x)  # (B, L, dt_rank + 2N)
        delta, B, C = x_dbl.split(
            [self.dt_rank, n, n], dim=-1
        )

        # Delta: softplus ensures positive values
        delta = F.softplus(self.dt_proj(delta))  # (B, L, ED)

        # ZOH discretization and selective scan
        y = self.selective_scan(x, delta, A, B, C, self.D)
        return y

    def selective_scan(self, u, delta, A, B, C, D):
        """
        Selective scan algorithm
        In practice, uses CUDA kernels from mamba_ssm
        Here shown in pure PyTorch for conceptual understanding
        """
        b, l, d_in = u.shape
        n = A.shape[1]

        # Discretization
        # delta: (B, L, ED), A: (ED, N) → dA: (B, L, ED, N)
        deltaA = torch.exp(
            torch.einsum("bld,dn->bldn", delta, A)
        )
        # delta: (B, L, ED), B: (B, L, N), u: (B, L, ED)
        # → deltaB_u: (B, L, ED, N)
        deltaB_u = torch.einsum("bld,bln,bld->bldn", delta, B, u)

        # Recurrent scan
        x = torch.zeros(b, d_in, n, device=u.device, dtype=u.dtype)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = torch.einsum("bdn,bn->bd", x, C[:, i, :])
            ys.append(y)
        y = torch.stack(ys, dim=1)  # (B, L, ED)

        # D (skip connection)
        y = y + u * D

        return y


class MambaModel(nn.Module):
    """Sequence model composed of stacked Mamba blocks"""
    def __init__(self, d_model, n_layers, d_state=16, expand=2):
        super().__init__()
        self.layers = nn.ModuleList([
            MambaBlock(d_model, d_state=d_state, expand=expand)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """x: (B, L, D)"""
        for layer in self.layers:
            x = x + layer(self.norm(x))  # Pre-norm residual
        return x

6. Mamba 2

In May 2024, Tri Dao and Albert Gu published Mamba 2, further strengthening the theoretical foundations.

Differences from Mamba 1

Mamba 1's selective scan required sequential computation at each timestep. Mamba 2 discovered State Space Duality (SSD) to make this even more efficient.

State Space Duality (SSD)

The core insight of Mamba 2: SSMs with a specific structure can be expressed as Semi-Separable Matrices, which are mathematically equivalent to a specific form of attention.

This mathematical duality enables:

  1. Expressing SSM computation as matrix multiplication form
  2. Leveraging highly optimized BLAS libraries
  3. Maximizing GPU efficiency with Tensor Core utilization
# SSD operation concept
# SSM recurrence: x[t] = A[t] * x[t-1] + B[t] * u[t]
#                 y[t] = C[t]^T * x[t]
#
# Process in chunks:
# - Within chunk: parallel computation via matrix multiplication
# - Between chunks: recurrence for state propagation

class Mamba2Block(nn.Module):
    """Mamba 2 Block"""
    def __init__(self, d_model, d_state=64, n_heads=8, chunk_size=64):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.n_heads = n_heads
        self.chunk_size = chunk_size
        self.d_head = d_model // n_heads

        # Multi-head structure
        self.norm = nn.LayerNorm(d_model)
        self.in_proj = nn.Linear(d_model, d_model * 2 + d_state * 2 + n_heads)
        self.out_proj = nn.Linear(d_model, d_model)

        # A parameters (per head)
        self.A_log = nn.Parameter(torch.randn(n_heads))

    def forward(self, x):
        """x: (B, L, D)"""
        # Actual implementation uses CUDA kernels from mamba_ssm package
        return x

Multi-Head Structure

Mamba 2 introduces multi-head SSM, making it structurally similar to Transformer's Multi-Head Attention. This provides:

  • Better expressiveness
  • Theoretical connection to attention
  • Easier hybrid architecture design

Integration Potential with Transformers

SSD theory shows that certain SSMs are mathematically equivalent to certain attention mechanisms. This provides the theoretical basis for hybrid architectures mixing Mamba and Transformers.


7. Hybrid Architectures

MambaFormer

MambaFormer interleaves Mamba blocks and Transformer blocks:

Layer 1: Mamba Block   (capture local patterns)
Layer 2: Attention     (capture global dependencies)
Layer 3: Mamba Block
Layer 4: Attention
...

This structure leverages the strengths of each component:

  • Mamba: efficient sequence processing, local patterns
  • Attention: selective information retrieval, global dependencies

Jamba (SSM + Transformer + MoE)

Jamba, released by AI21 Labs in 2024, is a more complex hybrid:

Jamba = Mamba + Transformer + Mixture of Experts

Architecture:
- 52B parameters (active: 12B)
- Layer ratio: Attention 1 : Mamba 7
- MoE applied to some layers
- Supports 256K context window

Compared to same-size Transformers, Jamba achieves:

  • 3x improvement in inference throughput
  • Dramatically improved memory efficiency for long contexts
# Jamba-style hybrid block (concept)
class JambaLayer(nn.Module):
    def __init__(self, d_model, layer_idx, attn_every_n=8):
        super().__init__()
        self.use_attention = (layer_idx % attn_every_n == 0)

        if self.use_attention:
            self.mixer = nn.MultiheadAttention(d_model, num_heads=8)
        else:
            self.mixer = MambaBlock(d_model)

        # MoE (only for some layers)
        self.use_moe = (layer_idx % 2 == 0)
        if self.use_moe:
            self.ffn = MixtureOfExperts(d_model, n_experts=16)
        else:
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_model * 4),
                nn.SiLU(),
                nn.Linear(d_model * 4, d_model)
            )

    def forward(self, x):
        x = x + self.mixer(x)
        x = x + self.ffn(x)
        return x

RWKV (Linear Attention)

RWKV is a hybrid of Transformer and RNN. Standing for "Receptance Weighted Key Value":

- Training: Parallelized like Transformer (matrix form)
- Inference: Recurrent like RNN (O(1) state)
- Token mixing without attention

Latest versions (RWKV-6, RWKV-7) show performance competitive with Mamba.

RetNet (Retentive Network)

Microsoft's RetNet aims to simultaneously achieve "Training Parallelism, Inference Efficiency, Competitive Performance":

Three computational paradigms:
1. Parallel: O(n^2) parallel computation during training (lower constant than Transformer)
2. Recurrent: O(1) memory during inference
3. Chunkwise Recurrent: balanced intermediate approach

8. Mamba Performance Comparison

Inference Speed (Linear Scale)

Mamba's greatest strength is its linear scaling with sequence length.

Normalizing performance at sequence length 2K to 1x:

Sequence Length   Transformer   Mamba
1K                0.5x          0.5x
2K                1.0x          1.0x
4K                3.5x          2.0x    ← 1.75x faster than Transformer
8K                13x           4.0x    ← 3.25x faster than Transformer
16K               50x           8.0x    ← 6.25x faster than Transformer
100K              ~1800x        ~50x    ← 36x faster than Transformer

Memory Efficiency

State size comparison during inference:

Model           Inference Memory (per 1K tokens)
Transformer     O(n) KV Cache
Mamba           O(1) SSM state (fixed size!)

Example: 130M parameter model, 1M token sequence
- Transformer: ~16GB KV Cache
- Mamba: ~1MB state (constant!)

Long Sequence Tasks

Long Range Arena (LRA) benchmark (sequence lengths 1K-16K):

Model        ListOps  Text  Retrieval  Image  Path-X  Avg
Transformer  36.4     65.0  57.5       42.4   0.0     40.3
LSTM         35.9     63.7  65.0       43.3   0.0     41.6
S4           59.6     86.8  90.9       88.7   86.1    82.4
Mamba        ~S4-level performance, faster processing

Mamba particularly excels over S4 on synthetic benchmarks like Selective Copying and Induction Heads, which are more relevant to real language modeling ability.


9. Applications of Mamba

Natural Language Processing

Mamba-based language models are emerging rapidly:

  • MambaChat: Conversational AI assistant
  • Falcon Mamba: Open-source 7B Mamba model released by TII
  • CodeMamba: Specialized for code generation

Mamba is efficient for long document processing, summarization, and translation compared to Transformers.

Bioinformatics

Genomic sequence analysis deals with very long sequences (millions of base pairs), making Mamba particularly advantageous:

  • Caduceus: Long DNA sequence modeling
  • Hyena: Long-sequence DNA/protein analysis
  • Potential applications in protein structure prediction
# Example of Mamba for biological sequence modeling
from mamba_ssm import MambaLMHeadModel
import torch

# DNA sequence modeling (A, T, G, C, N tokens)
DNA_VOCAB = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4}
VOCAB_SIZE = len(DNA_VOCAB)

# Process very long sequences efficiently
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m",
    device="cuda",
    dtype=torch.float16
)

Time Series Analysis

For financial, meteorological, IoT sensor data, and other long time series, Mamba shows its strengths:

  • TimeMamba: Long time series forecasting
  • MambaMixer: Multivariate time series modeling
  • Advancements in S4/S5-based time series models

Image Processing (VMamba)

VMamba extends Mamba to 2D image processing:

# VMamba core: 2D selective scan
# Scan image in 4 directions to capture 2D structure

# Directions:
# 1. Left→Right, Top→Bottom (standard raster scan)
# 2. Right→Left, Bottom→Top (reverse)
# 3. Top→Bottom, Left→Right (column-first)
# 4. Bottom→Top, Right→Left (reverse)

class VMambaBlock(nn.Module):
    """VMamba: Visual Mamba Block"""
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        # 4-directional SSMs
        self.ssms = nn.ModuleList([
            MambaBlock(d_model, d_state) for _ in range(4)
        ])
        self.out_proj = nn.Linear(d_model * 4, d_model)

    def forward(self, x):
        """x: (B, H, W, D) - image patch embeddings"""
        b, h, w, d = x.shape
        x_flat = x.view(b, h*w, d)

        outputs = []
        # 4-directional scan
        for i, ssm in enumerate(self.ssms):
            if i == 0:   # forward
                seq = x_flat
            elif i == 1: # reverse
                seq = x_flat.flip(1)
            elif i == 2: # column-first
                seq = x.permute(0, 2, 1, 3).reshape(b, h*w, d)
            else:        # column-first reverse
                seq = x.permute(0, 2, 1, 3).reshape(b, h*w, d).flip(1)

            out = ssm(seq)
            if i % 2 == 1:
                out = out.flip(1)
            outputs.append(out)

        # Combine 4 directional results
        combined = torch.cat(outputs, dim=-1)
        return self.out_proj(combined).view(b, h, w, d)

10. Practical Usage

Installing mamba-ssm Package

# Required packages
pip install torch torchvision torchaudio
pip install causal-conv1d>=1.2.0
pip install mamba-ssm

# Or install from source (latest features)
git clone https://github.com/state-spaces/mamba
cd mamba
pip install -e ".[dev]"

Using MambaLMHeadModel

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer

# Load pre-trained model
model_name = "state-spaces/mamba-2.8b-slimpj"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

model = MambaLMHeadModel.from_pretrained(
    model_name,
    device="cuda",
    dtype=torch.bfloat16
)
model.eval()

# Text generation
def generate_text(prompt, max_new_tokens=200, temperature=0.7):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
        )

    generated = output[0][input_ids.shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)

# Example
prompt = "State Space Models are powerful because"
result = generate_text(prompt)
print(result)

Fine-tuning Example

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=512):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt"
        )

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
        }

def finetune_mamba(
    model_name="state-spaces/mamba-130m",
    texts=None,
    num_epochs=3,
    learning_rate=1e-4,
    batch_size=8,
):
    """Fine-tune Mamba model"""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokenizer.pad_token = tokenizer.eos_token

    # Load model
    model = MambaLMHeadModel.from_pretrained(
        model_name,
        device=device,
        dtype=torch.bfloat16
    )

    # Dataset
    dataset = TextDataset(texts, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer (AdamW recommended for Mamba)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=0.1
    )

    # Scheduler
    total_steps = len(dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps
    )

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            input_ids = batch['input_ids'].to(device)

            # Language modeling: next token prediction
            outputs = model(input_ids)
            logits = outputs.logits

            # Shift for next token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()

            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}, Step {batch_idx}: Loss = {loss.item():.4f}")

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} complete: Avg Loss = {avg_loss:.4f}")

    return model

Custom Mamba Model Configuration

from mamba_ssm import Mamba
from mamba_ssm.models.config_mamba import MambaConfig

# Custom configuration
config = MambaConfig(
    d_model=1024,
    n_layer=48,
    vocab_size=50280,
    d_state=16,
    d_conv=4,
    expand=2,
    dt_rank="auto",
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    dt_init_floor=1e-4,
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    pad_vocab_size_multiple=8,
)

# Create model (~1.4B parameters)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.bfloat16)
print(f"Parameter count: {sum(p.numel() for p in model.parameters()):,}")

Conclusion: The Future of Mamba

Mamba has brought revolutionary changes to the field of deep learning sequence modeling. Summarizing the key contributions:

  1. Selective Mechanism: SSM parameters that dynamically change based on input
  2. Hardware-Aware Design: Efficient computation leveraging GPU memory hierarchy
  3. Dual Representation: Parallelized during training, recurrent during inference
  4. Linear Complexity: Linear computation and memory complexity with sequence length

There are still challenges to address:

  • Somewhat weaker in-context learning compared to Transformers
  • Needs validation at very large scales
  • Clear demonstration of superiority over attention-based models

However, Mamba and SSM-family models will increasingly play an important role in areas where Transformers struggle — long sequence processing, real-time inference, and edge device deployment.

References