
Mamba Paper Review: Going Beyond Transformers with Selective State Space Models


1. Introduction: The Limitations of Transformers and the Rise of SSMs

Since its introduction in 2017, the Transformer architecture has dominated nearly every field of sequence modeling, including NLP, vision, and audio. However, it has fundamental limitations:

- **O(N²) complexity**: The time and memory cost of self-attention grows quadratically with sequence length

- **Inference inefficiency**: Generating each token requires reading the entire KV cache, which grows with the context

- **Difficulty with long sequences**: Memory becomes a bottleneck at contexts of 128K tokens and beyond

**State Space Models (SSMs)** have been proposed as an alternative to overcome these limitations. In particular, **Mamba** (Gu & Dao, 2023) introduced a **selective** mechanism into SSMs, achieving performance comparable to Transformers with **O(N)** complexity in sequence length.

2. Background: Structured State Space Models (S4)

2.1 Continuous-Time SSM

SSMs start from a continuous-time linear system, which is then discretized:

$$
h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)
$$

$$
y(t) = \mathbf{C}h(t)
$$

- $h(t) \in \mathbb{R}^N$: hidden state

- $x(t) \in \mathbb{R}$: input

- $y(t) \in \mathbb{R}$: output

- $\mathbf{A} \in \mathbb{R}^{N \times N}$, $\mathbf{B} \in \mathbb{R}^{N \times 1}$, $\mathbf{C} \in \mathbb{R}^{1 \times N}$

2.2 Discretization (Zero-Order Hold)

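Converting the continuous system to discrete time with step size $\Delta$, using the zero-order hold (ZOH), which assumes the input is held constant over each step: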

$$
\overline{\mathbf{A}} = \exp(\Delta \mathbf{A})
$$

$$
\overline{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B}
$$
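For a diagonal $\mathbf{A}$ (as in S4D and Mamba; S4 itself uses a structured, non-diagonal $\mathbf{A}$), the ZOH formulas reduce to elementwise operations. The minimal NumPy sketch below, with arbitrary example values, computes $\overline{\mathbf{A}}$ and $\overline{\mathbf{B}}$ and compares one discrete step against a brute-force integration of $h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)$ with the input held constant, which is exactly the situation ZOH assumes. The function names are illustrative, not from any library.

```python
import numpy as np

def zoh_diagonal(a, b, delta):
    """ZOH discretization for A = diag(a).

    a: (N,) diagonal of A (negative entries for a stable system)
    b: (N,) input vector B
    delta: scalar step size
    """
    a_bar = np.exp(delta * a)          # exp(ΔA), elementwise for diagonal A
    # (ΔA)^{-1}(exp(ΔA) - I) ΔB simplifies per entry to (exp(Δa) - 1) / a * b
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar

def integrate_constant_input(a, b, h0, u, delta, n_steps=20_000):
    """Reference: many tiny Euler steps of h' = a*h + b*u with u held fixed."""
    h = h0.copy()
    dt = delta / n_steps
    for _ in range(n_steps):
        h = h + dt * (a * h + b * u)
    return h

a = np.array([-1.0, -2.0, -3.0, -4.0])
b = np.ones(4)
h0 = np.array([0.5, -0.2, 1.0, 0.0])
u, delta = 0.7, 0.1

a_bar, b_bar = zoh_diagonal(a, b, delta)
h_zoh = a_bar * h0 + b_bar * u         # one discrete ZOH step
h_ref = integrate_constant_input(a, b, h0, u, delta)
print(np.abs(h_zoh - h_ref).max())     # only the Euler reference's small error remains
```

Because ZOH solves the linear ODE exactly for a piecewise-constant input, the only discrepancy comes from the crude Euler reference, and it shrinks as `n_steps` grows.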

Discretized recurrence relation:

$$
h_t = \overline{\mathbf{A}} h_{t-1} + \overline{\mathbf{B}} x_t
$$

$$
y_t = \mathbf{C} h_t
$$
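When $\overline{\mathbf{A}}$, $\overline{\mathbf{B}}$, and $\mathbf{C}$ are fixed (time-invariant, as in S4), unrolling the recurrence from $h_{-1} = 0$ gives $y_t = \sum_{k=0}^{t} \mathbf{C}\overline{\mathbf{A}}^k\overline{\mathbf{B}}\, x_{t-k}$, i.e. a causal convolution with kernel $\overline{\mathbf{K}} = (\mathbf{C}\overline{\mathbf{B}}, \mathbf{C}\overline{\mathbf{A}}\overline{\mathbf{B}}, \mathbf{C}\overline{\mathbf{A}}^2\overline{\mathbf{B}}, \dots)$. This is the "convolution trick" that lets S4 train in parallel. The NumPy sketch below assumes a diagonal $\overline{\mathbf{A}}$ stored as a vector and random example values, and checks that both views produce the same output.

```python
import numpy as np

def ssm_recurrent(a_bar, b_bar, c, x):
    """h_t = Ā h_{t-1} + B̄ x_t, y_t = C h_t, one step at a time (Ā = diag(a_bar))."""
    h = np.zeros_like(a_bar)
    ys = []
    for x_t in x:
        h = a_bar * h + b_bar * x_t
        ys.append(c @ h)
    return np.array(ys)

def ssm_convolution(a_bar, b_bar, c, x):
    """Same output via the kernel K[k] = C Ā^k B̄ and a causal convolution."""
    L = len(x)
    powers = a_bar[None, :] ** np.arange(L)[:, None]  # (L, N): Ā^k for k = 0..L-1
    K = (powers * (c * b_bar)).sum(axis=1)            # (L,): K[k] = sum_n c_n a_n^k b_n
    return np.convolve(x, K)[:L]                      # y_t = sum_{k<=t} K[k] x[t-k]

rng = np.random.default_rng(0)
N, L = 4, 32
a_bar = rng.uniform(0.5, 0.99, N)
b_bar = rng.normal(size=N)
c = rng.normal(size=N)
x = rng.normal(size=L)

y_rec = ssm_recurrent(a_bar, b_bar, c, x)
y_conv = ssm_convolution(a_bar, b_bar, c, x)
```

S4 computes the kernel efficiently and applies it with FFT-based convolution; `np.convolve` is used here only for clarity.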

2.3 The Key to S4: HiPPO Initialization

A key ingredient of S4 (Structured State Spaces for Sequence Modeling) was initializing **A** with the **HiPPO (High-order Polynomial Projection Operator)** matrix, which lets the state track a polynomial approximation of the input history:

```python
import numpy as np

def make_hippo(N):
    """Generate the HiPPO-LegS matrix (N x N)."""
    P = np.sqrt(1 + 2 * np.arange(N))
    A = np.zeros((N, N))
    for i in range(N):
        for j in range(N):
            if i > j:
                A[i, j] = P[i] * P[j]
            elif i == j:
                A[i, j] = i + 1
            # i < j: entry stays 0 (lower-triangular)
    return -A  # Negative for stability
```
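Why does the sign flip give stability? The HiPPO-LegS matrix is lower triangular, so its eigenvalues are simply its diagonal entries $-1, -2, \dots, -N$, all negative. After ZOH, $\exp(\Delta \mathbf{A})$ is also lower triangular, with eigenvalues $\exp(-\Delta n)$ strictly between 0 and 1, so the state decays instead of blowing up. The loop-free rewrite below is an illustrative sketch (the name `make_hippo_vectorized` is ours). Incidentally, the diagonal $-(1, \dots, N)$ is the same initialization the Mamba block in Section 3.4 uses for A (`A_log = log(1..N)`, `A = -exp(A_log)`), often called S4D-Real.

```python
import numpy as np

def make_hippo_vectorized(N):
    """HiPPO-LegS matrix without Python loops (same entries as make_hippo)."""
    P = np.sqrt(1 + 2 * np.arange(N))
    below = np.tril(np.outer(P, P), k=-1)  # i > j: sqrt(2i+1) * sqrt(2j+1)
    diag = np.diag(np.arange(1, N + 1))    # i == j: i + 1
    return -(below + diag)                 # i < j stays 0

N, delta = 8, 0.01
A = make_hippo_vectorized(N)

# Lower triangular, so the eigenvalues are the diagonal entries: -1, -2, ..., -N
eigenvalues = np.diag(A)
# exp(ΔA) is lower triangular too, with eigenvalues exp(Δ * diag) in (0, 1)
A_bar_eigenvalues = np.exp(delta * eigenvalues)
print(eigenvalues)
print(A_bar_eigenvalues)
```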

2.4 Limitations of S4

S4 uses **fixed, input-independent parameters** (A, B, C, Δ). This causes two problems:

1. **Inability to perform content-based reasoning**: Performance degrades on tasks that require processing inputs differently depending on their content (e.g., selective copying, induction heads)

2. **Loss of Attention's advantages**: The content-aware matching capability of Transformers is sacrificed
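To make "selective copying" concrete: the Mamba paper contrasts the standard Copying task, where the tokens to copy sit at fixed offsets so a time-invariant model can solve it with a fixed delay, with Selective Copying, where the spacing is random and the model must decide which tokens to keep based on their content. The toy generator below is a hypothetical setup; sequence length, vocabulary size, and the noise token are arbitrary choices, not the paper's settings.

```python
import numpy as np

def selective_copying_example(rng, seq_len=16, n_data=4, vocab=8, noise_token=0):
    """One toy instance of the selective copying task (hypothetical setup).

    Data tokens are placed at random positions among noise tokens;
    the target is the data tokens in their original order.
    """
    inputs = np.full(seq_len, noise_token)
    positions = np.sort(rng.choice(seq_len, size=n_data, replace=False))
    data = rng.integers(1, vocab, size=n_data)  # drawn from 1..vocab-1, never the default noise token 0
    inputs[positions] = data
    return inputs, data

rng = np.random.default_rng(0)
inputs, target = selective_copying_example(rng)
print(inputs)
print(target)
```

Because the data positions change from example to example, no single fixed convolution kernel can pick them out; the model has to react to what each token is.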

3. Mamba: Selective State Space Models

3.1 Core Idea: Selection Mechanism

Mamba's key innovation is making **SSM parameters dependent on the input**:

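```
S4:    (A, B, C, Δ) = fixed parameters
Mamba: (B, C, Δ)    = f(input)   ← input-dependent!
```

**A** itself remains a learned, input-independent parameter in Mamba, but the discretized $\overline{\mathbf{A}} = \exp(\Delta \mathbf{A})$ still varies with the input through $\Delta$.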

Specifically:

$$
\mathbf{B}_t = \text{Linear}_B(x_t), \quad \mathbf{C}_t = \text{Linear}_C(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}_\Delta(x_t))
$$

This allows the model to **dynamically decide what information to remember and what to forget** based on the input.
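The shapes involved are easy to lose track of. Below is a minimal NumPy sketch of the selection projections for a single sequence, assuming plain dense projections with random weights; the names `W_B`, `W_C`, `W_dt`, and `b_dt` are hypothetical, and the reference implementation uses a low-rank projection for $\Delta$ rather than the full $D \times D$ matrix shown here.

```python
import numpy as np

def softplus(z):
    return np.logaddexp(0.0, z)  # log(1 + e^z), numerically stable

rng = np.random.default_rng(0)
L, D, N = 6, 8, 4                   # sequence length, channels, state size (arbitrary)
x = rng.normal(size=(L, D))         # one sequence of L tokens

# Hypothetical projection weights (random here; learned in a real model)
W_B = rng.normal(size=(D, N))
W_C = rng.normal(size=(D, N))
W_dt = rng.normal(size=(D, D)) / np.sqrt(D)
b_dt = np.zeros(D)

B_t = x @ W_B                       # (L, N): a different B for every token
C_t = x @ W_C                       # (L, N): a different C for every token
delta_t = softplus(x @ W_dt + b_dt) # (L, D): positive step size per token and channel
```

In S4, `B_t`, `C_t`, and `delta_t` would each be a single vector shared by all L positions; here every row is computed from its own token.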

3.2 Intuition Behind Selection

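$\Delta_t$ (step size) is the key. Since $\mathbf{A}$ has negative entries, $\overline{\mathbf{A}} = \exp(\Delta_t \mathbf{A})$ behaves as follows:

- **Small $\Delta_t$**: $\overline{\mathbf{A}} \approx \mathbf{I}$ and $\overline{\mathbf{B}} \approx 0$ → retains the previous state (ignores the current input)

- **Large $\Delta_t$**: $\overline{\mathbf{A}} \approx 0$ → resets the state and focuses on the current input

This is analogous to a **gating mechanism**:

RNN gate $g_t$ ≈ Mamba's $\Delta_t$ (with $\exp(\Delta_t \mathbf{A})$ playing the role of the forget gate)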
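The gating analogy can be made exact in a special case that the Mamba paper uses to connect selection to RNN gating: with $N = 1$, $\mathbf{A} = -1$, $\mathbf{B} = 1$, and $\Delta_t = \text{softplus}(z_t)$ where $z_t = \text{Linear}(x_t)$, ZOH gives $\overline{\mathbf{A}} = \exp(-\text{softplus}(z_t)) = 1 - \sigma(z_t)$ and $\overline{\mathbf{B}} = \sigma(z_t)$. The step is then $h_t = (1 - g_t) h_{t-1} + g_t x_t$ with $g_t = \sigma(z_t)$. A quick NumPy check, with `z` standing in for $\text{Linear}(x_t)$:

```python
import numpy as np

def softplus(z):
    return np.logaddexp(0.0, z)    # log(1 + e^z), numerically stable

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

z = np.linspace(-6.0, 6.0, 7)      # stand-in for Linear(x_t) over several inputs
delta = softplus(z)                # Δ_t = softplus(Linear(x_t))
a, b = -1.0, 1.0                   # N = 1, A = -1, B = 1

a_bar = np.exp(delta * a)          # ZOH: exp(ΔA)
b_bar = (a_bar - 1.0) / a * b      # ZOH: (ΔA)^{-1}(exp(ΔA) - I) ΔB, scalar case
g = sigmoid(z)                     # the RNN-style gate

for d, keep, write in zip(delta, a_bar, b_bar):
    # small Δ: keep ≈ 1 (state retained); large Δ: write ≈ 1 (state replaced by input)
    print(f"delta={d:6.3f}  keep={keep:.3f}  write={write:.3f}")
```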

3.3 Hardware-Aware Algorithm

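Introducing selection makes the parameters input-dependent, so the system is no longer time-invariant and the **convolution trick can no longer be used**. Training must then fall back on the recurrence, which is sequential and, implemented naively, materializes the expanded state of shape (B, L, D, N) in GPU memory.

Mamba addresses this with **kernel fusion**, a **parallel scan**, and **recomputation** (intermediate states are recomputed in the backward pass instead of being stored):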

```python
import torch

# Mamba's selective scan (simplified, sequential reference version)
def selective_scan(x, delta, A, B, C):
    """
    x:     (B, L, D) - input
    delta: (B, L, D) - step size (input-dependent)
    A:     (D, N)    - state matrix (diagonal per channel, negative entries)
    B:     (B, L, N) - input matrix (input-dependent)
    C:     (B, L, N) - output matrix (input-dependent)
    """
    B_batch, L, D = x.shape
    N = A.shape[1]

    # Discretization: ZOH for A, first-order approximation (Δ·B) for B
    deltaA = torch.exp(delta.unsqueeze(-1) * A)    # (B, L, D, N)
    deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, D, N)

    # Sequential scan (a parallel scan is used during training)
    h = torch.zeros(B_batch, D, N, device=x.device, dtype=x.dtype)
    ys = []
    for t in range(L):
        h = deltaA[:, t] * h + deltaB[:, t] * x[:, t, :, None]  # (B, D, N)
        y = (h * C[:, t, None, :]).sum(-1)                      # (B, D)
        ys.append(y)
    return torch.stack(ys, dim=1)  # (B, L, D)
```

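In the actual implementation, the SSM parameters are loaded from HBM into SRAM, and the discretization and scan **run entirely in SRAM within a single fused CUDA kernel**. Only the (B, L, D) output is written back, so the expanded (B, L, D, N) state never touches HBM (an IO-aware approach similar to FlashAttention).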
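The comment in `selective_scan` mentions a parallel scan. It works because each step $h_t = \bar{a}_t h_{t-1} + \bar{b}_t$ is an affine map, and composing affine maps is associative, so prefixes can be combined in a tree in $O(\log L)$ parallel rounds. The NumPy sketch below uses the simple Hillis-Steele scan to illustrate the idea; it is not Mamba's fused CUDA kernel, which uses a work-efficient scan. Here `a` and `b` stand for the per-step diagonal $\overline{\mathbf{A}}$ entries and the $\overline{\mathbf{B}} x_t$ terms.

```python
import numpy as np

def sequential_scan(a, b):
    """Reference: h_t = a_t * h_{t-1} + b_t with h_{-1} = 0, one step at a time."""
    h = np.zeros_like(b[0])
    out = []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return np.stack(out)

def parallel_scan(a, b):
    """Hillis-Steele inclusive scan with the associative operator
    (a1, b1) then (a2, b2)  ->  (a1 * a2, a2 * b1 + b2),
    i.e. composing h -> a1*h + b1 with h -> a2*h + b2.
    Needs ceil(log2(L)) rounds; each round is fully parallel over t."""
    a, b = a.copy(), b.copy()
    offset = 1
    while offset < len(a):
        new_a, new_b = a.copy(), b.copy()
        # combine the segment ending at t with the prefix ending at t - offset
        new_a[offset:] = a[:-offset] * a[offset:]
        new_b[offset:] = a[offset:] * b[:-offset] + b[offset:]
        a, b = new_a, new_b
        offset *= 2
    return b  # b[t] now equals h_t

rng = np.random.default_rng(0)
a = rng.uniform(0.0, 1.0, size=(37, 3))  # per-step decay, L = 37, D = 3
b = rng.normal(size=(37, 3))             # per-step input term
h_seq = sequential_scan(a, b)
h_par = parallel_scan(a, b)              # 6 rounds instead of 37 sequential steps
```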

3.4 Mamba Block Architecture

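```
Input (B, L, D)
  │
  ├─ Linear(D → ED) ─ Conv1d ─ SiLU ─ SSM ─┐
  │                                        × ─ Linear(ED → D) ─ Output
  └─ Linear(D → ED) ─ SiLU ────────────────┘
```

In the code, the two Linear(D → ED) branches are implemented as a single `in_proj` of size D → 2ED whose output is split in half.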

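```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand
        self.d_state = d_state
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
        # Depthwise conv; forward() keeps the first L outputs so it stays causal
        self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv,
                                padding=d_conv - 1, groups=d_inner)

        # SSM parameters: x_proj outputs B (d_state), C (d_state), and a rank-1 Δ input
        # (simplification: the reference code uses dt_rank = ceil(d_model / 16))
        self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, d_inner, bias=True)
        # A = -(1, 2, ..., N) per channel, stored as log so A stays negative
        A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)  # (ED, N)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(d_inner))  # skip connection
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        # x: (B, L, D)
        L = x.shape[1]
        xz = self.in_proj(x)  # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1)

        # Conv + SiLU
        x = self.conv1d(x.transpose(1, 2))[:, :, :L]
        x = F.silu(x.transpose(1, 2))  # (B, L, ED)

        # SSM with input-dependent B, C, Δ
        A = -torch.exp(self.A_log)  # (ED, N)
        B, C, dt = self.x_proj(x).split([self.d_state, self.d_state, 1], dim=-1)
        delta = F.softplus(self.dt_proj(dt))  # (B, L, ED), positive step sizes
        y = selective_scan(x, delta, A, B, C)  # selective_scan() from Section 3.3
        y = y + x * self.D

        # Gate
        y = y * F.silu(z)
        return self.out_proj(y)
```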
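Truncating the Conv1d output to the first L positions in `forward()` is what keeps the block causal: with `padding=d_conv-1` on both sides, output t only sees inputs t-d_conv+1 through t. The NumPy function below is a hypothetical mimic of that padding-and-truncation (a per-channel cross-correlation, which is what `nn.Conv1d` computes); perturbing a future input should leave all earlier outputs unchanged.

```python
import numpy as np

def conv1d_pad_truncate(x, w):
    """Hypothetical NumPy mimic of nn.Conv1d(C, C, k, padding=k-1, groups=C)
    followed by keeping the first L outputs.

    x: (C, L) input channels, w: (C, k) one kernel per channel (depthwise).
    Like PyTorch, this computes a cross-correlation (no kernel flip).
    """
    C, L = x.shape
    k = w.shape[1]
    xp = np.pad(x, ((0, 0), (k - 1, k - 1)))        # zeros on both sides
    full = np.array([[xp[c, t:t + k] @ w[c] for t in range(L + k - 1)]
                     for c in range(C)])             # (C, L + k - 1)
    return full[:, :L]  # output t sees only x[t-k+1 .. t]

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 10))
w = rng.normal(size=(3, 4))                          # d_conv = 4
y = conv1d_pad_truncate(x, w)

x_future = x.copy()
x_future[:, 6] += 1.0                                # perturb the input at t = 6
y_future = conv1d_pad_truncate(x_future, w)
```

Outputs before t = 6 are identical in `y` and `y_future`. Keeping only the first L outputs is equivalent to padding k - 1 zeros on the left only.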

4. Mamba-2: State Space Duality

4.1 SSD (State Space Dual) Model

Mamba-2 (Dao & Gu, ICML 2024) discovered a mathematical duality between SSMs and Attention:

- **SSM perspective**: Recursive state updates → O(N) inference

- **Attention perspective**: Structured matrix multiplication → Parallel training

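SSM recurrence (Mamba-2 restricts $A_t$ to a scalar times the identity, $A_t = a_t \mathbf{I}$):

$$
h_t = a_t h_{t-1} + B_t x_t, \qquad y_t = C_t h_t
$$

↕ Dual ↕

Structured attention (stacking the $C_t$ and $B_t^\top$ as the rows of $C, B \in \mathbb{R}^{T \times N}$):

$$
Y = (L \odot C B^\top) X, \qquad L_{ij} = \begin{cases} a_i a_{i-1} \cdots a_{j+1} & i \ge j \\ 0 & i < j \end{cases}
$$

where $L$ is a lower-triangular causal mask that carries the cumulative decay, and $M = L \odot C B^\top$ is a structured (semiseparable) matrix.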
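The duality can be checked numerically for a single channel. This NumPy sketch assumes Mamba-2's scalar decay $A_t = a_t \mathbf{I}$ and random example values; it computes the same output once with the linear-time recurrence and once by materializing the $T \times T$ matrix $M = L \odot (C B^\top)$, where $L_{ij} = a_i a_{i-1} \cdots a_{j+1}$ for $i \ge j$ and 0 otherwise.

```python
import numpy as np

def ssd_recurrent(a, B, C, x):
    """SSM view: h_t = a_t h_{t-1} + B_t x_t, y_t = C_t h_t (linear in T)."""
    h = np.zeros(B.shape[1])
    y = np.empty(len(x))
    for t in range(len(x)):
        h = a[t] * h + B[t] * x[t]
        y[t] = C[t] @ h
    return y

def ssd_matrix(a, B, C, x):
    """Attention view: y = (L ⊙ C Bᵀ) x with L_ij = a_i ... a_{j+1} for i >= j."""
    T = len(x)
    L = np.zeros((T, T))
    for i in range(T):
        for j in range(i + 1):
            L[i, j] = np.prod(a[j + 1:i + 1])  # empty product = 1 on the diagonal
    M = L * (C @ B.T)                           # T x T semiseparable matrix
    return M @ x

rng = np.random.default_rng(0)
T, N = 8, 4                                     # sequence length, state size
a = rng.uniform(0.5, 1.0, T)                    # scalar decay per step
B = rng.normal(size=(T, N))                     # row t is B_t
C = rng.normal(size=(T, N))                     # row t is C_t
x = rng.normal(size=T)                          # one input channel

y_rec = ssd_recurrent(a, B, C, x)
y_mat = ssd_matrix(a, B, C, x)
```

The matrix form costs O(T²) but consists purely of matrix multiplications. Mamba-2's SSD algorithm splits the sequence into chunks, using the matrix form within each chunk and the recurrence to pass states between chunks.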

4.2 Performance Comparison

| Model | Parameters | Pile (ppl) | Training throughput (relative) | Inference throughput (relative) |
| ------------- | ---------- | ---------- | ------------------------------ | ------------------------------- |
| Transformer++ | 2.7B | 6.7 | 1.0x | 1.0x |
| Mamba-1 | 2.8B | 6.2 | 1.3x | 5.2x |
| Mamba-2 | 2.7B | 6.1 | **2.1x** | **5.5x** |

Key improvements:

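- Roughly 2x training throughput relative to Transformer++ (the SSD algorithm is built mostly from matrix multiplications, which map well onto GPU tensor cores)

- Much larger state sizes become practical (from N=16 in Mamba-1 to N=64–256 in Mamba-2)

- A multi-head structure, analogous to multi-head attention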

5. Practical Usage: Working with Mamba Models

5.1 Installation and Inference

```bash
# Quote the version spec so the shell does not treat ">=" as a redirect
pip install mamba-ssm "causal-conv1d>=1.4.0"
```

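```python
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer

# Load Mamba-2.8B (the checkpoints use the GPT-NeoX-20B tokenizer)
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = model.cuda().half()

# Inference
input_ids = tokenizer("The future of AI is", return_tensors="pt").input_ids.cuda()
# mamba_ssm's generate() defaults to greedy decoding (top_k=1), which ignores
# temperature; disable top-k and enable top-p sampling so temperature takes effect
output = model.generate(input_ids, max_length=100, temperature=0.7, top_k=0, top_p=0.9)
print(tokenizer.decode(output[0]))
```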

5.2 Using Mamba-2

```python
import torch
from mamba_ssm import Mamba2

# Using a standalone Mamba-2 layer
layer = Mamba2(
    d_model=2048,
    d_state=128,  # Mamba-2 supports larger states
    d_conv=4,
    expand=2,
    headdim=64,   # multi-head: (2048 * 2) / 64 = 64 heads
).cuda()

x = torch.randn(1, 1024, 2048).cuda()  # (batch, seq_len, d_model)
y = layer(x)  # (1, 1024, 2048)
```

6. Mamba vs Transformer Comparison

| Property | Transformer | Mamba |
| -------------------- | ------------------------- | ------------------------------------------- |
| Training Complexity | O(N²) | O(N) |
| Inference Complexity | O(N) per token (KV Cache) | O(1) per token |
| Memory (Inference) | O(N) KV Cache | O(1) fixed state |
| In-context learning | Strong | Weaker (improving) |
| Long sequences | Bottleneck | Efficient |
| Hardware utilization | Optimized for parallelism | Recurrence bottleneck (improved in Mamba-2) |
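The O(N) versus O(1) inference-memory rows are easiest to see with numbers. This is a back-of-the-envelope sketch in plain Python, assuming fp16 (2 bytes per value), standard multi-head attention without grouped-query attention, and hypothetical layer counts and widths; these are illustrative configs, not official model specifications.

```python
def kv_cache_bytes(n_layers, d_model, n_tokens, bytes_per_elem=2):
    """Transformer KV cache: one key and one value vector of size d_model
    per layer per token (assumes standard MHA, no grouped-query attention)."""
    return 2 * n_layers * n_tokens * d_model * bytes_per_elem

def mamba_state_bytes(n_layers, d_model, expand=2, d_state=16, d_conv=4, bytes_per_elem=2):
    """Mamba recurrent state: an SSM state of size (d_inner, d_state) plus the
    last d_conv - 1 conv inputs per channel, per layer. Independent of length."""
    d_inner = expand * d_model
    per_layer = d_inner * d_state + d_inner * (d_conv - 1)
    return n_layers * per_layer * bytes_per_elem

# Hypothetical configs in the ~3B range (assumed values for illustration)
for n_tokens in (4_096, 131_072):
    kv = kv_cache_bytes(n_layers=32, d_model=2560, n_tokens=n_tokens)
    ssm = mamba_state_bytes(n_layers=64, d_model=2560)
    print(f"{n_tokens:>7} tokens: KV cache {kv / 2**30:6.2f} GiB | "
          f"Mamba state {ssm / 2**20:6.2f} MiB")
```

With these assumed configs, the KV cache is 1.25 GiB at 4,096 tokens and 40 GiB at 131,072 tokens, while the Mamba state stays at about 11.9 MiB at any length.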

7. Limitations and Outlook

7.1 Current Limitations

- **In-context learning**: Compressing an arbitrarily long context into a fixed-size state inevitably loses information

- **Retrieval tasks**: Attention has the advantage when exact token retrieval is needed

- **Training stability**: Gradient instability issues with long sequences

7.2 Hybrid Approaches

The recent trend is **Mamba + Attention hybrids**:

- **Jamba** (AI21): Interleaves Mamba layers with Attention layers (and adds mixture-of-experts layers)

- **Zamba** (Zyphra): Mamba-based + a small number of shared Attention layers

8. Quiz

**Q1. What is Mamba's key change compared to S4?**

Making the SSM parameters (B, C, Δ) **input-dependent** to enable content-based reasoning. S4 used fixed parameters, which made selective processing based on input content impossible.

**Q2. What role does Δ play?**

As a step size, it functions as a **gating mechanism**. When Δ is small, the previous state is retained (the input is ignored); when Δ is large, the state is reset and the model focuses on the current input. It is analogous to the gates of an LSTM or GRU.

**Q3. Why can't Mamba use S4's convolution trick?**

The convolution trick only works for **time-invariant** systems. When selection makes the parameters input-dependent, the system becomes time-varying, making global convolution impossible.

**Q4. How does Mamba's hardware-aware algorithm keep the scan efficient?**

It **keeps intermediate states in SRAM** within the CUDA kernel to minimize HBM access. This is an IO-aware approach, similar to FlashAttention, that optimizes memory-bound operations.

**Q5. What is the state space duality in Mamba-2?**

The mathematical equivalence between the SSM's recursive state updates and structured attention-style matrix multiplication. It allows parallel matrix multiplication during training and efficient recursion during inference.

**Q6. Where does Attention still have the advantage over Mamba, and why?**

In-context learning and exact token retrieval. Attention can directly reference any position in the sequence, whereas Mamba must compress information into a fixed-size state.

**Q7. What is the idea behind Mamba + Attention hybrids?**

Combining the O(N) efficiency of Mamba layers with the precise retrieval capability of a small number of Attention layers. Most layers are Mamba, with Attention placed only at key positions.
