Mamba and State Space Models Complete Guide: Beyond Transformers

Overview

Since the landmark "Attention is All You Need" paper in 2017, Transformers have dominated virtually every sequence modeling task — from natural language processing to images, code, and audio. However, Transformers have a fundamental weakness: the **quadratic complexity O(n^2) of self-attention**.

In December 2023, **Mamba**, published by Albert Gu and Tri Dao, offered a compelling answer to this problem and drew wide attention across the deep learning community. Mamba adapts State Space Models (SSMs), a concept from classical control theory, to modern deep learning, enabling linear-time processing of long sequences.

This guide covers everything from the mathematical foundations of SSMs to Mamba's core innovations and practical implementations.

1. The Limitations of Transformers and the Rise of SSMs

The O(n^2) Complexity Problem of Self-Attention

The core of Transformers is the Self-Attention mechanism. For a sequence of length n, computing the attention matrix requires computing relationships between all pairs of tokens, resulting in both time and memory complexity of **O(n^2)**.

Attention matrix size: n × n

- 1,000 tokens → 1,000,000 elements
- 10,000 tokens → 100,000,000 elements
- 100,000 tokens → 10,000,000,000 elements (infeasible!)
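The element counts above translate directly into memory. The quick sketch below assumes fp16 scores (2 bytes each) and a naive implementation that stores the full matrix for one head in one layer. Fused kernels such as FlashAttention avoid storing the matrix, but their compute still grows quadratically.

```python
def attention_matrix_bytes(n_tokens, bytes_per_elem=2, n_heads=1, n_layers=1):
    """Bytes needed to materialize the n x n attention score matrix (naive attention)."""
    return n_tokens * n_tokens * bytes_per_elem * n_heads * n_layers


for n in (1_000, 10_000, 100_000):
    gb = attention_matrix_bytes(n) / 1e9
    print(f"{n:>7,} tokens -> {gb:8.3f} GB per head per layer")
```

At 100,000 tokens this is already 20 GB for a single head in a single layer.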

Because of this problem, real-world LLMs are constrained in context window size. GPT-4 Turbo supports 128K tokens, but only with enormous computational cost and extensive optimization.

The Difficulty of Processing Long Sequences

Real applications — such as long document summarization, understanding entire codebases, and maintaining long-term conversations — require very long context. But Transformer's quadratic complexity makes practical use challenging.

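For example, processing an entire book (approximately 100,000 words ≈ 130,000 tokens) in one pass means an attention matrix of roughly 1.7 × 10^10 entries per head per layer. Materializing that with standard attention is impractical on current hardware, and even memory-efficient attention kernels still pay the quadratic compute cost.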

The Limitations of Recurrent Models (RNN, LSTM)

Recurrent neural networks (RNNs) and LSTMs theoretically have O(n) complexity because they maintain a fixed-size hidden state at each timestep. However, RNN/LSTM have the following problems:

- **Vanishing/Exploding Gradients**: Gradients disappear or explode during backpropagation on long sequences

- **Non-parallelizable**: Each timestep depends on the previous state, making GPU parallelization difficult

- **Difficulty capturing long-range dependencies**: Hard to connect information far apart in the sequence

The Solution SSMs Offer

State Space Models (SSMs) combine the advantages of both approaches:

- **During training**: Full parallelization via convolution (for time-invariant SSMs such as S4) → training efficiency comparable to Transformers

- **During inference**: Recurrence with O(1) memory and O(n) computation → efficient inference like RNNs

- **Long sequences**: Linear complexity enables processing of very long sequences

2. Mathematical Foundations of State Space Models

Continuous-Time SSM

The origins of SSMs trace back to Rudolf Kalman's control theory in the 1960s. A continuous-time linear system is expressed as:

x'(t) = A·x(t) + B·u(t)

y(t) = C·x(t) + D·u(t)

Where:

- `u(t)`: input signal

- `x(t)`: state vector (latent state), size N

- `y(t)`: output signal

- `A`: N×N state transition matrix

- `B`: N×1 input projection matrix

- `C`: 1×N output projection matrix

- `D`: direct feedthrough term (a skip connection; often dropped from the analysis, though S4 and Mamba keep it as a learned parameter)

This system receives input u(t), updates the internal state x(t), and produces output y(t). The state x(t) can be thought of as a "summary" of past information.

Discretization

To use continuous-time SSMs on digital computers, discretization is required. With a sampling interval Delta, there are two main discretization methods.

**Zero-Order Hold (ZOH)**:

A_bar = exp(Delta · A)

B_bar = (Delta · A)^(-1) · (exp(Delta · A) - I) · Delta · B

**Bilinear (Tustin) Transform**:

A_bar = (I - Delta/2 · A)^(-1) · (I + Delta/2 · A)

B_bar = (I - Delta/2 · A)^(-1) · Delta · B

After discretization, the system becomes:

x[t] = A_bar · x[t-1] + B_bar · u[t]

y[t] = C · x[t]

This form is a **Linear Recurrence**, updating the state at each timestep.
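For a diagonal A (the simplification later used by S4D-style models and Mamba), both discretizations reduce to element-wise formulas. The NumPy sketch below assumes a diagonal A given as a vector `a` of negative entries. The helper names are illustrative, not from any library.

```python
import numpy as np


def discretize_zoh(a, b, delta):
    """ZOH for diagonal A: A_bar = exp(delta*a), B_bar = (exp(delta*a) - 1) / a * b."""
    a_bar = np.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar


def discretize_bilinear(a, b, delta):
    """Bilinear for diagonal A: A_bar = (1 + delta/2*a) / (1 - delta/2*a), B_bar = delta*b / (1 - delta/2*a)."""
    denom = 1.0 - delta / 2.0 * a
    return (1.0 + delta / 2.0 * a) / denom, delta * b / denom


def run_recurrence(a_bar, b_bar, c, u):
    """x[t] = A_bar x[t-1] + B_bar u[t], y[t] = C x[t], starting from a zero state."""
    x = np.zeros_like(a_bar)
    ys = []
    for u_t in u:
        x = a_bar * x + b_bar * u_t
        ys.append(float(c @ x))
    return np.array(ys)


a = -np.arange(1, 5, dtype=float)   # stable continuous-time poles
b, c = np.ones(4), np.ones(4)
u = np.sin(np.linspace(0, 3, 50))
y_zoh = run_recurrence(*discretize_zoh(a, b, 0.01), c, u)
y_bil = run_recurrence(*discretize_bilinear(a, b, 0.01), c, u)
```

Both methods map a stable continuous system (negative `a`) to a stable discrete one (|A_bar| < 1). For a small Delta they produce nearly identical outputs.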

SSM as a Convolution Kernel

The powerful property of SSMs is that **they can be computed as convolutions during training**. Unrolling the recurrence from a zero initial state (x[-1] = 0):

y[0] = C · B_bar · u[0]

y[1] = C · A_bar · B_bar · u[0] + C · B_bar · u[1]

y[2] = C · A_bar^2 · B_bar · u[0] + C · A_bar · B_bar · u[1] + C · B_bar · u[2]

...

Expressing this as convolution kernel K:

K = (C·B_bar, C·A_bar·B_bar, C·A_bar^2·B_bar, ...)

y = K * u (convolution)

Because K does not depend on the input, it can be precomputed once per layer. The convolution K * u can then be evaluated with the FFT in O(n log n).

This is the fundamental duality of SSMs:

- **Inference**: Recurrence with O(1) memory

- **Training**: Convolution with O(n log n) parallel computation
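The duality can be checked numerically. The small NumPy sketch below has three steps:

1. Build K from a discretized (A_bar, B_bar, C).
2. Evaluate y = K * u with a zero-padded FFT.
3. Compare the result with the step-by-step recurrence.

The function names are illustrative.

```python
import numpy as np


def ssm_kernel(A_bar, B_bar, C, L):
    """K = (C B_bar, C A_bar B_bar, ..., C A_bar^(L-1) B_bar)."""
    K = np.empty(L)
    v = B_bar.copy()               # holds A_bar^k B_bar
    for k in range(L):
        K[k] = (C @ v).item()
        v = A_bar @ v
    return K


def ssm_conv_fft(K, u):
    """Causal convolution y[t] = sum_k K[k] u[t-k] via FFT, zero-padded to avoid wrap-around."""
    L = len(u)
    n = 2 * L
    return np.fft.irfft(np.fft.rfft(K, n) * np.fft.rfft(u, n), n)[:L]


def ssm_recurrent(A_bar, B_bar, C, u):
    """x[t] = A_bar x[t-1] + B_bar u[t], y[t] = C x[t] with x[-1] = 0."""
    x = np.zeros((A_bar.shape[0], 1))
    ys = []
    for u_t in u:
        x = A_bar @ x + B_bar * u_t
        ys.append((C @ x).item())
    return np.array(ys)


rng = np.random.default_rng(0)
N, L = 4, 32
A_bar = 0.5 * np.eye(N) + 0.1 * rng.standard_normal((N, N))
B_bar, C = rng.standard_normal((N, 1)), rng.standard_normal((1, N))
u = rng.standard_normal(L)
y_rec = ssm_recurrent(A_bar, B_bar, C, u)
y_conv = ssm_conv_fft(ssm_kernel(A_bar, B_bar, C, L), u)
```

The two outputs differ only by floating-point round-off.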

3. S4 (Structured State Space Sequence Model)

S4 was published by Albert Gu et al. in 2021 and is the first important work that practically applied SSMs to deep learning.

HiPPO Matrix Initialization

One of S4's key contributions is the HiPPO (High-order Polynomial Projection Operators) matrix initialization. Simply initializing A randomly leads to gradient vanishing problems. HiPPO provides a specially designed A matrix that approximates past inputs using polynomials.

HiPPO-LegS (Legendre polynomial-based):

A[n,k] = -sqrt((2n+1)(2k+1)) if n > k

A[n,k] = -(n+1) if n == k

A[n,k] = 0 if n < k

This initialization ensures that state x[t] maintains an optimal polynomial approximation of all past inputs u[0..t].

Structured Matrix A (DPLR)

To compute the convolution kernel efficiently, S4 expresses A in **Diagonal Plus Low-Rank (DPLR)** form:

A = Λ - P·Q^T

Where Λ is a diagonal matrix and P, Q are low-rank factors (rank 1 for HiPPO-LegS). Using this structure, the length-L kernel can be computed in roughly Õ(N + L) operations, instead of repeatedly multiplying by a dense N×N matrix.
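Why does HiPPO admit this structure? HiPPO-LegS is "normal plus low-rank". Adding the rank-1 term P·P^T with P_n = sqrt(n + 1/2) gives -1/2·I plus a skew-symmetric matrix. That sum is normal and therefore unitarily diagonalizable, and conjugating by that unitary yields the diagonal-plus-low-rank form above. The NumPy check below uses the same sign convention as the HiPPO-LegS formulas earlier.

```python
import numpy as np


def hippo_legs(N):
    """HiPPO-LegS: -sqrt((2n+1)(2k+1)) below the diagonal, -(n+1) on it, 0 above."""
    n = np.arange(N)
    r = np.sqrt(2 * n + 1)
    A = np.tril(-np.outer(r, r))   # keep n >= k entries
    A[n, n] = -(n + 1)             # overwrite the diagonal
    return A


N = 8
A = hippo_legs(N)
P = np.sqrt(np.arange(N) + 0.5)[:, None]   # rank-1 correction
M = A + P @ P.T                             # the "normal part" of A

print(np.allclose(M @ M.T, M.T @ M))                                # True: M is normal
print(np.allclose(M + 0.5 * np.eye(N), -(M + 0.5 * np.eye(N)).T))  # True: M + I/2 is skew-symmetric
```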

Efficient Computation

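```python
import numpy as np
import torch
import torch.nn as nn


class S4Layer(nn.Module):
    """Simplified implementation of an S4 layer.

    Keeps A as a dense HiPPO matrix and materializes the kernel with a plain loop;
    real S4 uses the DPLR form and a Cauchy kernel instead.
    """

    def __init__(self, d_model, d_state=64, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # HiPPO-LegS initialization (dense; shared by all channels in this sketch)
        A = self._make_hippo(d_state)
        self.A = nn.Parameter(torch.tensor(A, dtype=torch.float32))
        self.B = nn.Parameter(torch.randn(d_state, 1) * 0.01)
        self.C = nn.Parameter(torch.randn(1, d_state) * 0.01)
        self.log_delta = nn.Parameter(torch.zeros(d_model))  # one step size per channel
        self.D = nn.Parameter(torch.ones(d_model))
        self.dropout = nn.Dropout(dropout)

    def _make_hippo(self, N):
        """Generate the HiPPO-LegS matrix"""
        P = np.sqrt(1 + 2 * np.arange(N))
        A = P[:, np.newaxis] * P[np.newaxis, :]
        A = np.tril(A) - np.diag(np.arange(N))
        return -A

    def discretize(self, A, B, C, delta):
        """ZOH discretization: one (dA, dB) pair per channel"""
        dA = torch.matrix_exp(delta[:, None, None] * A)             # (H, N, N)
        I = torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
        dB = torch.linalg.solve(A.expand_as(dA), (dA - I) @ B)      # (H, N, 1) = A^-1 (dA - I) B
        return dA, dB, C

    def kernel(self, L):
        """K[h, l] = C · dA[h]^l · dB[h] for l = 0..L-1"""
        delta = torch.exp(self.log_delta)
        dA, dB, C = self.discretize(self.A, self.B, self.C, delta)
        ks, v = [], dB
        for _ in range(L):
            ks.append((C @ v)[:, 0, 0])   # (H,)
            v = dA @ v
        return torch.stack(ks, dim=-1)    # (H, L)

    def forward(self, u):
        # u: (batch, seq_len, d_model)
        L = u.shape[1]
        K = self.kernel(L)
        # Causal convolution y = K * u via FFT, zero-padded to 2L to avoid wrap-around
        u_f = torch.fft.rfft(u.transpose(1, 2), n=2 * L)   # (B, H, L + 1)
        k_f = torch.fft.rfft(K, n=2 * L)                    # (H, L + 1)
        y = torch.fft.irfft(u_f * k_f, n=2 * L)[..., :L]    # (B, H, L)
        y = y.transpose(1, 2) + u * self.D                  # D: skip connection
        return self.dropout(y)
```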

4. H3 (Hungry Hungry Hippos)

Published in 2022, H3 improves on S4 for language modeling. The name "Hungry Hungry Hippos" is a play on the HiPPO acronym (and on the children's game of the same name).

Differences from S4

S4 captures long-range dependencies well, but lacks the **token-to-token interactions** that language modeling relies on, such as recalling a token seen earlier and comparing tokens across the sequence. Attention directly computes "how related is this word to that word?". A purely linear, time-invariant recurrence like S4 struggles to express such comparisons.

Gating Mechanism

H3 stacks two SSMs and combines them with multiplicative gating, computing Q ⊙ SSM_diag(SSM_shift(K) ⊙ V):

```python
import torch.nn as nn


class H3Layer(nn.Module):
    """H3 layer (single-head simplification): Q ⊙ SSM_diag(SSM_shift(K) ⊙ V)"""

    def __init__(self, d_model, d_state=64):
        super().__init__()
        # Two SSMs. In the paper the shift SSM uses a fixed shift matrix A and the
        # diagonal SSM a diagonal A; here both reuse the simplified S4Layer defined above.
        self.shift_ssm = S4Layer(d_model, d_state)
        self.diag_ssm = S4Layer(d_model, d_state)

        # Projections
        self.Q_proj = nn.Linear(d_model, d_model)
        self.K_proj = nn.Linear(d_model, d_model)
        self.V_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # Q, K, V projections: each (B, L, D)
        q = self.Q_proj(x)
        k = self.K_proj(x)
        v = self.V_proj(x)

        # SSM 1: shift SSM on K, then multiplicative interaction with V
        kv = self.shift_ssm(k) * v

        # SSM 2: diagonal SSM on the K*V product, then gate with Q
        output = q * self.diag_ssm(kv)

        return self.out_proj(output)
```

Language Modeling Improvements

H3 achieved performance close to Transformers in GPT-2 sized language models while showing significantly faster inference speed. This was an important milestone demonstrating the practicality of SSMs for language modeling.

5. Mamba (Selective State Space Model)

Mamba, published by Albert Gu and Tri Dao in December 2023, represents the most important advance in SSM research to date. Its core innovation is the **Selective Mechanism**.

Core Innovation: The Selective Mechanism

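The fundamental limitation of S4 and H3 was that the parameters A, B, C (and the step size Delta) were **input-independent**. In other words, the system was linear time-invariant (LTI), and that time invariance is exactly what allows the kernel K to be precomputed for convolution-based training.

Mamba breaks this constraint and introduces **S6 (Selective SSM)**: B, C, and the step size Delta become functions of the input u(t). A itself stays input-independent. However, the discretized A_bar = exp(Delta · A) depends on Delta, so it becomes input-dependent too.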

B(t) = Linear_B(u(t)) # B that varies with input

C(t) = Linear_C(u(t)) # C that varies with input

Delta(t) = softplus(Linear_Delta(u(t))) # Delta that varies with input

This allows the model to **dynamically decide**, based on input, which information to store in the state and which to discard.

Intuitively:

- Large Delta → the state is largely reset and the current input dominates (focus on this token)

- Small Delta → the state persists almost unchanged and the current input is largely ignored

This plays a role similar to LSTM's gates (input gate, forget gate), but within the continuous-time SSM framework.
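To see how Delta acts as a gate, take a single state with a = -1 and discretize it with ZOH, as in the formulas above. This is a minimal sketch; the helper name is hypothetical.

```python
import math


def selective_step(h, x, delta, a=-1.0, b=1.0):
    """One ZOH-discretized update h <- A_bar * h + B_bar * x for a scalar state."""
    a_bar = math.exp(delta * a)        # A_bar = exp(Delta * A)
    b_bar = (a_bar - 1.0) / a * b      # B_bar = A^-1 (exp(Delta * A) - 1) * B
    return a_bar * h + b_bar * x


h = 5.0                                        # current state
print(selective_step(h, x=1.0, delta=10.0))    # large Delta: state reset to about the input (1.0)
print(selective_step(h, x=1.0, delta=0.001))   # small Delta: state stays near 5.0
```

In Mamba, Delta is computed from each token, so the model can choose per token between these two regimes.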

Hardware-Aware Algorithm

Introducing the selective mechanism means B, C depend on input, so convolution can no longer be used for computation. This can lead to severe computational inefficiency.

Mamba solves this with a **Hardware-Aware Algorithm** — a kernel fusion technique leveraging GPU memory hierarchy (HBM vs SRAM):

**Problem**: Storing intermediate states in HBM (slow memory) during per-timestep recurrence creates memory bandwidth bottlenecks.

**Solution**:

- Perform all intermediate computations in fast SRAM (on-chip memory)

- Use **Parallel Scan** algorithm instead of full convolution

- Store only the final output in HBM

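```python
import numpy as np


# Parallel scan concept (the actual implementation is a fused CUDA kernel)
def parallel_scan(gates, tokens):
    """
    Compute the linear recurrence x[t] = gates[t] * x[t-1] + tokens[t] (x[-1] = 0)
    as an inclusive scan over the associative operator
        (a1, b1) ∘ (a2, b2) = (a2 * a1, a2 * b1 + b2)
    This is the simple Hillis-Steele scan: O(log n) parallel depth, each step fully
    parallel (vectorized here). A work-efficient Blelloch scan adds up-/down-sweeps.
    """
    a = np.array(gates, dtype=np.float64)
    b = np.array(tokens, dtype=np.float64)
    n = len(b)
    offset = 1
    while offset < n:
        # Combine each element with the partial result `offset` positions to its left.
        # NumPy evaluates each right-hand side before assigning; b is updated first
        # because it needs the old gates a.
        b[offset:] = a[offset:] * b[:-offset] + b[offset:]
        a[offset:] = a[offset:] * a[:-offset]
        offset *= 2
    return b


# Check against the sequential recurrence
gates, tokens = np.full(8, 0.5), np.ones(8)
x, seq = 0.0, []
for g, t in zip(gates, tokens):
    x = g * x + t
    seq.append(x)
assert np.allclose(parallel_scan(gates, tokens), seq)
```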

Mamba Block Structure

```
              Input x (B, L, D)
                     │
        ┌────────────┴────────────┐
        │                         │
   Linear(D→ED)              Linear(D→ED)
        │                         │
  Conv1d (causal)               SiLU
        │                         │
      SiLU                        │
        │                         │
    SSM (S6)                      │
        │                         │
        └─────────── ⊙ ───────────┘
          (element-wise multiply)
                     │
               Linear(ED→D)
                     │
             Output y (B, L, D)
```

Where E is the expansion ratio (usually 2) and D is the model dimension.

Complete Mamba Implementation

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat


class MambaBlock(nn.Module):
    """
    Mamba block (pure-PyTorch reference, for understanding only)
    Paper: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
    """

    def __init__(
        self,
        d_model,            # Model dimension D
        d_state=16,         # SSM state dimension N
        d_conv=4,           # Convolution kernel size
        expand=2,           # Expansion ratio E
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",   # unused in this simplified version
        dt_scale=1.0,       # unused in this simplified version
        dt_init_floor=1e-4,
        bias=False,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank

        # Input projection (D -> 2*ED: SSM branch x and gate branch z in one matmul)
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=bias)

        # Local convolution (depthwise: one filter per channel)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=True,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,  # pads both sides; forward() trims the right side -> causal
        )

        # Input-dependent SSM parameters: x -> (Delta (low rank), B, C)
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # Initialize the dt bias so that softplus(bias) falls in [dt_min, dt_max]
        dt = torch.exp(
            torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        with torch.no_grad():
            self.dt_proj.bias.copy_(dt + torch.log(-torch.expm1(-dt)))  # inverse softplus

        # A initialization (S4D-Real: A_n = -(n+1), a diagonal simplification of HiPPO)
        A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32), "n -> d n", d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A).contiguous())  # store log(-A)
        self.A_log._no_weight_decay = True

        # D (skip connection)
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.D._no_weight_decay = True

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)

    def forward(self, hidden_states):
        """
        hidden_states: (B, L, D)
        Returns: (B, L, D)
        """
        batch, seqlen, dim = hidden_states.shape

        # Input projection, split into SSM branch x and gate branch z
        xz = self.in_proj(hidden_states)
        x, z = xz.chunk(2, dim=-1)  # each (B, L, ED)

        # Convolution (causal 1D conv)
        x = rearrange(x, "b l d -> b d l")
        x = self.conv1d(x)[:, :, :seqlen]  # causal trimming: keep the first L outputs
        x = rearrange(x, "b d l -> b l d")
        x = F.silu(x)

        # Selective SSM
        y = self.ssm(x)

        # Gating (element-wise multiply with the z branch)
        y = y * F.silu(z)

        # Output projection
        return self.out_proj(y)

    def ssm(self, x):
        """Selective SSM (S6) computation"""
        n = self.d_state

        # A = -exp(A_log) is always negative, so exp(Delta * A) stays in (0, 1): stable
        A = -torch.exp(self.A_log.float())  # (ED, N)

        # x_proj: compute Delta (low rank), B, C from the input
        x_dbl = self.x_proj(x)  # (B, L, dt_rank + 2N)
        delta, B, C = x_dbl.split([self.dt_rank, n, n], dim=-1)

        # Delta: softplus ensures positive values
        delta = F.softplus(self.dt_proj(delta))  # (B, L, ED)

        # Discretization and selective scan
        return self.selective_scan(x, delta, A, B, C, self.D)

    def selective_scan(self, u, delta, A, B, C, D):
        """
        Selective scan algorithm (sequential reference version)
        In practice, mamba_ssm runs this as a fused CUDA kernel;
        this pure-PyTorch loop is for conceptual understanding only.
        """
        b, l, d_in = u.shape
        n = A.shape[1]

        # Discretization
        # A_bar = exp(Delta * A) (ZOH): delta (B, L, ED), A (ED, N) -> (B, L, ED, N)
        deltaA = torch.exp(torch.einsum("bld,dn->bldn", delta, A))
        # B_bar * u approximated as Delta * B * u (simplified Euler-style step for B)
        # delta (B, L, ED), B (B, L, N), u (B, L, ED) -> (B, L, ED, N)
        deltaB_u = torch.einsum("bld,bln,bld->bldn", delta, B, u)

        # Recurrent scan over time
        x = torch.zeros(b, d_in, n, device=u.device, dtype=deltaA.dtype)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            ys.append(torch.einsum("bdn,bn->bd", x, C[:, i, :]))
        y = torch.stack(ys, dim=1)  # (B, L, ED)

        # D (skip connection)
        return y + u * D


class MambaModel(nn.Module):
    """Sequence model composed of stacked Mamba blocks"""

    def __init__(self, d_model, n_layers, d_state=16, expand=2):
        super().__init__()
        self.layers = nn.ModuleList([
            MambaBlock(d_model, d_state=d_state, expand=expand)
            for _ in range(n_layers)
        ])
        # One pre-norm per layer plus a final norm (the reference code uses RMSNorm)
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)])
        self.norm_f = nn.LayerNorm(d_model)

    def forward(self, x):
        """x: (B, L, D)"""
        for norm, layer in zip(self.norms, self.layers):
            x = x + layer(norm(x))  # Pre-norm residual
        return self.norm_f(x)
```

6. Mamba 2

In May 2024, Tri Dao and Albert Gu published Mamba 2, further strengthening the theoretical foundations.

Differences from Mamba 1

Mamba 1's selective scan is a recurrence evaluated with a parallel scan, not a matrix multiplication, so it cannot take advantage of GPU Tensor Cores. Mamba 2 introduces the **State Space Duality (SSD)** framework to make this computation more efficient.

State Space Duality (SSD)

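The core insight of Mamba 2 is that the input-to-output map of an SSM can be written as multiplication by a **semiseparable matrix**. When A is restricted to a scalar times the identity, that matrix is exactly a form of masked linear attention, with a decay mask in place of softmax.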

This mathematical duality enables:

1. Expressing SSM computation as **matrix multiplication** form

2. Leveraging highly optimized BLAS libraries

3. Maximizing GPU efficiency with Tensor Core utilization

SSD operation concept: with a scalar A[t] per head, the SSM recurrence

x[t] = A[t] * x[t-1] + B[t] * u[t]

y[t] = C[t]^T * x[t]

is processed in chunks:

- Within a chunk: parallel computation via matrix multiplication

- Between chunks: recurrence for state propagation
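The duality can be verified on a toy example. The sketch below makes several simplifying choices:

- It uses NumPy and hypothetical helper names.
- It models a single head.
- A scalar decay a[t] in (0, 1) stands in for A_bar[t].

It first runs the recurrence directly. It then materializes the lower-triangular semiseparable matrix M[t, s] = (C[t] · B[s]) · a[s+1] ⋯ a[t]. Multiplying M by u gives the same outputs. M has the shape of masked attention, with C as queries, B as keys, and a decay mask in place of softmax.

```python
import numpy as np


def selective_ssm_recurrent(a, B, C, u):
    """h[t] = a[t] * h[t-1] + B[t] * u[t],  y[t] = C[t] · h[t]  (h[-1] = 0)."""
    T, N = B.shape
    h = np.zeros(N)
    y = np.zeros(T)
    for t in range(T):
        h = a[t] * h + B[t] * u[t]
        y[t] = C[t] @ h
    return y


def selective_ssm_matrix(a, B, C):
    """Semiseparable M[t, s] = (C[t] · B[s]) * a[s+1] * ... * a[t] for s <= t, else 0."""
    T = len(a)
    M = np.zeros((T, T))
    for t in range(T):
        for s in range(t + 1):
            decay = np.prod(a[s + 1:t + 1])  # empty product (s == t) is 1
            M[t, s] = (C[t] @ B[s]) * decay
    return M


rng = np.random.default_rng(0)
T, N = 6, 3
a = rng.uniform(0.5, 1.0, T)                   # input-dependent decays in (0, 1)
B, C = rng.standard_normal((T, N)), rng.standard_normal((T, N))
u = rng.standard_normal(T)
M = selective_ssm_matrix(a, B, C)
print(np.allclose(selective_ssm_recurrent(a, B, C, u), M @ u))  # True
```

Computing M blockwise (matmuls inside a chunk) and passing states between chunks is what lets SSD run on Tensor Cores.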

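```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Mamba2Block(nn.Module):
    """Mamba 2 block (simplified: no conv1d, no D skip, naive sequential scan)"""

    def __init__(self, d_model, d_state=64, n_heads=8, chunk_size=64):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.n_heads = n_heads
        self.chunk_size = chunk_size  # used by chunked SSD kernels; unused in this naive scan
        self.d_head = d_model // n_heads

        # Multi-head structure: one projection yields z (gate), x, B, C and one dt per head
        self.norm = nn.LayerNorm(d_model)
        self.in_proj = nn.Linear(d_model, d_model * 2 + d_state * 2 + n_heads)
        self.out_proj = nn.Linear(d_model, d_model)

        # A parameters: one scalar per head (A = -exp(A_log) < 0)
        self.A_log = nn.Parameter(torch.randn(n_heads))

    def forward(self, x):
        """x: (B, L, D). The mamba_ssm package replaces this loop with fused chunked kernels."""
        b, l, _ = x.shape
        z, xs, B, C, dt = torch.split(
            self.in_proj(x),
            [self.d_model, self.d_model, self.d_state, self.d_state, self.n_heads],
            dim=-1,
        )
        dt = F.softplus(dt)                                  # (b, l, H), positive step sizes
        a = torch.exp(dt * -torch.exp(self.A_log))           # (b, l, H), per-head decay in (0, 1)
        xs = xs.reshape(b, l, self.n_heads, self.d_head)     # (b, l, H, P)
        h = x.new_zeros(b, self.n_heads, self.d_head, self.d_state)
        ys = []
        for t in range(l):
            # h <- a_t * h + dt_t * (x_t outer B_t), per head
            h = (a[:, t, :, None, None] * h
                 + dt[:, t, :, None, None] * xs[:, t, :, :, None] * B[:, t, None, None, :])
            ys.append(torch.einsum("bhpn,bn->bhp", h, C[:, t]))
        y = torch.stack(ys, dim=1).reshape(b, l, self.d_model)
        return self.out_proj(self.norm(y * F.silu(z)))       # gate, normalize, project
```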

Multi-Head Structure

Mamba 2 introduces **multi-head SSM**, making it structurally similar to Transformer's Multi-Head Attention. This provides:

- Better expressiveness

- Theoretical connection to attention

- Easier hybrid architecture design

Integration Potential with Transformers

SSD theory shows that certain SSMs are mathematically equivalent to certain attention mechanisms. This provides the theoretical basis for **hybrid architectures** mixing Mamba and Transformers.

7. Hybrid Architectures

MambaFormer

MambaFormer interleaves Mamba blocks and Transformer blocks:

Layer 1: Mamba Block (capture local patterns)

Layer 2: Attention (capture global dependencies)

Layer 3: Mamba Block

Layer 4: Attention

...

This structure leverages the strengths of each component:

- Mamba: efficient sequence processing, local patterns

- Attention: selective information retrieval, global dependencies

Jamba (SSM + Transformer + MoE)

Jamba, released by AI21 Labs in 2024, is a more complex hybrid:

Jamba = Mamba + Transformer + Mixture of Experts

Architecture:

- 52B parameters (active: 12B)

- Layer ratio: Attention 1 : Mamba 7

- MoE applied to every other layer (16 experts, top-2 routing)

- Supports 256K context window

Compared with Mixtral 8x7B, a Transformer of similar size, AI21 reports that Jamba achieves:

- About 3x higher inference throughput on long contexts

- Dramatically improved memory efficiency for long contexts

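```python
# Jamba-style hybrid block (concept)
import torch
import torch.nn as nn
import torch.nn.functional as F


class MixtureOfExperts(nn.Module):
    """Minimal top-k MoE FFN for illustration (no load balancing; every expert sees every token)."""

    def __init__(self, d_model, n_experts=16, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, n_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_model * 4), nn.SiLU(), nn.Linear(d_model * 4, d_model))
            for _ in range(n_experts)
        ])

    def forward(self, x):
        top_w, top_idx = self.router(x).topk(self.top_k, dim=-1)  # (B, L, k)
        top_w = F.softmax(top_w, dim=-1)
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            w = (top_w * (top_idx == e)).sum(-1, keepdim=True)  # 0 where expert e is not selected
            out = out + w * expert(x)
        return out


class JambaLayer(nn.Module):
    def __init__(self, d_model, layer_idx, attn_every_n=8, n_heads=8):
        super().__init__()
        # One attention layer per attn_every_n layers (1:7 attention:Mamba for n = 8)
        self.use_attention = (layer_idx % attn_every_n == 0)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        if self.use_attention:
            self.mixer = nn.MultiheadAttention(d_model, num_heads=n_heads, batch_first=True)
        else:
            self.mixer = MambaBlock(d_model)  # MambaBlock defined earlier
        # MoE on every other layer, a dense FFN elsewhere
        self.use_moe = (layer_idx % 2 == 0)
        if self.use_moe:
            self.ffn = MixtureOfExperts(d_model, n_experts=16)
        else:
            self.ffn = nn.Sequential(nn.Linear(d_model, d_model * 4), nn.SiLU(), nn.Linear(d_model * 4, d_model))

    def forward(self, x):
        h = self.norm1(x)  # pre-norm
        if self.use_attention:
            L = x.shape[1]
            causal = torch.triu(torch.ones(L, L, dtype=torch.bool, device=x.device), diagonal=1)
            h, _ = self.mixer(h, h, h, attn_mask=causal, need_weights=False)
        else:
            h = self.mixer(h)
        x = x + h
        return x + self.ffn(self.norm2(x))
```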

RWKV (Linear Attention)

RWKV ("Receptance Weighted Key Value") is a hybrid of Transformer and RNN:

- Training: Parallelized like Transformer (matrix form)

- Inference: Recurrent like RNN (O(1) state)

- Token mixing without attention

Latest versions (RWKV-6, RWKV-7) show performance competitive with Mamba.

RetNet (Retentive Network)

Microsoft's RetNet aims to simultaneously achieve "Training Parallelism, Inference Efficiency, Competitive Performance":

Three computational paradigms:

1. Parallel: O(n^2) parallel computation during training (attention-like, but without softmax)

2. Recurrent: O(1) memory during inference

3. Chunkwise Recurrent: balanced intermediate approach
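The parallel and recurrent paradigms compute the same output. The sketch below is a simplified single-head version in NumPy. It assumes a scalar decay gamma and omits the positional rotation and normalization that RetNet also applies. The function names are illustrative.

```python
import numpy as np


def retention_parallel(Q, K, V, gamma):
    """Parallel form: (Q K^T ⊙ D) V with D[n, m] = gamma^(n - m) if n >= m else 0."""
    T = Q.shape[0]
    diff = np.arange(T)[:, None] - np.arange(T)[None, :]
    D = np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)
    return (Q @ K.T * D) @ V


def retention_recurrent(Q, K, V, gamma):
    """Recurrent form: S_n = gamma * S_(n-1) + K_n^T V_n,  o_n = Q_n S_n (fixed-size state)."""
    S = np.zeros((K.shape[1], V.shape[1]))
    out = []
    for q, k, v in zip(Q, K, V):
        S = gamma * S + np.outer(k, v)
        out.append(q @ S)
    return np.array(out)


rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((7, 4)), rng.standard_normal((7, 4)), rng.standard_normal((7, 3))
print(np.allclose(retention_parallel(Q, K, V, 0.9), retention_recurrent(Q, K, V, 0.9)))  # True
```

The chunkwise form mixes the two: it uses the parallel form inside each chunk and carries S between chunks.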

8. Mamba Performance Comparison

Inference Speed (Linear Scale)

Mamba's greatest strength is its **linear scaling with sequence length**.

Relative processing time, normalized to 1x at sequence length 2K (treating both models as equally fast at 2K):

| Sequence length | Transformer | Mamba | Mamba advantage |
|---|---|---|---|
| 1K | 0.5x | 0.5x | - |
| 2K | 1.0x | 1.0x | - |
| 4K | 3.5x | 2.0x | 1.75x faster |
| 8K | 13x | 4.0x | 3.25x faster |
| 16K | 50x | 8.0x | 6.25x faster |
| 100K | ~1800x | ~50x | ~36x faster |

Memory Efficiency

State size comparison during inference:

| Model | Inference memory as sequence length n grows |
|---|---|
| Transformer | O(n) KV cache |
| Mamba | O(1) SSM state (fixed size!) |

Example: 130M parameter model, 1M token sequence

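- Transformer: a KV cache of tens of GB, growing linearly with sequence length (the exact size depends on depth, width, and precision)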

- Mamba: ~1MB state (constant!)
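The scaling is easy to estimate by hand. The back-of-the-envelope sketch below rests on these assumptions:

- Values are stored in fp16 (2 bytes each).
- Every Transformer layer stores keys and values of width d_model.
- Every Mamba layer stores an SSM state of size (expand · d_model) × d_state, plus a buffer of the last d_conv inputs for the causal convolution.

The layer counts and widths in the usage lines are hypothetical; the Mamba one is roughly the shape of mamba-130m.

```python
def kv_cache_bytes(n_tokens, n_layers, d_model, bytes_per_elem=2):
    """Transformer KV cache: keys + values, each (n_tokens, d_model), in every layer."""
    return 2 * n_layers * n_tokens * d_model * bytes_per_elem


def mamba_state_bytes(n_layers, d_model, expand=2, d_state=16, d_conv=4, bytes_per_elem=2):
    """Mamba recurrent state: SSM state plus conv buffer per layer (independent of length)."""
    d_inner = expand * d_model
    per_layer = d_inner * d_state + d_inner * d_conv
    return n_layers * per_layer * bytes_per_elem


# Hypothetical configs: 12-layer / 768-wide Transformer, 24-layer / 768-wide Mamba
print(f"KV cache, 1M tokens: {kv_cache_bytes(1_000_000, 12, 768) / 1e9:.1f} GB")
print(f"Mamba state, any length: {mamba_state_bytes(24, 768) / 1e6:.2f} MB")
```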

Long Sequence Tasks

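Long Range Arena (LRA) benchmark (sequence lengths 1K-16K):

| Model | ListOps | Text | Retrieval | Image | Path-X | Avg |
|---|---|---|---|---|---|---|
| Transformer | 36.4 | 65.0 | 57.5 | 42.4 | 0.0 | 40.3 |
| LSTM | 35.9 | 63.7 | 65.0 | 43.3 | 0.0 | 41.6 |
| S4 | 59.6 | 86.8 | 90.9 | 88.7 | 86.1 | 82.4 |

A Path-X score of 0.0 means the model failed to learn the task; it is counted as 0 in the average. Mamba: ~S4-level performance, with faster processing.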

Mamba particularly excels over S4 on synthetic tasks such as **Selective Copying** and **Induction Heads**. These tasks test content-dependent selection and in-context recall, abilities closely tied to real language modeling.
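Selective Copying shows why selectivity matters. The tokens to remember sit at random positions among filler tokens, so a fixed (time-invariant) convolution kernel cannot know where to look. Below is a hypothetical generator for one instance; its parameters are illustrative, not the paper's exact setup.

```python
import random


def selective_copying_example(n_memorize=4, seq_len=16, vocab=range(1, 9), noise=0, seed=None):
    """One Selective-Copying-style instance: content tokens scattered among noise tokens.
    The target is the content tokens in their original order."""
    rng = random.Random(seed)
    content = [rng.choice(list(vocab)) for _ in range(n_memorize)]
    positions = sorted(rng.sample(range(seq_len), n_memorize))  # random spacing
    inputs = [noise] * seq_len
    for pos, tok in zip(positions, content):
        inputs[pos] = tok
    return inputs, content


inputs, target = selective_copying_example(seed=0)
print(inputs, "->", target)
```

To solve the task, a model must decide per token whether to store it, which is exactly what an input-dependent Delta allows.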

9. Applications of Mamba

Natural Language Processing

Mamba-based language models are emerging rapidly:

- **MambaChat**: Conversational AI assistant

- **Falcon Mamba**: Open-source 7B Mamba model released by TII

- **CodeMamba**: Specialized for code generation

Mamba is efficient for long document processing, summarization, and translation compared to Transformers.

Bioinformatics

Genomic sequence analysis deals with very long sequences (millions of base pairs), making Mamba particularly advantageous:

- **Caduceus**: Long DNA sequence modeling

- **Hyena** (e.g., HyenaDNA): long-sequence DNA modeling with the related Hyena long-convolution operator rather than Mamba itself

- Potential applications in protein structure prediction

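```python
# Example of Mamba for biological sequence modeling
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# DNA sequence modeling (A, T, G, C, N tokens)
DNA_VOCAB = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4}
VOCAB_SIZE = len(DNA_VOCAB)

# Pretrained checkpoints such as state-spaces/mamba-130m use a text tokenizer, so a
# DNA model with this vocabulary is built from a config and trained from scratch.
# The sizes below are illustrative.
config = MambaConfig(d_model=256, n_layer=8, vocab_size=VOCAB_SIZE, pad_vocab_size_multiple=8)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)

# Process very long sequences efficiently: compute grows linearly with length
seq = "ACGTN" * 20_000  # 100K bases
input_ids = torch.tensor([[DNA_VOCAB[base] for base in seq]], device="cuda")
with torch.no_grad():
    logits = model(input_ids).logits  # (1, 100_000, padded vocab size)
```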

Time Series Analysis

For financial, meteorological, IoT sensor data, and other long time series, Mamba shows its strengths:

- **TimeMamba**: Long time series forecasting

- **MambaMixer**: Multivariate time series modeling

- Advancements in S4/S5-based time series models

Image Processing (VMamba)

VMamba extends Mamba to 2D image processing:

VMamba core: 2D selective scan

Scan image in 4 directions to capture 2D structure

Directions:

1. Left→Right, Top→Bottom (standard raster scan)

2. Right→Left, Bottom→Top (reverse)

3. Top→Bottom, Left→Right (column-first)

4. Bottom→Top, Right→Left (reverse)
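The four orders are easiest to see on a tiny grid. Before the directional outputs can be combined, a column-first scan must be mapped back to row-major order. This NumPy sketch prints each order on a 2×3 grid of patch ids; the helper name is illustrative.

```python
import numpy as np

H, W = 2, 3
grid = np.arange(H * W).reshape(H, W)   # patch ids in row-major order

orders = {
    "1: left->right, top->bottom": grid.reshape(-1),
    "2: reverse of 1": grid.reshape(-1)[::-1],
    "3: top->bottom, left->right": grid.T.reshape(-1),
    "4: reverse of 3": grid.T.reshape(-1)[::-1],
}
for name, order in orders.items():
    print(name, order.tolist())


def col_major_to_row_major(seq, H, W):
    """Reorder a sequence produced by a column-first scan back into row-major order."""
    return seq.reshape(W, H).T.reshape(-1)
```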

```python
import torch
import torch.nn as nn


class VMambaBlock(nn.Module):
    """VMamba-style block: 2D selective scan along 4 directions (simplified)"""

    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        # 4-directional SSMs (MambaBlock as defined above)
        self.ssms = nn.ModuleList([
            MambaBlock(d_model, d_state) for _ in range(4)
        ])
        self.out_proj = nn.Linear(d_model * 4, d_model)

    def forward(self, x):
        """x: (B, H, W, D) - image patch embeddings"""
        b, h, w, d = x.shape
        x = self.norm(x)
        row_major = x.reshape(b, h * w, d)                        # left->right, top->bottom
        col_major = x.permute(0, 2, 1, 3).reshape(b, h * w, d)   # top->bottom, left->right
        outputs = []

        # 4-directional scan
        for i, ssm in enumerate(self.ssms):
            seq = row_major if i < 2 else col_major
            if i % 2 == 1:  # reversed scans
                seq = seq.flip(1)
            out = ssm(seq)
            if i % 2 == 1:  # undo the reversal
                out = out.flip(1)
            if i >= 2:      # map column-major order back to row-major
                out = out.reshape(b, w, h, d).permute(0, 2, 1, 3).reshape(b, h * w, d)
            outputs.append(out)

        # Combine the 4 directional results (all in row-major order now)
        combined = torch.cat(outputs, dim=-1)
        return self.out_proj(combined).view(b, h, w, d)
```

10. Practical Usage

Installing mamba-ssm Package

```bash
# Required packages
pip install torch torchvision torchaudio
pip install "causal-conv1d>=1.2.0"   # quoted so the shell does not treat >= as a redirect
pip install mamba-ssm

# Or install from source (latest features)
git clone https://github.com/state-spaces/mamba
cd mamba
pip install -e ".[dev]"
```

Using MambaLMHeadModel

```python
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer

# Load pre-trained model (Mamba checkpoints reuse the GPT-NeoX tokenizer)
model_name = "state-spaces/mamba-2.8b-slimpj"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    model_name,
    device="cuda",
    dtype=torch.bfloat16,
)
model.eval()


# Text generation
def generate_text(prompt, max_new_tokens=200, temperature=0.7, top_p=0.9):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    with torch.no_grad():
        # mamba_ssm's generate() takes a total max_length (prompt + new tokens);
        # top_k=0 with top_p enables nucleus sampling (the default top_k=1 is greedy)
        output = model.generate(
            input_ids=input_ids,
            max_length=input_ids.shape[1] + max_new_tokens,
            temperature=temperature,
            top_k=0,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,
        )
    generated = output[0][input_ids.shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)


# Example
prompt = "State Space Models are powerful because"
result = generate_text(prompt)
print(result)
```

Fine-tuning Example

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, get_linear_schedule_with_warmup


class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=512):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )

    def __len__(self):
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        return {
            "input_ids": self.encodings["input_ids"][idx],
            "attention_mask": self.encodings["attention_mask"][idx],
        }


def finetune_mamba(
    model_name="state-spaces/mamba-130m",
    texts=None,
    num_epochs=3,
    learning_rate=1e-4,
    batch_size=8,
):
    """Fine-tune a Mamba model (the mamba_ssm kernels require a CUDA GPU)"""
    if not texts:
        raise ValueError("texts must be a non-empty list of strings")
    device = "cuda"

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokenizer.pad_token = tokenizer.eos_token

    # Load model: fp32 master weights; the forward pass runs under bf16 autocast below
    model = MambaLMHeadModel.from_pretrained(model_name, device=device, dtype=torch.float32)

    # Dataset
    dataset = TextDataset(texts, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer: AdamW, skipping weight decay on parameters flagged _no_weight_decay (A_log, D)
    no_wd = [p for p in model.parameters() if getattr(p, "_no_weight_decay", False)]
    wd = [p for p in model.parameters() if not getattr(p, "_no_weight_decay", False)]
    optimizer = torch.optim.AdamW(
        [{"params": wd, "weight_decay": 0.1}, {"params": no_wd, "weight_decay": 0.0}],
        lr=learning_rate,
    )

    # Scheduler: 10% linear warmup, then linear decay
    total_steps = len(dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps,
    )
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch_idx, batch in enumerate(dataloader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # Language modeling: next-token prediction
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = model(input_ids).logits

            # Shift for next-token prediction; padded positions are excluded from the loss
            shift_logits = logits[..., :-1, :].float().contiguous()
            shift_labels = input_ids[..., 1:].clone()
            shift_labels[attention_mask[..., 1:] == 0] = -100
            loss = loss_fn(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}, Step {batch_idx}: Loss = {loss.item():.4f}")

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} complete: Avg Loss = {avg_loss:.4f}")

    return model
```

Custom Mamba Model Configuration

```python
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Custom configuration: per-block SSM settings go in the ssm_cfg dict
config = MambaConfig(
    d_model=1024,
    n_layer=48,
    vocab_size=50280,
    ssm_cfg=dict(
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
    ),
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    pad_vocab_size_multiple=8,
)

# Create model (~370M parameters; same shape as state-spaces/mamba-370m)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.bfloat16)
print(f"Parameter count: {sum(p.numel() for p in model.parameters()):,}")
```
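As a sanity check on the size comment, the parameter count can be estimated by hand. The sketch below assumes:

- The layer layout of this guide's MambaBlock, with biases only on conv1d and dt_proj.
- One RMSNorm per layer plus a final one.
- No MLP.
- Tied input and output embeddings.

If the installed mamba_ssm version differs in any of these details, the count will differ slightly. The function is hypothetical, not part of mamba_ssm.

```python
import math


def mamba_param_count(d_model, n_layer, vocab_size, d_state=16, d_conv=4, expand=2,
                      pad_vocab_size_multiple=8):
    """Rough parameter count for a pure-Mamba LM with tied embeddings and RMSNorm."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)                  # dt_rank="auto"
    per_layer = (
        d_model * 2 * d_inner                          # in_proj (no bias)
        + d_inner * d_conv + d_inner                   # depthwise conv1d + bias
        + d_inner * (dt_rank + 2 * d_state)            # x_proj (no bias)
        + dt_rank * d_inner + d_inner                  # dt_proj + bias
        + d_inner * d_state + d_inner                  # A_log and D
        + d_inner * d_model                            # out_proj (no bias)
        + d_model                                      # RMSNorm weight
    )
    vocab = math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    return n_layer * per_layer + vocab * d_model + d_model  # + tied embedding + final norm


print(f"{mamba_param_count(1024, 48, 50280):,}")
```

Most of the budget sits in in_proj and out_proj, which together account for roughly 3 · expand · d_model² per layer.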

Conclusion: The Future of Mamba

Mamba has brought significant changes to deep learning sequence modeling. Summarizing the key contributions:

1. **Selective Mechanism**: SSM parameters that dynamically change based on input

2. **Hardware-Aware Design**: Efficient computation leveraging GPU memory hierarchy

3. **Dual Representation**: Parallelized during training, recurrent during inference

4. **Linear Complexity**: Linear computation and memory complexity with sequence length

There are still challenges to address:

- Somewhat weaker in-context learning compared to Transformers

- Needs validation at very large scales

- A clear demonstration of superiority over attention-based models is still lacking

However, Mamba and SSM-family models will increasingly play an important role in areas where Transformers struggle — long sequence processing, real-time inference, and edge device deployment.

References

- Mamba paper: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (Gu & Dao, 2023) — https://arxiv.org/abs/2312.00752

- Mamba 2 paper: "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" (Dao & Gu, 2024) — https://arxiv.org/abs/2405.21060

- S4 paper: "Efficiently Modeling Long Sequences with Structured State Spaces" (Gu et al., 2021) — https://arxiv.org/abs/2111.00396

- H3 paper: "Hungry Hungry Hippos: Towards Language Modeling with State Space Models" (Fu et al., 2022) — https://arxiv.org/abs/2212.14052

- Mamba GitHub: https://github.com/state-spaces/mamba

- Jamba paper: "Jamba: A Hybrid Transformer-Mamba Language Model" (AI21 Labs, 2024) — https://arxiv.org/abs/2403.19887
