
Introduction

GPT-2 started with 1024 tokens and GPT-3 with 2048 tokens, but as of 2026, Claude supports 200K and Gemini supports 2M token context windows. How was this explosive expansion possible?

The key lies in the evolution of Positional Encoding. This article covers the mathematical principles and practical code from absolute positional encoding to the latest RoPE extension techniques.

The Evolution of Positional Encoding

1. Absolute Positional Encoding (GPT Era)

import torch
import torch.nn as nn

class AbsolutePositionalEncoding(nn.Module):
    """Learnable absolute position embeddings"""
    def __init__(self, max_len, d_model):
        super().__init__()
        # Learnable parameter - max_len is fixed
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        positions = torch.arange(x.size(1), device=x.device)
        return x + self.pe(positions)

Limitation: It cannot generalize to positions beyond max_len.
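A quick way to see this limitation concretely (a minimal sketch with hypothetical toy sizes): positions at or beyond max_len simply have no row in the embedding table, so the lookup fails.

```python
import torch
import torch.nn as nn

# Toy embedding table with room for only 8 positions (hypothetical sizes)
pe = nn.Embedding(8, 4)

ok = pe(torch.arange(8))        # positions 0..7 are fine
assert ok.shape == (8, 4)

# Positions 8 and 9 have no row in the table -> lookup fails
try:
    pe(torch.arange(10))
    out_of_range_failed = False
except IndexError:
    out_of_range_failed = True
assert out_of_range_failed
```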

2. Sinusoidal Positional Encoding (Original Transformer Paper)

import math

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

Mathematically, the encoding of any relative offset can be expressed as a fixed linear transformation of the absolute encodings, so arbitrary lengths can in principle be represented; in practice, however, performance degrades sharply once the training range is exceeded.
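The linear-transformation property can be checked numerically. The sketch below (toy dimensions, my own variable names) verifies that the sin/cos values at position p+k are a fixed rotation of the values at position p, where the rotation depends only on the offset k:

```python
import math
import torch

d_model, max_len = 16, 128
pos = torch.arange(max_len).unsqueeze(1).float()
div = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
sin, cos = torch.sin(pos * div), torch.cos(pos * div)   # (max_len, d_model/2)

# Rotation coefficients that depend only on the offset k, not on p:
k = 5
sin_k, cos_k = torch.sin(k * div), torch.cos(k * div)
sin_shift = sin * cos_k + cos * sin_k   # angle-addition: sin((p+k)w)
cos_shift = cos * cos_k - sin * sin_k   # angle-addition: cos((p+k)w)

p = 37
assert torch.allclose(sin_shift[p], sin[p + k], atol=1e-5)
assert torch.allclose(cos_shift[p], cos[p + k], atol=1e-5)
```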

RoPE (Rotary Position Embedding)

Proposed by Su et al. (2021), RoPE is the positional encoding used by most current LLMs (LLaMA, Qwen, DeepSeek, etc.).

Mathematical Principles

RoPE encodes token positions as rotations in the complex plane.

Treating each 2D dimension pair (q_{2i}, q_{2i+1}) as a complex number q = q_{2i} + i*q_{2i+1}:

RoPE(q, m) = q * e^{im*theta_i}

Where theta_i = base^{-2i/d}, with a default base of 10000.

import torch

def precompute_freqs_cis(dim, max_seq_len, base=10000.0):
    """Precompute RoPE frequencies"""
    # theta_i = base^(-2i/d)
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Angle per position: m * theta_i
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, freqs)  # (max_seq_len, dim/2)
    # Convert to complex: e^{i*m*theta}
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply RoPE to queries and keys of shape (batch, seq_len, n_heads, head_dim)"""
    # Convert real tensors to complex by pairing adjacent dimensions
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Slice to the current sequence length and reshape to
    # (1, seq_len, 1, head_dim/2) so it broadcasts over batch and heads
    freqs_cis = freqs_cis[:xq.shape[1]].unsqueeze(0).unsqueeze(2)
    xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)

Key Properties of RoPE

Relative position dependency: The dot product between a query at position m and a key at position n depends only on (m-n).

# Proof sketch (dimension pairs treated as complex numbers):
# <RoPE(q,m), RoPE(k,n)>
# = Re[(q * e^{im*theta}) * conj(k * e^{in*theta})]
# = Re[q * conj(k) * e^{i(m-n)*theta}]  <- depends only on relative position (m-n)!
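This shift invariance can be verified numerically. The sketch below (a self-contained toy helper, not the exact functions above) rotates a random query/key pair at two different absolute positions with the same offset and checks that the inner products agree:

```python
import torch

def rope_rotate(x, pos, base=10000.0):
    """Rotate a (dim,) vector, viewed as complex pairs, by position `pos`."""
    dim = x.shape[-1]
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    cis = torch.polar(torch.ones_like(freqs), pos * freqs)   # e^{i*pos*theta_i}
    return torch.view_as_complex(x.reshape(-1, 2)) * cis

torch.manual_seed(0)
q, k = torch.randn(2, 64).unbind(0)

def rotated_dot(m, n):
    # Real inner product of the rotated vectors = Re[sum q_c * conj(k_c) * e^{i(m-n)theta}]
    return (rope_rotate(q, m) * rope_rotate(k, n).conj()).real.sum()

# Same relative offset m-n = 3, different absolute positions:
a = rotated_dot(10, 7)
b = rotated_dot(110, 107)
assert torch.allclose(a, b, atol=1e-3)
```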

Context Extension Techniques

1. Position Interpolation (PI)

The simplest method proposed by Meta: scale down the position indices.

def position_interpolation(dim, original_max_len, target_max_len, base=10000.0):
    """Extend by linearly interpolating position indices"""
    scale = original_max_len / target_max_len
    t = torch.arange(target_max_len) * scale  # Compress positions
    # The rest is identical to standard RoPE
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)

# Example: 4K -> 32K extension
# Map positions 0-32767 into the 0-4095 range

Problem: Resolution for nearby positions degrades. It becomes difficult to distinguish adjacent tokens.
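A tiny numeric illustration of the problem: with a 4K -> 32K extension the scale is 0.125, so adjacent tokens end up only 0.125 positions apart, and the highest-frequency dimension (theta_0 = 1) sees an angular step of 0.125 rad per token instead of 1 rad.

```python
import torch

original_max_len, target_max_len = 4096, 32768
scale = original_max_len / target_max_len            # 0.125

t = torch.arange(8).float() * scale                  # compressed positions
step = (t[1] - t[0]).item()

# Adjacent tokens are now only 0.125 "positions" apart; in the
# highest-frequency dimension (theta_0 = 1) this means an 8x smaller
# rotation per token, i.e. 8x less resolution for nearby tokens.
assert abs(step - 0.125) < 1e-6
```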

2. NTK-aware Scaling

Adjusts the base value to preserve high frequencies while extending only low frequencies.

def ntk_aware_scaling(dim, max_seq_len, base=10000.0, scale_factor=4):
    """NTK-aware RoPE: Adjust base to preserve high frequencies"""
    # Increase base proportionally to scale_factor
    new_base = base * (scale_factor ** (dim / (dim - 2)))

    freqs = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)

# 4K -> 16K (scale_factor=4)
# base: 10000 -> ~40900 (for dim=128)
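The effect of the base change can be verified directly: with new_base = base * s^{d/(d-2)}, the highest frequency (i = 0) is untouched while the lowest frequency is interpolated by exactly the factor s (a minimal check with dim=128):

```python
import torch

dim, base, s = 128, 10000.0, 4
new_base = base * (s ** (dim / (dim - 2)))

i = torch.arange(0, dim, 2).float() / dim
old_freqs = 1.0 / (base ** i)
new_freqs = 1.0 / (new_base ** i)

# Highest frequency (i = 0) is unchanged: theta_0 = 1 for any base
assert torch.allclose(new_freqs[0], old_freqs[0])

# Lowest frequency is interpolated by exactly the scale factor s,
# since (s^{d/(d-2)})^{(d-2)/d} = s
assert torch.allclose(old_freqs[-1] / new_freqs[-1], torch.tensor(float(s)), atol=1e-3)
```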

3. NTK-by-Parts

Applies different scaling per dimension. High-frequency dimensions (resolution for nearby positions) are barely touched, while only low-frequency dimensions (distant positions) are extended.

def ntk_by_parts(dim, max_seq_len, base=10000.0, original_max_len=4096,
                 target_max_len=32768, beta_fast=32, beta_slow=1):
    """Per-dimension differential scaling"""
    scale = target_max_len / original_max_len

    # Original frequencies
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    # Wavelength of each frequency
    wavelengths = 2 * math.pi / freqs

    # Calculate interpolation ratio (ramp function)
    low_threshold = original_max_len / beta_fast
    high_threshold = original_max_len / beta_slow

    ramp = (wavelengths - low_threshold) / (high_threshold - low_threshold)
    ramp = ramp.clamp(0, 1)

    # High frequency (ramp=0): keep original, Low frequency (ramp=1): full interpolation
    scaled_freqs = freqs / scale  # PI method
    new_freqs = (1 - ramp) * freqs + ramp * scaled_freqs

    t = torch.arange(target_max_len)
    freqs_matrix = torch.outer(t, new_freqs)
    return torch.polar(torch.ones_like(freqs_matrix), freqs_matrix)
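A quick check of the ramp behaviour (same parameters as above): the highest-frequency dimension has a wavelength of roughly 6.3 tokens, far below the 128-token threshold, so it keeps its original frequency; the lowest-frequency dimension's wavelength exceeds the training length, so it is fully interpolated.

```python
import math
import torch

dim, base = 128, 10000.0
original_max_len, target_max_len = 4096, 32768
beta_fast, beta_slow = 32, 1
scale = target_max_len / original_max_len

freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
wavelengths = 2 * math.pi / freqs
ramp = ((wavelengths - original_max_len / beta_fast)
        / (original_max_len / beta_slow - original_max_len / beta_fast)).clamp(0, 1)
new_freqs = (1 - ramp) * freqs + ramp * (freqs / scale)

# Highest-frequency dim (wavelength ~6.3 tokens << 128) is left untouched:
assert torch.allclose(new_freqs[0], freqs[0])
# Lowest-frequency dim (wavelength >> 4096 tokens) is fully interpolated:
assert torch.allclose(new_freqs[-1], freqs[-1] / scale)
```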

4. YaRN (Yet another RoPE extensioN)

A method that adds Attention Temperature Scaling on top of NTK-by-Parts. It is adopted by state-of-the-art LLMs such as DeepSeek and Qwen.

def yarn_rope(dim, max_seq_len, base=10000.0, original_max_len=4096,
              target_max_len=131072, beta_fast=32, beta_slow=1):
    """YaRN: NTK-by-Parts + Temperature Scaling"""
    scale = target_max_len / original_max_len

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    wavelengths = 2 * math.pi / freqs
    low_threshold = original_max_len / beta_fast
    high_threshold = original_max_len / beta_slow

    ramp = (wavelengths - low_threshold) / (high_threshold - low_threshold)
    ramp = ramp.clamp(0, 1)

    scaled_freqs = freqs / scale
    new_freqs = (1 - ramp) * freqs + ramp * scaled_freqs

    # Key: attention temperature (mscale)
    # mscale = 0.1 * ln(s) + 1 keeps the attention score distribution
    # (softmax entropy) roughly stable after extension
    temperature = 0.1 * math.log(scale) + 1.0

    t = torch.arange(target_max_len)
    freqs_matrix = torch.outer(t, new_freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs_matrix), freqs_matrix)

    return freqs_cis, temperature

# During attention computation, q and k are each multiplied by `temperature`,
# which is equivalent to:
# attn_weights = (Q @ K.T) * temperature**2 / sqrt(d)

YaRN's efficiency: The context can be extended by 32x with fine-tuning on only 0.1% of the original training data.

ALiBi (Attention with Linear Biases)

A different approach from RoPE that adds linear biases to attention scores without positional encoding.

def alibi_bias(num_heads, max_seq_len):
    """ALiBi: linear distance-based penalty on attention"""
    # Per-head slopes: 2^{-8/n}, 2^{-16/n}, ...
    # (this closed form assumes a power-of-two number of heads)
    slopes = torch.pow(2, -torch.arange(1, num_heads + 1) * 8.0 / num_heads)

    # Distance matrix
    positions = torch.arange(max_seq_len)
    distances = positions.unsqueeze(0) - positions.unsqueeze(1)  # (L, L)
    distances = distances.abs().neg()  # Negative distances

    # Per-head biases
    biases = slopes.unsqueeze(1).unsqueeze(2) * distances.unsqueeze(0)
    return biases  # (num_heads, L, L)

# Attention computation:
# attn = softmax(Q @ K.T / sqrt(d) + alibi_bias)
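A minimal, self-contained usage sketch (toy sizes, my own variable names): build the biases, add them to the scores before the softmax, and check that the penalty grows linearly with distance.

```python
import math
import torch

def alibi_slopes(num_heads):
    # Geometric slopes 2^{-8/n}, 2^{-16/n}, ... (power-of-two head counts)
    return torch.pow(2.0, -torch.arange(1, num_heads + 1).float() * 8.0 / num_heads)

num_heads, L, d = 4, 6, 8
slopes = alibi_slopes(num_heads)
positions = torch.arange(L)
# bias[h, i, j] = -slope_h * |i - j|
bias = slopes.view(-1, 1, 1) * -(positions.view(1, -1) - positions.view(-1, 1)).abs()

torch.manual_seed(0)
q = torch.randn(num_heads, L, d)
k = torch.randn(num_heads, L, d)
attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d) + bias, dim=-1)

assert attn.shape == (num_heads, L, L)
# The penalty is linear in distance: bias at distance 2 is twice distance 1
assert torch.allclose(bias[:, 0, 2], 2 * bias[:, 0, 1])
```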

Advantage: Extrapolation beyond the training length is possible without additional training. Disadvantage: Excessive penalties for distant tokens weaken long-range dependencies.

Ring Attention — Hardware-Level Extension

Ring Attention distributes sequences across multiple GPUs to handle context beyond memory limits.

# Ring Attention pseudocode (rank, d, online_softmax_update, and
# ring_send_recv are assumed to be provided by the distributed runtime)
def ring_attention(Q, K, V, num_devices):
    """
    Split the sequence into num_devices chunks for distributed processing.
    Each device computes attention with its own Q chunk and circulating KV chunks.
    """
    seq_len = Q.shape[1]
    chunk_size = seq_len // num_devices

    # Q chunk for each device
    Q_local = Q[:, rank * chunk_size:(rank + 1) * chunk_size]

    # KV circulates in a ring pattern
    K_local = K[:, rank * chunk_size:(rank + 1) * chunk_size]
    V_local = V[:, rank * chunk_size:(rank + 1) * chunk_size]

    output = torch.zeros_like(Q_local)
    log_sum_exp = torch.full_like(Q_local[:, :, :1], float('-inf'))

    for step in range(num_devices):
        # Compute attention with current KV chunk
        attn = Q_local @ K_local.transpose(-2, -1) / math.sqrt(d)
        # Accumulate with online softmax
        output, log_sum_exp = online_softmax_update(
            output, log_sum_exp, attn, V_local
        )
        # Receive KV from the next device (ring communication)
        K_local = ring_send_recv(K_local)
        V_local = ring_send_recv(V_local)

    return output
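The helpers above are placeholders for distributed primitives. The single-process simulation below is a hedged sketch that runs the same ring schedule with an explicit online-softmax accumulator and checks it against ordinary full attention:

```python
import math
import torch

def online_softmax_update(out, lse, scores, v_chunk):
    """Numerically stable streaming-softmax accumulation.
    out: (L, d) running normalized output, lse: (L, 1) running log-sum-exp,
    scores: (L, C) logits for the current KV chunk, v_chunk: (C, d)."""
    chunk_lse = torch.logsumexp(scores, dim=-1, keepdim=True)   # (L, 1)
    new_lse = torch.logaddexp(lse, chunk_lse)
    chunk_out = torch.softmax(scores, dim=-1) @ v_chunk
    # Re-weight the old and new contributions into a single normalized sum
    out = out * (lse - new_lse).exp() + chunk_out * (chunk_lse - new_lse).exp()
    return out, new_lse

def simulated_ring_attention(Q, K, V, num_devices):
    """Single-process simulation: each 'device' holds one Q chunk while
    the KV chunks rotate around the ring."""
    L, d = Q.shape
    c = L // num_devices
    K_chunks, V_chunks = list(K.split(c)), list(V.split(c))
    outputs = []
    for r in range(num_devices):              # one iteration per device
        q = Q[r * c:(r + 1) * c]
        out = torch.zeros_like(q)
        lse = torch.full((c, 1), float('-inf'))
        for step in range(num_devices):       # KV ring rotation
            k = K_chunks[(r + step) % num_devices]
            v = V_chunks[(r + step) % num_devices]
            scores = q @ k.T / math.sqrt(d)
            out, lse = online_softmax_update(out, lse, scores, v)
        outputs.append(out)
    return torch.cat(outputs)

torch.manual_seed(0)
Q, K, V = torch.randn(3, 16, 8).unbind(0)
ring_out = simulated_ring_attention(Q, K, V, num_devices=4)
full_out = torch.softmax(Q @ K.T / math.sqrt(8), dim=-1) @ V
assert torch.allclose(ring_out, full_out, atol=1e-5)
```

Because softmax normalization is recoverable from the running log-sum-exp, the chunks can arrive in any order, which is what makes the ring schedule valid.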

Ring Attention is widely believed to be among the techniques that make multi-million-token windows, such as Gemini's 2M context, feasible.

Practical: Applying YaRN with HuggingFace

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Qwen2.5-7B (default 32K -> 131K extension)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    rope_scaling={
        "type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
    }
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")

# Process long documents
long_text = open("long_document.txt").read()
inputs = tokenizer(long_text, return_tensors="pt", truncation=True, max_length=131072)
outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0]))

Performance Comparison Summary

| Technique      | Training Required | Extension Ratio    | Long-range Performance | Models Used    |
|----------------|-------------------|--------------------|------------------------|----------------|
| PI             | Fine-tune         | 4-8x               | Moderate               | Early LLaMA2   |
| NTK-aware      | None/Minimal      | 4-8x               | Good                   | Code Llama     |
| YaRN           | Minimal (0.1%)    | 4-32x              | Very Good              | Qwen, DeepSeek |
| ALiBi          | None              | Unlimited (theory) | Weak                   | MPT, BLOOM     |
| Ring Attention | None              | Scales with GPUs   | Very Good              | Gemini         |

Quiz: LLM Context Extension Comprehension Check (8 Questions)

Q1. What is the key advantage of RoPE over absolute positional encoding?

Relative position dependency: attention between two tokens depends only on their relative distance (m-n), not on absolute positions.

Q2. What is the downside of Position Interpolation?

Since positions are compressed, the resolution between adjacent tokens decreases, making it harder to distinguish nearby tokens.

Q3. Why is NTK-by-Parts better than PI?

It preserves high-frequency dimensions (resolution for nearby positions) and selectively extends only low-frequency dimensions (distant positions).

Q4. What is the role of Temperature Scaling in YaRN?

It corrects the change in attention score distribution during context extension to maintain the entropy of the softmax.

Q5. Explain ALiBi's working principle in one sentence.

It penalizes distant tokens by adding a negative linear bias proportional to token distance to the attention scores.

Q6. How does Ring Attention overcome memory limitations?

It distributes the sequence across multiple GPUs, circulates KV in a ring pattern, and accumulates using online softmax.

Q7. How much training data is needed to extend 32K to 131K using YaRN?

Approximately 0.1% of the original training data (about 400M tokens) is sufficient.

Q8. What happens when you increase the base value in RoPE?

The frequencies decrease overall, allowing distinction of more distant positions, but the resolution for nearby positions decreases.