
Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Paper Analysis)


Paper Overview

- **Title**: Mamba: Linear-Time Sequence Modeling with Selective State Spaces

- **Authors**: Albert Gu, Tri Dao (Carnegie Mellon University, Princeton University)

- **Published**: December 2023 (arXiv: 2312.00752); published at COLM 2024

- **Code**: [github.com/state-spaces/mamba](https://github.com/state-spaces/mamba)

Motivation: The Limitations of Transformers

Transformers dominate sequence modeling, but they suffer from fundamental problems:

- **O(N²) Complexity**: Self-attention has quadratic complexity with respect to sequence length

- **Long Sequence Bottleneck**: Memory and compute grow rapidly with context length and become a bottleneck at lengths such as 128K tokens

- **Inference Inefficiency**: KV Cache grows proportionally with sequence length
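A rough sketch makes the quadratic growth concrete. It assumes a naive attention implementation that materializes one head's full N × N score matrix in fp16 (2 bytes per entry). It compares that with a recurrent model whose state has a fixed size (d_inner = 2,048, d_state = 16). These sizes are illustrative assumptions, not numbers from the paper.

```python
def attention_scores_bytes(seq_len: int, bytes_per_entry: int = 2) -> int:
    """Bytes for one head's full N x N attention-score matrix (naive attention)."""
    return seq_len * seq_len * bytes_per_entry


def recurrent_state_bytes(d_inner: int, d_state: int, bytes_per_entry: int = 2) -> int:
    """Bytes for a fixed-size recurrent state; independent of sequence length."""
    return d_inner * d_state * bytes_per_entry


for n in (2_048, 8_192, 131_072):
    attn_gib = attention_scores_bytes(n) / 2**30
    state_kib = recurrent_state_bytes(d_inner=2_048, d_state=16) / 2**10
    print(f"N={n:>7}: attention scores {attn_gib:9.3f} GiB | recurrent state {state_kib:.0f} KiB")
```

Doubling N quadruples the score matrix, while the recurrent state stays the same size. Memory-efficient kernels such as FlashAttention avoid storing the full matrix, but the attention compute is still quadratic in N.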

State Space Models (SSMs) theoretically achieve O(N) complexity, but prior SSMs (S4, H3, etc.) fell short of Transformers in language modeling due to their **fixed, input-independent parameters**.

Core Idea: Selective State Spaces

The Problem with Existing SSMs

Traditional SSMs are **Linear Time-Invariant (LTI)** systems:

$$h'(t) = Ah(t) + Bx(t)$$

$$y(t) = Ch(t)$$

Here, A, B, and C are **fixed** matrices. Whether the input is "hello" or "world," the system processes it the same way. This means **content-aware reasoning** is impossible.
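Because A, B, and C (and, after discretization, the step size Δ) are the same for every token, the recurrence can be unrolled into a single convolution. That is what lets LTI SSMs such as S4 train in parallel.

The sketch below makes several assumptions:

- a diagonal A;
- a single input channel;
- a fixed Δ;
- the zero-order-hold rule Ā = exp(ΔA) and the simplified input rule B̄ = ΔB.

It checks that stepping through h_t = Ā h_{t-1} + B̄ x_t, y_t = C h_t gives the same outputs as convolving x with the kernel K_k = C Ā^k B̄.

```python
import torch

torch.manual_seed(0)
N, L = 4, 32                   # state size and sequence length (illustrative)
delta = 0.1                    # fixed step size: identical for every token (LTI)
A = -torch.rand(N) - 0.5       # diagonal A with negative entries for stability
B = torch.randn(N)
C = torch.randn(N)
x = torch.randn(L)             # one input channel

A_bar = torch.exp(delta * A)   # zero-order hold for A
B_bar = delta * B              # simplified discretization of B

# 1) Recurrent mode: h_t = A_bar * h_{t-1} + B_bar * x_t,  y_t = C . h_t
h = torch.zeros(N)
y_rec = []
for t in range(L):
    h = A_bar * h + B_bar * x[t]
    y_rec.append((C * h).sum())
y_rec = torch.stack(y_rec)

# 2) Convolutional mode: y_t = sum_{s<=t} K_{t-s} x_s, with K_k = sum_n C_n A_bar_n^k B_bar_n
k = torch.arange(L, dtype=torch.float32)
K = (C * B_bar * A_bar ** k.unsqueeze(-1)).sum(-1)  # (L,) kernel
y_conv = torch.stack([(K[: t + 1].flip(0) * x[: t + 1]).sum() for t in range(L)])

print(torch.allclose(y_rec, y_conv, atol=1e-5))
```

The direct sum is O(L²) for readability; S4-style models evaluate the same convolution with an FFT. If Δ, B, or C changed per token, no single kernel K would exist. That is why this shortcut stops applying once the parameters become input-dependent.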

Mamba's Solution: Selection Mechanism

Mamba's key innovation is making **B, C, and Δ input-dependent**:

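```python
import torch
import torch.nn as nn

D, N = 64, 16              # model width and state size (illustrative values)
x = torch.randn(2, 10, D)  # (batch, length, D)

# Traditional SSM: fixed parameters
B = nn.Parameter(torch.randn(N))  # Learned but input-independent

# Mamba: input-dependent parameters
B = nn.Linear(D, N)(x)  # (batch, length, N): B varies depending on x
C = nn.Linear(D, N)(x)  # (batch, length, N): C varies depending on x
delta = nn.functional.softplus(nn.Linear(D, D)(x))  # step size varies with x; softplus keeps it positive
```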

This is what **Selective** State Space means. Depending on the input, the model **selects** which information to store in its state and which to ignore.

Intuitive Understanding

The Selection Mechanism can be understood intuitively as:

- **Δ (delta)**: "How much attention should this token receive?" — A large Δ strongly reflects the current input, while a small Δ retains the previous state

- **B**: "Which parts of the input should be recorded in the state?"

- **C**: "Which parts of the state should be extracted as output?"

This content-based selection plays a role similar to attention in Transformers, but it runs in **O(N)** time.
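The role of Δ can be checked numerically for a single scalar state channel. The sketch assumes A = -1 and B = 1, with zero-order-hold discretization: Ā = exp(ΔA) and B̄ = (ΔA)⁻¹(exp(ΔA) − 1)·ΔB, which for a scalar reduces to (Ā − 1)/A · B.

- A small Δ keeps Ā close to 1, so the old state survives.
- A large Δ drives Ā toward 0, so the state is replaced by the current input.

The helper `one_step` and the chosen values are purely illustrative.

```python
import math


def one_step(h_prev: float, x_t: float, delta: float, A: float = -1.0, B: float = 1.0) -> float:
    """One zero-order-hold SSM update for a single scalar state channel."""
    A_bar = math.exp(delta * A)     # decay factor in (0, 1) when A < 0 and delta > 0
    B_bar = (A_bar - 1.0) / A * B   # scalar ZOH: (delta*A)^-1 (exp(delta*A) - 1) * delta*B
    return A_bar * h_prev + B_bar * x_t


h_prev, x_t = 5.0, 1.0
print(one_step(h_prev, x_t, delta=0.01))  # small delta: the old state (5.0) mostly survives
print(one_step(h_prev, x_t, delta=10.0))  # large delta: the state is reset to about x_t (1.0)
```

With A = -1 the update is exactly h_t = x_t + (h_{t-1} − x_t)·e^{−Δ}. In other words, Δ interpolates between keeping the previous state and copying the current input.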

Hardware Optimization: The Selective Scan Algorithm

Since the Selective SSM is input-dependent, it cannot leverage the efficient convolution tricks used by traditional SSMs. Mamba addresses this through a **hardware-aware algorithm**.

The Problem: HBM ↔ SRAM Bottleneck

GPU memory hierarchy:

```
┌────────────────────┐
│ SRAM (20MB)        │ ← Very fast, very small
├────────────────────┤
│ HBM (80GB)         │ ← Slow, large
└────────────────────┘
```

A naive implementation materializes the full expanded state h, of shape (batch, length, D, N), in HBM and then reads it back. The computation therefore becomes **memory-bound** rather than compute-bound.
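Rough arithmetic shows why materializing the expanded state hurts. The sizes below (batch 8, length 4,096, D = 1,536 channels, N = 16, fp32) are hypothetical and chosen only for illustration.

```python
def tensor_bytes(*shape: int, bytes_per_elem: int = 4) -> int:
    """Size in bytes of a dense tensor with the given shape (fp32 by default)."""
    total = bytes_per_elem
    for dim in shape:
        total *= dim
    return total


batch, length, D, N = 8, 4_096, 1_536, 16
state_bytes = tensor_bytes(batch, length, D, N)  # every h_t, kept for all time steps
output_bytes = tensor_bytes(batch, length, D)    # only y, which must be written anyway

print(f"expanded states: {state_bytes / 2**30:.1f} GiB")
print(f"output y:        {output_bytes / 2**20:.0f} MiB")
```

The states are N times larger than the output and far exceed on-chip SRAM. Writing them to HBM and reading them back therefore dominates the runtime, unless they never leave SRAM.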

Solution: Kernel Fusion + Recomputation

Mamba's Selective Scan algorithm:

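1. **Compute entirely in SRAM**: Load Δ, A, B, C from HBM into SRAM once. Then fuse discretization, the selective scan, and the multiplication by C into a single kernel, writing only the final (batch, length, D) output back to HBM

2. **No intermediate state storage**: Keep only the inputs (x, Δ, A, B, C) in HBM during the forward pass. The (batch, length, D, N) intermediate states are never saved and are recomputed during the backward pass

3. **Result**: IO-aware optimization similar to FlashAttention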
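The recomputation idea can be imitated at the PyTorch level with `torch.utils.checkpoint`. That utility discards a function's intermediate activations in the forward pass and recomputes them during backward.

This is only an analogy. The real Mamba kernel does its recomputation inside a fused CUDA kernel. This sketch instead wraps a plain Python scan (`scan` is a hypothetical helper) and checks that the gradients match the non-checkpointed version.

```python
import torch
from torch.utils.checkpoint import checkpoint


def scan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """h_t = a_t * h_{t-1} + b_t over dim 1; returns all h_t, shape (batch, length, d)."""
    h = torch.zeros_like(b[:, 0])
    hs = []
    for t in range(b.shape[1]):
        h = a[:, t] * h + b[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)


torch.manual_seed(0)
a = torch.rand(2, 16, 8, requires_grad=True)   # decay factors in (0, 1)
b = torch.randn(2, 16, 8, requires_grad=True)  # per-step inputs

# Plain autograd: tensors saved inside scan() are kept until backward.
loss_plain = scan(a, b).pow(2).sum()
grad_plain = torch.autograd.grad(loss_plain, (a, b))

# Checkpointed: those saved tensors are dropped and recomputed during backward.
loss_ckpt = checkpoint(scan, a, b, use_reentrant=False).pow(2).sum()
grad_ckpt = torch.autograd.grad(loss_ckpt, (a, b))

print(all(torch.allclose(g1, g2) for g1, g2 in zip(grad_plain, grad_ckpt)))
```

The trade-off matches the paper's choice: a bit of extra compute in the backward pass in exchange for not storing per-step states.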

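```python
import torch

# Simplified Selective Scan (conceptual code)
def selective_scan(x, A, B, C, delta):
    """
    x:     (batch, length, d_model)  input sequence
    A:     (d_model, d_state)        diagonal state matrix per channel (entries should be negative)
    B, C:  (batch, length, d_state)  input-dependent projections
    delta: (batch, length, d_model)  input-dependent step size (positive)
    Returns: (batch, length, d_model)
    """
    batch, length, d = x.shape
    n = A.shape[1]  # state dimension

    # Discretize A, B using delta: exp(delta*A) for A, the simplified delta*B for B
    deltaA = torch.exp(delta.unsqueeze(-1) * A)                        # (batch, length, d, n)
    deltaB_x = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)  # (batch, length, d, n)

    # Sequential scan (parallelized via CUDA kernel in practice)
    h = torch.zeros(batch, d, n, device=x.device, dtype=x.dtype)
    ys = []
    for i in range(length):
        h = deltaA[:, i] * h + deltaB_x[:, i]   # state update: (batch, d, n)
        y = (h * C[:, i].unsqueeze(1)).sum(-1)  # read-out with C: (batch, d)
        ys.append(y)
    return torch.stack(ys, dim=1)
```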
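The comment about parallelization rests on one fact: the linear recurrence h_t = a_t·h_{t-1} + b_t is associative. Applying step (a₁, b₁) and then step (a₂, b₂) equals the single step (a₁a₂, a₂b₁ + b₂).

The sketch below uses that operator in a simple Hillis-Steele prefix scan, which needs only ⌈log₂ L⌉ rounds of elementwise work. It is a teaching example, not the paper's fused CUDA kernel, and the helper names are hypothetical.

```python
import torch


def sequential_scan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Reference: h_t = a_t * h_{t-1} + b_t along dim 0, with h_{-1} = 0."""
    h, out = torch.zeros_like(b[0]), []
    for t in range(b.shape[0]):
        h = a[t] * h + b[t]
        out.append(h)
    return torch.stack(out)


def parallel_scan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hillis-Steele inclusive scan with (a1, b1) then (a2, b2) -> (a1*a2, a2*b1 + b2)."""
    L, offset = b.shape[0], 1
    while offset < L:
        # Each position t >= offset absorbs the partial result ending at t - offset.
        new_a, new_b = a.clone(), b.clone()
        new_a[offset:] = a[:-offset] * a[offset:]
        new_b[offset:] = a[offset:] * b[:-offset] + b[offset:]
        a, b = new_a, new_b
        offset *= 2
    return b  # with h_{-1} = 0, the accumulated b_t equals h_t


torch.manual_seed(0)
a = torch.rand(100, 4)   # per-step decay factors (like exp(delta * A))
b = torch.randn(100, 4)  # per-step inputs (like delta * B * x)
print(torch.allclose(sequential_scan(a, b), parallel_scan(a, b), atol=1e-5))
```

Each round is a batch of independent elementwise operations, which is why GPUs can run the scan in parallel even though the recurrence looks sequential.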

Mamba Block Architecture

Mamba draws inspiration from the H3 architecture and simplifies it into a single repeated block that merges the H3 block with the gated MLP block used in Transformers:

```
Input (x)
    │
    ├──── Linear projection (expand) ────┐
    │                                    │
    ▼                                    ▼
  Conv1D                            SiLU (gate)
    │                                    │
    ▼                                    │
  SiLU                                   │
    │                                    │
    ▼                                    │
SSM (selective scan)                     │
    │                                    │
    ▼                                    │
    ×────────────────────────────────────┘
    │
    ▼
Linear projection (reduce)
    │
    ▼
  Output
```

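A simplified PyTorch implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand
        self.d_state = d_state

        # Input projection: produces the SSM path (x) and the gate path (z) together
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

        # Depthwise Conv1D; padding + truncation in forward() makes it causal
        self.conv1d = nn.Conv1d(
            d_inner, d_inner, d_conv,
            padding=d_conv - 1, groups=d_inner
        )

        # SSM parameters
        self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False)  # B, C, delta (rank-1 delta here)
        self.dt_proj = nn.Linear(1, d_inner, bias=True)                 # expands delta to every channel

        # A parameter (structured): stored as a log, used as A = -exp(A_log) < 0
        A = torch.arange(1, d_state + 1).float().repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(d_inner))  # skip-connection weight

        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def ssm(self, x):
        # x: (batch, length, d_inner)
        A = -torch.exp(self.A_log)                                  # (d_inner, d_state)
        B, C, dt = self.x_proj(x).split([self.d_state, self.d_state, 1], dim=-1)
        delta = F.softplus(self.dt_proj(dt))                        # (batch, length, d_inner), positive
        y = selective_scan(x, A, B, C, delta)                       # conceptual scan defined above
        return y + x * self.D                                       # D skip connection

    def forward(self, x):
        batch, length, _ = x.shape

        # Dual path
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Conv + activation
        x = self.conv1d(x.transpose(1, 2))[:, :, :length].transpose(1, 2)
        x = F.silu(x)

        # SSM
        y = self.ssm(x)

        # Gate and project
        y = y * F.silu(z)
        return self.out_proj(y)
```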
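One detail of the block worth checking is the causal convolution. Padding both sides by `d_conv - 1` and keeping only the first `length` outputs means position t only sees inputs at positions ≤ t.

The sketch below uses illustrative sizes. It perturbs one time step and confirms that:

- earlier outputs do not change;
- only the next `d_conv` outputs are affected.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_inner, d_conv, length = 8, 4, 12
conv1d = nn.Conv1d(d_inner, d_inner, d_conv, padding=d_conv - 1, groups=d_inner)


def causal_conv(x: torch.Tensor) -> torch.Tensor:
    """x: (batch, length, d_inner) -> same shape, using the pad-then-truncate trick."""
    return conv1d(x.transpose(1, 2))[:, :, : x.shape[1]].transpose(1, 2)


x = torch.randn(1, length, d_inner)
x_changed = x.clone()
x_changed[:, 7] += 10.0  # perturb time step 7 only

with torch.no_grad():
    y, y_changed = causal_conv(x), causal_conv(x_changed)

unchanged = torch.allclose(y[:, :7], y_changed[:, :7])                     # steps 0..6 never see step 7
affected = not torch.allclose(y[:, 7:7 + d_conv], y_changed[:, 7:7 + d_conv])  # steps 7..10 do
print(unchanged, affected)
```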

Experimental Results

Language Modeling (The Pile)

| Model | Parameters | Perplexity |

| ------------- | ---------- | ---------- |

| Transformer++ | 1.3B | 8.94 |

| H3 | 1.3B | 9.19 |

| Hyena | 1.3B | 9.07 |

| RWKV-4 | 1.5B | 9.02 |

| **Mamba** | **1.4B** | **8.63** |

Mamba achieved **lower perplexity** than Transformers of comparable size. Notably, at the 2.8B scale, Mamba improved perplexity by approximately **0.3** over Transformer++.

Inference Speed

| Sequence Length | Transformer (baseline) | Mamba (relative to Transformer) |

| --------------- | ----------- | ------------- |

| 2K | 1x | 1x |

| 8K | 1x | 3x faster |

| 64K | 1x | **5x faster** |

| 1M | OOM | Works |

Mamba's speed advantage grows with sequence length. Each generated token only updates a fixed-size state, whereas a Transformer must attend over a KV cache that keeps growing.
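The arithmetic below compares per-layer inference memory. A Transformer appends a key and a value for every token. Mamba carries a fixed state made of two parts: the SSM state, plus a buffer of the last d_conv − 1 inputs (the minimum a width-d_conv causal convolution needs; an implementation may keep a slightly larger buffer).

The sizes (d_model = 2,048, expand = 2, d_state = 16, d_conv = 4, fp16) are assumptions for illustration. The KV-cache formula ignores multi-query or grouped-query attention.

```python
def kv_cache_bytes(tokens: int, d_model: int, bytes_per_elem: int = 2) -> int:
    """One layer's KV cache: a key and a value vector of size d_model per token."""
    return 2 * tokens * d_model * bytes_per_elem


def mamba_state_bytes(d_model: int, expand: int = 2, d_state: int = 16,
                      d_conv: int = 4, bytes_per_elem: int = 2) -> int:
    """One layer's recurrent state: SSM state (d_inner x d_state) + conv buffer (d_inner x (d_conv - 1))."""
    d_inner = expand * d_model
    return d_inner * (d_state + d_conv - 1) * bytes_per_elem


d_model = 2_048
for tokens in (2_048, 65_536, 1_048_576):
    kv_mib = kv_cache_bytes(tokens, d_model) / 2**20
    state_kib = mamba_state_bytes(d_model) / 2**10
    print(f"{tokens:>9} tokens: KV cache {kv_mib:7.0f} MiB | Mamba state {state_kib:.0f} KiB")
```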

Selective Copying & Induction Heads

The paper validates the effectiveness of the Selection Mechanism through two key synthetic tasks:

1. **Selective Copying**: Selectively copying specific tokens from a long sequence — Previous SSMs (S4) fail, Mamba succeeds

2. **Induction Heads**: Recognizing "A B ... A → B" patterns — Previous SSMs fail, Mamba succeeds

Both tasks require **content-aware reasoning** and are impossible to solve without the Selection Mechanism.
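To make the Selective Copying task concrete, here is a hypothetical generator in its spirit. A few data tokens are scattered at random positions among noise tokens, and the target is the data tokens in their original order. The vocabulary size, noise token id, and lengths are illustrative assumptions, not the paper's exact setup.

```python
import random
from typing import List, Optional, Tuple

NOISE = 0  # hypothetical id for the filler token


def make_selective_copy_example(seq_len: int = 32, n_data: int = 4, vocab: int = 8,
                                seed: Optional[int] = None) -> Tuple[List[int], List[int]]:
    """Return (inputs, targets): data tokens at random positions, targets in order."""
    rng = random.Random(seed)
    data = [rng.randint(1, vocab) for _ in range(n_data)]  # ids 1..vocab, never NOISE
    positions = sorted(rng.sample(range(seq_len), n_data))  # random slots, left to right
    inputs = [NOISE] * seq_len
    for pos, tok in zip(positions, data):
        inputs[pos] = tok
    return inputs, data


inputs, targets = make_selective_copy_example(seed=0)
print(inputs)
print(targets)
```

Because the data positions change from sample to sample, a model must decide from each token's content whether to keep it. A fixed (LTI) convolution kernel cannot make that decision.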

Mamba vs Transformer: When to Use Which?

| Criterion | Transformer | Mamba |

| -------------------- | ------------------ | -------------------- |

| Long sequences | △ (O(N²)) | ◎ (O(N)) |

| Inference speed | △ (KV Cache grows) | ◎ (Fixed state) |

| In-context learning | ◎ | ○ |

| Training parallelism | ◎ (Fully parallel) | ○ (Scan parallelism) |

| Ecosystem/tooling    | ◎                  | △                    |

(◎ = strong, ○ = adequate, △ = weak)

Follow-up Research: Mamba-2 and Hybrids

Significant progress has been made since Mamba:

- **Mamba-2** (Dao & Gu, 2024): Proposed the **State Space Duality (SSD)** framework, revealing the theoretical connection between SSMs and Attention

- **Jamba** (AI21, 2024): A Transformer + Mamba hybrid with 256K context

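- **Zamba** (Zyphra, 2024): A compact 7B hybrid that pairs a Mamba backbone with a single shared attention block, reported to be highly competitive among models of its size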
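The core observation behind SSD can be checked numerically in a restricted setting. Assume a scalar decay a_t per step (Mamba-2 restricts A to a scalar times the identity) and a single channel. Unrolling the recurrence then gives y = M x, where M is lower-triangular with M[t, s] = (C_t · B_s) · a_{s+1}⋯a_t.

The sketch builds M explicitly and compares it with the recurrent scan. The variable names and sizes are illustrative.

```python
import torch

torch.manual_seed(0)
L, N = 10, 4
a = torch.rand(L)        # scalar decay per step (e.g. exp(delta_t * A) with scalar A < 0)
B = torch.randn(L, N)    # input-dependent B_t
C = torch.randn(L, N)    # input-dependent C_t
x = torch.randn(L)

# Recurrent (SSM) form: h_t = a_t * h_{t-1} + B_t * x_t,  y_t = C_t . h_t
h, y_rec = torch.zeros(N), []
for t in range(L):
    h = a[t] * h + B[t] * x[t]
    y_rec.append(C[t] @ h)
y_rec = torch.stack(y_rec)

# Matrix ("attention-like") form: M[t, s] = (C_t . B_s) * prod_{k=s+1..t} a_k for s <= t, else 0
decay = torch.zeros(L, L)
for t in range(L):
    for s in range(t + 1):
        decay[t, s] = torch.prod(a[s + 1 : t + 1])  # empty product is 1 when s == t
M = decay * (C @ B.T)
y_mat = M @ x

print(torch.allclose(y_rec, y_mat, atol=1e-5))
```

In the matrix form, C plays the role of queries, B of keys, and x of values. The decay matrix acts as a causal mask, so the computation is masked attention without a softmax.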

Conclusion

Mamba overcame the limitations of SSMs with the simple yet powerful idea of the **Selection Mechanism**, achieving performance comparable to or exceeding Transformers at O(N) complexity. The hardware-aware algorithm design was key to translating theoretical efficiency into real-world speed.

While Transformers remain dominant, Mamba and hybrid architectures are gaining increasing traction in domains where long-sequence processing is critical.

Quiz

Q1: What fundamental limitation of Transformers does Mamba address?

The O(N²) time and space complexity of self-attention. As sequence length increases, computation and memory grow quadratically, making long-sequence processing inefficient.

Q2: What does "Selective" mean in Selective State Space?

It means making the SSM parameters (B, C, Δ) input-dependent, so the model can select which information to store in the state and which to ignore.

Q3: What was the fundamental reason previous SSMs (S4, etc.) underperformed in language modeling?

Their Linear Time-Invariant (LTI) nature. Since parameters were fixed and independent of the input, content-aware reasoning was impossible.

Q4: What role does Δ (delta) play in Mamba?

It is a "step size" that determines how much the current input is reflected. A large Δ strongly reflects the current input, while a small Δ retains the previous state.

Q5: Why can't Mamba's Selective Scan use the convolution trick from traditional SSMs?

Because the input-dependent parameters mean the system is no longer LTI. The trick of converting the recurrence to a convolution and computing it in parallel via FFT therefore no longer applies.

Q6: What is the key strategy that Mamba's hardware optimization shares with FlashAttention?

IO-aware optimization. All computations are fused within SRAM (kernel fusion), and intermediate states are not stored in HBM but recomputed during the backward pass (recomputation).

Q7: What capability does the Selective Copying task test?

The ability to selectively remember and copy specific tokens from a long sequence: a content-aware reasoning capability. SSMs with fixed parameters cannot solve this task.

Q8: What is the State Space Duality (SSD) framework proposed in Mamba-2?

A framework showing that SSMs and (masked, linear) attention are mathematically different representations of the same operation. It provides a theoretical foundation for combining the strengths of both approaches.
