Mamba: Linear-Time Sequence Modeling with Selective State Spaces — Paper Analysis

Paper Overview

  • Title: Mamba: Linear-Time Sequence Modeling with Selective State Spaces
  • Authors: Albert Gu, Tri Dao (Carnegie Mellon University, Princeton University)
  • Published: December 2023 (arXiv: 2312.00752), accepted at COLM 2024
  • Code: github.com/state-spaces/mamba

Motivation: The Limitations of Transformers

Transformers dominate sequence modeling, but they suffer from fundamental problems:

  • O(N²) Complexity: Self-attention has quadratic complexity with respect to sequence length
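  • Long Sequence Bottleneck: Memory and compute become prohibitive for very long contexts (hundreds of thousands of tokens and beyond)
  • Inference Inefficiency: The KV cache grows linearly with sequence length, and every newly generated token must attend over the entire cache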
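The inference point is easiest to see with numbers. The sketch below compares per-layer decoding memory for a single sequence. It assumes a hypothetical model with d_model = 2048, 16-bit storage, a plain multi-head KV cache (no grouped-query attention), and Mamba-style settings of expand = 2, d_state = 16, d_conv = 4; the helper names are illustrative.

```python
BYTES_FP16 = 2  # assumption: 16-bit storage for cache and state


def kv_cache_bytes(seq_len: int, d_model: int) -> int:
    """Per-layer KV cache for one sequence: one key and one value vector per token."""
    return 2 * seq_len * d_model * BYTES_FP16


def ssm_state_bytes(d_model: int, expand: int = 2, d_state: int = 16, d_conv: int = 4) -> int:
    """Per-layer recurrent state for one sequence (SSM state plus conv buffer).

    Independent of how many tokens have already been processed.
    """
    d_inner = expand * d_model
    ssm_state = d_inner * d_state          # h has shape (d_inner, d_state)
    conv_state = d_inner * (d_conv - 1)    # last d_conv - 1 inputs of the causal conv
    return (ssm_state + conv_state) * BYTES_FP16


d_model = 2048
for n_tokens in (2_048, 65_536, 1_048_576):
    kv = kv_cache_bytes(n_tokens, d_model)
    state = ssm_state_bytes(d_model)
    print(f"{n_tokens:>9} tokens: KV cache {kv / 2**20:8.1f} MiB vs SSM state {state / 2**20:.3f} MiB")
```

Under these assumptions the KV cache grows from 16 MiB at 2K tokens to 8 GiB at 1M tokens per layer, while the recurrent state stays at about 0.15 MiB regardless of length.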

State Space Models (SSMs) theoretically achieve O(N) complexity, but prior SSMs (S4, H3, etc.) fell short of Transformers in language modeling due to their fixed, input-independent parameters.

Core Idea: Selective State Spaces

The Problem with Existing SSMs

Traditional SSMs are Linear Time-Invariant (LTI) systems:

$$h'(t) = A\,h(t) + B\,x(t)$$

$$y(t) = C\,h(t)$$

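Here, A, B, and C are fixed matrices: the same dynamics are applied at every time step, whether the input is "hello" or "world." The model therefore cannot decide, based on what a token actually is, whether to store it or ignore it, so content-aware reasoning is out of reach.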
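To run on discrete tokens, the continuous system is discretized with a step size Δ. The paper uses the zero-order hold (ZOH) rule, which gives the discrete parameters and recurrence below; Δ here is the same step size that Mamba later makes input-dependent.

$$\bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\Delta B$$

$$h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t, \qquad y_t = C\,h_t$$

In an LTI SSM, Δ, A, B, and C depend on neither t nor x_t, so $\bar{A}$ and $\bar{B}$ are identical at every step. A common simplification, used in the selective-scan code later in this post, keeps $\bar{A} = \exp(\Delta A)$ but approximates $\bar{B} \approx \Delta B$.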

Mamba's Solution: Selection Mechanism

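Mamba's key innovation is making B, C, and Δ functions of the input. A itself remains an input-independent parameter, but its discretized form exp(ΔA) becomes input-dependent through Δ: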

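import torch
import torch.nn as nn
import torch.nn.functional as F

D, N = 64, 16                # model width, state size (example values)
x = torch.randn(2, 10, D)    # (batch, length, D)

# Traditional SSM: fixed parameters
B_fixed = nn.Parameter(torch.randn(N))  # Learned but input-independent

# Mamba: input-dependent parameters (inline layers for brevity; a real model stores them as submodules)
B = nn.Linear(D, N)(x)                  # (batch, length, N): B varies depending on x
C = nn.Linear(D, N)(x)                  # (batch, length, N): C varies depending on x
delta = F.softplus(nn.Linear(D, D)(x))  # (batch, length, D): positive step size, varies with x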

This is what Selective State Space means. Depending on the input, the model selects which information to store in its state and which to ignore.

Intuitive Understanding

The Selection Mechanism can be understood intuitively as:

  • Δ (delta): "How much attention should this token receive?" — A large Δ strongly reflects the current input, while a small Δ retains the previous state
  • B: "Which parts of the input should be recorded in the state?"
  • C: "Which parts of the state should be extracted as output?"

This plays a role loosely analogous to attention's content-based selection, but it compresses the past into a fixed-size state and costs O(N) in sequence length.
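The Δ intuition can be checked on the smallest possible case: one channel with a scalar state, A = -1, B = 1, and the ZOH rule. Then Ā = exp(-Δ) and B̄ = 1 - exp(-Δ), so each step is a convex mix of the old state and the new input. This scalar setup and the helper name `step` are assumptions chosen for clarity; real Mamba layers carry an N-dimensional state per channel.

```python
import math


def step(h: float, x: float, delta: float) -> float:
    """One ZOH-discretized step of h'(t) = -h(t) + x(t), i.e. A = -1, B = 1."""
    a_bar = math.exp(-delta)   # fraction of the old state that survives
    b_bar = 1.0 - a_bar        # fraction of the new input written into the state
    return a_bar * h + b_bar * x


h = 5.0  # a state that currently "remembers" the value 5
print(step(h, x=-3.0, delta=0.001))  # tiny delta: input mostly ignored, h stays near 5
print(step(h, x=-3.0, delta=10.0))   # large delta: state mostly overwritten by -3
```

Because Δ is computed from the token itself, the model can choose per token between these two regimes, which is what lets it skip filler tokens and latch onto relevant ones.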

Hardware Optimization: The Selective Scan Algorithm

Since the Selective SSM is input-dependent, it cannot leverage the efficient convolution tricks used by traditional SSMs. Mamba instead evaluates the recurrence as a scan and makes that scan fast with a hardware-aware algorithm.
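For an LTI SSM the recurrence unrolls into a convolution, y_t = Σ_k K_k x_{t-k} with kernel K_k = C Ā^k B̄, so the whole output can be computed in parallel (S4 uses an FFT for this on long sequences; here `np.convolve` does it directly). The NumPy sketch below uses random stable parameters chosen only for illustration and checks that both forms agree. Once Ā and B̄ vary per token, no single kernel K exists, which is exactly what rules the trick out for Mamba.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 12                               # state size, sequence length (arbitrary)
A_bar = np.diag(rng.uniform(0.1, 0.9, N))  # fixed (LTI), stable transition
B_bar = rng.normal(size=N)
C = rng.normal(size=N)
x = rng.normal(size=L)


def recurrent(x):
    """Sequential form: h_t = A_bar h_{t-1} + B_bar x_t, y_t = C h_t."""
    h, ys = np.zeros(N), []
    for x_t in x:
        h = A_bar @ h + B_bar * x_t
        ys.append(C @ h)
    return np.array(ys)


def convolutional(x):
    """Parallel form: y = K * x with kernel K_k = C A_bar^k B_bar."""
    K = np.array([C @ np.linalg.matrix_power(A_bar, k) @ B_bar for k in range(len(x))])
    # Causal convolution: y_t = sum_{k=0}^{t} K_k x_{t-k}
    return np.convolve(x, K)[: len(x)]


print(np.allclose(recurrent(x), convolutional(x)))  # True
```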

The Problem: HBM ↔ SRAM Bottleneck

GPU memory hierarchy:

┌────────────────────┐
│ SRAM (20MB)        │  ← Very fast, very small
├────────────────────┤
│ HBM (80GB)         │  ← Slow, large
└────────────────────┘

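A naive implementation materializes the expanded hidden state, of shape (batch, length, D, N), in HBM and repeatedly writes and reads it, so the computation is limited by memory bandwidth rather than arithmetic.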
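A rough size count shows why. The sketch below uses hypothetical sizes (batch 8, length 2048, D = 1536, N = 16, fp32) and ignores A and the output: the inputs x, Δ, B, C together hold batch · L · (2D + 2N) numbers, while each expanded tensor (such as the discretized Ā or the state h across all steps) holds batch · L · D · N. The function name `scan_io_bytes` is illustrative.

```python
def scan_io_bytes(batch, length, d, n, bytes_per_elem=4):
    """Rough sizes in bytes for one selective-scan call, ignoring A and the output."""
    inputs = batch * length * (2 * d + 2 * n)  # x, delta: (batch, L, D); B, C: (batch, L, N)
    expanded = batch * length * d * n          # one (batch, L, D, N) tensor, e.g. A_bar or h
    return inputs * bytes_per_elem, expanded * bytes_per_elem


inputs_b, expanded_b = scan_io_bytes(batch=8, length=2048, d=1536, n=16)
print(f"inputs: {inputs_b / 2**20:.0f} MiB, one expanded tensor: {expanded_b / 2**20:.0f} MiB")
```

Here a single expanded tensor (1536 MiB) is roughly 8× larger than all inputs combined (194 MiB), and a naive implementation writes several of them to HBM and reads them back. Keeping them in SRAM and writing only the (batch, L, D) output avoids most of that traffic.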

Solution: Kernel Fusion + Recomputation

Mamba's Selective Scan algorithm:

  1. Compute entirely in SRAM: Load Δ, A, B, C from HBM into SRAM, and fuse discretization, the selective scan, and the multiplication by C into a single kernel
  2. No intermediate state storage: Write only the output of shape (batch, length, D) back to HBM; during the backward pass, recompute the intermediate states from the inputs instead of storing them
  3. Result: An IO-aware design in the spirit of FlashAttention, which minimizes traffic between HBM and SRAM
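import torch

# Simplified Selective Scan (conceptual reference code, not the fused CUDA kernel)
def selective_scan(x, A, B, C, delta):
    """
    x:     (batch, length, d)  input
    A:     (d, n)              diagonal state matrix per channel (negative entries)
    B, C:  (batch, length, n)  input-dependent projections
    delta: (batch, length, d)  input-dependent step sizes (positive)
    Returns: (batch, length, d)
    """
    batch, length, d = x.shape
    n = A.shape[1]  # state dimension

    # Discretize with delta: A_bar = exp(delta * A) (ZOH), B_bar * x approximated as delta * B * x
    deltaA = torch.exp(delta.unsqueeze(-1) * A)                        # (batch, length, d, n)
    deltaB_x = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)  # (batch, length, d, n)

    # Sequential scan (parallelized via CUDA kernel in practice)
    h = torch.zeros(batch, d, n, device=x.device, dtype=x.dtype)
    ys = []
    for i in range(length):
        h = deltaA[:, i] * h + deltaB_x[:, i]   # state update for step i
        y = (h * C[:, i].unsqueeze(1)).sum(-1)  # y_i = C_i . h_i for every channel
        ys.append(y)

    return torch.stack(ys, dim=1)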
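The comment "parallelized via CUDA kernel in practice" rests on one observation: each step h_t = a_t · h_{t-1} + b_t is an affine map, and composing affine maps is associative, so the recurrence can be evaluated with a parallel prefix scan in O(log L) rounds. The NumPy sketch below uses a simple Hillis-Steele scan on scalar values for one (channel, state) pair; it illustrates the principle only and is simpler than the work-efficient scan a production kernel would use. Here a_t and b_t play the roles of `deltaA[:, i]` and `deltaB_x[:, i]` in the code above.

```python
import numpy as np


def combine(left, right):
    """Compose two affine steps h -> a*h + b: apply `left` first, then `right`."""
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2


def sequential_scan(a, b):
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return np.array(out)


def parallel_scan(a, b):
    """Hillis-Steele inclusive scan: about log2(L) rounds, each parallel across t."""
    a, b = a.copy(), b.copy()
    shift = 1
    while shift < len(a):
        # Position t (t >= shift) absorbs the prefix that ends at t - shift.
        a_new, b_new = combine((a[:-shift], b[:-shift]), (a[shift:], b[shift:]))
        a[shift:], b[shift:] = a_new, b_new
        shift *= 2
    return b  # starting from h = 0, the accumulated offset b equals h_t


rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, 16)  # per-step decay, like exp(delta * A)
b = rng.normal(size=16)        # per-step input, like delta * B * x
print(np.allclose(sequential_scan(a, b), parallel_scan(a, b)))  # True
```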

Mamba Block Architecture

Mamba draws inspiration from the H3 architecture: it merges the H3 block with the Transformer's gated MLP block into a single, simplified block that is stacked homogeneously:

Input (x)
    │
Linear projection (expand)
    ├───────────────────────┐
    ▼                       ▼
  Conv1D               SiLU (gate)
    │                       │
    ▼                       │
  SiLU                      │
    │                       │
    ▼                       │
  SSM (selective scan)      │
    │                       │
    ▼                       │
    ×───────────────────────┘
    │
    ▼
  Linear projection (reduce)
    │
    ▼
  Output

Implemented in PyTorch:

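import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand
        self.d_state = d_state
        self.dt_rank = math.ceil(d_model / 16)  # low-rank width of the delta projection

        # Input projection: produces the SSM branch (x) and the gate branch (z)
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

        # Depthwise Conv1D (extra outputs from the padding are trimmed in forward => causal)
        self.conv1d = nn.Conv1d(
            d_inner, d_inner, d_conv,
            padding=d_conv - 1, groups=d_inner
        )

        # SSM parameters: x_proj yields (low-rank delta, B, C); dt_proj expands delta to d_inner
        self.x_proj = nn.Linear(d_inner, self.dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, d_inner, bias=True)

        # A parameter (structured, S4D-Real style: A = -[1, 2, ..., d_state] per channel)
        A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))

        self.D = nn.Parameter(torch.ones(d_inner))  # skip-connection weight
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def ssm(self, x):
        # x: (batch, length, d_inner)
        A = -torch.exp(self.A_log)  # (d_inner, d_state), negative for a stable recurrence
        dt, B, C = torch.split(
            self.x_proj(x), [self.dt_rank, self.d_state, self.d_state], dim=-1
        )
        delta = F.softplus(self.dt_proj(dt))   # (batch, length, d_inner), positive
        y = selective_scan(x, A, B, C, delta)  # selective_scan from the snippet above
        return y + x * self.D

    def forward(self, x):
        batch, length, _ = x.shape

        # Dual path
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Conv + activation (keep only the first `length` outputs)
        x = self.conv1d(x.transpose(1, 2))[:, :, :length].transpose(1, 2)
        x = F.silu(x)

        # SSM
        y = self.ssm(x)

        # Gate and project
        y = y * F.silu(z)
        return self.out_proj(y)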
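To see where a block's parameters live, the function below tallies the tensors declared in MambaBlock, including biases where the layer has them. The name `mamba_block_params` is hypothetical, and `dt_rank` (the width of the Δ projection) is left as an argument because it is a design choice; the example assumes d_model = 768 with dt_rank = 48, i.e. ceil(768 / 16).

```python
def mamba_block_params(d_model, d_state=16, d_conv=4, expand=2, dt_rank=1):
    """Parameter count per tensor for the MambaBlock layout above."""
    d_inner = expand * d_model
    return {
        "in_proj": d_model * 2 * d_inner,             # no bias
        "conv1d": d_inner * d_conv + d_inner,         # depthwise weights + bias
        "x_proj": d_inner * (dt_rank + 2 * d_state),  # no bias
        "dt_proj": dt_rank * d_inner + d_inner,       # weights + bias
        "A_log": d_inner * d_state,
        "D": d_inner,
        "out_proj": d_inner * d_model,                # no bias
    }


counts = mamba_block_params(d_model=768, dt_rank=48)
total = sum(counts.values())
share = (counts["in_proj"] + counts["out_proj"]) / total
print(f"total={total:,}  in_proj+out_proj share={share:.1%}")
```

For these settings the total is 3,770,880 parameters, about 94% of which sit in `in_proj` and `out_proj`. Those two layers alone hold 3 · expand · d_model² weights, so the SSM-specific tensors (A_log, D, x_proj, dt_proj) are comparatively small.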

Experimental Results

Language Modeling (The Pile)

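| Model | Parameters | Perplexity on The Pile (lower is better) |
|---|---|---|
| Transformer++ | 1.3B | 8.94 |
| H3 | 1.3B | 9.19 |
| Hyena | 1.3B | 9.07 |
| RWKV-4 | 1.5B | 9.02 |
| Mamba | 1.4B | 8.63 |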

Mamba achieved lower perplexity than Transformers of comparable size. Notably, at the 2.8B scale, Mamba improved perplexity by approximately 0.3 over Transformer++.
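Perplexity is the exponential of the average per-token cross-entropy (in nats), so a perplexity gap translates into a loss gap through a logarithm. A quick check using the Transformer++ and Mamba rows of the table (the helper name `ppl_to_loss` is illustrative):

```python
import math


def ppl_to_loss(ppl: float) -> float:
    """Average cross-entropy in nats per token: loss = ln(perplexity)."""
    return math.log(ppl)


transformer_ppl, mamba_ppl = 8.94, 8.63
loss_gap = ppl_to_loss(transformer_ppl) - ppl_to_loss(mamba_ppl)
rel_drop = (transformer_ppl - mamba_ppl) / transformer_ppl
print(f"loss gap: {loss_gap:.4f} nats/token, perplexity reduction: {rel_drop:.1%}")
```

A 0.31 perplexity gap corresponds to roughly 0.035 nats per token, or about a 3.5% lower perplexity.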

Inference Speed

| Sequence length | Transformer | Mamba |
|---|---|---|
| 2K | 1x | 1x |
| 8K | 1x | 3x faster |
| 64K | 1x | 5x faster |
| 1M | OOM | Works |

Mamba's speed advantage over Transformers grows with sequence length; the paper reports 5× higher inference throughput than Transformers of similar size.
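The constant per-token cost comes from decoding in recurrent mode. The NumPy sketch below mirrors the math of `selective_scan` for a single token; the function name `mamba_decode_step` and the random values standing in for the learned projections (x_t, Δ_t, B_t, C_t) are assumptions for illustration. However many tokens are generated, the carried state stays (D, N), whereas attention has to keep one key and one value per past token.

```python
import numpy as np


def mamba_decode_step(h, x_t, delta_t, A, B_t, C_t):
    """Advance the recurrent state by one token.

    h: (D, N) state; x_t, delta_t: (D,); A: (D, N); B_t, C_t: (N,).
    """
    A_bar = np.exp(delta_t[:, None] * A)                 # (D, N) per-token decay
    Bx = delta_t[:, None] * B_t[None, :] * x_t[:, None]  # (D, N) simplified B_bar * x_t
    h = A_bar * h + Bx                                   # state keeps shape (D, N)
    y_t = h @ C_t                                        # (D,) output for this token
    return h, y_t


rng = np.random.default_rng(0)
D, N = 8, 4
A = -np.exp(rng.normal(size=(D, N)))  # negative entries keep the recurrence stable
h = np.zeros((D, N))
for t in range(1000):                 # generate 1000 tokens
    x_t = rng.normal(size=D)
    delta_t = np.log1p(np.exp(rng.normal(size=D)))  # softplus -> positive step sizes
    B_t, C_t = rng.normal(size=N), rng.normal(size=N)
    h, y_t = mamba_decode_step(h, x_t, delta_t, A, B_t, C_t)

print(h.shape, y_t.shape)  # (8, 4) (8,)
```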

Selective Copying & Induction Heads

The paper validates the effectiveness of the Selection Mechanism through two key synthetic tasks:

  1. Selective Copying: Selectively copying specific tokens from a long sequence — Previous SSMs (S4) fail, Mamba succeeds
  2. Induction Heads: Recognizing "A B ... A → B" patterns — Previous SSMs fail, Mamba succeeds

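Both tasks require content-aware reasoning. LTI models such as S4 apply the same dynamics to every token and cannot tell content from filler, whereas the Selection Mechanism lets Mamba decide per token what to keep.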
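Both tasks are easy to state as data generators. The sketch below is a simplified, hypothetical construction (vocabulary size, lengths, and the NOISE and trigger token ids are arbitrary choices, not the paper's exact setup). In Selective Copying, content tokens sit at random positions among noise tokens and must be reproduced in order; in Induction Heads, a trigger token appears twice and the model must output whatever followed its first occurrence.

```python
import random

NOISE = 0  # hypothetical filler token id


def selective_copying_example(n_memorize=4, seq_len=16, vocab=8, rng=random):
    """Content tokens (1..vocab-1) scattered among NOISE; target = content in order."""
    positions = sorted(rng.sample(range(seq_len), n_memorize))
    content = [rng.randrange(1, vocab) for _ in range(n_memorize)]
    seq = [NOISE] * seq_len
    for pos, tok in zip(positions, content):
        seq[pos] = tok
    return seq, content


def induction_heads_example(seq_len=16, vocab=8, rng=random):
    """Random tokens with 'trigger answer' placed early and 'trigger' at the end."""
    trigger = vocab  # an id outside the random range, so it occurs exactly twice
    seq = [rng.randrange(1, vocab) for _ in range(seq_len)]
    pos = rng.randrange(0, seq_len - 3)
    seq[pos] = trigger
    answer = seq[pos + 1]
    seq[-1] = trigger  # query: the model should predict `answer` next
    return seq, answer


rng = random.Random(0)
print(selective_copying_example(rng=rng))
print(induction_heads_example(rng=rng))
```

Because the content positions change from example to example, an LTI model, whose convolution kernel depends only on relative position, has no fixed kernel that picks out the right tokens; a model whose Δ, B, C depend on the token can simply ignore NOISE.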

Mamba vs Transformer: When to Use Which?

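| Criterion | Transformer | Mamba |
|---|---|---|
| Long sequences | △ (O(N²)) | ◎ (O(N)) |
| Inference speed | △ (KV cache grows) | ◎ (fixed-size state) |
| In-context learning | | |
| Training parallelism | ◎ (fully parallel) | ○ (parallel scan) |
| Ecosystem/tooling | | |

(◎ = excellent, ○ = good, △ = fair)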

Follow-up Research: Mamba-2 and Hybrids

Significant progress has been made since Mamba:

  • Mamba-2 (Dao & Gu, 2024): Proposed the State Space Duality (SSD) framework, revealing the theoretical connection between SSMs and Attention
  • Jamba (AI21, 2024): A Transformer + Mamba hybrid with 256K context
  • Zamba (Zyphra, 2024): A 7B-parameter Mamba-based hybrid that adds a single shared attention block, reporting strong results for its size
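The duality can be checked numerically for the special case Mamba-2 builds on: a scalar decay a_t per step (A restricted to a scalar times the identity). Unrolling h_t = a_t · h_{t-1} + B_t x_t, y_t = C_t · h_t gives y = M x with M[i, j] = (C_i · B_j) · a_{j+1} ··· a_i for j ≤ i and 0 otherwise: attention-like scores, with C as queries, B as keys, and x as values, multiplied by a causal decay mask. The NumPy sketch below uses random values in place of learned projections and a single channel; it is an illustration of the identity, not Mamba-2's SSD algorithm.

```python
import numpy as np

rng = np.random.default_rng(0)
L, N = 10, 4                  # sequence length, state size (arbitrary)
a = rng.uniform(0.5, 1.0, L)  # scalar decay per step (input-dependent in Mamba-2)
B = rng.normal(size=(L, N))   # per-step "keys"
C = rng.normal(size=(L, N))   # per-step "queries"
x = rng.normal(size=L)        # one input channel ("values")

# SSM (recurrent) form
h, y_rec = np.zeros(N), []
for t in range(L):
    h = a[t] * h + B[t] * x[t]
    y_rec.append(C[t] @ h)
y_rec = np.array(y_rec)

# Attention (matrix) form: y = (decay_mask * C B^T) x
log_cum = np.cumsum(np.log(a))                       # log(a_0 * ... * a_t)
decay = np.exp(log_cum[:, None] - log_cum[None, :])  # prod_{k=j+1}^{i} a_k when i >= j
mask = np.tril(decay)                                # zero out j > i (causal)
M = mask * (C @ B.T)                                 # (L, L) attention-like matrix
y_att = M @ x

print(np.allclose(y_rec, y_att))  # True
```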

Conclusion

Mamba overcame the limitations of SSMs with the simple yet powerful idea of the Selection Mechanism, achieving performance comparable to or exceeding Transformers at O(N) complexity. The hardware-aware algorithm design was key to translating theoretical efficiency into real-world speed.

While Transformers remain dominant, Mamba and hybrid architectures are gaining increasing traction in domains where long-sequence processing is critical.

Quiz

Q1: What is the key limitation of Transformers that Mamba aims to solve?

The O(N²) time and space complexity of self-attention. As sequence length increases, computation and memory grow quadratically, making long-sequence processing inefficient.

Q2: What does "Selective" mean in Selective State Space?

It means making the SSM parameters (B, C, Δ) input-dependent, so the model can select which information to store in the state and which to ignore.

Q3: What was the fundamental reason previous SSMs (S4, etc.) underperformed in language modeling?

Their Linear Time-Invariant (LTI) nature. Since parameters were fixed and independent of the input, content-aware reasoning was impossible.

Q4: What is the intuitive meaning of the Δ (delta) parameter in Mamba?

It is a "step size" that determines how much the current input is reflected. A large Δ strongly reflects the current input, while a small Δ retains the previous state.

Q5: Why can't Mamba's Selective Scan use the convolution trick from traditional SSMs?

Because the input-dependent parameters mean the system is no longer LTI, so the trick of converting to convolution and computing in parallel via FFT is no longer applicable.

Q6: What is the key strategy that Mamba's hardware optimization shares with FlashAttention?

IO-aware optimization. All computations are fused within SRAM (kernel fusion), and intermediate states are not stored in HBM but recomputed during the backward pass (recomputation).

Q7: What capability does the Selective Copying task validate?

The ability to selectively remember and copy specific tokens from a long sequence, which is a content-aware reasoning capability. SSMs with fixed parameters cannot solve this task.

Q8: What is State Space Duality (SSD), the key contribution of Mamba-2?

A framework showing that a structured class of SSMs (where A is a scalar times the identity) and a form of masked linear attention compute the same operation, both expressible as multiplication by a structured (semiseparable) matrix. It provides a theoretical foundation for combining the strengths of both approaches.