Mamba and State Space Model Paper Deep Dive: Transformer Alternative Architectures from Selective SSM to Mamba-2



## Introduction

Since its introduction in 2017, the Transformer architecture has dominated almost every sequence modeling domain, including natural language processing, computer vision, audio, and time series. However, Self-Attention's quadratic time complexity $O(L^2)$ imposes a fundamental limitation: compute and memory usage grow explosively as the sequence length $L$ increases. The problem is becoming more serious in modern LLMs, which handle context windows of 128K tokens or more, and the KV Cache that grows in proportion to sequence length during inference is also a practical bottleneck.

Against this background, the State Space Model (SSM) has been studied as an alternative that can process sequences with linear time complexity $O(L)$. Structured State Spaces (S4), proposed by Albert Gu in 2021, opened a renaissance in SSM research by surpassing the Transformer on long-sequence benchmarks. However, because S4 and other early SSMs are Linear Time-Invariant (LTI) systems, they lacked the ability to select information dynamically based on input content and fell short of the Transformer in language modeling.

In December 2023, the Mamba paper published by Albert Gu and Tri Dao directly broke through this limitation. A selection mechanism that makes SSM parameters dependent on input was introduced, and efficiency was secured with a scan algorithm optimized for GPU hardware. Subsequently, Mamba-2, announced in May 2024, proved State Space Duality (SSD), which states that SSM and Attention belong to the same mathematical family, and presented an algorithm that is 2-8 times faster.

This article covers in depth the mathematical foundations of SSM (S4), Mamba's selection mechanism, Mamba-2's SSD framework, quantitative comparisons with the Transformer, applications to domains such as vision, audio, and time series, and practical considerations for implementation and operation.

## SSM basic theory (S4)

### Continuous-time State Space Model

A State Space Model is a continuous-time dynamical system from control theory. The process of converting the input signal $x(t)$ to the output $y(t)$ through the hidden state $h(t)$ is described by the following differential equations.

$$h'(t) = Ah(t) + Bx(t)$$

$$y(t) = Ch(t) + Dx(t)$$

Here, $A \in \mathbb{R}^{N \times N}$ is the state transition matrix, $B \in \mathbb{R}^{N \times 1}$ is the input projection matrix, $C \in \mathbb{R}^{1 \times N}$ is the output projection matrix, and $D \in \mathbb{R}$ is the skip connection term. $N$ is the state dimension and determines the memory capacity of the SSM.

Processing discrete sequences requires discretizing the continuous system. Applying Zero-Order Hold (ZOH) discretization with step size $\Delta$ gives:

$$\bar{A} = \exp(\Delta A)$$

$$\bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B$$

The discretized SSM can be computed in two ways. First, a recurrent form that processes the sequence sequentially:

$$h_k = \bar{A} h_{k-1} + \bar{B} x_k$$

$$y_k = C h_k$$

Second, a convolutional form that processes the entire sequence at once. By pre-computing the kernel $\bar{K}$, the output can be computed in $O(L \log L)$ using the FFT:

$$\bar{K} = (C\bar{B},\ C\bar{A}\bar{B},\ C\bar{A}^2\bar{B},\ \ldots,\ C\bar{A}^{L-1}\bar{B})$$

$$y = x * \bar{K}$$
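Both computational paths produce identical outputs, which a few lines of PyTorch can verify numerically. This is an illustrative sketch (the dimensions and values are arbitrary choices, not from the papers):

```python
import torch

# For a fixed (A_bar, B_bar, C), the recurrent and convolutional forms of
# the discretized SSM produce the same output.
torch.manual_seed(0)
N, L = 4, 8                                # state dimension, sequence length
A_bar = torch.diag(torch.rand(N) * 0.9)    # stable diagonal transition
B_bar = torch.randn(N, 1)
C = torch.randn(1, N)
x = torch.randn(L)

# 1) Recurrent form: h_k = A_bar h_{k-1} + B_bar x_k,  y_k = C h_k
h = torch.zeros(N, 1)
y_rec = []
for k in range(L):
    h = A_bar @ h + B_bar * x[k]
    y_rec.append((C @ h).item())
y_rec = torch.tensor(y_rec)

# 2) Convolutional form: y = x * K with K[k] = C A_bar^k B_bar
K = torch.stack([(C @ torch.matrix_power(A_bar, k) @ B_bar).squeeze()
                 for k in range(L)])
y_conv = torch.stack([sum(K[j] * x[k - j] for j in range(k + 1))
                      for k in range(L)])

assert torch.allclose(y_rec, y_conv, atol=1e-5)
```

The recurrent form is what makes $O(1)$-per-token inference possible; the convolutional form is what makes parallel training possible.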

### HiPPO initialization in S4

The key contribution of S4 (Structured State Spaces for Sequence Modeling, Gu et al., 2021) is the initialization strategy for the state matrix $A$. A randomly initialized $A$ fails to capture long-term dependencies and suffers from vanishing/exploding gradients. S4 uses the High-order Polynomial Projection Operators (HiPPO) framework to initialize $A$ with a special structure.

The HiPPO-LegS matrix is designed to continuously approximate past inputs in a Legendre polynomial basis. A key property of this structure is that it preserves information equally across all time scales.

```python
import torch
import torch.nn as nn
import numpy as np


def make_hippo_matrix(N: int) -> torch.Tensor:
    """Build the HiPPO-LegS matrix.

    This matrix is designed to optimally approximate past inputs in a
    Legendre polynomial basis, which makes it well suited to capturing
    long-term dependencies.

    Args:
        N: state dimension

    Returns:
        A: (N, N) HiPPO matrix
    """
    P = np.sqrt(2 * np.arange(N) + 1)
    A = np.zeros((N, N))
    for i in range(N):
        for j in range(N):
            if i > j:
                A[i, j] = P[i] * P[j]
            elif i == j:
                A[i, j] = i + 1
            # i < j: stays 0
    A = -A  # negate for stability
    return torch.tensor(A, dtype=torch.float32)


class S4Layer(nn.Module):
    """Core implementation of an S4 (Structured State Spaces) layer.

    Combines a HiPPO-initialized A matrix with convolution-based
    parallel computation.
    """

    def __init__(self, d_model: int, state_dim: int = 64, seq_len: int = 1024):
        super().__init__()
        self.d_model = d_model
        self.state_dim = state_dim
        self.seq_len = seq_len

        # HiPPO initialization (keeping only the diagonal for simplicity)
        A = make_hippo_matrix(state_dim)
        self.A_log = nn.Parameter(torch.log(torch.clamp(-A.diagonal(), min=1e-4)))
        self.B = nn.Parameter(torch.randn(state_dim, 1) * 0.01)
        self.C = nn.Parameter(torch.randn(1, state_dim) * 0.01)
        self.log_dt = nn.Parameter(torch.empty(d_model).uniform_(-4, -1))

        # skip connection
        self.D = nn.Parameter(torch.ones(d_model))

    def _compute_kernel(self, L: int) -> torch.Tensor:
        """Precompute the SSM convolution kernel."""
        dt = torch.exp(self.log_dt)   # (d_model,)
        A = -torch.exp(self.A_log)    # (state_dim,) diagonal entries

        # ZOH discretization (diagonal A assumed)
        dtA = dt.unsqueeze(-1) * A.unsqueeze(0)  # (d_model, state_dim)

        # Kernel: K[k] = C @ A_bar^k @ B_bar, with A_bar^k = exp(k * dtA).
        # Computed as a Vandermonde-style product in O(N * L).
        powers = torch.arange(L, device=dt.device).float()   # (L,)
        vandermonde = torch.exp(dtA.unsqueeze(-1) * powers)  # (d_model, state_dim, L)
        C = self.C.squeeze(0).expand(self.d_model, -1)       # (d_model, state_dim)
        B_bar = dt.unsqueeze(-1) * self.B.squeeze(-1)        # (d_model, state_dim)
        kernel = torch.einsum('dn,dn,dnl->dl', C, B_bar, vandermonde)
        return kernel  # (d_model, L)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass in convolution mode.

        Args:
            x: (batch, seq_len, d_model) input sequence

        Returns:
            y: (batch, seq_len, d_model) output sequence
        """
        B, L, D = x.shape
        kernel = self._compute_kernel(L)  # (d_model, L)

        # FFT-based convolution (zero-padded to avoid circular wrap-around)
        x_perm = x.permute(0, 2, 1)  # (B, D, L)
        k_f = torch.fft.rfft(kernel, n=2 * L)
        x_f = torch.fft.rfft(x_perm, n=2 * L)
        y = torch.fft.irfft(x_f * k_f, n=2 * L)[..., :L]

        # skip connection
        y = y + self.D.unsqueeze(0).unsqueeze(-1) * x_perm
        return y.permute(0, 2, 1)
```


## Selection mechanism in Mamba

### Transition from LTI to Selection

Mamba's core insight is simple but powerful. The system becomes time-varying by making the SSM parameters $B$, $C$, and $\Delta$ functions of the input. This allows the model to decide dynamically, at each timestep, what information to record into the state and what to ignore.

In existing LTI SSM, all input tokens are processed with the same dynamics. In the sentence "The cat sat on the mat", the article "the" and the key noun "cat" are reflected in the state in the same way, so it is impossible to selectively highlight important information or filter out unnecessary information. Transformer's Attention essentially performs such selective information processing, and input-dependent parameters are essential for SSM to have comparable expressive power.

Expressed as formulas, Mamba's selective SSM is as follows.

$$B_t = \text{Linear}_B(x_t), \quad C_t = \text{Linear}_C(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}_\Delta(x_t))$$

$$\bar{A}_t = \exp(\Delta_t A), \quad \bar{B}_t = \Delta_t B_t$$

$$h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t$$

$$y_t = C_t h_t$$

The role of $\Delta_t$ is particularly important. When $\Delta_t$ is large, $\bar{A}_t = \exp(\Delta_t A)$ approaches 0, the previous state is reset, and the new input $x_t$ dominates the state. Conversely, if $\Delta_t$ is small, the previous state is preserved and the influence of the current input is reduced. This is similar to the gating mechanism of RNN, but is derived naturally within the mathematical framework of continuous-time systems.
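A two-line numerical check (illustrative values, not from the paper) makes this gating behavior concrete for a stable one-dimensional state with $A = -1$:

```python
import torch

# How the step size delta gates the state update: A_bar = exp(delta * A)
A = torch.tensor(-1.0)
for delta in (0.01, 1.0, 10.0):
    A_bar = torch.exp(delta * A).item()
    print(f"delta={delta:5.2f}  A_bar={A_bar:.4f}")
# Small delta -> A_bar close to 1: the previous state is carried over.
# Large delta -> A_bar close to 0: the state resets and x_t dominates.
```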

### Synthetic tasks solved by selection

The Mamba paper demonstrates the effectiveness of the selection mechanism in two synthetic tasks.

First, the Selective Copying task. Only specific tokens must be selectively copied from the input sequence; an LTI SSM fails because it cannot determine which tokens to copy based on the input content. Mamba records important token information strongly into the state through $B_t$ and ignores unnecessary tokens.

Second, the Induction Heads task. When the pattern "A B ... A" appears, the model must predict the next "B". The Transformer solves this naturally with Attention, but an LTI SSM cannot match the pattern. Mamba's selection mechanism can predict "B" by reactivating the memory of the earlier "A B" pattern when "A" reappears.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSM(nn.Module):
    """Core implementation of Mamba's selective State Space Model.

    Making B, C, and delta functions of the input enables
    content-aware sequence processing.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # A matrix: diagonal structure, parameterized in log space
        A = torch.arange(1, d_state + 1, dtype=torch.float32)
        # .clone() so the rows do not share storage after expand()
        self.A_log = nn.Parameter(torch.log(A).unsqueeze(0).expand(d_model, -1).clone())

        # input-dependent B and C projections
        self.B_proj = nn.Linear(d_model, d_state, bias=False)
        self.C_proj = nn.Linear(d_model, d_state, bias=False)

        # delta (step size) projection
        self.dt_proj = nn.Linear(d_model, d_model, bias=True)

        # skip connection
        self.D = nn.Parameter(torch.ones(d_model))

    def selective_scan(
        self,
        x: torch.Tensor,
        dt: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
    ) -> torch.Tensor:
        """Selective scan algorithm (sequential implementation).

        The real Mamba uses a CUDA kernel optimized for GPU SRAM; the scan
        is implemented sequentially here to make the algorithm's logic clear.

        Args:
            x: (batch, seq_len, d_model) input
            dt: (batch, seq_len, d_model) discretization step sizes
            B: (batch, seq_len, d_state) input projection
            C: (batch, seq_len, d_state) output projection

        Returns:
            y: (batch, seq_len, d_model) output
        """
        batch, seq_len, d_model = x.shape
        d_state = B.shape[-1]

        # Discretize A: A_bar = exp(dt * A)
        A = -torch.exp(self.A_log)  # (d_model, d_state)
        dt_A = torch.einsum('bld,dn->bldn', dt, A)  # (B, L, D, N)
        A_bar = torch.exp(dt_A)

        # Discretize B: B_bar = dt * B
        dt_B = torch.einsum('bld,bln->bldn', dt, B)  # (B, L, D, N)

        # sequential scan
        h = torch.zeros(batch, d_model, d_state, device=x.device)
        outputs = []

        for t in range(seq_len):
            # h_t = A_bar_t * h_{t-1} + B_bar_t * x_t
            h = A_bar[:, t] * h + dt_B[:, t] * x[:, t].unsqueeze(-1)
            # y_t = C_t @ h_t
            y_t = torch.einsum('bdn,bn->bd', h, C[:, t])
            outputs.append(y_t)

        y = torch.stack(outputs, dim=1)  # (B, L, D)
        return y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the selective SSM.

        Args:
            x: (batch, seq_len, d_model)

        Returns:
            y: (batch, seq_len, d_model)
        """
        # compute input-dependent parameters
        B = self.B_proj(x)                          # (B, L, N)
        C = self.C_proj(x)                          # (B, L, N)
        dt = F.softplus(self.dt_proj(x))            # (B, L, D), guaranteed positive

        # selective scan
        y = self.selective_scan(x, dt, B, C)

        # skip connection
        y = y + self.D.unsqueeze(0).unsqueeze(0) * x

        return y
```

### Hardware-Aware scanning algorithm

When the selection mechanism is introduced, the SSM is no longer LTI, so the FFT-based convolution used in S4 no longer applies: because the time-varying parameters change at every timestep, the convolution kernel cannot be pre-computed. A sequential scan must be used instead, but a naive for-loop implementation is very slow because it does not exploit the GPU's parallelism.

Mamba solves this problem with a hardware-aware algorithm inspired by FlashAttention. There are three core ideas. First, kernel fusion minimizes data movement between GPU HBM (High Bandwidth Memory) and SRAM (on-chip memory): discretization, scan, and output projection are performed in a single CUDA kernel, keeping intermediate results in SRAM. Second, the intermediate state tensors ($\bar{A}$, $\bar{B}$, etc.) are materialized only in SRAM rather than written to HBM, so memory usage does not grow with sequence length. Third, backpropagation also saves memory by recomputing intermediate states instead of storing them.
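Although the selective recurrence looks inherently sequential, the update $h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t$ is associative, so it can also be evaluated as a log-depth parallel (associative) scan. Below is a minimal Hillis-Steele-style sketch for a scalar recurrence in pure PyTorch; the function name and structure are my own illustration, not the fused CUDA kernel described in the paper:

```python
import torch

def parallel_linear_scan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute h_t = a_t * h_{t-1} + b_t (with h_{-1} = 0) for all t
    in O(log L) parallel steps, using the associative combine
    (a1, b1) o (a2, b2) = (a1 * a2, a2 * b1 + b2).
    """
    L = a.shape[0]
    A, B = a.clone(), b.clone()
    shift = 1
    while shift < L:
        # Combine each prefix with the prefix ending `shift` steps earlier;
        # positions without a predecessor use the identity element (1, 0).
        A_prev = torch.cat([torch.ones(shift), A[:-shift]])
        B_prev = torch.cat([torch.zeros(shift), B[:-shift]])
        B = A * B_prev + B
        A = A * A_prev
        shift *= 2
    return B

# Check against the naive sequential recurrence
torch.manual_seed(0)
a = torch.rand(16) * 0.9
b = torch.randn(16)
h, hs = torch.tensor(0.0), []
for t in range(16):
    h = a[t] * h + b[t]
    hs.append(h)
assert torch.allclose(parallel_linear_scan(a, b), torch.stack(hs), atol=1e-5)
```

In the actual kernel, this scan runs per (channel, state) pair and is fused with discretization so intermediates stay in SRAM.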

## Mamba architecture details

### Mamba block structure

Mamba blocks have a different structure from Transformer blocks. A Transformer block consists of two sublayers, Self-Attention and FFN, but Mamba integrates both roles into one block. Specifically, the input is split into two branches: one branch goes through convolution and the selective SSM, while the other acts as a gate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


@dataclass
class MambaConfig:
    d_model: int = 768
    d_state: int = 16
    d_conv: int = 4
    expand: int = 2
    n_layers: int = 24
    vocab_size: int = 50257
    dropout: float = 0.0


class MambaBlock(nn.Module):
    """Mamba block: integrates the selective SSM, convolution, and gating.

    Replaces the Transformer's Attention + FFN with a single block.
    The expansion ratio (expand) determines the inner dimension;
    the default of 2 means d_inner = 2 * d_model.
    """

    def __init__(self, config: MambaConfig):
        super().__init__()
        self.d_model = config.d_model
        self.d_state = config.d_state
        self.d_conv = config.d_conv
        d_inner = config.d_model * config.expand

        # input projection: split into two branches (z, x)
        self.in_proj = nn.Linear(config.d_model, d_inner * 2, bias=False)

        # 1D convolution: captures local context
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=config.d_conv,
            padding=config.d_conv - 1,
            groups=d_inner,  # depthwise convolution
            bias=True,
        )

        # selective SSM parameters
        self.B_proj = nn.Linear(d_inner, config.d_state, bias=False)
        self.C_proj = nn.Linear(d_inner, config.d_state, bias=False)
        self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)

        # A matrix: diagonal structure
        A = torch.arange(1, config.d_state + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(
            torch.log(A).unsqueeze(0).expand(d_inner, -1).clone()
        )
        self.D = nn.Parameter(torch.ones(d_inner))

        # output projection
        self.out_proj = nn.Linear(d_inner, config.d_model, bias=False)
        self.norm = nn.RMSNorm(config.d_model)

    def _selective_scan(self, x, dt, B, C):
        """Selective scan (sequential implementation); the real model uses a CUDA kernel."""
        batch, seq_len, d_inner = x.shape
        d_state = B.shape[-1]

        A = -torch.exp(self.A_log)
        dt_A = torch.einsum('bld,dn->bldn', dt, A)
        A_bar = torch.exp(dt_A)
        dt_B_x = torch.einsum('bld,bln->bldn', dt * x, B)

        h = torch.zeros(batch, d_inner, d_state, device=x.device)
        outputs = []
        for t in range(seq_len):
            h = A_bar[:, t] * h + dt_B_x[:, t]
            y_t = torch.einsum('bdn,bn->bd', h, C[:, t])
            outputs.append(y_t)

        y = torch.stack(outputs, dim=1)
        return y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the Mamba block.

        Args:
            x: (batch, seq_len, d_model)

        Returns:
            output: (batch, seq_len, d_model)
        """
        residual = x
        x = self.norm(x)

        # split into two branches
        xz = self.in_proj(x)
        x_branch, z = xz.chunk(2, dim=-1)

        # convolution path: captures local patterns
        x_conv = x_branch.transpose(1, 2)  # (B, D_inner, L)
        x_conv = self.conv1d(x_conv)[:, :, :x.shape[1]]  # causal trim
        x_conv = x_conv.transpose(1, 2)    # (B, L, D_inner)
        x_conv = F.silu(x_conv)

        # selective SSM
        B = self.B_proj(x_conv)
        C = self.C_proj(x_conv)
        dt = F.softplus(self.dt_proj(x_conv))

        y = self._selective_scan(x_conv, dt, B, C)
        y = y + self.D.unsqueeze(0).unsqueeze(0) * x_conv

        # gating: elementwise product with the z branch
        y = y * F.silu(z)

        # output projection + residual connection
        output = self.out_proj(y)
        return output + residual
```

### Parameter efficiency and architecture comparison

Mamba's architecture has several structural differences from the Transformer. The Transformer has two large sublayers, Multi-Head Attention and FFN, each with its own Layer Normalization and residual connection. Mamba integrates these into one block, capturing local patterns with convolution, modeling global dependencies with the selective SSM, and controlling information flow with gating.

While Transformer's Attention is a "global attention" method that explicitly computes the relationships between all token pairs, Mamba's selective SSM is a "compressed memory" method that accumulates information into a compressed state vector. This difference is the fundamental cause of the difference in complexity over sequence length.

At the same parameter count (about 2.8B), Mamba uses about 1.5 times fewer FLOPs than the Transformer. This is because Attention's $O(L^2)$ operation disappears, replaced by a selective SSM whose cost is comparable to the FFN.

## Mamba-2 and SSD Framework

### Discovery of State Space Duality

The most important theoretical contribution of Mamba-2 (Dao & Gu, 2024) is the discovery of State Space Duality (SSD). Although SSM and Attention appear to be very different operations on the surface, it has been proven that mathematically they belong to the same structured matrix family.

Specifically, the output of the selective SSM can be expressed as the following matrix-vector product.

$$y = Mx$$

Here $M$ is a semi-separable matrix whose $(i, j)$ element is:

$$M_{ij} = \begin{cases} C_i^\top \bar{A}_{i:j} B_j & \text{if } i \geq j \\ 0 & \text{if } i < j \end{cases}$$

Here, $\bar{A}_{i:j} = \bar{A}_i \bar{A}_{i-1} \cdots \bar{A}_{j+1}$ means the cumulative transition from time point $j$ to $i$. This matrix has a lower triangular structure, which is the same structure as the masked attention matrix of Causal Attention.

If softmax is removed from the attention matrix $\text{softmax}(QK^\top / \sqrt{d})$, it becomes $QK^\top$, which is also a rank-$d$ semi-separable matrix. SSD formalizes this connection, showing that SSM and Linear Attention share the same matrix algebraic structure.
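This duality can be checked numerically by materializing $M$ for a toy selective SSM (one scalar $\bar{A}_t$ per step, as in Mamba-2 heads; the shapes and values here are arbitrary illustrative choices):

```python
import torch

torch.manual_seed(0)
L, N = 6, 4
A_bar = torch.rand(L) * 0.9   # scalar transition per timestep
B = torch.randn(L, N)         # B_t
C = torch.randn(L, N)         # C_t
x = torch.randn(L)

# Materialize M[i, j] = (A_bar_i * ... * A_bar_{j+1}) * (C_i . B_j) for i >= j
M = torch.zeros(L, L)
for i in range(L):
    for j in range(i + 1):
        decay = torch.prod(A_bar[j + 1:i + 1])  # empty product = 1 when i == j
        M[i, j] = decay * (C[i] @ B[j])
y_matrix = M @ x

# The recurrence h_t = A_bar_t h_{t-1} + B_t x_t, y_t = C_t . h_t gives the same y
h = torch.zeros(N)
y_rec = []
for t in range(L):
    h = A_bar[t] * h + B[t] * x[t]
    y_rec.append(C[t] @ h)
assert torch.allclose(y_matrix, torch.stack(y_rec), atol=1e-5)
```

The lower-triangular $M$ plays exactly the role of a masked attention matrix: computing $y = Mx$ directly is the "attention mode", while unrolling the recurrence is the "SSM mode".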

### SSD algorithm and chunk division

The practical significance of SSD is that it enables the design of new algorithms. Mamba-2 uses a hybrid algorithm that divides the sequence into fixed-size chunks, processes each chunk in parallel with matrix multiplications (attention-like operations), and passes state between chunks with the SSM recursion.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SSDBlock(nn.Module):
    """Mamba-2 Structured State Space Duality (SSD) block.

    A hybrid algorithm that splits the sequence into chunks, handling
    the inside of each chunk with matrix multiplications (parallel) and
    carrying state between chunks with the SSM recursion.
    """

    def __init__(
        self,
        d_model: int = 768,
        d_state: int = 128,
        n_heads: int = 8,
        chunk_size: int = 256,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.n_heads = n_heads
        self.chunk_size = chunk_size
        self.head_dim = d_model // n_heads

        # Multi-head projections: Q (plays C), K (plays B), V (plays x).
        # Q and K must share the state dimension for the similarity product.
        self.q_proj = nn.Linear(d_model, d_state * n_heads, bias=False)
        self.k_proj = nn.Linear(d_model, d_state * n_heads, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.dt_proj = nn.Linear(d_model, n_heads, bias=True)

        # A matrix (one scalar per head)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, n_heads + 1, dtype=torch.float32)))

        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm = nn.RMSNorm(d_model)

    def _chunk_scan(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        dt: torch.Tensor,
    ) -> torch.Tensor:
        """Chunk-based SSD scan (intra-chunk part).

        Args:
            Q: (batch, n_heads, seq_len, d_state)   -- plays the role of C
            K: (batch, n_heads, seq_len, d_state)   -- plays the role of B
            V: (batch, n_heads, seq_len, head_dim)  -- plays the role of x
            dt: (batch, n_heads, seq_len)           -- plays the role of delta

        Returns:
            y: (batch, n_heads, seq_len, head_dim)
        """
        B, H, L, _ = Q.shape
        C = self.chunk_size

        # Split the sequence into chunks (pad to a multiple of chunk_size)
        n_chunks = (L + C - 1) // C
        pad_len = n_chunks * C - L
        if pad_len > 0:
            Q = F.pad(Q, (0, 0, 0, pad_len))
            K = F.pad(K, (0, 0, 0, pad_len))
            V = F.pad(V, (0, 0, 0, pad_len))
            dt = F.pad(dt, (0, pad_len))

        # Reshape into chunks
        Q = Q.view(B, H, n_chunks, C, -1)
        K = K.view(B, H, n_chunks, C, -1)
        V = V.view(B, H, n_chunks, C, -1)
        dt = dt.view(B, H, n_chunks, C)

        # Inside each chunk: attention-like matrix products (parallelizable).
        # Build a causal mask with the cumulative decay applied.
        A = -torch.exp(self.A_log)     # (n_heads,)
        dt_cumsum = dt.cumsum(dim=-1)  # (B, H, n_chunks, C)

        # Intra-chunk decay matrix:
        # M[i, j] = exp(A * (dt_cumsum[i] - dt_cumsum[j])) for i >= j
        decay_diff = dt_cumsum.unsqueeze(-1) - dt_cumsum.unsqueeze(-2)  # (B,H,nc,C,C)
        decay = torch.exp(A.view(1, H, 1, 1, 1) * decay_diff)
        causal_mask = torch.tril(torch.ones(C, C, device=Q.device))
        decay = decay * causal_mask

        # Intra-chunk output
        attn = torch.einsum('bhncqs,bhncks->bhncqk', Q, K) * decay
        y_intra = torch.einsum('bhncqk,bhnckd->bhncqd', attn, V)

        # NOTE: for brevity this reference implementation omits the
        # inter-chunk recursion that carries state across chunk boundaries;
        # the full SSD algorithm adds that correction term.

        # Final reshape
        y = y_intra.reshape(B, H, n_chunks * C, -1)
        return y[:, :, :L]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the SSD block."""
        residual = x
        x = self.norm(x)
        B, L, D = x.shape

        # Multi-head projections
        Q = self.q_proj(x).view(B, L, self.n_heads, self.d_state).transpose(1, 2)
        K = self.k_proj(x).view(B, L, self.n_heads, self.d_state).transpose(1, 2)
        V = self.v_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
        dt = F.softplus(self.dt_proj(x)).transpose(1, 2)  # (B, H, L)

        # SSD scan
        y = self._chunk_scan(Q, K, V, dt)

        # Output projection
        y = y.transpose(1, 2).contiguous().view(B, L, D)
        return self.out_proj(y) + residual
```

### Structural improvements to Mamba-2

Mamba-2 introduced the following structural changes compared to Mamba-1.

First, the introduction of a multi-head structure. Mamba-1 applies an independent SSM to each channel, while Mamba-2 shares parameters using a head structure similar to Attention. This allows more state dimensions to be used for the same amount of computation.

Second, simplification of the A matrix. Mamba-1 used a diagonal A matrix per channel, but Mamba-2 simplifies it further to a single scalar per head. This simplification is the key condition that makes the SSD matrix decomposition possible.

Third, the chunk-based parallel algorithm. The sequence is divided into fixed-size chunks, each chunk is processed with matrix multiplications on Tensor Cores, and state is passed between chunks through the SSM recursion. This method is 2-8 times faster than a pure recurrent scan and takes full advantage of the GPU's parallel hardware.

## Performance comparison with the Transformer

### Quantitative Benchmarks

We summarize comparison results across several sequence modeling architectures. Measurements use roughly 2.8B-parameter models and the same training data (The Pile, 300B tokens).

| Item | Transformer (2.8B) | Mamba (2.8B) | Mamba-2 (2.8B) | RWKV-6 (3B) | RetNet (2.7B) |
| --- | --- | --- | --- | --- | --- |
| Pile PPL (val) | 6.73 | 6.22 | 6.10 | 6.85 | 7.02 |
| Training throughput (tok/s) | baseline (1.0x) | 1.2x | 1.8x | 1.1x | 1.3x |
| Inference latency (1K tokens) | 38 ms/tok | 12 ms/tok | 10 ms/tok | 14 ms/tok | 16 ms/tok |
| Inference latency (16K tokens) | 180 ms/tok | 12 ms/tok | 10 ms/tok | 14 ms/tok | 16 ms/tok |
| Inference memory (1K tokens) | 2.1 GB | 0.8 GB | 0.7 GB | 0.9 GB | 1.0 GB |
| Inference memory (16K tokens) | 8.5 GB | 0.8 GB | 0.7 GB | 0.9 GB | 1.0 GB |
| Scaling with sequence length | KV Cache grows linearly | constant | constant | constant | constant |
| In-Context Learning | strong | medium | medium-strong | weak | weak |
| Attention pattern analysis | possible (visualization) | not possible | partial | not possible | partial |

There are several points to note in this table. First, in the Mamba family, inference latency and memory stay constant as sequence length grows: the Transformer's per-token latency increases about 4.7x from 1K to 16K tokens, while Mamba's is unchanged. Second, Mamba-2 is 1.5 times faster than Mamba-1 in training throughput, thanks to the SSD algorithm's chunk-based parallelization. Third, the Transformer still shows superiority in In-Context Learning, which remains one of the key weaknesses of SSM-based models.

### Scaling by sequence length

The advantages of SSM become clearer when analyzing performance as a function of sequence length. The Transformer's Self-Attention is $O(L^2)$: doubling the sequence length quadruples the computation, and optimizations such as FlashAttention-2 do not change this fundamental complexity. Mamba's selective SSM, by contrast, is $O(L)$, so the per-token computation is constant regardless of sequence length.

The real wall appears around the 32K-64K token range. In this regime, on an A100 80GB GPU, the Transformer must shrink its batch size because of the KV Cache, while Mamba can keep the same batch size. For extremely long sequences of 1M tokens or more, the Transformer is practically infeasible, while an SSM can handle them in principle.
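A back-of-envelope calculation makes the KV Cache wall concrete. The configuration below is a hypothetical 2.8B-class setting (the layer count and widths are my assumptions, stored in fp16), not the exact published architecture of any model:

```python
# KV Cache grows linearly with context length; Mamba's recurrent state is
# constant. Assumed 2.8B-class config (n_layers, d_model are illustrative).
n_layers, d_model, bytes_per = 32, 2560, 2
d_state, d_conv, expand = 16, 4, 2

def kv_cache_mb(L: int) -> float:
    # K and V: two (n_layers, L, d_model) tensors
    return 2 * n_layers * L * d_model * bytes_per / 2**20

def mamba_state_mb() -> float:
    # per layer: SSM state (d_inner, d_state) + conv buffer (d_inner, d_conv)
    d_inner = expand * d_model
    return n_layers * d_inner * (d_state + d_conv) * bytes_per / 2**20

for L in (1024, 16384, 131072):
    print(f"L={L:6d}  KV cache: {kv_cache_mb(L):8.1f} MB  Mamba state: {mamba_state_mb():.2f} MB")
```

Under these assumptions the KV Cache already reaches hundreds of MB per sequence at 1K tokens and tens of GB at 128K, while the Mamba-style state stays at a few MB regardless of context length.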

## Applications to various domains

### Vision: Vision Mamba and VMamba

Research on applying SSM to vision tasks is active. Vision Mamba (Vim, Zhu et al., 2024) replaces the Self-Attention of ViT (Vision Transformer) with a bidirectional SSM. By scanning image patches forward and backward to capture bidirectional context, it achieved accuracy equivalent to DeiT on ImageNet classification with 2.8 times less GPU memory. VMamba (Liu et al., 2024) proposed the Cross-Scan Module (CSM) to better reflect the spatial structure of 2D images. By scanning the image in four directions (top-left to bottom-right, top-right to bottom-left, bottom-left to top-right, bottom-right to top-left), 2D spatial information is captured in a 1D sequence model.

The key advantage of SSM in vision is high-resolution image processing. ViT's computation is proportional to the square of the number of patches (the sequence length), so increasing the resolution from 224x224 to 512x512 increases the computation by about 5 times. Because of Vision Mamba's linear complexity, this cost increase is only about 2x.

### Audio and speech

Audio signals are inherently long sequences. At 16kHz sampling, one minute of audio contains 960,000 samples, which is unrealistic to process directly with a Transformer. Audio Mamba (Erol et al., 2024) replaced the Audio Spectrogram Transformer (AST) on audio classification tasks, reducing memory usage by 70% at equivalent accuracy.

Mamba is also being applied to speech synthesis (TTS). Because the Transformer's KV Cache bottleneck is severe when conditionally generating long audio sequences, an SSM-based decoder is better suited to real-time speech synthesis.

### Time series forecasting

Time series data is the domain where SSM is most naturally applied. This is because the mathematical framework of SSM, which models continuous-time dynamic systems, fits well with the continuous nature of time series. S-Mamba (Wang et al., 2024) shows equal or better performance than PatchTST in multivariate time series prediction, and the inference speed is 3 times faster.

A special advantage of SSM in time series forecasting is its handling of irregularly-sampled time series. Continuous-time SSM can adjust the discretization step $\Delta$ to suit the observation interval, so data with irregular time intervals can be directly processed without interpolation. This is a huge advantage for medical data, IoT sensor data, etc.
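As a sketch of this idea (the function and its interface are hypothetical illustrations, not a specific library's API), the per-step $\Delta$ can be read directly off the gaps between observation timestamps:

```python
import torch

# Irregular-sample handling: the per-step Delta comes from the gaps between
# observation timestamps instead of interpolating the data onto a grid.
def irregular_ssm(x, timestamps, A, B, C):
    """x: (L,) values, timestamps: (L,) observation times,
    A: negative scalar, B, C: (N,) vectors."""
    deltas = torch.diff(timestamps, prepend=timestamps[:1])
    h = torch.zeros_like(B)
    ys = []
    for t in range(x.shape[0]):
        dt = deltas[t].clamp(min=1e-3)  # guard the zero gap at t = 0
        A_bar = torch.exp(dt * A)       # ZOH with the actual elapsed time
        h = A_bar * h + dt * B * x[t]   # Euler-style discretization of B
        ys.append(C @ h)
    return torch.stack(ys)

# Irregularly spaced observations, e.g. clinical measurements
ts = torch.tensor([0.0, 0.1, 0.15, 1.2, 1.25, 4.0])
x = torch.randn(6)
y = irregular_ssm(x, ts, A=torch.tensor(-0.5), B=torch.ones(8), C=torch.ones(8) / 8)
```

A long gap between observations yields a large $\Delta$ and therefore strong decay of the old state, which is exactly the behavior one wants for stale measurements.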

### Genomics and proteins

DNA sequences are extremely long sequences (millions to billions of bp), and long-distance interactions are biologically important. Caduceus (Schiff et al., 2024) is a bidirectional Mamba-based DNA language model that reflects the double-stranded structure and reverse complementarity of DNA in its architecture. It showed excellent performance in predicting long-distance mutation effects compared to existing DNA Transformer models (Nucleotide Transformer, HyenaDNA).

## Limitations and challenges

### Weaknesses of In-Context Learning

The most serious limitation of SSM-based models, including Mamba, concerns In-Context Learning (ICL). The Transformer has excellent ICL ability: it performs new tasks by referring to the few-shot examples provided in the prompt, because Self-Attention can directly look up any position in the sequence.

SSM, on the other hand, packs information into a fixed-size state vector, making it difficult to accurately retrieve a specific example from the prompt. When the state dimension $N$ is much smaller than the sequence length ($N \ll L$), information loss is inevitable. Mamba-2 partially alleviates this problem by greatly increasing the state dimension ($N = 128$ or more), but still falls short of the Transformer's explicit token-level lookup.

### Limitations of the Retrieval task

In Information Retrieval, that is, the task of accurately extracting specific information from a sequence, SSM shows a distinct weakness compared to the Transformer. For example, on the "passkey retrieval" task of finding a secret key hidden in a haystack of 100,000 tokens, Mamba's performance is unstable. This is because SSM's compressed-memory approach is suited to statistical summarization rather than precise information preservation.

### The rise of hybrid architectures

Recognizing these limitations, recent research explores hybrid architectures that combine SSM and Attention. Jamba (AI21 Labs, 2024) secures both the efficiency of SSM and the ICL capability of Attention by interleaving Mamba layers and Transformer layers. It supports a 256K context at 52B parameters, with inference throughput 3 times higher than a pure Transformer.

Zamba (Zyphra, 2024) and Griffin (De et al., 2024) also adopt a similar hybrid approach, which is proving a more practical compromise than either pure SSM or pure Transformer.

## Practical implementation guide

### Utilize the official Mamba library

The official implementation of Mamba is provided in the `state-spaces/mamba` repository. A GPU environment is essential because CUDA kernels are included.

```bash
# Install Mamba and basic usage

# 1. Install (requires CUDA 11.8 or later)
pip install mamba-ssm

# 2. Install the causal-conv1d dependency (quoted so the shell does not treat >= as a redirection)
pip install "causal-conv1d>=1.2.0"

# 3. Use a pretrained model (Hugging Face)
pip install transformers

# 4. 사전학습된 Mamba 모델로 텍스트 생성
python -c "
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import torch

# Mamba-2.8B 모델 로드
model = MambaLMHeadModel.from_pretrained(
    'state-spaces/mamba-2.8b',
    device='cuda',
    dtype=torch.float16,
)
model.eval()

# 토크나이저 (GPT-NeoX 토크나이저 호환)
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

# 텍스트 생성
prompt = 'State Space Models are'
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to('cuda')
output = model.generate(
    input_ids=input_ids,
    max_length=100,
    temperature=0.7,
    top_p=0.9,
)
print(tokenizer.decode(output[0]))
"
```

### Benchmark script

This is a benchmark script that quantitatively compares the inference performance of Mamba and Transformer.

```python
import torch
import time
from dataclasses import dataclass

@dataclass
class BenchmarkResult:
    model_name: str
    seq_len: int
    batch_size: int
    latency_ms: float
    memory_mb: float
    tokens_per_sec: float


def benchmark_inference(
    model: torch.nn.Module,
    model_name: str,
    seq_lengths: list[int],
    batch_size: int = 1,
    n_warmup: int = 5,
    n_measure: int = 20,
    device: str = "cuda",
) -> list[BenchmarkResult]:
    """다양한 시퀀스 길이에서 추론 성능을 측정한다.

    Args:
        model: 벤치마크 대상 모델
        model_name: 모델 이름 (결과 표시용)
        seq_lengths: 테스트할 시퀀스 길이 목록
        batch_size: 배치 크기
        n_warmup: 워밍업 반복 수
        n_measure: 측정 반복 수
        device: 실행 디바이스

    Returns:
        시퀀스 길이별 벤치마크 결과 목록
    """
    model = model.to(device).eval()
    results = []

    for seq_len in seq_lengths:
        input_ids = torch.randint(
            0, 50257, (batch_size, seq_len), device=device
        )

        # Clear the GPU cache and reset peak-memory statistics
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Warmup
        with torch.no_grad():
            for _ in range(n_warmup):
                _ = model(input_ids)
        torch.cuda.synchronize()

        # Measure
        latencies = []
        with torch.no_grad():
            for _ in range(n_measure):
                torch.cuda.synchronize()
                start = time.perf_counter()
                _ = model(input_ids)
                torch.cuda.synchronize()
                end = time.perf_counter()
                latencies.append((end - start) * 1000)

        avg_latency = sum(latencies) / len(latencies)
        peak_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        tokens_per_sec = (batch_size * seq_len) / (avg_latency / 1000)

        results.append(BenchmarkResult(
            model_name=model_name,
            seq_len=seq_len,
            batch_size=batch_size,
            latency_ms=avg_latency,
            memory_mb=peak_memory,
            tokens_per_sec=tokens_per_sec,
        ))

        print(
            f"[{model_name}] seq_len={seq_len:>6d} | "
            f"latency={avg_latency:>8.2f}ms | "
            f"memory={peak_memory:>8.1f}MB | "
            f"throughput={tokens_per_sec:>10.0f} tok/s"
        )

    return results


def compare_models(results_dict: dict[str, list[BenchmarkResult]]):
    """모델 간 성능 비교 테이블을 출력한다."""
    print("\n" + "=" * 80)
    print(f"{'모델':<20} {'시퀀스 길이':>10} {'지연(ms)':>10} "
          f"{'메모리(MB)':>12} {'처리량(tok/s)':>15}")
    print("=" * 80)

    for model_name, results in results_dict.items():
        for r in results:
            print(
                f"{r.model_name:<20} {r.seq_len:>10d} "
                f"{r.latency_ms:>10.2f} {r.memory_mb:>12.1f} "
                f"{r.tokens_per_sec:>15.0f}"
            )
        print("-" * 80)
```

## Precautions during operation

### CUDA kernel compatibility

Mamba's official CUDA kernels are optimized for specific GPU architectures (Ampere and above, sm_80+). On V100 (sm_70) or earlier GPUs, a fallback implementation is used and performance drops significantly. Before production deployment, check the compute capability of the target GPU and use Ampere (A100/A10G) or Hopper (H100) GPUs.
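A quick pre-deployment check with PyTorch's device-capability API might look like this (the helper name is illustrative):

```python
import torch

def check_mamba_gpu_support() -> bool:
    """Return True if the visible GPU meets the sm_80+ (Ampere or newer)
    requirement for Mamba's optimized CUDA kernels."""
    if not torch.cuda.is_available():
        print("No CUDA device visible; mamba-ssm kernels cannot run.")
        return False
    major, minor = torch.cuda.get_device_capability()
    name = torch.cuda.get_device_name()
    if major >= 8:
        print(f"{name} (sm_{major}{minor}): optimized kernels supported")
        return True
    print(f"{name} (sm_{major}{minor}): fallback path, expect slowdown")
    return False

supported = check_mamba_gpu_support()
```

Running this in the deployment image catches sm_70-class hardware before it silently falls back to the slow path.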

You should also pay attention to version compatibility of the causal-conv1d package. If the mamba-ssm and causal-conv1d versions do not match, a runtime error occurs; be sure to check the compatibility table in the official README.

### State initialization and sequence boundaries

The most common inference mistake with SSM-based models is mishandling state initialization. In batch inference, mixing states from different sequences contaminates the output. At the beginning of each sequence, the state must be explicitly initialized to a zero vector, and in streaming inference the state must likewise be reset at conversation-session boundaries.
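The principle can be illustrated with a toy diagonal recurrence; this is not Mamba itself, and all names are illustrative:

```python
import torch

class ToySSM:
    """Toy diagonal recurrence h_t = a*h_{t-1} + b*x_t, used only to
    illustrate explicit state handling at sequence boundaries."""

    def __init__(self, d: int):
        self.a = torch.full((d,), 0.9)   # per-channel decay
        self.b = torch.full((d,), 0.1)   # per-channel input gain
        self.state = None

    def reset_state(self, batch_size: int):
        # Call at every sequence boundary / new session: starting from
        # anything but zeros leaks the previous sequence into the output.
        self.state = torch.zeros(batch_size, self.a.shape[0])

    def step(self, x: torch.Tensor) -> torch.Tensor:
        if self.state is None:
            self.reset_state(x.shape[0])
        self.state = self.a * self.state + self.b * x
        return self.state
```

The same discipline applies to the real model: whatever object carries the recurrent state across `step` calls must be zeroed (or freshly allocated) whenever a new sequence begins.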

In a multi-turn conversation, whether to maintain or reset the state of the previous turn is a design decision. Maintaining the state preserves the context of previous conversations, but because the state's ability to preserve information is more limited than Transformer's KV Cache, information loss may accumulate in long conversations.

### Precision and stability during training

Mamba's selective SSM is sensitive to the numerical stability of the $\Delta$ parameter, which uses a softplus activation. If $\Delta$ is too large, $\exp(\Delta A)$ converges to 0 and gradients vanish; if it is too small, state updates become negligible and learning stagnates. It is recommended to initialize $\Delta$ so that $\log \Delta \sim \text{Uniform}(-4, -1)$ at the start of training, ensuring appropriate coverage in log space.
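A sketch of such a log-uniform $\Delta$ initialization, in the style of the reference implementation; the $[10^{-3}, 0.1]$ range is a commonly used default and is an assumption here:

```python
import math
import torch

def init_dt_bias(d_inner: int, dt_min: float = 1e-3,
                 dt_max: float = 0.1) -> torch.Tensor:
    """Sample Delta log-uniformly in [dt_min, dt_max] and return its
    inverse-softplus, so softplus(bias) lands back in the target range."""
    log_dt = (torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
              + math.log(dt_min))
    dt = torch.exp(log_dt)                 # log-uniform coverage
    # Inverse of softplus: softplus(dt + log(1 - exp(-dt))) == dt
    return dt + torch.log(-torch.expm1(-dt))

bias = init_dt_bias(256)
dt = torch.nn.functional.softplus(bias)    # recovered Delta values
```

Sampling in log space rather than uniformly guarantees that both fast-decaying and slow-decaying timescales are represented from the first step.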

BF16 training is generally stable, but numerical instability can occasionally occur in the exponential operations on the state matrix. Parameterizing $A$ in log space ($A = -\exp(A_{\log})$) guarantees negativity and stability at the same time.
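A minimal sketch of this parameterization with an S4D-real-style initialization; the shapes are illustrative:

```python
import torch

d_inner, d_state = 256, 16
# S4D-real style: A[n] = -(n+1) for each state dim, shared across channels
A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
A_log = torch.log(A)            # the trainable parameter lives in log space
A_neg = -torch.exp(A_log)       # always strictly negative -> stable decay
```

Because `exp` of any real value is positive, the recovered `A_neg` can never cross zero during training, so $\exp(\Delta A)$ stays in $(0, 1)$ and the recurrence cannot blow up.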

### Serving framework support status

vLLM partially supports serving Mamba models. PagedAttention targets Attention-based models and therefore does not apply to SSM layers, but vLLM's scheduling and batching infrastructure can still be used. TensorRT-LLM supports optimization of Mamba-1 models, while Mamba-2 requires a separate plugin.

For Triton Inference Server, it is possible to convert a Mamba model to ONNX and deploy it, but there is a performance loss because the custom CUDA kernels are not captured in the ONNX graph. Currently, a custom serving server that directly wraps the mamba-ssm library is the most reliable option.

### Debugging Checklist

The following summarizes problems that may occur in production and how to respond to them.

- **When the output falls into a repetitive loop**: the state is stuck in a specific pattern. Increase the temperature or add logic to partially reset the state.
- **When quality deteriorates rapidly on long inputs**: the state dimension is insufficient for the sequence length. When selecting a model, choose one with a sufficiently large d_state (N=64 or more).
- **When results differ across sequences in batch inference**: incorrect padding handling most likely let padding tokens leak into the state. Force $\Delta$ to 0 at padding positions to block state updates.
- **When CUDA OOM occurs intermittently during training**: the recomputation strategy may not have been applied correctly. Enable gradient checkpointing and reduce the batch size.
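The padding fix from the checklist above can be sketched as a hypothetical helper; the function name and tensor layout are illustrative:

```python
import torch

def mask_delta_for_padding(delta: torch.Tensor,
                           attention_mask: torch.Tensor) -> torch.Tensor:
    """Zero Delta at padding positions: with Delta = 0, exp(Delta * A) is
    the identity and the discretized input term vanishes, so the state
    passes through padding positions untouched.

    delta:          (batch, seq_len, d_inner)
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    """
    return delta * attention_mask.unsqueeze(-1).to(delta.dtype)
```

Applying this before discretization makes left- or right-padded batches produce the same per-sequence outputs as unpadded single-sequence inference.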

## References

1. Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. _arXiv preprint arXiv:2312.00752_. [https://arxiv.org/abs/2312.00752](https://arxiv.org/abs/2312.00752)

2. Dao, T., & Gu, A. (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. _arXiv preprint arXiv:2405.21060_. [https://arxiv.org/abs/2405.21060](https://arxiv.org/abs/2405.21060)

3. Mamba official GitHub repository. [https://github.com/state-spaces/mamba](https://github.com/state-spaces/mamba)

4. Gu, A., Goel, K., & Re, C. (2021). Efficiently Modeling Long Sequences with Structured State Spaces (S4). _ICLR 2022_. [https://arxiv.org/abs/2111.00396](https://arxiv.org/abs/2111.00396)

5. Patro, B. N., & Agneeswaran, V. S. (2024). Mamba-360: Survey of State Space Models as Transformer Alternative for Long Sequence Modeling. _arXiv preprint arXiv:2404.16112_. [https://arxiv.org/abs/2404.16112](https://arxiv.org/abs/2404.16112)

6. Lieber, O. et al. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. _arXiv preprint arXiv:2403.19887_. [https://arxiv.org/abs/2403.19887](https://arxiv.org/abs/2403.19887)

7. Zhu, L. et al. (2024). Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model. _ICML 2024_. [https://arxiv.org/abs/2401.09417](https://arxiv.org/abs/2401.09417)

8. Schiff, Y. et al. (2024). Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling. _ICML 2024_. [https://arxiv.org/abs/2403.03234](https://arxiv.org/abs/2403.03234)