1. Introduction: The Limitations of Transformers and the Rise of SSMs
The Transformer architecture has dominated nearly every field of sequence modeling — NLP, Vision, Audio, and more — since its introduction in 2017. However, it has fundamental limitations:
- **O(N²) complexity**: The time and memory cost of self-attention grows quadratically with the sequence length N
- **Inference inefficiency**: Generating each new token requires attending over the entire KV cache, which grows linearly with the context
- **Difficulty with long sequences**: Memory bottleneck at 128K+ contexts
**State Space Models (SSMs)** have been proposed as an alternative to overcome these limitations. In particular, **Mamba** (Gu & Dao, COLM 2024) introduced a **selective** mechanism to SSMs, achieving performance comparable to Transformers with complexity linear in sequence length (**O(N)**).
2. Background: Structured State Space Models (S4)
2.1 Continuous-Time SSM
An SSM starts from a continuous-time linear system, which is then discretized so it can be applied to token sequences:
$$
h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)
$$
$$
y(t) = \mathbf{C}h(t)
$$
- $h(t) \in \mathbb{R}^N$: hidden state
- $x(t) \in \mathbb{R}$: input
- $y(t) \in \mathbb{R}$: output
- $\mathbf{A} \in \mathbb{R}^{N \times N}$, $\mathbf{B} \in \mathbb{R}^{N \times 1}$, $\mathbf{C} \in \mathbb{R}^{1 \times N}$
2.2 Discretization (Zero-Order Hold)
Converting the continuous system to discrete time:
$$
\overline{\mathbf{A}} = \exp(\Delta \mathbf{A})
$$
$$
\overline{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B}
$$
Discretized recurrence relation:
$$
h_t = \overline{\mathbf{A}} h_{t-1} + \overline{\mathbf{B}} x_t
$$
$$
y_t = \mathbf{C} h_t
$$
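To make the discretization concrete, here is a minimal NumPy sketch. It assumes a diagonal A (as later used by S4D and Mamba), so the matrix exponential and inverse in the ZOH formulas become elementwise operations. The function names `discretize_zoh_diag` and `run_ssm` are illustrative, not from any library.

```python
import numpy as np

def discretize_zoh_diag(a, b, delta):
    """ZOH for a diagonal A = diag(a); every matrix op becomes elementwise.

    a, b: (N,) diagonal of A and the column vector B; delta: scalar step size.
    """
    a_bar = np.exp(delta * a)                            # exp(ΔA)
    b_bar = (a_bar - 1.0) / (delta * a) * (delta * b)    # (ΔA)^{-1}(exp(ΔA) - I) ΔB
    return a_bar, b_bar

def run_ssm(a_bar, b_bar, c, xs):
    """Unroll h_t = A_bar h_{t-1} + B_bar x_t, y_t = C h_t, starting from h = 0."""
    h = np.zeros_like(a_bar)
    ys = []
    for x in xs:
        h = a_bar * h + b_bar * x
        ys.append(float(c @ h))
    return np.array(ys)

# Example: N = 4 stable modes (negative diagonal), constant input x_t = 1
a = -np.arange(1.0, 5.0)
b = np.ones(4)
c = np.ones(4) / 4
a_bar, b_bar = discretize_zoh_diag(a, b, delta=0.1)
ys = run_ssm(a_bar, b_bar, c, np.ones(200))
```

Because ZOH assumes the input is held constant within each step, the discrete trajectory for a constant input matches the continuous solution $h(t) = -\mathbf{A}^{-1}(\mathbf{I} - e^{\mathbf{A}t})\mathbf{B}$ exactly at the sample times.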
2.3 The Key to S4: HiPPO Initialization
A key ingredient of S4 (Structured State Spaces for Sequence Modeling) is initializing A with the **HiPPO (High-order Polynomial Projection Operator)** matrix, which lets the state memorize the input history by projecting it onto (scaled) Legendre polynomials:
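```python
import numpy as np

def make_hippo(N):
    """Generate the HiPPO-LegS matrix A (N x N)."""
    P = np.sqrt(1 + 2 * np.arange(N))  # sqrt(2n + 1)
    A = np.zeros((N, N))
    for i in range(N):
        for j in range(N):
            if i > j:
                A[i, j] = P[i] * P[j]  # below the diagonal: sqrt(2i+1) * sqrt(2j+1)
            elif i == j:
                A[i, j] = i + 1        # diagonal: n + 1
            # i < j: entry stays 0 (the matrix is lower-triangular)
    return -A  # Negative for stability
```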
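Because the HiPPO-LegS matrix is lower-triangular, its eigenvalues are simply its diagonal entries, -1, -2, ..., -N, so the negation is what makes every mode decay. A vectorized sketch (the name `make_hippo_vectorized` is hypothetical) builds the same matrix as the loop above and checks this property:

```python
import numpy as np

def make_hippo_vectorized(N):
    """HiPPO-LegS without loops: -sqrt(2i+1)sqrt(2j+1) below the diagonal, -(i+1) on it."""
    P = np.sqrt(1 + 2 * np.arange(N))
    A = np.tril(np.outer(P, P), k=-1)   # strictly lower triangle: P_i * P_j
    A += np.diag(np.arange(1, N + 1))   # diagonal: i + 1
    return -A

A = make_hippo_vectorized(8)
# Triangular matrix: eigenvalues are the diagonal entries, all strictly negative
stable = bool(np.all(np.diag(A) < 0))
```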
2.4 Limitations of S4
S4 uses **fixed, input-independent parameters** (A, B, C, Δ); in other words, it is a linear time-invariant (LTI) system. This causes two problems:
1. **No content-based reasoning**: Performance degrades on tasks that require processing tokens differently depending on their content (e.g., selective copying, induction heads)
2. **Loss of Attention's advantages**: The content-aware matching capability of Transformers is sacrificed
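What the fixed parameters buy S4 is the *convolution view*: with $\overline{\mathbf{A}}$, $\overline{\mathbf{B}}$, $\mathbf{C}$ constant, unrolling the recurrence gives $y_t = \sum_k K_k x_{t-k}$ with kernel $K_k = \mathbf{C}\overline{\mathbf{A}}^k\overline{\mathbf{B}}$, so the kernel can be computed once and the whole sequence processed by a single (FFT-friendly) convolution. The NumPy sketch below assumes a diagonal $\overline{\mathbf{A}}$ and uses hypothetical helper names; it checks that the convolutional and recurrent modes agree. This time-invariance is exactly what input-dependent parameters break.

```python
import numpy as np

def ssm_kernel(a_bar, b_bar, c, L):
    """K_k = C A_bar^k B_bar for k = 0..L-1 (diagonal A_bar, so powers are elementwise)."""
    powers = a_bar[None, :] ** np.arange(L)[:, None]   # (L, N): A_bar^k per mode
    return powers @ (b_bar * c)                         # (L,)

def ssm_conv(a_bar, b_bar, c, xs):
    """Causal convolution y_t = sum_{k<=t} K_k x_{t-k}."""
    K = ssm_kernel(a_bar, b_bar, c, len(xs))
    return np.convolve(xs, K)[: len(xs)]

def ssm_recurrent(a_bar, b_bar, c, xs):
    """Same model, computed step by step."""
    h = np.zeros_like(a_bar)
    ys = []
    for x in xs:
        h = a_bar * h + b_bar * x
        ys.append(c @ h)
    return np.array(ys)

rng = np.random.default_rng(0)
a_bar = np.exp(-0.1 * np.arange(1, 5))   # fixed, input-independent parameters
b_bar, c = rng.normal(size=4), rng.normal(size=4)
xs = rng.normal(size=32)
```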
3. Mamba: Selective State Space Models
3.1 Core Idea: Selection Mechanism
Mamba's key innovation is making **SSM parameters dependent on the input**:
```
S4:    (A, B, C, Δ) = fixed parameters
Mamba: (B, C, Δ)    = f(input)   ← input-dependent!
```
Specifically:
$$
\mathbf{B}_t = \text{Linear}_B(x_t), \quad \mathbf{C}_t = \text{Linear}_C(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}_\Delta(x_t))
$$
This allows the model to **dynamically decide what information to remember and what to forget** based on the input.
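A minimal PyTorch sketch of these three projections follows. The module name `SelectionProjections` is hypothetical, and it uses a full-rank $\text{Linear}_\Delta$ for readability; Mamba's reference code uses a low-rank projection for Δ instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectionProjections(nn.Module):
    """Hypothetical helper computing per-token B_t, C_t and Δ_t."""

    def __init__(self, d_inner, d_state):
        super().__init__()
        self.linear_B = nn.Linear(d_inner, d_state)      # Linear_B
        self.linear_C = nn.Linear(d_inner, d_state)      # Linear_C
        self.linear_delta = nn.Linear(d_inner, d_inner)  # Linear_Δ: one step size per channel

    def forward(self, x):
        # x: (batch, L, d_inner); every output changes from token to token
        B = self.linear_B(x)                       # (batch, L, d_state)
        C = self.linear_C(x)                       # (batch, L, d_state)
        delta = F.softplus(self.linear_delta(x))   # (batch, L, d_inner), strictly positive
        return B, C, delta

proj = SelectionProjections(d_inner=8, d_state=4)
x = torch.randn(2, 5, 8)
B, C, delta = proj(x)
```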
3.2 Intuition Behind Selection
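$\Delta_t$ (step size) is the key. Because $\mathbf{A}$ has negative entries, $\overline{\mathbf{A}} = \exp(\Delta_t \mathbf{A})$ behaves as follows:
- **Small $\Delta_t$**: $\overline{\mathbf{A}} \approx \mathbf{I}$ and $\overline{\mathbf{B}} \approx \Delta_t \mathbf{B} \approx 0$ → retains the previous state (ignores the current input)
- **Large $\Delta_t$**: $\overline{\mathbf{A}} \approx 0$ → resets the state and focuses on the current input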
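This is analogous to a **gating mechanism**: token by token, $\Delta_t$ decides how much of the old state to keep and how much of the new input to write, much like the forget gate of an LSTM.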
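The gating behaviour follows directly from the ZOH formulas of Section 2.2. For a single state with $a = -1$ and $b = 1$, they give $\overline{A} = e^{-\Delta}$ and $\overline{B} = 1 - e^{-\Delta}$, so $h_t = e^{-\Delta} h_{t-1} + (1 - e^{-\Delta}) x_t$ is a convex mix of old state and new input with gate $g = 1 - e^{-\Delta}$. A small NumPy check (`zoh_scalar` is an illustrative helper):

```python
import numpy as np

def zoh_scalar(delta, a=-1.0, b=1.0):
    """ZOH for a single state: returns (A_bar, B_bar)."""
    a_bar = np.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b   # (Δa)^{-1}(exp(Δa) - 1) Δb
    return a_bar, b_bar

for delta in [0.01, 0.1, 1.0, 10.0]:
    a_bar, b_bar = zoh_scalar(delta)
    # With a = -1, b = 1, A_bar + B_bar = 1: a convex mix of old state and new input
    print(f"delta={delta:>5}: keep old state {a_bar:.3f}, write input {b_bar:.3f}")
```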
3.3 Hardware-Aware Algorithm
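Introducing selection makes the parameters input-dependent, so the system becomes time-varying and the **convolution trick can no longer be used**. Training must fall back to the recurrence, which is sequential and, if implemented naively, materializes the expanded state of shape (B, L, D, N) in GPU memory (HBM).
Mamba solves this with **kernel fusion**, a **parallel scan**, and **recomputation**: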
```python
import torch

# Mamba's Selective Scan (simplified reference version)
def selective_scan(x, delta, A, B, C):
    """
    x:     (B, L, D) - input
    delta: (B, L, D) - step size (input-dependent)
    A:     (D, N)    - state matrix (diagonal: N entries per channel)
    B:     (B, L, N) - input matrix (input-dependent)
    C:     (B, L, N) - output matrix (input-dependent)
    """
    B_batch, L, D = x.shape
    N = A.shape[1]

    # Discretization: ZOH for A, simplified approximation delta * B for B
    deltaA = torch.exp(delta.unsqueeze(-1) * A)     # (B, L, D, N)
    deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)   # (B, L, D, N)

    # Sequential scan for clarity (a parallel scan is used during training)
    h = torch.zeros(B_batch, D, N, device=x.device, dtype=x.dtype)
    ys = []
    for t in range(L):
        h = deltaA[:, t] * h + deltaB[:, t] * x[:, t, :, None]  # (B, D, N)
        y = (h * C[:, t, None, :]).sum(-1)                        # (B, D)
        ys.append(y)
    return torch.stack(ys, dim=1)  # (B, L, D)
```
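In the actual implementation, discretization and scan are **fused into a single CUDA kernel**: the parameters are loaded from HBM into fast on-chip SRAM, the expanded (B, L, D, N) states exist only in SRAM, and only the (B, L, D) output is written back to HBM. For the backward pass, the intermediate states are recomputed rather than stored. This is an IO-aware approach similar to FlashAttention.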
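The recurrence $h_t = \bar a_t h_{t-1} + \bar b_t$ (independently for each channel and state dimension) can still be parallelized because affine steps compose associatively: applying $(a_1, b_1)$ and then $(a_2, b_2)$ equals the single step $(a_2 a_1,\ a_2 b_1 + b_2)$. The NumPy sketch below uses a simple Hillis-Steele scan to illustrate this; it shows the idea only and is not the scan that Mamba's CUDA kernel implements.

```python
import numpy as np

def combine(left, right):
    """Compose two affine steps h -> a*h + b, where `left` is applied first."""
    a1, b1 = left
    a2, b2 = right
    return a2 * a1, a2 * b1 + b2

def parallel_scan(a, b):
    """Hillis-Steele inclusive scan: log2(L) rounds, each parallel over all t."""
    a, b = a.astype(float).copy(), b.astype(float).copy()
    step = 1
    while step < len(a):
        # Position t absorbs the partial result ending at t - step; combine()
        # builds new arrays first, so all positions update simultaneously.
        new_a, new_b = combine((a[:-step], b[:-step]), (a[step:], b[step:]))
        a[step:], b[step:] = new_a, new_b
        step *= 2
    return b  # with h_{-1} = 0, the accumulated offset b is exactly h_t

def sequential_scan(a, b):
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return np.array(out)

rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, size=37)   # per-step decays (input-dependent A_bar_t)
b = rng.normal(size=37)              # per-step inputs (B_bar_t * x_t)
```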
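3.4 Mamba Block Architecture
```
Input
  │
  ├──── Linear(D → ED) ──── Conv1d ──── SiLU ──── SSM ────┐
  │                                                       │
  └──── Linear(D → ED) ──── SiLU ──────────────────────── ×
                                                          │
                                                   Linear(ED → D)
                                                          │
                                                        Output
```
Here D = `d_model` and E = `expand`. Both Linear(D → ED) branches are computed by a single Linear(D → 2ED) (`in_proj`) whose output is split in half.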
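```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand
        self.d_state = d_state
        # One projection produces both branches: x (SSM path) and z (gate path)
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
        # Depthwise conv; padding d_conv-1 plus truncation to L keeps it causal
        self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv,
                                padding=d_conv - 1, groups=d_inner)

        # SSM parameters: x_proj emits B (d_state), C (d_state) and a rank-1 Δ (1)
        self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False)  # B, C, Δ
        self.dt_proj = nn.Linear(1, d_inner, bias=True)  # broadcast Δ to every channel
        # S4D-Real init: A = -[1, 2, ..., N] for every channel, stored as log
        A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(d_inner))  # skip connection weight
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        # x: (B, L, D)
        L = x.shape[1]
        xz = self.in_proj(x)  # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1)

        # Conv + SiLU (keep only the first L outputs so the conv stays causal)
        x = self.conv1d(x.transpose(1, 2))[:, :, :L]
        x = x.transpose(1, 2)
        x = F.silu(x)

        # SSM
        A = -torch.exp(self.A_log)  # (ED, N), strictly negative
        B_C_dt = self.x_proj(x)
        B, C, dt = torch.split(B_C_dt, [self.d_state, self.d_state, 1], dim=-1)
        delta = F.softplus(self.dt_proj(dt))    # (B, L, ED), positive step sizes
        y = selective_scan(x, delta, A, B, C)   # selective_scan from Section 3.3
        y = y + x * self.D                      # skip connection

        # Gate
        y = y * F.silu(z)
        return self.out_proj(y)
```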
4. Mamba-2: State Space Duality
4.1 SSD (State Space Dual) Model
Mamba-2 (Dao & Gu, ICML 2024) discovered a mathematical duality between SSMs and Attention:
- **SSM perspective**: Recursive state updates → O(N) inference
- **Attention perspective**: Structured matrix multiplication → Parallel training
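For the scalar-times-identity transitions $\mathbf{A}_t = a_t \mathbf{I}$ used in Mamba-2, the SSM recurrence
$$
h_t = a_t h_{t-1} + \mathbf{B}_t x_t, \qquad y_t = \mathbf{C}_t h_t
$$
is exactly equivalent (dual) to a structured, attention-like matrix multiplication:
$$
Y = \mathbf{M} X, \qquad \mathbf{M} = \mathbf{L} \circ (\mathbf{C}\mathbf{B}^\top), \qquad
\mathbf{L}_{ij} = \begin{cases} a_i a_{i-1} \cdots a_{j+1} & i > j \\ 1 & i = j \\ 0 & i < j \end{cases}
$$
Here $\mathbf{C}$ and $\mathbf{B}$ stack $\mathbf{C}_t$ and $\mathbf{B}_t^\top$ as rows ($L \times N$), $\mathbf{L}$ is a lower-triangular causal mask with decays, and $\mathbf{M}$ is a semi-separable matrix. With all $a_t = 1$, $\mathbf{L}$ reduces to the plain causal mask and this becomes causal linear attention with $\mathbf{C}$ as queries, $\mathbf{B}$ as keys, and $X$ as values.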
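The duality can be verified numerically. This NumPy sketch (scalar decays $a_t$, a single input channel, arbitrary sizes) runs the recurrence step by step, then builds the $L \times L$ matrix with entries $(a_i \cdots a_{j+1})\, \mathbf{C}_i \cdot \mathbf{B}_j$ for $i \ge j$ (zero above the diagonal) and applies it in one matrix multiplication:

```python
import numpy as np

rng = np.random.default_rng(0)
L, N = 6, 3
a = rng.uniform(0.5, 1.0, size=L)   # scalar decay a_t per step
B = rng.normal(size=(L, N))         # B_t stacked as rows
C = rng.normal(size=(L, N))         # C_t stacked as rows
x = rng.normal(size=L)              # one input channel

# 1) Recurrent (SSM) form: L sequential steps
h = np.zeros(N)
y_rec = np.zeros(L)
for t in range(L):
    h = a[t] * h + B[t] * x[t]
    y_rec[t] = C[t] @ h

# 2) Dual (attention-like) form: Y = (Lmask ∘ C B^T) X
Lmask = np.zeros((L, L))
for i in range(L):
    for j in range(i + 1):
        Lmask[i, j] = np.prod(a[j + 1 : i + 1])   # a_i * ... * a_{j+1}; empty product = 1
M = Lmask * (C @ B.T)
y_mat = M @ x
```

The matrix form costs O(L²) like attention but runs as dense matrix multiplications; Mamba-2's SSD algorithm combines both views by processing the sequence in chunks.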
4.2 Performance Comparison
| Model         | Parameters | Pile (ppl ↓) | Training FLOPS/s (relative) | Inference tok/s (relative) |
| ------------- | ---------- | ---------- | ---------------- | ----------------- |
| Transformer++ | 2.7B | 6.7 | 1.0x | 1.0x |
| Mamba-1 | 2.8B | 6.2 | 1.3x | 5.2x |
| Mamba-2 | 2.7B | 6.1 | **2.1x** | **5.5x** |
Key improvements:
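- 2x training speed improvement, because the SSD algorithm expresses most of the computation as structured matrix multiplications that run efficiently on GPU tensor cores
- Much larger state sizes become practical (N = 16 in Mamba-1 → N = 64 to 256 in Mamba-2)
- Introduction of a multi-head structure, analogous to multi-head attention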
5. Practical Usage: Working with Mamba Models
5.1 Installation and Inference
```bash
# Quote the version specifier; unquoted, the shell treats ">" as output redirection
pip install mamba-ssm "causal-conv1d>=1.4.0"
```
```python
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer

# Load Mamba-2.8B (Mamba checkpoints use the GPT-NeoX tokenizer)
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = model.cuda().half()

# Inference
input_ids = tokenizer("The future of AI is", return_tensors="pt").input_ids.cuda()
# generate() defaults to greedy decoding (top_k=1); set top_k > 1 so temperature takes effect
output = model.generate(input_ids, max_length=100, temperature=0.7, top_k=50)
print(tokenizer.decode(output[0]))
```
5.2 Using Mamba-2
```python
import torch
from mamba_ssm import Mamba2

# Using a standalone Mamba-2 layer
layer = Mamba2(
    d_model=2048,
    d_state=128,  # Mamba-2 supports larger states
    d_conv=4,
    expand=2,
    headdim=64,   # Multi-head: 2048 * 2 / 64 = 64 heads
).cuda()

x = torch.randn(1, 1024, 2048).cuda()  # (batch, seq_len, d_model)
y = layer(x)  # (1, 1024, 2048)
```
6. Mamba vs Transformer Comparison
| Property | Transformer | Mamba |
| -------------------- | ------------------------- | ------------------------------------------- |
| Training Complexity  | O(N²)                     | O(N)                                        |
| Inference Complexity | O(N) per token (KV Cache) | O(1) per token |
| Memory (Inference) | O(N) KV Cache | O(1) fixed state |
| In-context learning | Strong | Weaker (improving) |
| Long sequences | Bottleneck | Efficient |
| Hardware utilization | Optimized for parallelism | Recurrence bottleneck (improved in Mamba-2) |
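The memory row can be made concrete with back-of-the-envelope arithmetic. The configurations below are hypothetical (32 attention layers with 32 KV heads of dimension 80; 64 Mamba layers with d_inner = 5120, N = 16, conv width 4), and both caches are assumed to be stored in fp16 (2 bytes per element):

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """Keys + values for every past token in every layer: grows linearly with seq_len."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

def ssm_state_bytes(n_layers, d_inner, d_state, d_conv=4, bytes_per_elem=2):
    """Per layer: SSM state (d_inner x d_state) plus conv buffer (d_inner x d_conv).

    Independent of sequence length.
    """
    return n_layers * d_inner * (d_state + d_conv) * bytes_per_elem

kv_128k = kv_cache_bytes(131072, n_layers=32, n_kv_heads=32, head_dim=80)
state = ssm_state_bytes(n_layers=64, d_inner=5120, d_state=16)
print(f"KV cache @128K: {kv_128k / 2**30:.1f} GiB, SSM state: {state / 2**20:.1f} MiB")
```

With these made-up settings, the KV cache at 128K tokens is 40 GiB and doubles whenever the context doubles, while the recurrent state stays at 12.5 MiB for any length.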
7. Limitations and Outlook
7.1 Current Limitations
- **In-context learning**: Information loss when compressing unlimited context into a fixed-size state
- **Retrieval tasks**: Attention has the advantage when exact token retrieval is needed
- **Training stability**: Gradient instability issues with long sequences
7.2 Hybrid Approaches
The recent trend is **Mamba + Attention hybrids**:
- **Jamba** (AI21): Mixing Mamba layers + Attention layers
- **Zamba** (Zyphra): Mamba-based + a small number of shared Attention layers
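At its simplest, a hybrid stack is a layer schedule. The helper below is hypothetical, and the one-in-eight ratio and its placement are illustrative parameters only; real hybrids such as Jamba and Zamba each choose their own ratios and layouts.

```python
def hybrid_layer_plan(n_layers, attn_every=8, attn_offset=4):
    """Hypothetical schedule: one attention layer per `attn_every` layers, the rest Mamba."""
    return [
        "attention" if i % attn_every == attn_offset else "mamba"
        for i in range(n_layers)
    ]

plan = hybrid_layer_plan(32)
print(plan.count("attention"), "attention layers,", plan.count("mamba"), "Mamba layers")
```

Only the attention layers need a KV cache, so in this schedule the cache is 1/8 the size of an all-attention model with the same layer shape, while those few layers still provide exact token retrieval.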
8. Quiz
Making the SSM parameters (B, C, delta) **input-dependent** to enable content-based reasoning. S4 used fixed parameters, which made selective processing based on input content impossible.
As a step size, Δ functions as a **gating mechanism**. When Δ is small, the previous state is retained (the input is ignored); when Δ is large, the state is reset and the model focuses on the current input (the state is updated). It is analogous to an LSTM's forget gate.
The convolution trick only works for **time-invariant** systems. When Selection makes parameters input-dependent, the system becomes time-varying, making global convolution impossible.
**Keeping intermediate states in SRAM** within the CUDA kernel to minimize HBM access. An IO-aware approach similar to FlashAttention that optimizes memory-bound operations.
The mathematical equivalence between SSM's recursive state updates and structured attention matrix multiplication. This allows leveraging parallelized matrix multiplication during training and efficient recursion during inference.
In-context learning and exact token retrieval tasks. Attention can directly reference any position within the sequence, whereas Mamba must compress information into a fixed-size state.
Combining the O(N) efficiency of Mamba layers with the precise retrieval capability of a small number of Attention layers. Most layers are Mamba, with Attention placed only at key positions.