Mamba Paper Review: Going Beyond Transformers with Selective State Space Models

1. Introduction: The Limitations of Transformers and the Rise of SSMs

The Transformer architecture has dominated nearly every field of sequence modeling (NLP, vision, audio, and more) since its introduction in 2017. However, it has fundamental limitations:

  • O(N²) complexity: The time and space complexity of Self-Attention grows quadratically with sequence length
  • Inference inefficiency: Each token generation requires referencing the entire KV Cache
  • Difficulty with long sequences: Memory bottleneck at 128K+ contexts

State Space Models (SSMs) have been proposed as an alternative that overcomes these limitations. In particular, Mamba (Gu & Dao, 2023) introduced a selection mechanism to SSMs, achieving performance comparable to Transformers at O(N) complexity.

2. Background: Structured State Space Models (S4)

2.1 Continuous-Time SSM

An SSM is defined as a continuous-time linear system, which is later discretized for use on sequences:

$$h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)$$
$$y(t) = \mathbf{C}h(t)$$

  • $h(t) \in \mathbb{R}^N$: hidden state
  • $x(t) \in \mathbb{R}$: input
  • $y(t) \in \mathbb{R}$: output
  • $\mathbf{A} \in \mathbb{R}^{N \times N}$, $\mathbf{B} \in \mathbb{R}^{N \times 1}$, $\mathbf{C} \in \mathbb{R}^{1 \times N}$

2.2 Discretization (Zero-Order Hold)

Converting the continuous system to discrete time:

$$\overline{\mathbf{A}} = \exp(\Delta \mathbf{A})$$
$$\overline{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B}$$
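To make the two formulas concrete, here is a minimal sketch for the special case of a diagonal A, where the matrix exponential and the inverse reduce to elementwise operations. The function name `zoh_diagonal` and the sample values are illustrative assumptions, not taken from the S4 or Mamba code.

```python
import math

def zoh_diagonal(a_diag, b, delta):
    """Zero-order hold for a diagonal A (all entries a_i must be nonzero).

    a_bar_i = exp(delta * a_i)
    b_bar_i = (exp(delta * a_i) - 1) / a_i * b_i
    """
    a_bar = [math.exp(delta * a) for a in a_diag]
    b_bar = [(math.exp(delta * a) - 1.0) / a * bi for a, bi in zip(a_diag, b)]
    return a_bar, b_bar

# Hypothetical stable system: negative diagonal entries.
a_diag = [-1.0, -2.0, -3.0]
b = [1.0, 1.0, 1.0]

a_bar, b_bar = zoh_diagonal(a_diag, b, delta=0.01)
for a, ab, bb in zip(a_diag, a_bar, b_bar):
    print(f"a={a:+.1f}  a_bar={ab:.5f}  b_bar={bb:.5f}")
```

For small Δ the values approach Ā ≈ 1 + Δa and B̄ ≈ Δb, which is why a simple Δ·B rule is a common shortcut for the input matrix.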

Discretized recurrence relation:

$$h_t = \overline{\mathbf{A}} h_{t-1} + \overline{\mathbf{B}} x_t$$
$$y_t = \mathbf{C} h_t$$
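With fixed parameters, unrolling this recurrence gives a causal convolution with kernel K_k = C Ā^k B̄, which is the "convolution trick" that lets S4 train in parallel. The sketch below checks that equivalence numerically for a small diagonal system; the function names and values are hypothetical and chosen only for illustration.

```python
def ssm_recurrent(a_bar, b_bar, c, xs):
    """y_t = C h_t with h_t = A_bar h_{t-1} + B_bar x_t (diagonal A_bar)."""
    h = [0.0] * len(a_bar)
    ys = []
    for x in xs:
        h = [a * hi + b * x for a, hi, b in zip(a_bar, h, b_bar)]
        ys.append(sum(ci * hi for ci, hi in zip(c, h)))
    return ys

def ssm_kernel(a_bar, b_bar, c, length):
    """Convolution kernel K_k = C A_bar^k B_bar for k = 0..length-1."""
    return [sum(ci * (a ** k) * b for ci, a, b in zip(c, a_bar, b_bar))
            for k in range(length)]

def ssm_convolutional(a_bar, b_bar, c, xs):
    """Same outputs via the causal convolution y_t = sum_k K_k x_{t-k}."""
    kernel = ssm_kernel(a_bar, b_bar, c, len(xs))
    return [sum(kernel[k] * xs[t - k] for k in range(t + 1))
            for t in range(len(xs))]

a_bar, b_bar, c = [0.9, 0.5], [0.1, 0.2], [1.0, -1.0]
xs = [1.0, -2.0, 0.5, 3.0]
print(ssm_recurrent(a_bar, b_bar, c, xs))
print(ssm_convolutional(a_bar, b_bar, c, xs))
```

The kernel depends only on Ā, B̄, and C, so it can be precomputed once. That stops working as soon as any of them changes from step to step.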

2.3 The Key to S4: HiPPO Initialization

A key ingredient of S4 (Structured State Spaces for Sequence Modeling) is initializing A with the HiPPO (High-order Polynomial Projection Operators) matrix, which is designed so that the hidden state keeps a compressed memory of the input history:

import torch
import numpy as np

def make_hippo(N):
    """Generate HiPPO-LegS matrix"""
    P = np.sqrt(1 + 2 * np.arange(N))
    A = np.zeros((N, N))
    for i in range(N):
        for j in range(N):
            if i > j:
                A[i, j] = P[i] * P[j]
            elif i == j:
                A[i, j] = i + 1
            # i < j: 0
    return -A  # Negative for stability
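As a cross-check on the loop above, the same matrix can be built without explicit loops. The sketch below is an assumed vectorized equivalent (the name `hippo_legs` is made up here); it also confirms that the matrix is lower triangular, so its eigenvalues are the diagonal entries -1, ..., -N, all negative.

```python
import numpy as np

def hippo_legs(N):
    """Vectorized HiPPO-LegS matrix, already negated for stability."""
    p = np.sqrt(1 + 2 * np.arange(N))
    # Strictly lower triangle: sqrt(2i+1) * sqrt(2j+1); diagonal: i + 1.
    A = np.tril(np.outer(p, p), k=-1) + np.diag(np.arange(1, N + 1))
    return -A

A = hippo_legs(4)
print(np.round(A, 3))
print("lower triangular:", np.allclose(np.triu(A, k=1), 0.0))
print("diagonal:", np.diag(A))
```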

2.4 Limitations of S4

S4 uses fixed, input-independent parameters (A, B, C, Δ). This causes two problems:

  1. No content-based reasoning: Performance degrades on tasks that require different processing depending on input content (e.g., selective copying, induction heads)
  2. Loss of Attention's advantages: The content-aware matching capability of Transformers is sacrificed

3. Mamba: Selective State Space Models

3.1 Core Idea: Selection Mechanism

Mamba's key innovation is making SSM parameters dependent on the input:

S4:    (A, B, C, Δ) = fixed parameters
Mamba: (B, C, Δ) = f(input)  ← input-dependent!

Specifically:

$$\mathbf{B}_t = \text{Linear}_B(x_t), \quad \mathbf{C}_t = \text{Linear}_C(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}_\Delta(x_t))$$

This allows the model to dynamically decide what information to remember and what to forget based on the input.
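The projections can be sketched in a few lines of NumPy. Everything here is an assumption made for illustration: the weight names, the shapes, and the choice to compute Δ from a single scalar per token that is then expanded to D channels. Random weights stand in for learned ones.

```python
import numpy as np

def softplus(z):
    # log(1 + exp(z)), computed in a numerically stable way
    return np.logaddexp(0.0, z)

def selection_params(x, W_B, W_C, w_dt, dt_weight, dt_bias):
    """x: (L, D) -> B: (L, N), C: (L, N), delta: (L, D), one set per token."""
    B = x @ W_B                                 # (L, N)
    C = x @ W_C                                 # (L, N)
    dt = x @ w_dt                               # (L, 1) scalar per token
    delta = softplus(dt * dt_weight + dt_bias)  # broadcast to (L, D), always > 0
    return B, C, delta

rng = np.random.default_rng(0)
L, D, N = 5, 8, 4
x = rng.standard_normal((L, D))
B, C, delta = selection_params(
    x,
    W_B=rng.standard_normal((D, N)),
    W_C=rng.standard_normal((D, N)),
    w_dt=rng.standard_normal((D, 1)),
    dt_weight=rng.standard_normal(D),
    dt_bias=rng.standard_normal(D),
)
print(B.shape, C.shape, delta.shape)
```

The softplus keeps every Δ positive, which a step size has to be.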

3.2 Intuition Behind Selection

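$\Delta_t$ (step size) is the key. Because $\mathbf{A}$ has negative entries, $\overline{\mathbf{A}} = \exp(\Delta_t \mathbf{A})$ shrinks toward zero as $\Delta_t$ grows:

  • Small $\Delta_t$: $\overline{\mathbf{A}} \approx \mathbf{I}$ and $\overline{\mathbf{B}} \approx \mathbf{0}$ → retains the previous state (ignores the current input)
  • Large $\Delta_t$: $\overline{\mathbf{A}} \approx \mathbf{0}$ → resets the state and focuses on the current input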

This is analogous to a gating mechanism:

LSTM's forget gate ≈ Mamba's Δ
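The gating analogy can be checked numerically in the simplest setting, a one-dimensional state with A = -1 and B = 1 (a special case the Mamba paper also uses to relate selection to RNN gates). Under zero-order hold with Δ = softplus(z), the discretized coefficients are exactly 1 - sigmoid(z) and sigmoid(z). The helper names below are illustrative.

```python
import math

def softplus(z):
    return math.log1p(math.exp(z))

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def discretize_scalar(z, a=-1.0, b=1.0):
    """ZOH discretization of a 1-D SSM with step size delta = softplus(z)."""
    delta = softplus(z)
    a_bar = math.exp(delta * a)      # how much of h_{t-1} is kept
    b_bar = (a_bar - 1.0) / a * b    # how much of x_t is written
    return delta, a_bar, b_bar

for z in (-4.0, 0.0, 4.0):
    delta, a_bar, b_bar = discretize_scalar(z)
    print(f"z={z:+.1f}  delta={delta:.4f}  keep={a_bar:.4f}  write={b_bar:.4f}  "
          f"sigmoid(z)={sigmoid(z):.4f}")
```

As Δ grows, the keep coefficient falls from about 1 toward 0 while the write coefficient rises from about 0 toward 1, so the update h_t = (1 - g_t) h_{t-1} + g_t x_t behaves like a gated RNN.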

3.3 Hardware-Aware Algorithm

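Introducing selection makes the parameters input-dependent, so the system is no longer time-invariant and the convolution trick can no longer be used. Training would fall back to the recurrence, which is sequential and, implemented naively, materializes an expanded state of shape (B, L, D, N) in GPU memory.

Mamba solves this with kernel fusion, a parallel scan, and recomputation: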

# Mamba's Selective Scan (simplified)
def selective_scan(x, delta, A, B, C):
    """
    x: (B, L, D)  - input
    delta: (B, L, D) - step size (input-dependent)
    A: (D, N)      - state matrix
    B: (B, L, N)   - input matrix (input-dependent)
    C: (B, L, N)   - output matrix (input-dependent)
    """
    B_batch, L, D = x.shape
    N = A.shape[1]

    # Discretization: ZOH for A, simplified rule deltaB = delta * B for B
    deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, D, N)
    deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, D, N)

    # Sequential reference scan (a parallel scan is used during training)
    h = torch.zeros(B_batch, D, N, device=x.device, dtype=x.dtype)
    ys = []
    for t in range(L):
        h = deltaA[:, t] * h + deltaB[:, t] * x[:, t, :, None]
        y = (h * C[:, t, None, :]).sum(-1)  # (B, D)
        ys.append(y)

    return torch.stack(ys, dim=1)  # (B, L, D)

In the actual implementation, all intermediate states are kept in SRAM within the CUDA kernel to minimize HBM access (an IO-aware approach similar to FlashAttention).
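The recurrence h_t = a_t h_{t-1} + b_t can be computed with a parallel scan because composing two steps is associative: applying (a1, b1) and then (a2, b2) equals the single step (a1·a2, a2·b1 + b2). The sketch below illustrates the idea in plain Python with a Hillis-Steele style scan. It is a toy model of the algorithm, not the fused CUDA kernel, and the function names are made up for this example.

```python
def combine(left, right):
    """Compose two affine steps h -> a*h + b, applying `left` first."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

def sequential_scan(a, b):
    """Reference: h_t = a_t * h_{t-1} + b_t with h_{-1} = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out

def parallel_style_scan(a, b):
    """Inclusive scan in O(log L) rounds; each round is independent per position."""
    elems = list(zip(a, b))
    n, step = len(elems), 1
    while step < n:
        # On a GPU, every position i would be updated simultaneously.
        elems = [combine(elems[i - step], elems[i]) if i >= step else elems[i]
                 for i in range(n)]
        step *= 2
    return [b_t for _, b_t in elems]

a = [0.9, 0.5, 1.0, 0.2, 0.7]
b = [1.0, -1.0, 2.0, 0.5, 3.0]
print(sequential_scan(a, b))
print(parallel_style_scan(a, b))
```

Both functions return the same states. The scan version needs only about log2(L) rounds, which is what makes training on long sequences practical despite the recurrence.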

3.4 Mamba Block Architecture

Input
  ├──── Linear(D → ED) ──── Conv1d ──── SiLU ──── SSM ────┐
  │                                                       │
  └──── Linear(D → ED) ──── SiLU ────────────── × ────────┘
                                         Linear(ED → D)
                                             Output
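
A simplified PyTorch sketch of the block (it reuses selective_scan from Section 3.3):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_state = d_state
        d_inner = d_model * expand

        # One projection produces both branches (x and the gate z)
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
        # Depthwise conv; padding + truncation in forward() keeps it causal
        self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv,
                                padding=d_conv-1, groups=d_inner)

        # SSM parameters (simplified: Δ is projected through a rank-1 bottleneck)
        self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False)  # B, C, Δ
        self.dt_proj = nn.Linear(1, d_inner, bias=True)

        A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))  # learn log(-A) so A stays negative
        self.D = nn.Parameter(torch.ones(d_inner))  # skip connection weight

        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        # x: (B, L, D)
        seq_len = x.shape[1]
        xz = self.in_proj(x)  # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1)

        # Conv + SiLU
        x = self.conv1d(x.transpose(1, 2))[:, :, :seq_len]
        x = F.silu(x.transpose(1, 2))  # (B, L, ED)

        # SSM with input-dependent B, C, Δ
        A = -torch.exp(self.A_log)  # (ED, N)
        B, C, dt = self.x_proj(x).split([self.d_state, self.d_state, 1], dim=-1)
        delta = F.softplus(self.dt_proj(dt))  # (B, L, ED)
        y = selective_scan(x, delta, A, B, C) + self.D * x

        # Gate
        y = y * F.silu(z)
        return self.out_proj(y)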

4. Mamba-2: State Space Duality

4.1 SSD (State Space Dual) Model

Mamba-2 (Dao & Gu, ICML 2024) discovered a mathematical duality between SSMs and Attention:

  • SSM perspective: Recursive state updates → O(N) inference
  • Attention perspective: Structured matrix multiplication → Parallel training

SSM recurrence:       h_t = A_t h_{t-1} + B_t x_t
                      y_t = C_t h_t

                        ⇕ dual

Structured attention: Y = (T ∘ M) X
  where T = lower-triangular causal mask
        M = structured (semi-separable) matrix
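The duality can be verified numerically in the scalar-A case that Mamba-2 uses, where A_t = a_t·I. Unrolling the recurrence gives y_i = Σ_{j≤i} (a_{j+1}···a_i)(C_i·B_j) x_j, which is a lower-triangular decay mask multiplied elementwise with C Bᵀ and then applied to x. The NumPy sketch below compares both forms; the function names and random inputs are illustrative assumptions, and a single input channel is used for brevity.

```python
import numpy as np

def ssd_recurrent(a, B, C, x):
    """Recurrent form: h_t = a_t h_{t-1} + B_t x_t, y_t = C_t . h_t."""
    T, N = B.shape
    h = np.zeros(N)
    y = np.zeros(T)
    for t in range(T):
        h = a[t] * h + B[t] * x[t]
        y[t] = C[t] @ h
    return y

def ssd_matrix(a, B, C, x):
    """Matrix ("attention-like") form: y = (mask * (C B^T)) x."""
    cum = np.cumsum(np.log(a))  # requires a_t > 0
    # mask[i, j] = a[j+1] * ... * a[i] for i >= j, and 0 above the diagonal
    mask = np.tril(np.exp(cum[:, None] - cum[None, :]))
    M = mask * (C @ B.T)
    return M @ x

rng = np.random.default_rng(0)
T, N = 8, 4
a = rng.uniform(0.5, 0.99, size=T)   # per-step decay in (0, 1)
B = rng.standard_normal((T, N))
C = rng.standard_normal((T, N))
x = rng.standard_normal(T)

print(np.allclose(ssd_recurrent(a, B, C, x), ssd_matrix(a, B, C, x)))
```

The recurrent form costs O(T) and suits token-by-token inference, while the matrix form is a batch of matrix multiplications that maps well onto GPU tensor cores during training.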

4.2 Performance Comparison

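| Model | Parameters | Pile (ppl) | Training FLOPS/s (relative) | Inference tok/s (relative) |
|---|---|---|---|---|
| Transformer++ | 2.7B | 6.7 | 1.0x | 1.0x |
| Mamba-1 | 2.8B | 6.2 | 1.3x | 5.2x |
| Mamba-2 | 2.7B | 6.1 | 2.1x | 5.5x |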

Key improvements:

  • 2x training speed improvement (leveraging structured matrix multiplication)
  • Larger state sizes possible (N=16 in Mamba-1 → N=64 to 256 in Mamba-2)
  • Introduction of multi-head structure

5. Practical Usage: Working with Mamba Models

5.1 Installation and Inference

pip install mamba-ssm "causal-conv1d>=1.4.0"

from mamba_ssm import MambaLMHeadModel
from transformers import AutoTokenizer

# Load Mamba-2.8B
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

model = model.cuda().half()

# Inference
input_ids = tokenizer("The future of AI is", return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids, max_length=100, temperature=0.7)
print(tokenizer.decode(output[0]))

5.2 Using Mamba-2

import torch
from mamba_ssm import Mamba2

# Using a standalone Mamba-2 layer
layer = Mamba2(
    d_model=2048,
    d_state=128,    # Mamba-2 supports larger states
    d_conv=4,
    expand=2,
    headdim=64,     # Multi-head
).cuda()

x = torch.randn(1, 1024, 2048).cuda()  # (batch, seq_len, d_model)
y = layer(x)  # (1, 1024, 2048)

6. Mamba vs Transformer Comparison

| Property | Transformer | Mamba |
|---|---|---|
| Training complexity | O(N²) | O(N) |
| Inference complexity | O(N) per token (KV Cache) | O(1) per token |
| Memory (inference) | O(N) KV Cache | O(1) fixed state |
| In-context learning | Strong | Weaker (improving) |
| Long sequences | Bottleneck | Efficient |
| Hardware utilization | Optimized for parallelism | Recurrence bottleneck (improved in Mamba-2) |
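A back-of-the-envelope calculation shows what the memory row means in practice. The sketch assumes standard multi-head attention (no grouped-query sharing), fp16 values, and round, hypothetical model sizes. It counts only the KV cache or the recurrent state, not weights or activations.

```python
def kv_cache_bytes(n_layers, d_model, seq_len, bytes_per_value=2):
    """Keys and values: two (seq_len, d_model) tensors per layer."""
    return 2 * n_layers * seq_len * d_model * bytes_per_value

def ssm_state_bytes(n_layers, d_model, d_state=16, expand=2, d_conv=4,
                    bytes_per_value=2):
    """Per layer: SSM state (d_inner, d_state) plus conv buffer (d_inner, d_conv)."""
    d_inner = expand * d_model
    return n_layers * (d_inner * d_state + d_inner * d_conv) * bytes_per_value

# Hypothetical configurations, chosen only to show the scaling behavior.
for seq_len in (1_000, 10_000, 100_000):
    kv = kv_cache_bytes(n_layers=32, d_model=2560, seq_len=seq_len)
    ssm = ssm_state_bytes(n_layers=64, d_model=2560)
    print(f"L={seq_len:>7}: KV cache {kv / 2**20:10.1f} MiB | "
          f"SSM state {ssm / 2**20:6.1f} MiB")
```

The KV cache grows linearly with the number of tokens already processed, while the recurrent state stays the same size no matter how long the sequence gets.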

7. Limitations and Outlook

7.1 Current Limitations

  • In-context learning: Information loss when compressing unlimited context into a fixed-size state
  • Retrieval tasks: Attention has the advantage when exact token retrieval is needed
  • Training stability: Gradient instability issues with long sequences

7.2 Hybrid Approaches

The recent trend is Mamba + Attention hybrids:

  • Jamba (AI21): Mixing Mamba layers + Attention layers
  • Zamba (Zyphra): Mamba-based + a small number of shared Attention layers
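The basic interleaving idea can be sketched as a layer schedule. The ratio below is a hypothetical placeholder rather than the published configuration of either model, and weight sharing between attention layers (as in Zamba) is not modeled.

```python
def hybrid_layer_schedule(n_layers, attention_every):
    """Return layer types, placing one attention layer every `attention_every` layers."""
    if attention_every < 1:
        raise ValueError("attention_every must be >= 1")
    return ["attention" if (i + 1) % attention_every == 0 else "mamba"
            for i in range(n_layers)]

schedule = hybrid_layer_schedule(n_layers=16, attention_every=8)
print(schedule)
print("attention layers:", schedule.count("attention"), "of", len(schedule))
```

Because only the attention layers keep a KV cache, the cache shrinks roughly in proportion to their share of the stack.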

8. Quiz

Q1. What is the key improvement of Mamba over S4?

Making the SSM parameters (B, C, Δ) input-dependent to enable content-based reasoning. S4 used fixed parameters, which made selective processing based on input content impossible.

Q2. What role does Δ play in Mamba?

As a step size, it functions as a gating mechanism. When Δ is small, the previous state is retained (the current input is largely ignored); when Δ is large, the state is reset and the model focuses on the current input. It is analogous to LSTM's forget gate.

Q3. Why can't the convolution trick be used after introducing Selection?

The convolution trick only works for time-invariant systems. When Selection makes parameters input-dependent, the system becomes time-varying, making global convolution impossible.

Q4. What is the essence of Mamba's Hardware-Aware algorithm?

Keeping intermediate states in SRAM within the CUDA kernel to minimize HBM access. An IO-aware approach similar to FlashAttention that optimizes memory-bound operations.

Q5. What is Mamba-2's State Space Duality?

The mathematical equivalence between SSM's recursive state updates and structured attention matrix multiplication. This allows leveraging parallelized matrix multiplication during training and efficient recursion during inference.

Q6. In which tasks do Transformers outperform Mamba?

In-context learning and exact token retrieval tasks. Attention can directly reference any position within the sequence, whereas Mamba must compress information into a fixed-size state.

Q7. What is the design principle behind hybrid models like Jamba and Zamba?

Combining the O(N) efficiency of Mamba layers with the precise retrieval capability of a small number of Attention layers. Most layers are Mamba, with Attention placed only at key positions.