- 1. Introduction: The Limitations of Transformers and the Rise of SSMs
- 2. Background: Structured State Space Models (S4)
- 3. Mamba: Selective State Space Models
- 4. Mamba-2: State Space Duality
- 5. Practical Usage: Working with Mamba Models
- 6. Mamba vs Transformer Comparison
- 7. Limitations and Outlook
- 8. Quiz
1. Introduction: The Limitations of Transformers and the Rise of SSMs
The Transformer architecture has dominated nearly every field of sequence modeling — NLP, Vision, Audio, and more — since its introduction in 2017. However, it has fundamental limitations:
- O(N²) complexity: The time and space complexity of Self-Attention grows quadratically with sequence length
- Inference inefficiency: Each token generation requires referencing the entire KV Cache
- Difficulty with long sequences: Memory bottleneck at 128K+ contexts
State Space Models (SSMs) have been proposed as an alternative to overcome these limitations. In particular, Mamba (Gu & Dao, 2023) introduced a selective mechanism into SSMs, achieving performance comparable to Transformers at O(N) complexity.
2. Background: Structured State Space Models (S4)
2.1 Continuous-Time SSM
SSMs are discretized continuous-time systems:

h′(t) = A h(t) + B x(t)
y(t) = C h(t)

- h(t) ∈ ℝᴺ: hidden state
- x(t) ∈ ℝ: input
- y(t) ∈ ℝ: output
- A ∈ ℝᴺˣᴺ, B ∈ ℝᴺˣ¹, C ∈ ℝ¹ˣᴺ
2.2 Discretization (Zero-Order Hold)
Converting the continuous system to discrete time:

Ā = exp(Δ·A)
B̄ = (Δ·A)⁻¹ (exp(Δ·A) − I) · Δ·B

Discretized recurrence relation:

h_t = Ā h_{t-1} + B̄ x_t
y_t = C h_t
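For diagonal A (the structure used in practice by S4D and Mamba), ZOH reduces to elementwise operations. A minimal NumPy sketch (the helper name `zoh_discretize_diag` is mine):

```python
import numpy as np

def zoh_discretize_diag(a, b, delta):
    """ZOH discretization for a diagonal A.

    a:     (N,) diagonal entries of A (negative for stability)
    b:     (N,) entries of B
    delta: scalar step size

    A_bar = exp(delta*a); B_bar = (delta*a)^-1 (exp(delta*a) - 1) * delta*b,
    which simplifies elementwise to (A_bar - 1)/a * b.
    """
    a_bar = np.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar

a_bar, b_bar = zoh_discretize_diag(np.array([-1.0]), np.array([1.0]), 0.1)
# A small step barely decays the state: a_bar = exp(-0.1) ≈ 0.905
```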
2.3 The Key to S4: HiPPO Initialization
The core contribution of S4 (Structured State Spaces for Sequence Modeling) was initializing A with the HiPPO (High-order Polynomial Projection Operator) matrix:
```python
import numpy as np

def make_hippo(N):
    """Generate the HiPPO-LegS matrix."""
    P = np.sqrt(1 + 2 * np.arange(N))
    A = np.zeros((N, N))
    for i in range(N):
        for j in range(N):
            if i > j:
                A[i, j] = P[i] * P[j]
            elif i == j:
                A[i, j] = i + 1
            # i < j: stays 0
    return -A  # negated for stability
```
2.4 Limitations of S4
S4 uses fixed, input-independent parameters (A, B, C, Δ). This causes two problems:
- Inability for content-based reasoning: Performance degrades on tasks requiring different processing based on input content (e.g., selective copying, induction heads)
- Loss of Attention's advantages: The content-aware matching capability of Transformers is sacrificed
3. Mamba: Selective State Space Models
3.1 Core Idea: Selection Mechanism
Mamba's key innovation is making SSM parameters dependent on the input:
S4: (A, B, C, Δ) = fixed parameters
Mamba: (B, C, Δ) = f(input) ← input-dependent!
Specifically:

B_t = Linear_B(x_t)
C_t = Linear_C(x_t)
Δ_t = softplus(Linear_Δ(x_t))
This allows the model to dynamically decide what information to remember and what to forget based on the input.
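A minimal PyTorch sketch of these per-token projections (the module name `SelectiveParams` is hypothetical; real implementations fuse the projections into a single linear layer):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveParams(nn.Module):
    """Hypothetical sketch: compute input-dependent B_t, C_t, Δ_t per token."""
    def __init__(self, d_model, d_state):
        super().__init__()
        self.B_proj = nn.Linear(d_model, d_state)
        self.C_proj = nn.Linear(d_model, d_state)
        self.dt_proj = nn.Linear(d_model, d_model)

    def forward(self, x):                       # x: (batch, L, d_model)
        B = self.B_proj(x)                      # (batch, L, N), input-dependent
        C = self.C_proj(x)                      # (batch, L, N), input-dependent
        delta = F.softplus(self.dt_proj(x))     # (batch, L, D), positive step sizes
        return B, C, delta
```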
3.2 Intuition Behind Selection
Δ (step size) is the key:
- Large Δ: focuses on the current input (the state is reset and updated)
- Small Δ: retains the previous state (the current input is largely ignored)
This is analogous to a gating mechanism:
LSTM's forget gate ≈ Mamba's Δ
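This behavior falls directly out of the discretization Ā = exp(Δ·A) with negative A, which a two-line check illustrates:

```python
import numpy as np

a = -1.0  # one (negative) diagonal entry of A
for delta in (0.01, 10.0):
    a_bar = np.exp(delta * a)  # state carry-over coefficient after discretization
    print(f"delta={delta}: a_bar={a_bar:.4f}")

# delta = 0.01 -> a_bar ≈ 0.99: the state persists, the input barely registers
# delta = 10.0 -> a_bar ≈ 0.00: the state resets, the current input dominates
```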
3.3 Hardware-Aware Algorithm
Introducing selection makes the parameters input-dependent, i.e., time-varying, so the global convolution trick can no longer be used. A naive sequential recurrence is slow on GPUs and must materialize the full (B, L, D, N) state in memory.
Mamba solves this with a parallel scan combined with kernel fusion and recomputation:
```python
import torch

# Mamba's selective scan (simplified reference implementation)
def selective_scan(x, delta, A, B, C):
    """
    x:     (B, L, D) - input
    delta: (B, L, D) - step size (input-dependent)
    A:     (D, N)    - state matrix
    B:     (B, L, N) - input matrix (input-dependent)
    C:     (B, L, N) - output matrix (input-dependent)
    """
    B_batch, L, D = x.shape
    N = A.shape[1]

    # Discretization
    deltaA = torch.exp(delta.unsqueeze(-1) * A)    # (B, L, D, N)
    deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, D, N)

    # Sequential scan (a parallel scan is used during training)
    h = torch.zeros(B_batch, D, N, device=x.device)
    ys = []
    for t in range(L):
        h = deltaA[:, t] * h + deltaB[:, t] * x[:, t, :, None]
        y = (h * C[:, t, None, :]).sum(-1)         # (B, D)
        ys.append(y)
    return torch.stack(ys, dim=1)                  # (B, L, D)
```
In the actual implementation, all intermediate states are kept in SRAM within the CUDA kernel to minimize HBM access (an IO-aware approach similar to FlashAttention).
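What makes the training-time parallel scan possible is that the recurrence h_t = ā_t·h_{t-1} + b̄_t·x_t admits an associative combine operator. A scalar sketch (illustrative only, not the fused CUDA kernel):

```python
import numpy as np

def scan_op(e1, e2):
    """Associative combine for h_t = a_t * h_{t-1} + b_t.
    Applying (a1, b1) then (a2, b2): h -> a2*(a1*h + b1) + b2
                                       = (a1*a2)*h + (a2*b1 + b2)."""
    a1, b1 = e1
    a2, b2 = e2
    return a1 * a2, a2 * b1 + b2

def sequential_scan(a, b):
    """Reference sequential recurrence starting from h_0 = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out

a = [0.9, 0.5, 0.7]
b = [1.0, 2.0, 3.0]

# Associativity is what lets a parallel prefix scan compute all h_t
# in O(log L) depth: regrouping the combines does not change the result.
left  = scan_op(scan_op((a[0], b[0]), (a[1], b[1])), (a[2], b[2]))
right = scan_op((a[0], b[0]), scan_op((a[1], b[1]), (a[2], b[2])))
assert np.allclose(left, right)
assert np.isclose(left[1], sequential_scan(a, b)[-1])
```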
3.4 Mamba Block Architecture
Input
  │
  ├── Linear(D → ED) ── Conv1d ── SiLU ── SSM ──┐
  │                                             │
  └── Linear(D → ED) ── SiLU ────────────────── ×
                                                │
                                          Linear(ED → D)
                                                │
                                             Output

(The two branch projections are fused into a single Linear(D → 2ED) in the code.)
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv,
                                padding=d_conv - 1, groups=d_inner)
        # SSM parameters (dt_rank = 1 for simplicity)
        self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False)  # B, C, Δ
        self.dt_proj = nn.Linear(1, d_inner, bias=True)
        A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        # x: (B, L, D)
        xz = self.in_proj(x)          # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1)

        # Conv + SiLU (causal: trim the right-side padding back to length L)
        x = self.conv1d(x.transpose(1, 2))[:, :, :x.shape[1]]
        x = x.transpose(1, 2)
        x = F.silu(x)

        # SSM
        A = -torch.exp(self.A_log)
        B_C_dt = self.x_proj(x)
        # ... selective scan producing y ...

        # Gate
        y = y * F.silu(z)
        return self.out_proj(y)
```
4. Mamba-2: State Space Duality
4.1 SSD (State Space Dual) Model
Mamba-2 (Dao & Gu, ICML 2024) discovered a mathematical duality between SSMs and Attention:
- SSM perspective: Recursive state updates → O(N) inference
- Attention perspective: Structured matrix multiplication → Parallel training
SSM recurrence:       h_t = A_t h_{t-1} + B_t x_t
                      y_t = C_t h_t

        ↕ Dual ↕

Structured attention: Y = (L ⊙ (C Bᵀ)) X
        where L = lower-triangular causal decay mask (cumulative products of the A_t),
        making the overall matrix semiseparable
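The duality can be checked numerically in the scalar-decay case (A_t = a_t·I, as in SSD): the linear-time recurrence and the materialized attention-style matrix produce identical outputs. A NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
L, N = 6, 4
a = rng.uniform(0.5, 1.0, L)   # scalar decay per step (A_t = a_t * I)
B = rng.normal(size=(L, N))    # input-dependent B_t
C = rng.normal(size=(L, N))    # input-dependent C_t
x = rng.normal(size=L)

# SSM view: linear-time recurrence
h = np.zeros(N)
y_rec = np.empty(L)
for t in range(L):
    h = a[t] * h + B[t] * x[t]
    y_rec[t] = C[t] @ h

# Attention view: y = M x with M[t, s] = (C_t . B_s) * a_{s+1} * ... * a_t
M = np.zeros((L, L))
for t in range(L):
    for s in range(t + 1):
        M[t, s] = (C[t] @ B[s]) * np.prod(a[s + 1 : t + 1])

y_att = M @ x
assert np.allclose(y_rec, y_att)
```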
4.2 Performance Comparison
| Model | Parameters | Pile (ppl ↓) | Training speed (rel.) | Inference speed (rel.) |
|---|---|---|---|---|
| Transformer++ | 2.7B | 6.7 | 1.0x | 1.0x |
| Mamba-1 | 2.8B | 6.2 | 1.3x | 5.2x |
| Mamba-2 | 2.7B | 6.1 | 2.1x | 5.5x |
Key improvements:
- 2x training speed improvement (leveraging structured matrix multiplication)
- Larger state sizes possible (e.g., N=16 → N=128)
- Introduction of multi-head structure
5. Practical Usage: Working with Mamba Models
5.1 Installation and Inference
```shell
pip install mamba-ssm "causal-conv1d>=1.4.0"
```
```python
from mamba_ssm import MambaLMHeadModel
from transformers import AutoTokenizer

# Load Mamba-2.8B (Mamba models reuse the GPT-NeoX tokenizer)
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = model.cuda().half()

# Inference
input_ids = tokenizer("The future of AI is", return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids, max_length=100, temperature=0.7)
print(tokenizer.decode(output[0]))
```
5.2 Using Mamba-2
```python
import torch
from mamba_ssm import Mamba2

# Using a standalone Mamba-2 layer
layer = Mamba2(
    d_model=2048,
    d_state=128,  # Mamba-2 supports larger states
    d_conv=4,
    expand=2,
    headdim=64,   # multi-head structure
).cuda()

x = torch.randn(1, 1024, 2048).cuda()  # (batch, seq_len, d_model)
y = layer(x)                           # (1, 1024, 2048)
```
6. Mamba vs Transformer Comparison
| Property | Transformer | Mamba |
|---|---|---|
| Training Complexity | O(N²) | O(N) |
| Inference Complexity | O(N) per token (KV Cache) | O(1) per token |
| Memory (Inference) | O(N) KV Cache | O(1) fixed state |
| In-context learning | Strong | Weaker (improving) |
| Long sequences | Bottleneck | Efficient |
| Hardware utilization | Optimized for parallelism | Recurrence bottleneck (improved in Mamba-2) |
7. Limitations and Outlook
7.1 Current Limitations
- In-context learning: Information loss when compressing unlimited context into a fixed-size state
- Retrieval tasks: Attention has the advantage when exact token retrieval is needed
- Training stability: Gradient instability issues with long sequences
7.2 Hybrid Approaches
The recent trend is Mamba + Attention hybrids:
- Jamba (AI21): Mixing Mamba layers + Attention layers
- Zamba (Zyphra): Mamba-based + a small number of shared Attention layers
8. Quiz
Q1. What is the key improvement of Mamba over S4?
Making the SSM parameters (B, C, Δ) input-dependent to enable content-based reasoning. S4 used fixed parameters, which made selective processing based on input content impossible.
Q2. What role does delta play in Mamba?
As a step size, it functions as a gating mechanism. When Δ is large, the model focuses on the current input (the state is reset and updated); when Δ is small, the previous state is retained (the input is largely ignored). It is analogous to LSTM's forget gate.
Q3. Why can't the convolution trick be used after introducing Selection?
The convolution trick only works for time-invariant systems. When Selection makes parameters input-dependent, the system becomes time-varying, making global convolution impossible.
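The answer above can be verified numerically: with fixed parameters, the recurrent output equals a causal convolution with kernel K_k = C Ā^k B̄. The sketch below uses an already-discretized diagonal Ā for brevity; once the parameters vary with t, no single kernel K exists.

```python
import numpy as np

rng = np.random.default_rng(1)
L, N = 8, 3
A = np.diag(rng.uniform(0.3, 0.9, N))  # fixed (time-invariant), already discretized
B = rng.normal(size=N)
C = rng.normal(size=N)
x = rng.normal(size=L)

# Recurrent form: h_t = A h_{t-1} + B x_t, y_t = C h_t
h = np.zeros(N)
y_rec = np.empty(L)
for t in range(L):
    h = A @ h + B * x[t]
    y_rec[t] = C @ h

# Convolution form: K_k = C A^k B, valid only because A, B, C do not vary with t
K = np.array([C @ np.linalg.matrix_power(A, k) @ B for k in range(L)])
y_conv = np.array([sum(K[k] * x[t - k] for k in range(t + 1)) for t in range(L)])

assert np.allclose(y_rec, y_conv)
```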
Q4. What is the essence of Mamba's Hardware-Aware algorithm?
Keeping intermediate states in SRAM within the CUDA kernel to minimize HBM access. An IO-aware approach similar to FlashAttention that optimizes memory-bound operations.
Q5. What is Mamba-2's State Space Duality?
The mathematical equivalence between SSM's recursive state updates and structured attention matrix multiplication. This allows leveraging parallelized matrix multiplication during training and efficient recursion during inference.
Q6. In which tasks do Transformers outperform Mamba?
In-context learning and exact token retrieval tasks. Attention can directly reference any position within the sequence, whereas Mamba must compress information into a fixed-size state.
Q7. What is the design principle behind hybrid models like Jamba and Zamba?
Combining the O(N) efficiency of Mamba layers with the precise retrieval capability of a small number of Attention layers. Most layers are Mamba, with Attention placed only at key positions.