Complete Guide to LLM Context Window Extension: From RoPE, ALiBi, and YaRN to Ring Attention

Introduction

GPT-2 started with 1024 tokens and GPT-3 with 2048 tokens, but as of 2026, Claude supports 200K and Gemini supports 2M token context windows. How was this explosive expansion possible?

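Much of the answer lies in the evolution of positional encoding, complemented by system-level techniques such as Ring Attention. This article covers the mathematical principles and practical code, from absolute positional encoding to the latest RoPE extension techniques and distributed attention.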

The Evolution of Positional Encoding

1. Absolute Positional Encoding (GPT Era)

import torch
import torch.nn as nn

class AbsolutePositionalEncoding(nn.Module):
    """Learnable absolute position embeddings"""
    def __init__(self, max_len, d_model):
        super().__init__()
        # Learnable parameter - max_len is fixed
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        positions = torch.arange(x.size(1), device=x.device)
        return x + self.pe(positions)

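Limitation: the embedding table has exactly max_len rows, so there is no representation at all for positions at or beyond max_len. Longer inputs cannot be processed without retraining or enlarging the table.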
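The sketch below shows this failure mode. The sizes (max_len=8, d_model=4) are arbitrary and chosen only for illustration. `nn.Embedding` has no row for position `max_len`, so on CPU PyTorch raises an `IndexError` instead of returning a meaningful vector.

```python
import torch
import torch.nn as nn

max_len, d_model = 8, 4
pe = nn.Embedding(max_len, d_model)  # one learned vector per position 0..7

# Positions inside the trained range work.
ok = pe(torch.arange(max_len))
print(ok.shape)  # torch.Size([8, 4])

# Position 8 has no row in the table, so the lookup fails outright.
try:
    pe(torch.arange(max_len + 1))
    failed = False
except IndexError:
    failed = True
print(failed)  # True
```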

2. Sinusoidal Positional Encoding (Original Transformer Paper)

import math

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

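Mathematically, for any fixed offset k, PE(pos + k) is a linear function of PE(pos): each sin/cos pair is rotated by an angle that depends only on k. Relative positions are therefore easy to represent in principle. In practice, however, performance still degrades sharply once inputs exceed the training length.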
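The linear-transformation claim can be checked numerically. The helper names `sinusoidal_pe` and `shift_by` are illustrative and not from any library. The sketch applies a fixed 2x2 rotation to each sin/cos pair, with an angle that depends only on the offset `k`. It then compares the result with directly encoding the shifted positions. float64 is used only to keep rounding error negligible.

```python
import math
import torch

def sinusoidal_pe(positions, d_model):
    """Sinusoidal encoding (same formula as above) for an arbitrary 1D tensor of positions."""
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64)
                         * -(math.log(10000.0) / d_model))
    angles = positions.to(torch.float64).unsqueeze(1) * div_term  # (P, d_model/2)
    pe = torch.zeros(len(positions), d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe, div_term

def shift_by(pe, div_term, k):
    """Apply the fixed linear map (one rotation per pair) that moves every position forward by k."""
    c, s = torch.cos(k * div_term), torch.sin(k * div_term)
    sin_p, cos_p = pe[:, 0::2], pe[:, 1::2]
    out = torch.empty_like(pe)
    out[:, 0::2] = c * sin_p + s * cos_p    # sin((p + k) w)
    out[:, 1::2] = -s * sin_p + c * cos_p   # cos((p + k) w)
    return out

d_model, k = 16, 37
positions = torch.arange(0, 50)
pe, w = sinusoidal_pe(positions, d_model)
pe_shifted, _ = sinusoidal_pe(positions + k, d_model)
print(torch.allclose(shift_by(pe, w, k), pe_shifted))  # True
```

The map in `shift_by` uses only `k` and never `p`, which is exactly what "relative positions are a linear transformation" means here.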

RoPE (Rotary Position Embedding)

Proposed by Su et al. (2021), RoPE is the positional encoding used by most current LLMs (LLaMA, Qwen, DeepSeek, etc.).

Mathematical Principles

RoPE encodes token positions as rotations in the complex plane.

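For each 2D dimension pair (q_{2i}, q_{2i+1}), i = 0, 1, ..., d/2 - 1, view the pair as one complex number q_i = q_{2i} + j*q_{2i+1}. Here j is the imaginary unit, written j to avoid clashing with the index i. At position m:

RoPE(q, m)_i = q_i * e^{j*m*theta_i}

Here theta_i = base^{-2i/d}, with a default base of 10000. In real coordinates, this is a 2D rotation of the pair by the angle m*theta_i.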
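To make the notation concrete, here is a minimal check using a toy head dimension of 8. It confirms that the complex-number form and the 2x2 rotation-matrix form are the same operation. In the complex view, each pair is multiplied by a unit complex number of angle m*theta_i. In the real view, each pair is multiplied by [[cos, -sin], [sin, cos]].

```python
import torch

torch.manual_seed(0)
d, base, m = 8, 10000.0, 5
theta = base ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)  # theta_i = base^(-2i/d)
q = torch.randn(d, dtype=torch.float64)

# (1) Complex view: pair i becomes q[2i] + j*q[2i+1], times a unit complex number of angle m*theta_i
q_c = torch.view_as_complex(q.reshape(-1, 2))
rot_c = torch.view_as_real(q_c * torch.polar(torch.ones_like(theta), m * theta)).flatten()

# (2) Real view: a 2x2 rotation matrix [[cos, -sin], [sin, cos]] applied to each pair
cos, sin = torch.cos(m * theta), torch.sin(m * theta)
q0, q1 = q[0::2], q[1::2]
rot_r = torch.stack([q0 * cos - q1 * sin, q0 * sin + q1 * cos], dim=-1).flatten()

print(torch.allclose(rot_c, rot_r))  # True
```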

import torch

def precompute_freqs_cis(dim, max_seq_len, base=10000.0):
    """Precompute RoPE frequencies"""
    # theta_i = base^(-2i/d)
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Angle per position: m * theta_i
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, freqs)  # (max_seq_len, dim/2)
    # Convert to complex: e^{i*m*theta}
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply RoPE to queries and keys"""
    # Convert real tensors to complex
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

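    # Apply rotation. Assumes the sequence axis is dim 1, e.g. (batch, seq_len, n_heads, head_dim);
    # reshape freqs_cis so it broadcasts over every other axis (batch, heads).
    seq_len = xq.shape[1]
    shape = [1] * xq_complex.ndim
    shape[1], shape[-1] = seq_len, xq_complex.shape[-1]
    freqs_cis = freqs_cis[:seq_len].view(*shape)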
    xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)

Key Properties of RoPE

Relative position dependency: The dot product between a query at position m and a key at position n depends only on (m-n).

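# Proof sketch: view each pair as a complex number q_i, k_i (j = imaginary unit);
# the real dot product of two 2D vectors a, b equals Re[a * conj(b)].
# <RoPE(q,m), RoPE(k,n)>
# = sum_i Re[ q_i*e^{j*m*theta_i} * conj(k_i*e^{j*n*theta_i}) ]
# = sum_i Re[ q_i * conj(k_i) * e^{j*(m-n)*theta_i} ]  <- depends only on relative position (m-n)!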
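This property can be checked numerically. The sketch uses a hypothetical single-vector helper `rope`, not an existing library function. The same query/key pair placed at positions (10, 7) and at (1003, 1000) yields the same score. Both scores also match the closed form sum_i Re[q_i * conj(k_i) * e^{j*(m-n)*theta_i}] with m - n = 3.

```python
import torch

def rope(x, pos, base=10000.0):
    """Rotate each (x[2i], x[2i+1]) pair of a 1D vector by the angle pos * theta_i."""
    d = x.shape[-1]
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)
    x_c = torch.view_as_complex(x.reshape(-1, 2))
    return torch.view_as_real(x_c * torch.polar(torch.ones_like(theta), pos * theta)).flatten()

torch.manual_seed(0)
d = 64
q = torch.randn(d, dtype=torch.float64)
k = torch.randn(d, dtype=torch.float64)

# Same offset m - n = 3 at very different absolute positions
s_near = torch.dot(rope(q, 10), rope(k, 7))
s_far = torch.dot(rope(q, 1003), rope(k, 1000))

# Closed form: sum_i Re[q_i * conj(k_i) * e^{j*(m-n)*theta_i}] with m - n = 3
theta = 10000.0 ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)
q_c = torch.view_as_complex(q.reshape(-1, 2))
k_c = torch.view_as_complex(k.reshape(-1, 2))
closed = (q_c * k_c.conj() * torch.polar(torch.ones_like(theta), 3 * theta)).sum().real

print(torch.allclose(s_near, s_far), torch.allclose(s_near, closed))  # True True
```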

Context Extension Techniques

1. Position Interpolation (PI)

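The simplest method, proposed by Meta (Chen et al., 2023), linearly scales down the position indices so that the extended range maps back into the trained range, followed by a short fine-tune.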

def position_interpolation(dim, original_max_len, target_max_len, base=10000.0):
    """Extend by linearly interpolating positions"""
    scale = original_max_len / target_max_len  # e.g. 4096 / 32768 = 1/8
    t = torch.arange(target_max_len) * scale  # Compress positions into [0, original_max_len)
    # The rest is identical to precompute_freqs_cis
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)

# Example: 4K -> 32K extension
# Map positions 0-32767 into [0, 4096): position m becomes m/8

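Problem: adjacent tokens end up only 1/scale of a position apart, so the high-frequency dimensions that separate nearby tokens lose resolution. Adjacent tokens become harder to distinguish, and the model needs fine-tuning to adapt.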
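To put numbers on this, the sketch assumes a head dimension of 128 and the 4K to 32K example above. Every adjacent-token rotation angle shrinks by the same factor of 8. In the fastest dimension pair, the angle drops from 1 radian to 0.125 radians. The largest interpolated position still stays inside the trained range.

```python
import torch

dim, base = 128, 10000.0
original_max_len, target_max_len = 4096, 32768
scale = target_max_len / original_max_len  # 8.0

theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)

# Rotation angle between adjacent tokens m and m+1 in each dimension pair
step_original = theta       # without PI: theta_i
step_pi = theta / scale     # with PI: every position is divided by `scale`

print(step_original[0].item(), step_pi[0].item())  # 1.0 0.125

# The largest interpolated position still falls inside the trained range
print((target_max_len - 1) / scale)  # 4095.875
```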

2. NTK-aware Scaling

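Instead of compressing positions, NTK-aware scaling enlarges the RoPE base. The highest-frequency pair (theta_0 = 1) is left unchanged. The slowdown then grows across dimensions until the lowest-frequency pair is divided by exactly the scale factor. High frequencies (local detail) are therefore preserved while low frequencies are interpolated.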

def ntk_aware_scaling(dim, max_seq_len, base=10000.0, scale_factor=4):
    """NTK-aware RoPE: Adjust base to preserve high frequencies"""
    # Raise base by scale_factor^(d/(d-2)): theta_0 stays 1, the lowest frequency is divided by exactly scale_factor
    new_base = base * (scale_factor ** (dim / (dim - 2)))

    freqs = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)

# 4K -> 16K (scale_factor=4)
# base (dim=128): 10000 -> ~40890
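A quick check of the NTK-aware formula, assuming a head dimension of 128 and scale factor 4, as in the comment above. The sketch computes the new base, then the factor by which each frequency is slowed. That factor is 1 for the first pair and exactly 4 for the last pair.

```python
import torch

dim, base, s = 128, 10000.0, 4.0
new_base = base * s ** (dim / (dim - 2))   # NTK-aware base

idx = torch.arange(0, dim, 2, dtype=torch.float64)
theta_old = base ** (-idx / dim)
theta_new = new_base ** (-idx / dim)
slowdown = theta_old / theta_new           # factor by which each frequency was reduced

print(round(new_base))                 # 40890
print(slowdown[0].item())              # 1.0 -> highest frequency untouched
print(round(slowdown[-1].item(), 6))   # 4.0 -> lowest frequency divided by exactly s
```

Algebraically, slowdown_i = s^(2i/(d-2)). The exponent d/(d-2) is chosen so that this reaches exactly s at the last pair, i = d/2 - 1.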

3. NTK-by-Parts

Applies different scaling per dimension. High-frequency dimensions (resolution for nearby positions) are barely touched, while only low-frequency dimensions (distant positions) are extended.

def ntk_by_parts(dim, base=10000.0, original_max_len=4096,
                 target_max_len=32768, beta_fast=32, beta_slow=1):
    """Per-dimension differential scaling (ramp as defined in the YaRN paper)"""
    scale = target_max_len / original_max_len

    # Original frequencies
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    # Wavelength of each frequency, and how many full rotations fit in the original context
    wavelengths = 2 * math.pi / freqs
    r = original_max_len / wavelengths

    # Ramp: gamma = 1 when r >= beta_fast (many rotations: high frequency),
    #       gamma = 0 when r <= beta_slow (at most one rotation: low frequency)
    gamma = ((r - beta_slow) / (beta_fast - beta_slow)).clamp(0, 1)

    # High frequency (gamma=1): keep original; low frequency (gamma=0): full interpolation
    scaled_freqs = freqs / scale  # PI method
    new_freqs = gamma * freqs + (1 - gamma) * scaled_freqs

    t = torch.arange(target_max_len)
    freqs_matrix = torch.outer(t, new_freqs)
    return torch.polar(torch.ones_like(freqs_matrix), freqs_matrix)
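The sketch below counts how many of the 64 frequency pairs are kept, blended, or fully interpolated. It assumes a head dimension of 128, a 4K original context, a 32K target, and beta_fast=32 / beta_slow=1. It uses the ramp as defined in the YaRN paper: gamma = clamp((r - beta_slow) / (beta_fast - beta_slow), 0, 1), with r = original_max_len / wavelength.

```python
import math
import torch

dim, base = 128, 10000.0
original_max_len, target_max_len = 4096, 32768
beta_fast, beta_slow = 32, 1
scale = target_max_len / original_max_len

freqs = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)
r = original_max_len / (2 * math.pi / freqs)   # rotations completed within the original context
gamma = ((r - beta_slow) / (beta_fast - beta_slow)).clamp(0, 1)
new_freqs = gamma * freqs + (1 - gamma) * freqs / scale

kept = int((gamma == 1).sum())          # r >= beta_fast: original frequency kept
interpolated = int((gamma == 0).sum())  # r <= beta_slow: plain PI (divided by scale)
blended = dim // 2 - kept - interpolated
print(kept, blended, interpolated)      # 21 25 18
```

With these settings, roughly a third of the pairs are left untouched, and only the slowest ones get the full PI compression.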

4. YaRN (Yet another RoPE extensioN)

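YaRN (Peng et al., 2023) adds attention temperature scaling on top of NTK-by-Parts. Several recent open models use it, for example DeepSeek-V2/V3 and Qwen2.5 (for inputs beyond 32K tokens). Llama 3.1 instead uses a related but distinct frequency-dependent RoPE scaling scheme.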

def yarn_rope(dim, base=10000.0, original_max_len=4096,
              target_max_len=131072, beta_fast=32, beta_slow=1):
    """YaRN: NTK-by-Parts + Temperature Scaling"""
    scale = target_max_len / original_max_len

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    # Same per-dimension ramp as NTK-by-Parts
    wavelengths = 2 * math.pi / freqs
    r = original_max_len / wavelengths
    gamma = ((r - beta_slow) / (beta_fast - beta_slow)).clamp(0, 1)

    scaled_freqs = freqs / scale
    new_freqs = gamma * freqs + (1 - gamma) * scaled_freqs

    # Key: Attention Temperature
    # YaRN sets sqrt(1/t) = 0.1 * ln(s) + 1, so attention logits are multiplied by mscale^2
    mscale = 0.1 * math.log(scale) + 1.0

    t = torch.arange(target_max_len)
    freqs_matrix = torch.outer(t, new_freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs_matrix), freqs_matrix)

    return freqs_cis, mscale

# During attention computation (1/t = mscale^2, which slightly sharpens the softmax):
# attn_weights = (Q @ K.T) * mscale**2 / sqrt(d)

YaRN's efficiency: the YaRN paper reports extending Llama 2 from 4K to 128K tokens (32x) with fine-tuning on roughly 0.1% of the original pre-training data.
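In the YaRN paper, the temperature enters as softmax(q.k / (t * sqrt(d))) with sqrt(1/t) = 0.1 * ln(s) + 1. The logits are therefore multiplied by (0.1 * ln(s) + 1)^2, which slightly sharpens attention. Below is a minimal sketch; the function names are illustrative. Because the factor is a square, it can equivalently be applied by scaling q and k each by 0.1 * ln(s) + 1. That is how implementations commonly fold it into the rotary cos/sin tables.

```python
import math
import torch

def yarn_mscale(s):
    """sqrt(1/t) from the YaRN paper: 0.1 * ln(s) + 1 (no change when s <= 1)."""
    return 0.1 * math.log(s) + 1.0 if s > 1 else 1.0

def yarn_logits(q, k, s):
    """Scaled dot-product logits with YaRN temperature: q.k^T / (t * sqrt(d)), where 1/t = mscale^2."""
    d = q.shape[-1]
    return (q @ k.transpose(-2, -1)) * yarn_mscale(s) ** 2 / math.sqrt(d)

for s in (4, 8, 32):
    print(s, round(yarn_mscale(s) ** 2, 3))  # logit multiplier: 1.296, 1.459, 1.813

# Equivalent formulation: scale q and k by mscale each before the dot product
torch.manual_seed(0)
q, k = torch.randn(5, 64), torch.randn(7, 64)
m = yarn_mscale(32)
same = torch.allclose(yarn_logits(q, k, 32), (q * m) @ (k * m).T / math.sqrt(64), atol=1e-5)
print(same)  # True
```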

ALiBi (Attention with Linear Biases)

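ALiBi takes a different approach from RoPE and adds no positional embeddings to the token representations at all. Instead, each head adds a linear penalty to the attention scores. The penalty is proportional to the query-key distance, with a head-specific slope.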

def alibi_bias(num_heads, max_seq_len):
    """ALiBi: linear distance-based penalty on attention"""
    # Per-head slopes
    slopes = torch.pow(2, -torch.arange(1, num_heads + 1) * 8.0 / num_heads)

    # Distance matrix
    positions = torch.arange(max_seq_len)
    distances = positions.unsqueeze(0) - positions.unsqueeze(1)  # (L, L)
    distances = distances.abs().neg()  # Negative distances

    # Per-head biases
    biases = slopes.unsqueeze(1).unsqueeze(2) * distances.unsqueeze(0)
    return biases  # (num_heads, L, L)

# Attention computation:
# attn = softmax(Q @ K.T / sqrt(d) + alibi_bias)

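Advantage: the bias is defined for any distance, so the model can run on inputs longer than its training length without additional training. Disadvantage: because the penalty grows linearly with distance, distant tokens receive very little attention, which weakens long-range dependencies.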
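Below is a minimal sketch of the bias inside a causal attention layer. It assumes (num_heads, L, d) tensors and a power-of-two head count; the slope recipe 2^(-8k/n) used above is the ALiBi paper's recipe for that case. The function names are illustrative. The bias is computed from distances rather than looked up in a learned table, so the same function runs unchanged for any sequence length L.

```python
import math
import torch

def alibi_slopes(num_heads):
    """Geometric slopes 2^(-8/n), 2^(-16/n), ..., 2^(-8) for a power-of-two head count n."""
    return 2.0 ** (-8.0 * torch.arange(1, num_heads + 1, dtype=torch.float32) / num_heads)

def alibi_causal_attention(q, k, v):
    """q, k, v: (num_heads, L, d). Causal attention with ALiBi biases and no position embeddings."""
    h, L, d = q.shape
    pos = torch.arange(L)
    rel = (pos.unsqueeze(0) - pos.unsqueeze(1)).float()  # rel[i, j] = j - i (<= 0 for past keys)
    bias = alibi_slopes(h).view(h, 1, 1) * rel            # -slope * distance
    scores = q @ k.transpose(-2, -1) / math.sqrt(d) + bias
    causal = torch.ones(L, L, dtype=torch.bool).tril()    # query i may only see keys j <= i
    scores = scores.masked_fill(~causal, float('-inf'))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q, k, v = torch.randn(3, 8, 50, 16).unbind(0)  # 8 heads, L = 50
out = alibi_causal_attention(q, k, v)
print(out.shape)          # torch.Size([8, 50, 16])
print(alibi_slopes(8))    # 1/2, 1/4, ..., 1/256
```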

Ring Attention — Hardware-Level Extension

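Ring Attention is not a positional encoding. It is a way to compute exact attention over a sequence split across multiple GPUs, so the context length is no longer bounded by a single device's memory.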

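# Ring Attention, simulated in a single process (non-causal).
# Each "device" (rank) owns one Q/K/V chunk. At every step it attends to the KV chunk
# it currently holds, then passes that chunk to its neighbor. Here the ring
# communication is emulated by indexing: at step s, rank r holds the KV chunk that
# started on rank (r - s) % num_devices. On real hardware the ranks run in parallel
# and each hop is a send/recv (e.g. via torch.distributed), overlapped with compute.
import math
import torch

def ring_attention(Q, K, V, num_devices):
    """
    Q, K, V: (batch, seq_len, d); seq_len must be divisible by num_devices.
    Returns the same result as softmax(Q @ K^T / sqrt(d)) @ V.
    """
    d = Q.shape[-1]
    chunk_size = Q.shape[1] // num_devices
    Q_chunks = Q.split(chunk_size, dim=1)
    K_chunks = K.split(chunk_size, dim=1)
    V_chunks = V.split(chunk_size, dim=1)

    outputs = []
    for rank in range(num_devices):  # in practice, one GPU per rank
        Q_local = Q_chunks[rank]
        output = torch.zeros_like(Q_local)
        log_sum_exp = torch.full_like(Q_local[:, :, :1], float('-inf'))

        for step in range(num_devices):
            src = (rank - step) % num_devices  # KV chunk held after `step` ring hops
            K_local, V_local = K_chunks[src], V_chunks[src]
            # Compute attention with current KV chunk
            attn = Q_local @ K_local.transpose(-2, -1) / math.sqrt(d)
            # Accumulate with online softmax: rescale old and new partial results
            block_lse = torch.logsumexp(attn, dim=-1, keepdim=True)
            new_lse = torch.logaddexp(log_sum_exp, block_lse)
            output = (output * torch.exp(log_sum_exp - new_lse)
                      + (torch.softmax(attn, dim=-1) @ V_local) * torch.exp(block_lse - new_lse))
            log_sum_exp = new_lse
        outputs.append(output)

    return torch.cat(outputs, dim=1)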

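Google has not published how Gemini's 2M-token context is implemented, so whether it relies on Ring Attention is not confirmed. Ring Attention's authors did use it to train the Large World Model (LWM) on million-token sequences.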
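Ring Attention is exact because of the online-softmax merge. Attention results over disjoint KV blocks can be combined using each block's log-sum-exp, so no device ever needs the full score matrix. Below is a self-contained sketch; the names `merge` and `block_attention` are illustrative. It is checked against ordinary softmax attention over all keys.

```python
import torch

def block_attention(scores, v):
    """Softmax attention restricted to one key block, plus that block's log-sum-exp."""
    return torch.softmax(scores, dim=-1) @ v, torch.logsumexp(scores, dim=-1, keepdim=True)

def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial results computed over disjoint key blocks into one exact result."""
    lse = torch.logaddexp(lse_a, lse_b)
    return out_a * torch.exp(lse_a - lse) + out_b * torch.exp(lse_b - lse), lse

torch.manual_seed(0)
scores = torch.randn(4, 10, dtype=torch.float64)  # 4 queries x 10 keys
v = torch.randn(10, 3, dtype=torch.float64)

full = torch.softmax(scores, dim=-1) @ v
out_a, lse_a = block_attention(scores[:, :6], v[:6])
out_b, lse_b = block_attention(scores[:, 6:], v[6:])
merged, _ = merge(out_a, lse_a, out_b, lse_b)
print(torch.allclose(merged, full))  # True
```

Each block's output is normalized by its own partial sum. Rescaling by exp(lse_block - lse_total) converts it to the normalization of the full softmax, so the merge can be repeated one block at a time as KV chunks arrive around the ring.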

Practical: Applying YaRN with HuggingFace

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Qwen2.5-7B (native 32K -> 131K with YaRN: 4.0 x 32768 = 131072)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    rope_scaling={
        "type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
    }
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")

# Process long documents (leave room for the generated tokens within 131072)
max_new_tokens = 512
with open("long_document.txt", encoding="utf-8") as f:
    long_text = f.read()
inputs = tokenizer(long_text, return_tensors="pt", truncation=True,
                   max_length=131072 - max_new_tokens).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
# Decode only the newly generated tokens
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))

Performance Comparison Summary

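Technique | Training Required | Extension Ratio | Long-range Performance | Models Used
PI | Fine-tune | 4-8x | Moderate | Early LLaMA 2
NTK-aware | None/Minimal | 4-8x | Good | Code Llama
YaRN | Minimal (~0.1% of pre-training data) | 4-32x | Very Good | Qwen2.5, DeepSeek-V2/V3
ALiBi | None | Unlimited (in theory) | Weak | MPT, BLOOM
Ring Attention | None (exact attention; complements the techniques above) | Scales with number of GPUs | Very Good | LWM (Large World Model)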

Quiz: LLM Context Extension Comprehension Check (8 Questions)

Q1. What is the key advantage of RoPE over absolute positional encoding?

Relative position dependency: attention between two tokens depends only on their relative distance (m-n), not on absolute positions.

Q2. What is the downside of Position Interpolation?

Since positions are compressed, the resolution between adjacent tokens decreases, making it harder to distinguish nearby tokens.

Q3. Why is NTK-by-Parts better than PI?

It preserves high-frequency dimensions (resolution for nearby positions) and selectively extends only low-frequency dimensions (distant positions).

Q4. What is the role of Temperature Scaling in YaRN?

Interpolating positions over a longer context flattens the attention distribution. YaRN multiplies the logits by (0.1 ln s + 1)^2 to restore the softmax's sharpness (entropy) to roughly its original level.

Q5. Explain ALiBi's working principle in one sentence.

It penalizes distant tokens by adding a negative linear bias proportional to token distance to the attention scores.

Q6. How does Ring Attention overcome memory limitations?

It distributes the sequence across multiple GPUs, circulates KV in a ring pattern, and accumulates using online softmax.

Q7. Roughly how much fine-tuning data does the YaRN paper report for extending Llama 2's context by 32x (4K to 128K)?

About 0.1% of the original pre-training data.

Q8. What happens when you increase the base value in RoPE?

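All rotation frequencies except the first (theta_0 = 1) decrease, so the slow dimensions can distinguish more distant positions before wrapping around. The trade-off is a smaller angular difference between nearby positions in those dimensions.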
