Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Paper Analysis)
Paper Overview
- **Title**: Mamba: Linear-Time Sequence Modeling with Selective State Spaces
- **Authors**: Albert Gu, Tri Dao (Carnegie Mellon University, Princeton University)
- **Published**: December 2023 (arXiv: 2312.00752); published at COLM 2024
- **Code**: [github.com/state-spaces/mamba](https://github.com/state-spaces/mamba)
Motivation: The Limitations of Transformers
Transformers dominate sequence modeling, but they suffer from fundamental problems:
- **O(N²) Complexity**: Self-attention has quadratic complexity with respect to sequence length
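- **Long Sequence Bottleneck**: Because cost grows quadratically, memory and compute become prohibitive for very long contexts (e.g., 128K tokens and beyond)
- **Inference Inefficiency**: During autoregressive generation, the KV cache grows linearly with sequence length, and each new token must attend over every cached token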
State Space Models (SSMs) theoretically achieve O(N) complexity, but prior SSMs (S4, H3, etc.) fell short of Transformers in language modeling due to their **fixed, input-independent parameters**.
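The inference-memory contrast can be estimated with a back-of-the-envelope sketch. All model sizes below are hypothetical. The sketch assumes standard multi-head attention, where keys and values are cached for every layer and token and grouped-query attention is not used. It also assumes a Mamba-style layer whose recurrent state is a (d_inner × d_state) SSM state plus a small convolution buffer.

```python
def kv_cache_bytes(seq_len: int, n_layers: int, d_model: int, bytes_per_elem: int = 2) -> int:
    """Transformer KV cache: one key and one value vector of size d_model per layer per token."""
    return 2 * n_layers * seq_len * d_model * bytes_per_elem


def ssm_state_bytes(n_layers: int, d_model: int, d_state: int = 16, d_conv: int = 4,
                    expand: int = 2, bytes_per_elem: int = 2) -> int:
    """Mamba-style recurrent state (assumed layout): independent of sequence length."""
    d_inner = expand * d_model
    per_layer = d_inner * d_state + d_inner * (d_conv - 1)  # SSM state + conv buffer
    return n_layers * per_layer * bytes_per_elem


# Hypothetical config (illustrative only), 16-bit values
n_layers, d_model = 48, 2048
for seq_len in (2_048, 65_536, 1_048_576):
    kv = kv_cache_bytes(seq_len, n_layers, d_model) / 2**20
    ssm = ssm_state_bytes(n_layers, d_model) / 2**20
    print(f"L={seq_len:>9,}: KV cache {kv:>10,.0f} MiB | SSM state {ssm:.1f} MiB")
```

The KV cache column grows linearly with L. The SSM state column stays the same for every L, because the recurrent state has a fixed size.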
Core Idea: Selective State Spaces
The Problem with Existing SSMs
Traditional SSMs are **Linear Time-Invariant (LTI)** systems:
$$h'(t) = Ah(t) + Bx(t)$$
$$y(t) = Ch(t)$$
Here, A, B, and C are **fixed** matrices. Whether the input is "hello" or "world," the system processes it the same way. This means **content-aware reasoning** is impossible.
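To process discrete token sequences, the continuous system is discretized with a step size Δ. The Mamba paper uses the zero-order hold (ZOH) rule, which turns the differential equation into a linear recurrence over time steps t:

$$\bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right) \cdot \Delta B$$
$$h_t = \bar{A} h_{t-1} + \bar{B} x_t, \qquad y_t = C h_t$$

In an LTI model, Δ, A, B, and C are identical at every step t, so $\bar{A}$ and $\bar{B}$ are constant too. The `selective_scan` code later in this article uses the ZOH formula for $\bar{A}$ and the common first-order simplification $\bar{B} \approx \Delta B$.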
Mamba's Solution: Selection Mechanism
Mamba's key innovation is making **B, C, and Δ input-dependent**:
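```python
import torch
import torch.nn as nn

D, N = 64, 16               # model width and state size (illustrative values)
x = torch.randn(2, 10, D)   # (batch, length, D)

# Traditional SSM: fixed parameters
B_fixed = nn.Parameter(torch.randn(N))  # learned but input-independent

# Mamba: input-dependent parameters (layers created inline for brevity;
# a real model registers them once as module attributes)
B = nn.Linear(D, N)(x)      # (batch, length, N): B varies with x
C = nn.Linear(D, N)(x)      # (batch, length, N): C varies with x
delta = nn.functional.softplus(nn.Linear(D, D)(x))  # step size varies with x, kept positive
```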
This is what **Selective** State Space means. Depending on the input, the model **selects** which information to store in its state and which to ignore.
Intuitive Understanding
The Selection Mechanism can be understood intuitively as:
- **Δ (delta)**: "How much attention should this token receive?" — A large Δ strongly reflects the current input, while a small Δ retains the previous state
- **B**: "Which parts of the input should be recorded in the state?"
- **C**: "Which parts of the state should be extracted as output?"
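This resembles the content-based selection of attention. Instead of comparing each token with every previous token, however, Mamba compresses the history into a fixed-size state, which gives **O(N)** complexity.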
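A tiny numeric sketch makes the role of Δ concrete. It assumes a single channel with a scalar state, A = -1, and B = 1; these values are chosen only for illustration. It applies the ZOH discretization, which for scalars reduces to $\bar{A} = e^{\Delta A}$ and $\bar{B} = (e^{\Delta A} - 1)/A \cdot B$.

```python
import math


def ssm_step(h_prev: float, x: float, delta: float, A: float = -1.0, B: float = 1.0) -> float:
    """One ZOH-discretized step of a scalar SSM: h = A_bar * h_prev + B_bar * x."""
    A_bar = math.exp(delta * A)                   # how much of the previous state survives
    B_bar = (math.exp(delta * A) - 1.0) / A * B   # how strongly the current input is written
    return A_bar * h_prev + B_bar * x


h_prev, x = 5.0, 1.0
small = ssm_step(h_prev, x, delta=0.01)   # small step: state is mostly retained
large = ssm_step(h_prev, x, delta=5.0)    # large step: state is reset toward the input
print(f"small delta: {small:.3f}, large delta: {large:.3f}")
# prints: small delta: 4.960, large delta: 1.027
```

With a small Δ the new state stays close to the old value 5.0. With a large Δ the old state decays almost to zero, and the new state is close to the current input 1.0.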
Hardware Optimization: The Selective Scan Algorithm
Since the Selective SSM is input-dependent, it cannot leverage the efficient convolution tricks used by traditional SSMs. Mamba addresses this through a **hardware-aware algorithm**.
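The following sketch shows why input dependence rules out the convolution trick. It uses pure Python, a diagonal A as in S4D and Mamba, and made-up parameter values. With fixed parameters, the recurrence unrolls into a causal convolution with kernel $K_k = C \bar{A}^k \bar{B}$, which LTI models such as S4 can precompute and apply with an FFT. Here the convolution is computed directly in O(L²) for clarity.

```python
import math


def discretize(A_diag, B, delta):
    """ZOH discretization for a diagonal A (one scalar per state dimension)."""
    A_bar = [math.exp(delta * a) for a in A_diag]
    B_bar = [(math.exp(delta * a) - 1.0) / a * b for a, b in zip(A_diag, B)]
    return A_bar, B_bar


def run_recurrence(xs, A_bar, B_bar, C):
    """h_t = A_bar * h_{t-1} + B_bar * x_t, y_t = C . h_t, with h_{-1} = 0."""
    h = [0.0] * len(A_bar)
    ys = []
    for x in xs:
        h = [ab * hi + bb * x for ab, hi, bb in zip(A_bar, h, B_bar)]
        ys.append(sum(c * hi for c, hi in zip(C, h)))
    return ys


def run_convolution(xs, A_bar, B_bar, C):
    L = len(xs)
    # Kernel K[k] = sum_n C[n] * A_bar[n]**k * B_bar[n]: one fixed kernel, since nothing depends on x
    K = [sum(c * ab**k * bb for c, ab, bb in zip(C, A_bar, B_bar)) for k in range(L)]
    # Causal convolution y[t] = sum_{k=0..t} K[k] * x[t-k]
    return [sum(K[k] * xs[t - k] for k in range(t + 1)) for t in range(L)]


A_diag = [-1.0, -0.5, -2.0]
B = [1.0, 0.5, -1.0]
C = [0.3, -0.2, 1.0]
A_bar, B_bar = discretize(A_diag, B, delta=0.1)
xs = [1.0, 0.0, -2.0, 3.0, 0.5]
rec = run_recurrence(xs, A_bar, B_bar, C)
conv = run_convolution(xs, A_bar, B_bar, C)
print(max(abs(r - c) for r, c in zip(rec, conv)))  # ~0: both views agree
```

If $\bar{A}$, $\bar{B}$, or C changed with each $x_t$, every output would need its own product of per-step factors. No single kernel K could then be precomputed, and that is exactly the situation of the selective SSM.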
The Problem: HBM ↔ SRAM Bottleneck
GPU memory hierarchy (approximate figures for an NVIDIA A100 80GB):
```
┌──────────────────────┐
│ SRAM (~20 MB)        │ ← very fast, very small (on-chip)
├──────────────────────┤
│ HBM (80 GB)          │ ← much slower, large (off-chip)
└──────────────────────┘
```
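A naive implementation materializes the full hidden state h, of shape (batch, length, D, N), in HBM and repeatedly reads and writes it. The scan then becomes **memory-bound**: its runtime is dominated by HBM traffic rather than by arithmetic.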
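A quick size estimate shows why materializing the states is costly. The full set of hidden states has shape (batch, length, D, N), which is N times larger than the input or output of shape (batch, length, D). The sizes below are hypothetical and chosen only for illustration.

```python
def tensor_gib(*shape: int, bytes_per_elem: int = 4) -> float:
    """Size in GiB of a dense tensor with the given shape (fp32 by default)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem / 2**30


# Illustrative sizes (hypothetical): batch, length, channels (d_inner), state size
batch, length, d_inner, d_state = 8, 2048, 5120, 16

states = tensor_gib(batch, length, d_inner, d_state)  # every h_t, shape (B, L, D, N)
io = tensor_gib(batch, length, d_inner)               # input x or output y, shape (B, L, D)
print(f"materialized states: {states:.2f} GiB, input/output: {io:.2f} GiB each")
```

Keeping h in SRAM means only the much smaller (B, L, D) tensors, together with Δ, B, and C, ever travel to and from HBM.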
Solution: Kernel Fusion + Recomputation
Mamba's Selective Scan algorithm:
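1. **Compute entirely in SRAM**: Load Δ, A, B, and C from HBM into SRAM, then fuse discretization, the selective scan, and the output computation (multiplication by C) into a single kernel. Only the output y, of shape (batch, length, D), is written back to HBM
2. **No intermediate state storage**: During the forward pass, keep only the inputs (x, Δ, A, B, C) in HBM, not the intermediate states h. During the backward pass, reload the inputs and recompute the states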
3. **Result**: IO-aware optimization similar to FlashAttention
Simplified Selective Scan (conceptual code)
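```python
import torch

def selective_scan(x, A, B, C, delta):
    """
    x:     (batch, length, d)  input sequence (d channels)
    A:     (d, n)              diagonal state matrix per channel (negative entries)
    B, C:  (batch, length, n)  input-dependent projections
    delta: (batch, length, d)  input-dependent step size (positive)
    Returns: (batch, length, d)
    """
    batch, length, d = x.shape
    n = A.shape[1]  # state dimension

    # Discretize A, B using delta (ZOH for A, simplified Euler step for B)
    deltaA = torch.exp(delta.unsqueeze(-1) * A)                        # (batch, length, d, n)
    deltaB_x = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)  # (batch, length, d, n)

    # Sequential scan (parallelized via a CUDA kernel in practice)
    h = torch.zeros(batch, d, n, device=x.device, dtype=x.dtype)
    ys = []
    for i in range(length):
        h = deltaA[:, i] * h + deltaB_x[:, i]    # state update
        y = (h * C[:, i].unsqueeze(1)).sum(-1)   # readout: (batch, d)
        ys.append(y)
    return torch.stack(ys, dim=1)
```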
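The loop above is sequential, but the recurrence $h_t = \bar{a}_t h_{t-1} + \bar{b}_t$ can still be parallelized. Composing two affine steps gives another affine step: (a₁, b₁) followed by (a₂, b₂) equals (a₂a₁, a₂b₁ + b₂), and this combine is associative. The pure-Python sketch below runs a Hillis-Steele (log-depth) inclusive scan over scalar steps, one channel and state entry, and checks it against the sequential loop. It illustrates the idea only; the actual Mamba kernel is a fused CUDA implementation and differs from this sketch.

```python
from typing import List, Tuple

Step = Tuple[float, float]  # (a, b) represents the affine map h -> a * h + b


def combine(first: Step, second: Step) -> Step:
    """Apply `first`, then `second`: h -> a2 * (a1 * h + b1) + b2."""
    a1, b1 = first
    a2, b2 = second
    return (a2 * a1, a2 * b1 + b2)


def sequential_scan(steps: List[Step]) -> List[float]:
    h, out = 0.0, []
    for a, b in steps:
        h = a * h + b
        out.append(h)
    return out


def parallel_scan(steps: List[Step]) -> List[float]:
    """Hillis-Steele inclusive scan: O(log L) rounds; combines within a round are independent."""
    acc = list(steps)
    offset = 1
    while offset < len(acc):
        # Each position i >= offset merges with position i - offset (parallelizable)
        acc = [acc[i] if i < offset else combine(acc[i - offset], acc[i]) for i in range(len(acc))]
        offset *= 2
    # Starting from h = 0, the state after step t is the accumulated offset b
    return [b for _, b in acc]


steps = [(0.9, 1.0), (0.5, -2.0), (0.99, 0.3), (0.1, 4.0), (0.7, 0.0), (0.8, 1.5)]
print(sequential_scan(steps))
print(parallel_scan(steps))  # same values, up to floating-point rounding
```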
Mamba Block Architecture
Mamba simplifies the H3 architecture. It merges the H3 block with the gated MLP block of Transformers into a single block, which is then stacked homogeneously:
```
          Input (x)
              │
              ├──── Linear projection (expand) ────┐
              │                                    │
              ▼                                    ▼
           Conv1D                             SiLU (gate)
              │                                    │
              ▼                                    │
            SiLU                                   │
              │                                    │
              ▼                                    │
     SSM (selective scan)                          │
              │                                    │
              ▼                                    │
              ×────────────────────────────────────┘
              │
              ▼
  Linear projection (reduce)
              │
              ▼
            Output
```
Implemented in PyTorch:
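```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand

        # Input projection: produces both the SSM branch and the gate branch
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

        # Causal depthwise Conv1D (one filter per channel)
        self.conv1d = nn.Conv1d(
            d_inner, d_inner, d_conv,
            padding=d_conv - 1, groups=d_inner
        )

        # SSM parameters: x_proj yields delta (rank 1 here for simplicity), B, and C
        self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, d_inner, bias=True)

        # A parameter (structured), stored as a log so that A = -exp(A_log) stays negative
        A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)  # (d_inner, d_state)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(d_inner))  # skip connection

        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def ssm(self, x):
        # x: (batch, length, d_inner)
        d_state = self.A_log.shape[1]
        A = -torch.exp(self.A_log)
        dt, B, C = torch.split(self.x_proj(x), [1, d_state, d_state], dim=-1)
        delta = F.softplus(self.dt_proj(dt))   # (batch, length, d_inner), positive
        y = selective_scan(x, A, B, C, delta)  # selective_scan as defined above
        return y + x * self.D

    def forward(self, x):
        length = x.shape[1]

        # Dual path: SSM branch x and gate branch z
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Causal conv (trim the right-side padding) + activation
        x = self.conv1d(x.transpose(1, 2))[:, :, :length].transpose(1, 2)
        x = F.silu(x)

        # SSM
        y = self.ssm(x)

        # Gate and project back to d_model
        y = y * F.silu(z)
        return self.out_proj(y)
```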
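As a sanity check on the block's size, here is a hypothetical helper that counts the parameters of the simplified `MambaBlock`, layer by layer. It mirrors the layer shapes of that sketch, including the rank-1 `dt_proj`. Its numbers therefore apply to this simplified version, not to the official implementation. The input and output projections account for 3 · expand · d_model² parameters, which is consistent with the paper's remark that the linear projections hold most of the parameters.

```python
def mamba_block_params(d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2) -> dict:
    """Parameter counts for the simplified MambaBlock sketch (same layer shapes)."""
    d_inner = d_model * expand
    counts = {
        "in_proj": d_model * 2 * d_inner,       # Linear(d_model, 2*d_inner), no bias
        "conv1d": d_inner * d_conv + d_inner,   # depthwise weight (d_inner, 1, d_conv) + bias
        "x_proj": d_inner * (2 * d_state + 1),  # Linear(d_inner, 2*d_state + 1), no bias
        "dt_proj": 1 * d_inner + d_inner,       # Linear(1, d_inner) with bias
        "A_log": d_inner * d_state,
        "D": d_inner,
        "out_proj": d_inner * d_model,          # Linear(d_inner, d_model), no bias
    }
    counts["total"] = sum(counts.values())
    return counts


p = mamba_block_params(d_model=768)
proj_share = (p["in_proj"] + p["out_proj"]) / p["total"]
print(p["total"], f"{proj_share:.1%} in in_proj + out_proj")
```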
Experimental Results
Language Modeling (The Pile)
| Model | Parameters | Perplexity |
| ------------- | ---------- | ---------- |
| Transformer++ | 1.3B | 8.94 |
| H3 | 1.3B | 9.19 |
| Hyena | 1.3B | 9.07 |
| RWKV-4 | 1.5B | 9.02 |
| **Mamba** | **1.4B** | **8.63** |
Mamba achieved **lower perplexity** than Transformers of comparable size. Notably, at the 2.8B scale, Mamba improved perplexity by approximately **0.3** over Transformer++.
Inference Speed
| Sequence Length | Transformer | Mamba |
| --------------- | ----------- | ------------- |
| 2K | 1x | 1x |
| 8K | 1x | 3x faster |
| 64K | 1x | **5x faster** |
| 1M | OOM | Works |
Mamba's advantage grows with sequence length. Its per-token generation cost and memory stay constant, while a Transformer must store and read a KV cache that grows with the context.
Selective Copying & Induction Heads
The paper validates the effectiveness of the Selection Mechanism through two key synthetic tasks:
1. **Selective Copying**: Selectively copying specific tokens from a long sequence — Previous SSMs (S4) fail, Mamba succeeds
2. **Induction Heads**: Recognizing "A B ... A → B" patterns — Previous SSMs fail, Mamba succeeds
Both tasks require **content-aware reasoning** and are impossible to solve without the Selection Mechanism.
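To make the Selective Copying setup concrete, here is a hypothetical data generator. It scatters a fixed number of content tokens at random positions among noise tokens, and the target is the content tokens in order. The specific token ids and lengths, and the use of 0 as the noise token, are assumptions for illustration, not the paper's exact configuration.

```python
import random
from typing import List, Tuple

NOISE = 0  # hypothetical id for the filler/noise token


def selective_copying_example(seq_len: int, n_memorize: int, vocab_size: int,
                              rng: random.Random) -> Tuple[List[int], List[int]]:
    """Scatter n_memorize content tokens among noise tokens; target = content tokens in order."""
    content = [rng.randint(1, vocab_size - 1) for _ in range(n_memorize)]
    positions = sorted(rng.sample(range(seq_len), n_memorize))  # random, not evenly spaced
    inputs = [NOISE] * seq_len
    for pos, tok in zip(positions, content):
        inputs[pos] = tok
    return inputs, content


rng = random.Random(0)
inputs, target = selective_copying_example(seq_len=20, n_memorize=4, vocab_size=8, rng=rng)
print(inputs)
print(target)
```

The content positions change from one example to the next, so a fixed, time-invariant kernel cannot know in advance which steps to keep. The model has to decide from each token's content whether it is noise, and the input-dependent Δ, B, and C make that decision possible.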
Mamba vs Transformer: When to Use Which?
| Criterion | Transformer | Mamba |
| -------------------- | ------------------ | -------------------- |
| Long sequences | △ (O(N²)) | ◎ (O(N)) |
| Inference speed | △ (KV Cache grows) | ◎ (Fixed state) |
| In-context learning | ◎ | ○ |
| Training parallelism | ◎ (Fully parallel) | ○ (Scan parallelism) |
| Ecosystem/tooling    | ◎                  | △                    |

◎ = strong, ○ = good, △ = weak
Follow-up Research: Mamba-2 and Hybrids
Significant progress has been made since Mamba:
- **Mamba-2** (Dao & Gu, 2024): Proposed the **State Space Duality (SSD)** framework, revealing the theoretical connection between SSMs and Attention
- **Jamba** (AI21, 2024): A Transformer + Mamba hybrid with 256K context
- **Zamba** (Zyphra, 2024): A compact 7B hybrid that combines a Mamba backbone with a shared attention block
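The duality behind Mamba-2 can be checked numerically on a toy case. The sketch below assumes the restricted setting Mamba-2 studies, where the state transition at each step is a scalar a_t times the identity. It computes the same outputs two ways. The first is a recurrence (the SSM view, linear in L). The second is multiplication by a lower-triangular L × L matrix, M[t][s] = (C_t · B_s) · a_{s+1} ⋯ a_t, an attention-like view that is quadratic in L. All values are random and hypothetical.

```python
import random


def ssm_recurrent(a, B, C, x):
    """h_t = a_t * h_{t-1} + B_t * x_t (vector state), y_t = C_t . h_t."""
    h = [0.0] * len(B[0])
    ys = []
    for a_t, B_t, C_t, x_t in zip(a, B, C, x):
        h = [a_t * h_i + b_i * x_t for h_i, b_i in zip(h, B_t)]
        ys.append(sum(c_i * h_i for c_i, h_i in zip(C_t, h)))
    return ys


def ssm_as_matrix(a, B, C, x):
    """y = M x with M[t][s] = (C_t . B_s) * prod(a[s+1..t]) for s <= t, else 0."""
    ys = []
    for t in range(len(x)):
        y_t = 0.0
        for s in range(t + 1):
            decay = 1.0
            for k in range(s + 1, t + 1):
                decay *= a[k]                               # acts like a causal decay mask
            score = sum(c * b for c, b in zip(C[t], B[s]))  # like a query-key dot product
            y_t += decay * score * x[s]
        ys.append(y_t)
    return ys


rng = random.Random(0)
L, N = 6, 4
a = [rng.uniform(0.5, 1.0) for _ in range(L)]
B = [[rng.gauss(0, 1) for _ in range(N)] for _ in range(L)]
C = [[rng.gauss(0, 1) for _ in range(N)] for _ in range(L)]
x = [rng.gauss(0, 1) for _ in range(L)]
rec = ssm_recurrent(a, B, C, x)
mat = ssm_as_matrix(a, B, C, x)
print(max(abs(r - m) for r, m in zip(rec, mat)))  # ~0: both views give the same outputs
```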
Conclusion
Mamba overcame the limitations of SSMs with the simple yet powerful idea of the **Selection Mechanism**, achieving performance comparable to or exceeding Transformers at O(N) complexity. The hardware-aware algorithm design was key to translating theoretical efficiency into real-world speed.
While Transformers remain dominant, Mamba and hybrid architectures are gaining increasing traction in domains where long-sequence processing is critical.
Quiz
Q1: What is the main limitation of Transformers that Mamba addresses?
The O(N²) time and space complexity of self-attention. As sequence length increases, computation and memory grow quadratically, making long-sequence processing inefficient.
Q2: What does "Selective" mean in Selective State Spaces?
It means making the SSM parameters (B, C, Δ) input-dependent, so the model can select which information to store in the state and which to ignore.
Q3: What was the fundamental reason previous SSMs (S4, etc.) underperformed in language modeling?
Their Linear Time-Invariant (LTI) nature. Since parameters were fixed and independent of the input, content-aware reasoning was impossible.
Q4: What role does Δ (delta) play in the Selection Mechanism?
It is a "step size" that determines how much the current input is reflected. A large Δ strongly reflects the current input, while a small Δ retains the previous state.
Q5: Why can't Mamba's Selective Scan use the convolution trick from traditional SSMs?
Because the input-dependent parameters mean the system is no longer LTI, so the trick of converting the recurrence into a convolution and computing it in parallel via FFT no longer applies.
Q6: What is the key strategy that Mamba's hardware optimization shares with FlashAttention?
IO-aware optimization. All computations are fused within SRAM (kernel fusion), and intermediate states are not stored in HBM but recomputed during the backward pass (recomputation).
Q7: What capability does the Selective Copying task test?
The ability to selectively remember and copy specific tokens from a long sequence, which is a content-aware reasoning capability. SSMs with fixed parameters cannot solve this task.
Q8: What is the State Space Duality (SSD) framework proposed in Mamba-2?
A framework showing that a restricted class of SSMs and a form of masked (linear) attention compute the same operation in two different ways: a linear-time recurrence and a quadratic, attention-like matrix form. It provides a theoretical foundation for combining the strengths of both approaches.