
Context Extension: How LLMs Scale Beyond Training Length

A comprehensive deep dive into context extension techniques—how models trained on 4K tokens extrapolate to 128K+. Understand RoPE scaling, Position Interpolation, NTK-aware scaling, YaRN, and the mathematics of long-context LLMs.

5 min read

The Context Length Problem

Modern LLMs are trained with a fixed context length—typically 2K, 4K, or 8K tokens. But applications increasingly demand longer contexts: entire codebases, full documents, long conversations. How do you use a model at context lengths far beyond what it was trained on?

The naive approach, feeding the model longer sequences as-is, fails catastrophically. A model trained on 4K tokens typically produces garbage at 8K tokens. The problem isn't memory or compute: the model has never seen positions beyond 4K and doesn't know what to do with them.

Context extension techniques solve this problem, enabling models to extrapolate to 2×, 4×, or even 32× their training length with minimal or no additional training. This post explains how they work, from the fundamentals of positional encoding to state-of-the-art methods like YaRN.

Understanding context extension matters because long context windows are now a headline feature: GPT-4 Turbo's 128K context, Claude's 200K context, Gemini 1.5's 1M context, and more recently Grok-4's 2M context, Gemini 2.0 Pro's 2M context, and Llama 4 Scout's 10M context. Open models such as Llama document their use of variations of these techniques; closed models don't publish their recipes, so their exact methods are unknown. You can apply these techniques to extend your own models.


Part I: Why Position Matters

The Position Problem in Attention

Without positional information, attention is permutation-equivariant: if you shuffle the input tokens, the outputs shuffle the same way. This is a problem because word order matters. "The cat sat on the mat" means something different from "mat the on sat cat the."

Positional encoding solves this by injecting position information into the model. But how you encode position determines how the model handles unseen positions.
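To make the permutation problem concrete, here is a minimal pure-Python sketch of single-head self-attention. It assumes identity query/key/value projections, made-up 2-d embeddings, and a toy `[sin(p), cos(p)]` vector as the position signal; none of these come from a real model. Without positions, each token's output is identical in any order; adding the position vectors breaks that symmetry.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(xs):
    """Single-head self-attention with identity Q/K/V projections (simplifying assumption)."""
    d = len(xs[0])
    outputs = []
    for q in xs:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in xs]
        weights = softmax(scores)
        outputs.append([sum(w * v[j] for w, v in zip(weights, xs)) for j in range(d)])
    return outputs

# Hypothetical 2-d token embeddings
EMB = {"The": [1.0, 0.0], "cat": [0.0, 1.0], "sat": [1.0, 1.0]}

def outputs_by_token(order, use_positions=False):
    """Run attention over the tokens in `order`; return {token: output vector}."""
    xs = []
    for p, tok in enumerate(order):
        vec = list(EMB[tok])
        if use_positions:
            # Toy position signal added to the embedding (not a real scheme)
            vec = [vec[0] + math.sin(p), vec[1] + math.cos(p)]
        xs.append(vec)
    return dict(zip(order, self_attention(xs)))

def same(u, v):
    return all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12) for a, b in zip(u, v))

a = outputs_by_token(["The", "cat", "sat"])
b = outputs_by_token(["sat", "The", "cat"])
print(all(same(a[t], b[t]) for t in a))  # True: order is invisible

a_pos = outputs_by_token(["The", "cat", "sat"], use_positions=True)
b_pos = outputs_by_token(["sat", "The", "cat"], use_positions=True)
print(same(a_pos["The"], b_pos["The"]))  # False: position changes the output
```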

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    WHY POSITION ENCODING MATTERS                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  WITHOUT POSITION ENCODING:                                              │
│  ──────────────────────────                                              │
│                                                                          │
│  Input:  "The cat sat"                                                  │
│  Tokens: [The, cat, sat]                                                │
│                                                                          │
│  Attention sees:                                                        │
│  • Token embeddings only                                               │
│  • No information about order                                          │
│  • [The, cat, sat] = [sat, The, cat] = [cat, sat, The]               │
│                                                                          │
│  The model cannot distinguish word order!                              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WITH POSITION ENCODING:                                                 │
│  ───────────────────────                                                 │
│                                                                          │
│  Position info added to embeddings:                                    │
│                                                                          │
│  Token:     [The,      cat,      sat     ]                             │
│  Position:  [pos_0,    pos_1,    pos_2   ]                             │
│  Combined:  [The+p0,   cat+p1,   sat+p2  ]                             │
│                                                                          │
│  Now the model knows:                                                  │
│  • "The" is at position 0                                              │
│  • "cat" is at position 1                                              │
│  • "sat" is at position 2                                              │
│                                                                          │
│  Different orderings produce different representations.                │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  THE CONTEXT LENGTH PROBLEM:                                             │
│  ───────────────────────────                                             │
│                                                                          │
│  Model trained with max_position = 4096                                │
│                                                                          │
│  During training, model sees:                                          │
│  pos_0, pos_1, pos_2, ... pos_4095  ✓                                 │
│                                                                          │
│  At inference, if sequence has 8000 tokens:                           │
│  pos_0, pos_1, ... pos_4095, pos_4096, ... pos_7999                   │
│                              │                                         │
│                              └── NEVER SEEN IN TRAINING!              │
│                                                                          │
│  The model doesn't know what to do with positions > 4095.             │
│  Output quality degrades catastrophically.                            │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Types of Position Encoding

Different position encoding schemes have different extrapolation properties:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    POSITION ENCODING APPROACHES                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  1. ABSOLUTE LEARNED EMBEDDINGS (GPT-2, BERT):                          │
│  ──────────────────────────────────────────────                          │
│                                                                          │
│  pos_embedding = nn.Embedding(max_position, hidden_size)               │
│                                                                          │
│  position 0 → learned vector [0.1, -0.2, 0.5, ...]                   │
│  position 1 → learned vector [0.3, 0.1, -0.1, ...]                   │
│  ...                                                                   │
│  position 4095 → learned vector [...]                                 │
│  position 4096 → NO EMBEDDING EXISTS! Cannot extrapolate.             │
│                                                                          │
│  Extrapolation: IMPOSSIBLE                                             │
│  The embedding table simply doesn't have entries beyond training.     │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  2. SINUSOIDAL EMBEDDINGS (Original Transformer):                       │
│  ────────────────────────────────────────────────                        │
│                                                                          │
│  PE(pos, 2i) = sin(pos / 10000^(2i/d))                                │
│  PE(pos, 2i+1) = cos(pos / 10000^(2i/d))                              │
│                                                                          │
│  Mathematical formula—can compute for any position!                   │
│                                                                          │
│  Extrapolation: THEORETICALLY POSSIBLE                                 │
│  But in practice, the model hasn't learned to use unseen positions.   │
│  Quality degrades beyond training length.                             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  3. ROTARY POSITION EMBEDDING (RoPE) - Modern Standard:                │
│  ───────────────────────────────────────────────────────                 │
│                                                                          │
│  Instead of adding position to embeddings, ROTATE the vectors.        │
│  Rotation angle depends on position.                                  │
│                                                                          │
│  q_rotated = q × rotation_matrix(position)                            │
│  k_rotated = k × rotation_matrix(position)                            │
│                                                                          │
│  Key property: attention score depends on RELATIVE position!          │
│  score(q_i, k_j) depends on position only through (i - j)            │
│                                                                          │
│  Extrapolation: PARTIALLY POSSIBLE with modifications                 │
│  RoPE has special structure that enables scaling techniques.          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  4. ALiBi (Attention with Linear Biases):                              │
│  ─────────────────────────────────────────                               │
│                                                                          │
│  No position in embeddings at all.                                    │
│  Instead, bias attention scores by distance:                          │
│                                                                          │
│  attention_score -= m × |i - j|    (m = fixed, head-specific slope)   │
│                                                                          │
│  Closer tokens → higher attention                                     │
│  Farther tokens → lower attention                                     │
│                                                                          │
│  Extrapolation: NATURALLY SUPPORTS                                     │
│  The bias formula works for any distance.                             │
│  But the penalty grows with distance, so far-away tokens get little  │
│  attention and the effective context is still limited.               │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
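The two formula-based schemes above can be evaluated at any position or distance, which is why they extrapolate "by construction" even when quality does not. A minimal pure-Python sketch: `sinusoidal_pe` follows the PE formula in the box, and `alibi_bias` builds a causal bias matrix. The slope schedule in `alibi_slopes` (2^(-8/n), 2^(-16/n), ... for n heads, assuming n is a power of two) is the one proposed in the ALiBi paper; the function names are illustrative.

```python
import math

def sinusoidal_pe(pos, d_model):
    """PE(pos, 2i) = sin(pos / 10000^(2i/d)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d))."""
    pe = []
    for i in range(d_model // 2):
        angle = pos / (10000 ** (2 * i / d_model))
        pe.extend([math.sin(angle), math.cos(angle)])
    return pe

def alibi_slopes(num_heads):
    """Geometric head slopes 2^(-8/n), 2^(-16/n), ... (assumes num_heads is a power of two)."""
    return [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)]

def alibi_bias(seq_len, slope):
    """Causal ALiBi bias: -slope * (i - j) for keys j <= query i, -inf for future keys."""
    return [
        [-slope * (i - j) if j <= i else float("-inf") for j in range(seq_len)]
        for i in range(seq_len)
    ]

# The formula works for positions far beyond any training length...
print(len(sinusoidal_pe(100_000, 8)))   # 8
# ...and the bias grows linearly with distance for any sequence length.
print(alibi_slopes(8)[:3])              # [0.5, 0.25, 0.125]
print(alibi_bias(3, 0.5)[2])            # [-1.0, -0.5, -0.0]
```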

Part II: RoPE Deep Dive

How Rotary Position Embedding Works

RoPE has become the dominant position encoding for modern LLMs (Llama, Mistral, Qwen, etc.) because it enables relative position modeling and can be scaled for context extension.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ROTARY POSITION EMBEDDING (RoPE)                      │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE KEY INSIGHT:                                                        │
│  ────────────────                                                        │
│                                                                          │
│  Encode position by ROTATING query and key vectors.                   │
│  When we compute attention (q · k), the rotation angles subtract,     │
│  making the result depend only on RELATIVE position.                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ROTATION IN 2D (simplified):                                           │
│  ─────────────────────────────                                           │
│                                                                          │
│  Consider a 2D vector [x, y].                                         │
│  Rotation by angle θ:                                                  │
│                                                                          │
│  [x']   [cos(θ)  -sin(θ)] [x]                                         │
│  [y'] = [sin(θ)   cos(θ)] [y]                                         │
│                                                                          │
│  x' = x·cos(θ) - y·sin(θ)                                             │
│  y' = x·sin(θ) + y·cos(θ)                                             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ROPE APPLIES ROTATION TO Q AND K:                                      │
│  ──────────────────────────────────                                      │
│                                                                          │
│  For query at position m:    q_rot = rotate(q, m × θ)                 │
│  For key at position n:      k_rot = rotate(k, n × θ)                 │
│                                                                          │
│  Attention score: q_rot · k_rot                                       │
│                                                                          │
│  Due to rotation properties:                                           │
│  q_rot · k_rot = qᵀ · R(m×θ)ᵀ · R(n×θ) · k                            │
│                = qᵀ · R((n-m)×θ) · k     (since R(m×θ)ᵀ = R(-m×θ))   │
│                = depends only on the offset (m - n)!                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  HIGHER DIMENSIONS:                                                      │
│  ──────────────────                                                      │
│                                                                          │
│  Head dimension d = 128 (Llama) → split into d/2 = 64 pairs           │
│  (RoPE rotates each attention head separately, so d is head_dim,      │
│  not the model's hidden size of 4096.)                                │
│  Each pair is rotated at a different frequency:                       │
│                                                                          │
│  Pair 0:  θ₀  = 1                  (high frequency, local patterns)   │
│  Pair 1:  θ₁  = 1/10000^(2/d)      (slightly lower)                   │
│  Pair 2:  θ₂  = 1/10000^(4/d)      (lower still)                      │
│  ...                                                                   │
│  Pair 63: θ₆₃ = 1/10000^(126/d) ≈ 0.000115  (very low freq, global)   │
│                                                                          │
│  Higher pair index → lower frequency → longer-range patterns          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  IMPLEMENTATION:                                                         │
│  ───────────────                                                         │
│                                                                          │
│  import torch                                                          │
│                                                                          │
│  def apply_rope(x, position, base=10000):                             │
│      """                                                               │
│      x: (batch, seq_len, num_heads, head_dim)                        │
│      position: (seq_len,) tensor of position indices                 │
│      """                                                               │
│      head_dim = x.shape[-1]                                           │
│                                                                          │
│      # Compute frequencies for each dimension pair                    │
│      # θᵢ = 1 / base^(2i / head_dim)                                 │
│      dim_indices = torch.arange(0, head_dim, 2, device=x.device).float()│
│      freqs = 1.0 / (base ** (dim_indices / head_dim))                │
│                                                                          │
│      # Compute rotation angles: position × frequency                  │
│      angles = position.float().unsqueeze(-1) * freqs  # (seq_len, head_dim/2)│
│      # Add a head axis so the angles broadcast against x              │
│      angles = angles.unsqueeze(1)           # (seq_len, 1, head_dim/2)│
│                                                                          │
│      # Create rotation components                                      │
│      cos = torch.cos(angles)                                          │
│      sin = torch.sin(angles)                                          │
│                                                                          │
│      # Split x into pairs and rotate                                  │
│      x1, x2 = x[..., ::2], x[..., 1::2]  # Even, odd dimensions     │
│      x_rot1 = x1 * cos - x2 * sin                                    │
│      x_rot2 = x1 * sin + x2 * cos                                    │
│                                                                          │
│      # Interleave back                                                │
│      x_rot = torch.stack([x_rot1, x_rot2], dim=-1).flatten(-2)      │
│      return x_rot                                                      │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
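The derivation above claims the score depends on the two positions only through their offset. A quick dependency-free check, assuming the same interleaved pair layout and base = 10000 as `apply_rope` (the helper names and the example vectors are illustrative): shifting both positions by the same amount leaves the score unchanged, while changing the gap does not.

```python
import math

def rope_freqs(head_dim, base=10000.0):
    """θ_i = base^(-2i/head_dim) for each of the head_dim/2 dimension pairs."""
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]

def rotate(vec, pos, freqs):
    """Rotate each (even, odd) pair of vec by pos * θ_i."""
    out = []
    for i, theta in enumerate(freqs):
        x, y = vec[2 * i], vec[2 * i + 1]
        a = pos * theta
        out.extend([x * math.cos(a) - y * math.sin(a), x * math.sin(a) + y * math.cos(a)])
    return out

def rope_score(q, k, m, n, freqs):
    """Dot product of q rotated to position m and k rotated to position n."""
    return sum(a * b for a, b in zip(rotate(q, m, freqs), rotate(k, n, freqs)))

freqs = rope_freqs(8)
q = [0.3, -1.2, 0.8, 0.5, -0.7, 0.1, 1.1, -0.4]  # arbitrary example vectors
k = [1.0, 0.4, -0.6, 0.9, 0.2, -1.3, 0.5, 0.7]

s1 = rope_score(q, k, m=10, n=3, freqs=freqs)
s2 = rope_score(q, k, m=1010, n=1003, freqs=freqs)  # same gap of 7
s3 = rope_score(q, k, m=10, n=5, freqs=freqs)       # gap of 5
print(math.isclose(s1, s2, rel_tol=1e-9, abs_tol=1e-9))  # True
print(math.isclose(s1, s3, rel_tol=1e-9, abs_tol=1e-9))  # False
```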

The Extrapolation Problem with RoPE

RoPE's rotation angles are position × frequency. For the low-frequency dimensions, positions beyond the training length produce angles the model never saw during training:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ROPE EXTRAPOLATION FAILURE                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  TRAINING CONTEXT: 4096 tokens                                          │
│  INFERENCE REQUEST: 8000 tokens                                         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  HIGH-FREQUENCY DIMENSIONS:                                              │
│  ──────────────────────────                                              │
│                                                                          │
│  θ₀ = 1 (base frequency)                                               │
│  Position 0: angle = 0                                                 │
│  Position 4095: angle = 4095 radians ≈ 651 full rotations            │
│  Position 8000: angle = 8000 radians ≈ 1273 full rotations           │
│                                                                          │
│  High frequencies wrap many times—this is fine.                       │
│  cos(4095) and cos(8000) are both valid values.                      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  LOW-FREQUENCY DIMENSIONS (THE PROBLEM):                                │
│  ────────────────────────────────────────                                │
│                                                                          │
│  Lowest-frequency pair (head_dim = 128): θ ≈ 0.000115                 │
│  Position 0: angle = 0                                                 │
│  Position 4095: angle ≈ 0.47 radians (about 7.5% of one rotation)     │
│  Position 8000: angle ≈ 0.92 radians                                  │
│                                                                          │
│  During training, the model ONLY saw angles 0 to 0.47.                │
│  At position 8000, the angle is 0.92: NEVER SEEN!                     │
│                                                                          │
│  VISUALIZATION:                                                         │
│  ──────────────                                                          │
│                                                                          │
│       Training range        │ Extrapolation region                     │
│       (0 to 0.47 rad)       │ (0.47 to 0.92 rad)                      │
│  ────────────────────────────┼─────────────────────                     │
│  0         0.47             │             0.92                         │
│  │▓▓▓▓▓▓▓▓▓▓│               │              │                          │
│  │ SEEN     │               │  UNSEEN      │                          │
│                                                                          │
│  The model learned patterns for angles 0-0.47.                        │
│  It has no idea what to do with 0.47-0.92!                           │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WHY THIS MATTERS:                                                       │
│  ─────────────────                                                       │
│                                                                          │
│  Low-frequency dimensions encode LONG-RANGE patterns:                 │
│  • Which paragraph is this token in?                                  │
│  • Beginning, middle, or end of document?                             │
│  • Relationship to distant context                                    │
│                                                                          │
│  When these dimensions have unseen angles:                            │
│  • Model can't encode position correctly                              │
│  • Long-range dependencies break                                      │
│  • Quality degrades severely                                          │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
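You can check which dimension pairs fall into the problematic regime by computing each pair's wavelength 2π/θᵢ, the number of positions it takes to complete one full rotation. A rough sketch, assuming head_dim = 128 and base = 10000 (Llama-style defaults; the function names are illustrative): pairs whose wavelength exceeds the 4096-token training length never complete a full rotation during training, so longer inputs push them into unseen angles.

```python
import math

def rope_wavelengths(head_dim=128, base=10000.0):
    """Tokens needed for each pair to complete one rotation: 2π / θ_i = 2π · base^(2i/head_dim)."""
    return [2 * math.pi * base ** (2 * i / head_dim) for i in range(head_dim // 2)]

def pairs_without_full_rotation(train_len, head_dim=128, base=10000.0):
    """Indices of pairs whose wavelength exceeds the training length."""
    return [i for i, w in enumerate(rope_wavelengths(head_dim, base)) if w > train_len]

def angle(pos, pair, head_dim=128, base=10000.0):
    """Rotation angle pos · θ_pair in radians."""
    return pos * base ** (-2 * pair / head_dim)

slow = pairs_without_full_rotation(4096)
print(len(slow), slow[0])         # 18 46
print(round(angle(4095, 63), 2))  # 0.47  (largest angle seen in training)
print(round(angle(8000, 63), 2))  # 0.92  (never seen)
```

Under these assumptions, the 18 slowest of the 64 pairs are the ones that need help; the fastest pairs wrap hundreds of times within the training window.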

Part III: Position Interpolation

The First Scaling Technique

Position Interpolation (PI) was one of the first successful RoPE scaling techniques. The idea is simple: instead of extrapolating to unseen positions, rescale positions so they interpolate within the training range.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    POSITION INTERPOLATION (PI)                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE CORE IDEA:                                                          │
│  ──────────────                                                          │
│                                                                          │
│  Don't use position 8000 (unseen).                                    │
│  Instead, SCALE positions to fit within training range.               │
│                                                                          │
│  Training max: 4096                                                    │
│  Inference max: 8192                                                   │
│  Scale factor: 8192 / 4096 = 2                                        │
│                                                                          │
│  Position 8000 → 8000 / 2 = 4000 (within training range!)            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  BEFORE (Extrapolation - FAILS):                                        │
│  ─────────────────────────────────                                       │
│                                                                          │
│  Position: 0    1000   2000   3000   4000   5000   6000   7000   8000 │
│  Angle:    0    1000   2000   3000   4000   5000   6000   7000   8000 │
│                                          │                             │
│                    Training range ───────┘                             │
│                                                                          │
│  Positions 4096-8000 produce unseen angles. Model fails.             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  AFTER (Interpolation - WORKS):                                         │
│  ────────────────────────────────                                        │
│                                                                          │
│  Position: 0    1000   2000   3000   4000   5000   6000   7000   8000 │
│  Scaled:   0    500    1000   1500   2000   2500   3000   3500   4000 │
│  Angle:    0    500    1000   1500   2000   2500   3000   3500   4000 │
│            │                                                    │     │
│            └────────────── All within training range ───────────┘     │
│                                                                          │
│  All positions map into the trained angle range.                      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  IMPLEMENTATION:                                                         │
│  ───────────────                                                         │
│                                                                          │
│  import torch                                                          │
│                                                                          │
│  def apply_rope_with_scaling(x, position, base=10000, scale=1.0):     │
│      """                                                               │
│      x: (batch, seq_len, num_heads, head_dim); position: (seq_len,)  │
│      scale > 1 means we're extending beyond training length.          │
│      Divide positions by scale to interpolate.                        │
│      """                                                               │
│      head_dim = x.shape[-1]                                           │
│                                                                          │
│      dim_indices = torch.arange(0, head_dim, 2, device=x.device).float()│
│      freqs = 1.0 / (base ** (dim_indices / head_dim))                │
│                                                                          │
│      # POSITION INTERPOLATION: divide positions by scale             │
│      scaled_positions = position.float() / scale                      │
│                                                                          │
│      # (seq_len, 1, head_dim/2) so the angles broadcast over heads   │
│      angles = (scaled_positions.unsqueeze(-1) * freqs).unsqueeze(1)  │
│                                                                          │
│      # Rest is same as standard RoPE                                  │
│      cos, sin = torch.cos(angles), torch.sin(angles)                 │
│      x1, x2 = x[..., ::2], x[..., 1::2]                              │
│      x_rot = torch.stack([                                            │
│          x1 * cos - x2 * sin,                                        │
│          x1 * sin + x2 * cos                                         │
│      ], dim=-1).flatten(-2)                                          │
│      return x_rot                                                      │
│                                                                          │
│  # For 4K→8K extension (toy shapes: batch=1, 8 heads, head_dim=128): │
│  scale = 8192 / 4096  # = 2.0                                        │
│  x = torch.randn(1, 8192, 8, 128)                                    │
│  positions = torch.arange(8192)                                      │
│  output = apply_rope_with_scaling(x, positions, scale=scale)         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  THE TRADEOFF:                                                           │
│  ─────────────                                                           │
│                                                                          │
│  Position Interpolation works but has a cost:                         │
│                                                                          │
│  Adjacent positions are now CLOSER together (θ₀ = 1 pair shown):     │
│  • Original: pos 100 and 101 differ by angle 1                       │
│  • Scaled (2×): pos 100 and 101 differ by angle 0.5                  │
│                                                                          │
│  The model has LESS resolution for distinguishing nearby tokens.     │
│  This hurts short-range pattern recognition.                          │
│                                                                          │
│  Solution: Fine-tune briefly to adapt.                                │
│  ~1000 steps of fine-tuning usually recovers quality.                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
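The core of PI is a single division. A dependency-free sketch, assuming a single frequency θ per call and the 4096 → 8192 setup from the box (`pi_angles` is an illustrative name, and θ for the lowest pair assumes head_dim = 128): every scaled position lands inside the trained span, and the angle gap between neighbouring tokens halves, which is the resolution cost described above.

```python
def pi_angles(positions, theta, train_len, target_len):
    """Position Interpolation: rotate by (pos / scale) * theta, with scale = target_len / train_len."""
    scale = target_len / train_len
    return [(p / scale) * theta for p in positions]

train_len, target_len = 4096, 8192
theta_low = 10000 ** (-126 / 128)  # lowest-frequency pair for head_dim = 128 (assumption)

# Every position up to 8191 maps below the trained span [0, 4096) in scaled units.
angles_low = pi_angles(range(target_len), theta_low, train_len, target_len)
print(max(angles_low) < train_len * theta_low)  # True

# Resolution cost in the θ = 1 pair: neighbours now differ by 0.5 rad instead of 1.
a = pi_angles([100, 101], 1.0, train_len, target_len)
print(a[1] - a[0])  # 0.5
```

Because the same division is applied to every pair, the fast pairs lose resolution even though they never needed interpolation; that uniformity is exactly the limitation listed below.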

PI Results and Limitations

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    POSITION INTERPOLATION RESULTS                        │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  LLAMA 7B EXTENSION (Original PI Paper):                                │
│  ────────────────────────────────────────                                │
│                                                                          │
│  Training length: 2048                                                  │
│  Extended to: 8192 (4× extension)                                      │
│  Fine-tuning: 1000 steps                                               │
│                                                                          │
│  Results:                                                               │
│  • Perplexity at 8192 tokens: Good (comparable to native)             │
│  • Perplexity at 2048 tokens: Slightly degraded                       │
│  • Long-context tasks: Strong performance                             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  LIMITATIONS OF PI:                                                      │
│  ──────────────────                                                      │
│                                                                          │
│  1. SHORT-CONTEXT DEGRADATION                                           │
│     Uniform scaling hurts nearby-token resolution.                    │
│     Quality on short sequences drops.                                  │
│                                                                          │
│  2. REQUIRES FINE-TUNING                                                │
│     Without fine-tuning, quality is poor.                             │
│     Fine-tuning on long sequences is expensive.                       │
│                                                                          │
│  3. HIGH-FREQUENCY DIMENSIONS OVER-SCALED                              │
│     High frequencies (for local patterns) don't need scaling.         │
│     They already wrap many times—scaling just reduces resolution.    │
│                                                                          │
│  These limitations motivated NTK-aware scaling and YaRN.              │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
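Limitation 3 is easy to quantify. The sketch below counts how many full turns each RoPE rotation pair completes over the training length, assuming LLaMA-style settings (head_dim = 128, base = 10000) and the 2048-token training length from the PI setup above. `rope_rotations` is a hypothetical helper name used only for illustration.

```python
import math


def rope_rotations(head_dim: int, base: float, train_len: int) -> list[float]:
    """Full turns each RoPE rotation pair completes over train_len tokens."""
    turns = []
    for i in range(head_dim // 2):
        freq = base ** (-2 * i / head_dim)  # radians advanced per token by pair i
        wavelength = 2 * math.pi / freq     # tokens needed for one full turn
        turns.append(train_len / wavelength)
    return turns


# Assumed LLaMA-style settings: head_dim=128, base=10000, 2048-token training
turns = rope_rotations(128, 10000.0, 2048)
print(f"pair 0:  {turns[0]:.1f} turns")   # highest frequency wraps hundreds of times
print(f"pair 63: {turns[63]:.3f} turns")  # lowest frequency covers a few percent of a turn

# PI with s = 4 divides every angle by 4, so pair 0's step between
# neighboring tokens shrinks from 1 rad to 0.25 rad, even though that
# pair already covered every angle during training.
s = 4
print(f"pair 0 step under PI: {1.0 / s} rad/token")
```

Pair 0 has seen every possible angle many times over, while pair 63 never completes a single turn. Only the low-frequency pairs actually need interpolation.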

Part IV: NTK-Aware Interpolation

The Frequency-Aware Insight

NTK-aware interpolation recognizes that different frequency dimensions need different treatment:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    NTK-AWARE INTERPOLATION                               │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE KEY INSIGHT:                                                        │
│  ────────────────                                                        │
│                                                                          │
│  Not all dimensions need the same scaling!                             │
│                                                                          │
│  HIGH-FREQUENCY dimensions (local patterns):                           │
│  • Already complete many rotations in training range                  │
│  • Extrapolation is FINE (angles wrap periodically)                   │
│  • DON'T scale these—preserve local resolution                        │
│                                                                          │
│  LOW-FREQUENCY dimensions (global patterns):                           │
│  • Only partial rotation in training range                            │
│  • Extrapolation FAILS (unseen angles)                                │
│  • MUST scale these to avoid unseen territory                         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  POSITION INTERPOLATION (PI):                                           │
│  ─────────────────────────────                                           │
│                                                                          │
│  scales ALL dimensions equally by dividing position                   │
│                                                                          │
│  Dim 0 (high freq):    θ × pos → θ × (pos/2)   ← Loses resolution!   │
│  Last dim (low freq):  θ × pos → θ × (pos/2)   ← Needed              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  NTK-AWARE SCALING:                                                      │
│  ──────────────────                                                      │
│                                                                          │
│  Scales the BASE (not position), affecting dimensions differently     │
│                                                                          │
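│  Original: freqs = 1 / base^(dim / head_dim)                         │
│  NTK:      freqs = 1 / (base × scale)^(dim / head_dim)               │
│  (Exact NTK-aware form: base × scale^(head_dim / (head_dim − 2)),    │
│   which is close to base × scale for head_dim = 128)                  │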
│                                                                          │
│  Effect on each dimension:                                             │
│                                                                          │
│  Dim 0 (high freq):                                                    │
│  Original: 1 / 10000^(0/d) = 1                                        │
│  NTK 2×:   1 / 20000^(0/d) = 1         ← UNCHANGED!                  │
│                                                                          │
│  Dim d/2 (mid freq):                                                   │
│  Original: 1 / 10000^(0.5) = 0.01                                    │
│  NTK 2×:   1 / 20000^(0.5) = 0.007     ← Slightly scaled             │
│                                                                          │
│  Dim d (low freq):                                                     │
│  Original: 1 / 10000^(1) = 0.0001                                    │
│  NTK 2×:   1 / 20000^(1) = 0.00005     ← HALVED (interpolated)      │
│                                                                          │
│  Higher dim indices (lower freq) are scaled MORE;                     │
│  lower dim indices (higher freq) are scaled LESS!                     │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  IMPLEMENTATION:                                                         │
│  ───────────────                                                         │
│                                                                          │
│  import torch                                                            │
│                                                                          │
│  def apply_rope_ntk(x, position, base=10000, scale=1.0):              │
│      """                                                               │
│      NTK-aware scaling: modify base instead of position.              │
│      x: (..., seq_len, head_dim); position: (seq_len,)                │
│      """                                                               │
│      head_dim = x.shape[-1]                                           │
│                                                                          │
│      # NTK: enlarge the base; the d/(d-2) exponent makes the         │
│      # lowest-frequency pair interpolate by exactly `scale`          │
│      scaled_base = base * scale ** (head_dim / (head_dim - 2))       │
│                                                                          │
│      dim_indices = torch.arange(0, head_dim, 2, dtype=torch.float32) │
│      freqs = 1.0 / (scaled_base ** (dim_indices / head_dim))         │
│                                                                          │
│      # Position NOT scaled                                            │
│      angles = position.float().unsqueeze(-1) * freqs                 │
│                                                                          │
│      # Rotate each (even, odd) channel pair by its angle             │
│      cos, sin = torch.cos(angles), torch.sin(angles)                 │
│      x1, x2 = x[..., ::2], x[..., 1::2]                              │
│      x_rot = torch.stack([                                            │
│          x1 * cos - x2 * sin,                                        │
│          x1 * sin + x2 * cos                                         │
│      ], dim=-1).flatten(-2)                                          │
│      return x_rot                                                      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WHY "NTK"?                                                             │
│  ──────────                                                              │
│                                                                          │
│  Named after Neural Tangent Kernel theory, which suggests networks    │
│  struggle to learn high-frequency detail when their input embeddings  │
│  lack high-frequency components.                                       │
│                                                                          │
│  PI compresses every frequency, including the highest ones;           │
│  enlarging the base instead keeps those high frequencies intact.      │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
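To see the per-dimension effect directly, the sketch below compares how much PI and NTK-aware scaling shrink each pair's frequency for a 2× extension. It assumes head_dim = 128 and base = 10000 and uses the exact NTK-aware base, base × scale^(d/(d−2)); `rope_freqs` is a hypothetical helper.

```python
import torch


def rope_freqs(head_dim: int, base: float) -> torch.Tensor:
    """Standard RoPE frequencies, one per rotation pair."""
    dims = torch.arange(0, head_dim, 2, dtype=torch.float64)
    return 1.0 / (base ** (dims / head_dim))


head_dim, base, s = 128, 10000.0, 2.0
orig = rope_freqs(head_dim, base)

# PI: every frequency divided by s (same as dividing positions by s)
pi_ratio = (orig / s) / orig

# NTK-aware: enlarge the base instead; the d/(d-2) exponent makes the
# lowest-frequency pair land exactly on orig / s
ntk_base = base * s ** (head_dim / (head_dim - 2))
ntk_ratio = rope_freqs(head_dim, ntk_base) / orig

for i in (0, 32, 63):
    print(f"pair {i:2d}: PI x{pi_ratio[i].item():.3f}  NTK x{ntk_ratio[i].item():.3f}")
```

PI shrinks every pair by the same factor of 0.5, while NTK-aware scaling leaves pair 0 untouched and only reaches the full 0.5 factor at the lowest-frequency pair.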

NTK Variants

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    NTK SCALING VARIANTS                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  1. NTK-AWARE (Simple):                                                 │
│  ──────────────────────                                                  │
│  scaled_base = base × scale                                            │
│                                                                          │
│  For 4K→8K: base = 10000 × 2 = 20000                                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
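│  2. NTK-BY-PARTS:                                                       │
│  ────────────────                                                        │
│                                                                          │
│  Decide per dimension pair, based on how many full rotations r it     │
│  completes over the training length:                                   │
│                                                                          │
│  r > β (e.g., 32):   no scaling (high freq)                            │
│  α ≤ r ≤ β:          partial scaling (ramp)                            │
│  r < α (e.g., 1):    full interpolation (low freq)                     │
│                                                                          │
│  Llama 2 (head_dim 128, base 10000, 4K training): pairs 0-20          │
│  unscaled, pairs 21-45 ramped, pairs 46-63 fully interpolated.         │
│                                                                          │
│  More fine-grained than simple NTK.                                   │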
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  3. DYNAMIC NTK:                                                         │
│  ───────────────                                                         │
│                                                                          │
│  Compute scale dynamically based on current sequence length:          │
│                                                                          │
│  if seq_len <= train_length:                                          │
│      scale = 1.0  (no scaling needed)                                 │
│  else:                                                                 │
│      scale = seq_len / train_length                                   │
│                                                                          │
│  Adapts automatically—short sequences get no scaling penalty.        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMPARISON:                                                             │
│                                                                          │
│  Method          Short Context  Long Context  Fine-tune Needed        │
│  ───────────────────────────────────────────────────────────          │
│  PI              Degraded       Good          Yes                     │
│  NTK-Aware       Better         Good          Less                    │
│  Dynamic NTK     Preserved      Good          Minimal                 │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
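The Dynamic NTK rule can be sketched as a function that recomputes the RoPE base from the current sequence length. This assumes the simple scale = seq_len / train_length rule from the box combined with the exact NTK-aware base enlargement; `dynamic_ntk_base` is a hypothetical name, and production implementations (e.g., in Hugging Face Transformers) may use a slightly different scale expression.

```python
def dynamic_ntk_base(seq_len: int, train_len: int,
                     base: float = 10000.0, head_dim: int = 128) -> float:
    """RoPE base to use for the current sequence length under Dynamic NTK."""
    if seq_len <= train_len:
        return base  # within the training range: plain RoPE, no penalty
    scale = seq_len / train_len
    # NTK-aware base enlargement; the d/(d-2) exponent makes the
    # lowest-frequency pair interpolate by exactly `scale`
    return base * scale ** (head_dim / (head_dim - 2))


for n in (2048, 4096, 8192, 32768):
    print(n, round(dynamic_ntk_base(n, train_len=4096)))
```

Because the base depends on the sequence length, short prompts run with exactly the original RoPE, which is why the table lists short-context quality as preserved.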

Part V: YaRN - Yet another RoPE extensioN

The State of the Art

YaRN combines the best ideas from PI and NTK with additional improvements:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    YARN (Yet another RoPE extensioN)                     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  YARN COMBINES THREE TECHNIQUES:                                         │
│  ────────────────────────────────                                        │
│                                                                          │
│  1. NTK-by-parts interpolation (per-dimension partitioning)            │
│  2. Attention scaling (temperature adjustment)                         │
│  3. A smooth ramp between interpolated and original frequencies        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMPONENT 1: DIMENSION PARTITIONING                                    │
│  ────────────────────────────────────                                    │
│                                                                          │
│  Divide dimensions into three groups:                                  │
│                                                                          │
│  Group 1 (high freq): Don't interpolate                               │
│  • Already wrap many times                                            │
│  • Preserve full resolution                                           │
│                                                                          │
│  Group 2 (mid freq): Partially interpolate (NTK-style)               │
│  • Smooth transition                                                  │
│  • Balance between resolution and extrapolation                      │
│                                                                          │
│  Group 3 (low freq): Fully interpolate (PI-style)                    │
│  • Must avoid unseen angles                                          │
│  • Accept resolution loss                                            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMPONENT 2: ATTENTION SCALING                                         │
│  ──────────────────────────────                                          │
│                                                                          │
│  Problem: After interpolation, attention distributions change.        │
│  Tokens that were "far" now have similar position encodings.         │
│  This makes attention too flat (entropy too high).                   │
│                                                                          │
│  Solution: Scale attention logits by a temperature t:                 │
│                                                                          │
│  Original: softmax(Q × Kᵀ / √d)                                       │
│  YaRN:     softmax(Q × Kᵀ / (t × √d))                                │
│                                                                          │
│  Where t < 1 (sharpens attention), fitted empirically as:            │
│                                                                          │
│  √(1/t) = 0.1 × ln(s) + 1, where s = scale factor                    │
│  For 4× extension: √(1/t) ≈ 1.14, so logits grow ≈ 1.30× (t ≈ 0.77)  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMPONENT 3: SMOOTH INTERPOLATION FACTOR                               │
│  ─────────────────────────────────────────                               │
│                                                                          │
│  For each dimension, compute interpolation factor λ:                  │
│                                                                          │
│  λ(d) = {                                                              │
│    1 (no interp)     if freq_d > high_freq_threshold                 │
│    0 (full interp)   if freq_d < low_freq_threshold                  │
│    smooth_ramp       otherwise                                        │
│  }                                                                     │
│                                                                          │
│  Final frequency = (1 − λ) × interpolated_freq + λ × original_freq   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  IMPLEMENTATION:                                                         │
│  ───────────────                                                         │
│                                                                          │
│  import math                                                            │
│  import torch                                                           │
│                                                                          │
│  def apply_yarn(                                                        │
│      x, position, base=10000, scale=1.0,                              │
│      original_max_position=4096,                                       │
│      beta_fast=32, beta_slow=1,                                       │
│  ):                                                                     │
│      """                                                               │
│      YaRN: NTK-by-parts + attention scaling.                         │
│      Apply to both queries and keys.                                  │
│      """                                                               │
│      head_dim = x.shape[-1]                                           │
│                                                                          │
│      # Compute base frequencies                                        │
│      dim_indices = torch.arange(0, head_dim, 2, dtype=torch.float32) │
│      freqs = 1.0 / (base ** (dim_indices / head_dim))                │
│                                                                          │
│      # Rotations each pair completes over the training length        │
│      wavelengths = 2 * math.pi / freqs                                │
│      rotations = original_max_position / wavelengths                 │
│                                                                          │
│      # gamma = 1: keep original freq (r > beta_fast)                  │
│      # gamma = 0: fully interpolate   (r < beta_slow)                 │
│      gamma = (rotations - beta_slow) / (beta_fast - beta_slow)       │
│      gamma = gamma.clamp(0, 1)                                        │
│                                                                          │
│      # Blend PI-style and original frequencies                        │
│      freqs_interpolated = freqs / scale                               │
│      freqs_final = (1 - gamma) * freqs_interpolated + gamma * freqs  │
│                                                                          │
│      # Compute angles                                                  │
│      angles = position.float().unsqueeze(-1) * freqs_final           │
│                                                                          │
│      # Attention scaling: multiplying cos/sin by mscale on both      │
│      # q and k multiplies every logit by mscale**2 = 1/t              │
│      mscale = 0.1 * math.log(scale) + 1.0 if scale > 1 else 1.0      │
│                                                                          │
│      # Apply rotation                                                  │
│      cos = torch.cos(angles) * mscale                                 │
│      sin = torch.sin(angles) * mscale                                 │
│      x1, x2 = x[..., ::2], x[..., 1::2]                              │
│      x_rot = torch.stack([                                            │
│          x1 * cos - x2 * sin,                                        │
│          x1 * sin + x2 * cos                                         │
│      ], dim=-1).flatten(-2)                                          │
│                                                                          │
│      return x_rot                                                      │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
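The partition and temperature numbers can be checked with a few lines of arithmetic. The sketch below assumes Llama 2 settings (head_dim = 128, base = 10000, 4096-token training length) and the beta_fast = 32, beta_slow = 1 defaults; `yarn_groups` and `yarn_logit_scale` are hypothetical helpers, not part of any library. Under these assumptions the split comes out to pairs 0-20 kept, 21-45 ramped, and 46-63 fully interpolated.

```python
import math


def yarn_groups(head_dim=128, base=10000.0, train_len=4096,
                beta_fast=32.0, beta_slow=1.0):
    """Split RoPE pairs into YaRN's three groups by rotations completed in training."""
    keep, ramp, interp = [], [], []
    for i in range(head_dim // 2):
        wavelength = 2 * math.pi * base ** (2 * i / head_dim)
        r = train_len / wavelength  # full turns over the training length
        if r > beta_fast:
            keep.append(i)          # high freq: no interpolation
        elif r < beta_slow:
            interp.append(i)        # low freq: full PI-style interpolation
        else:
            ramp.append(i)          # mid freq: blended
    return keep, ramp, interp


def yarn_logit_scale(s: float) -> float:
    """Factor YaRN's temperature applies to attention logits: 1/t = (0.1 ln s + 1)^2."""
    return (0.1 * math.log(s) + 1.0) ** 2


keep, ramp, interp = yarn_groups()
print(f"keep {keep[0]}-{keep[-1]}, ramp {ramp[0]}-{ramp[-1]}, interp {interp[0]}-{interp[-1]}")
print(f"s=4 logit scale: {yarn_logit_scale(4):.3f}")
```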

YaRN Results

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    YARN BENCHMARK RESULTS                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  LLAMA 2 7B EXTENSION:                                                   │
│  ─────────────────────                                                   │
│                                                                          │
│  Training context: 4096                                                 │
│  Extended to: 128K                                                      │
│  Scale factor: 32×                                                      │
│                                                                          │
│  Method          Perplexity @4K   Perplexity @64K   Fine-tune Steps   │
│  ───────────────────────────────────────────────────────────────────   │
│  No extension    2.5              diverges          N/A               │
│  PI              3.2              12.5              1000              │
│  NTK-aware       2.8              10.2              400               │
│  YaRN            2.6              6.8               400               │
│                                                                          │
│  YaRN: Best short-context preservation + best long-context quality   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  PASSKEY RETRIEVAL TEST:                                                 │
│  ────────────────────────                                                │
│                                                                          │
│  Task: Find a random number hidden in long context                    │
│  Tests: Can the model attend to any position?                         │
│                                                                          │
│  Context Length    PI Accuracy    NTK Accuracy    YaRN Accuracy       │
│  ────────────────────────────────────────────────────────────         │
│  8K               95%            97%             99%                  │
│  16K              82%            91%             98%                  │
│  32K              65%            85%             97%                  │
│  64K              41%            72%             95%                  │
│  128K             12%            48%             89%                  │
│                                                                          │
│  YaRN maintains retrieval accuracy at extreme lengths.               │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WHEN TO USE YARN:                                                       │
│  ─────────────────                                                       │
│                                                                          │
│  • Extending models 4× or more                                        │
│  • When short-context quality matters                                 │
│  • When you can do brief fine-tuning                                 │
│  • Production deployments needing reliability                        │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
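The passkey test is simple to reproduce. The sketch below builds one prompt with a random number hidden at a chosen depth in filler text. The filler sentences, the prompt wording, the 5-digit key, and the helper names are illustrative assumptions, not the exact format behind the numbers in the table above.

```python
import random


def make_passkey_prompt(n_filler: int, depth: float, seed: int = 0) -> tuple[str, str]:
    """Build a passkey-retrieval prompt: a random number hidden in filler text."""
    rng = random.Random(seed)
    passkey = str(rng.randint(10000, 99999))
    filler = "The grass is green. The sky is blue. The sun is yellow. "
    needle = f"The pass key is {passkey}. Remember it. "
    chunks = [filler] * n_filler
    chunks.insert(int(depth * n_filler), needle)  # depth 0.0 = start, 1.0 = end
    prompt = "".join(chunks) + "What is the pass key? The pass key is"
    return prompt, passkey


def score(completion: str, passkey: str) -> bool:
    """Count a hit if the expected digits appear in the model's completion."""
    return passkey in completion


prompt, key = make_passkey_prompt(n_filler=100, depth=0.5)
print(len(prompt), key)
```

Sweeping `n_filler` to reach each target context length and `depth` across the prompt tests whether the model can attend to every position, not just the beginning or end.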

Part VI: Other Context Extension Approaches

ALiBi and Its Properties

ALiBi (Attention with Linear Biases) takes a different approach—no position in embeddings at all:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ALiBi (Attention with Linear Biases)                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE CONCEPT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  Don't encode position in embeddings.                                  │
│  Instead, bias attention scores by distance:                           │
│                                                                          │
│  attention_score(i, j) = q_i · k_j - m × |i - j|                      │
│                                                                          │
│  • Nearby tokens: small penalty, high attention                       │
│  • Distant tokens: large penalty, low attention                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  MULTI-HEAD SLOPES:                                                      │
│  ──────────────────                                                      │
│                                                                          │
│  Each head h = 1..n uses a different slope m (geometric sequence):    │
│                                                                          │
│  Head 1: m = 2^(-8/n)       (steepest, local focus)                   │
│  Head 2: m = 2^(-8/n × 2)   (gentler)                                 │
│  Head 3: m = 2^(-8/n × 3)   (gentler still)                           │
│  ...                                                                   │
│  Head n: m = 2^(-8)         (gentlest, long-range)                    │
│                                                                          │
│  With n = 8 heads: slopes 1/2, 1/4, ..., 1/256.                       │
│  This gives heads different "receptive fields."                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  EXTRAPOLATION PROPERTIES:                                               │
│  ─────────────────────────                                               │
│                                                                          │
│  The bias formula works for ANY distance!                             │
│  No position embeddings to extrapolate.                               │
│                                                                          │
│  BUT: Model learns attention patterns assuming certain biases.        │
│  At extreme lengths, distant tokens get such large penalties that     │
│  the model effectively ignores them.                                   │
│  Practical limit: ~2-4× training length.                             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ALIBI VS ROPE:                                                         │
│  ──────────────                                                          │
│                                                                          │
│  ALiBi:                                                                │
│  + Simpler (no rotation computation)                                  │
│  + Some natural extrapolation                                         │
│  - Less expressive position encoding                                  │
│  - Doesn't scale as well to very long contexts                       │
│                                                                          │
│  RoPE:                                                                 │
│  + More expressive relative positions                                 │
│  + Better scaling with YaRN etc.                                      │
│  - More computation                                                   │
│  - Requires explicit scaling for extension                           │
│                                                                          │
│  Most modern LLMs use RoPE (including Falcon). ALiBi is used by       │
│  BLOOM and MPT.                                                         │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
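A minimal PyTorch sketch of the bias described above, assuming causal attention and a power-of-two head count (the ALiBi paper uses a different slope recipe for other head counts). `alibi_slopes` and `alibi_bias` are hypothetical helper names; the returned tensor would be added to the scaled Q·Kᵀ logits before softmax.

```python
import torch


def alibi_slopes(n_heads: int) -> torch.Tensor:
    """Geometric slopes 2^(-8h/n) for heads h = 1..n (n a power of two)."""
    h = torch.arange(1, n_heads + 1, dtype=torch.float32)
    return 2.0 ** (-8.0 * h / n_heads)


def alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    """Causal ALiBi bias of shape (n_heads, seq_len, seq_len)."""
    pos = torch.arange(seq_len)
    # distance[i, j] = i - j for keys at or before the query, 0 otherwise
    distance = (pos[:, None] - pos[None, :]).clamp(min=0).float()
    bias = -alibi_slopes(n_heads)[:, None, None] * distance
    # Future positions (j > i) are masked out entirely
    future = pos[None, :] > pos[:, None]
    return bias.masked_fill(future, float("-inf"))


print(alibi_slopes(8))
```

Nothing here depends on a learned position table, which is why the same function works for any `seq_len`; the limit on extrapolation comes from the model, not the formula.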

Sliding Window Attention

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    SLIDING WINDOW ATTENTION                              │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE CONCEPT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  Limit attention to a local window around each token.                 │
│  Used by: Mistral 7B (4K window in every layer); Gemma 2 and 3 mix    │
│  sliding-window layers with global attention layers.                   │
│                                                                          │
│  Standard attention (all-to-all):                                      │
│  Token 5 attends to: 0, 1, 2, 3, 4, 5                                │
│  Token 10 attends to: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10              │
│                                                                          │
│  Sliding window (window=4):                                            │
│  Token 5 attends to: 2, 3, 4, 5                                      │
│  Token 10 attends to: 7, 8, 9, 10                                    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  HOW IT ENABLES LONGER CONTEXT:                                         │
│  ──────────────────────────────                                          │
│                                                                          │
│  KV cache only stores window_size keys/values per layer.             │
│  Memory: O(window_size × num_layers) instead of O(seq_len × layers) │
│                                                                          │
│  For 4K window, 32K sequence:                                         │
│  Standard: 32K × 32 layers = 1M KV entries                           │
│  Sliding:  4K × 32 layers = 128K KV entries                          │
│                                                                          │
│  8× memory reduction!                                                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  EFFECTIVE CONTEXT VIA STACKING:                                        │
│  ────────────────────────────────                                        │
│                                                                          │
│  Layer 1: Token 32 sees tokens 29-32 (window of 4)                   │
│  Layer 2: Token 32's representation includes info from 26-32            │
│           (token 29 already mixed 26-29 in layer 1)                     │
│  Layer 3: Effective range: 23-32                                        │
│  ...                                                                    │
│  Layer 11: Effective range: 1-32 (whole sequence!)                      │
│                                                                          │
│  Each layer extends the reach by (window_size - 1) tokens, so           │
│  Effective context ≈ window_size × num_layers                            │
│                                                                          │
│  With 4K window and 32 layers: ~128K theoretical reach, an upper        │
│  bound, not the range the model reliably uses.                          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMBINING WITH ROPE SCALING:                                           │
│  ────────────────────────────                                            │
│                                                                          │
│  Sliding window + RoPE scaling = very long context                    │
│                                                                          │
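│  Hybrid local/global approach (e.g., Gemma 3):                          │
│  • Sliding-window (local) attention in most layers                      │
│  • Periodic global attention layers over the full sequence              │
│  • Larger RoPE base on global layers; local layers keep the original    │
│                                                                          │
│  Only global layers need a full-length KV cache, which keeps 128K       │
│  contexts (Gemma 3) within a manageable memory budget.                  │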
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
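The windowing rule, the stacking argument, and the KV-entry count in the box above can be checked with a few lines of Python. This is a minimal sketch that assumes the diagram's convention (a window of W tokens includes the query token itself) and counts cached positions per layer rather than bytes; the function names are illustrative, not from any library.

```python
from __future__ import annotations

# Sketch: causal sliding-window attention, following the diagram's convention
# that a window of W tokens includes the query token itself.


def visible_positions(i: int, window: int) -> list[int]:
    """Key positions that query position i attends to within one layer."""
    return list(range(max(0, i - window + 1), i + 1))


def receptive_field_start(i: int, window: int, num_layers: int) -> int:
    """Earliest position whose information can reach position i after num_layers layers.

    Each layer lets information travel (window - 1) positions further back.
    """
    return max(0, i - num_layers * (window - 1))


def kv_entries(seq_len: int, num_layers: int, window: int | None = None) -> int:
    """Cached key/value positions summed over layers for one sequence."""
    cached = seq_len if window is None else min(seq_len, window)
    return cached * num_layers


if __name__ == "__main__":
    print(visible_positions(10, window=4))              # [7, 8, 9, 10]
    for layers in (1, 2, 3, 11):
        print(layers, receptive_field_start(32, 4, layers))
    full = kv_entries(32 * 1024, num_layers=32)
    swa = kv_entries(32 * 1024, num_layers=32, window=4 * 1024)
    print(full, swa, full // swa)                       # 1048576 131072 8
```

The receptive field grows by only window - 1 positions per layer, which is why the "window × layers" figure is an approximation and an upper bound.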

Part VII: Practical Implementation

Extending Your Own Model

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    PRACTICAL CONTEXT EXTENSION                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  STEP 1: CHOOSE YOUR METHOD                                             │
│  ──────────────────────────                                              │
│                                                                          │
│  Extension Factor    Recommended Method    Fine-tuning Needed          │
│  ─────────────────────────────────────────────────────────────         │
│  2×                  Dynamic NTK          Minimal (100-500 steps)      │
│  4×                  YaRN                 Moderate (500-1000 steps)    │
│  8×+                 YaRN                 More (1000-2000 steps)       │
│                                                                          │
│  Step counts are rough guidance; budgets vary with model and data.      │
│  For zero-shot (no fine-tuning): Dynamic NTK or NTK-by-parts.        │
│  For best quality: YaRN with fine-tuning.                             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  STEP 2: MODIFY ROPE COMPUTATION                                        │
│  ───────────────────────────────                                         │
│                                                                          │
│  # In your model's attention module (a sketch, not a library API).      │
│  # Pairs channels (0,1), (2,3), ...; HF Llama instead pairs i with      │
│  # i + dim/2 ("rotate_half"), so match your checkpoint's layout.        │
│  import math                                                            │
│  import torch                                                           │
│                                                                          │
│  class ExtendedRoPE(torch.nn.Module):                                   │
│      def __init__(                                                      │
│          self,                                                          │
│          dim: int,                                                      │
│          max_position: int = 4096,    # ORIGINAL training length        │
│          base: float = 10000.0,                                         │
│          scaling_type: str = "yarn",  # or "ntk", "pi"                  │
│          scale: float = 1.0,          # extension factor s              │
│      ):                                                                 │
│          super().__init__()                                             │
│          self.dim = dim                                                 │
│          self.max_position = max_position                               │
│          self.base = base                                               │
│          self.scaling_type = scaling_type                               │
│          self.scale = scale                                             │
│          self.mscale = 1.0            # YaRN attention scale            │
│                                                                          │
│          # Precompute frequencies once                                  │
│          self.register_buffer("freqs", self._compute_frequencies())     │
│                                                                          │
│      def _compute_frequencies(self):                                    │
│          dim_indices = torch.arange(0, self.dim, 2, dtype=torch.float32)│
│                                                                          │
│          if self.scaling_type == "ntk":                                 │
│              # NTK-aware: base' = base * s^(dim / (dim - 2))            │
│              scaled_base = self.base * self.scale ** (self.dim / (self.dim - 2))
│              return 1.0 / (scaled_base ** (dim_indices / self.dim))     │
│                                                                          │
│          base_freqs = 1.0 / (self.base ** (dim_indices / self.dim))     │
│          if self.scaling_type == "pi":                                  │
│              # PI: freqs unchanged; positions divided in forward()      │
│              return base_freqs                                          │
│          if self.scaling_type == "yarn":                                │
│              return self._compute_yarn_freqs(base_freqs)                │
│          raise ValueError(f"unknown scaling_type: {self.scaling_type}") │
│                                                                          │
│      def _compute_yarn_freqs(self, base_freqs):                         │
│          # r = rotations each dim completes within the original context │
│          wavelengths = 2 * math.pi / base_freqs                         │
│          r = self.max_position / wavelengths                            │
│                                                                          │
│          # YaRN ramp (alpha=1, beta=32 are the paper's Llama settings)  │
│          beta_fast, beta_slow = 32.0, 1.0                               │
│          # gamma = 1: keep original freq (fast dims); gamma = 0: full PI│
│          gamma = ((r - beta_slow) / (beta_fast - beta_slow)).clamp(0, 1)│
│                                                                          │
│          freqs_interp = base_freqs / self.scale                         │
│          # Attention temperature: sqrt(1/t) = 0.1 * ln(s) + 1           │
│          if self.scale > 1:                                             │
│              self.mscale = 0.1 * math.log(self.scale) + 1.0             │
│          return freqs_interp * (1 - gamma) + base_freqs * gamma         │
│                                                                          │
│      def forward(self, x, positions):                                   │
│          # x: [..., seq, dim]; positions: [seq]; call on both q and k   │
│          positions = positions.float()                                  │
│          if self.scaling_type == "pi":                                  │
│              positions = positions / self.scale                         │
│                                                                          │
│          angles = positions.unsqueeze(-1) * self.freqs    # [seq, dim/2]│
│          cos = torch.cos(angles) * self.mscale                          │
│          sin = torch.sin(angles) * self.mscale                          │
│                                                                          │
│          # Apply rotation to each (even, odd) channel pair              │
│          x1, x2 = x[..., ::2], x[..., 1::2]                             │
│          x_rot = torch.stack([                                          │
│              x1 * cos - x2 * sin,                                       │
│              x1 * sin + x2 * cos,                                       │
│          ], dim=-1).flatten(-2)                                         │
│          return x_rot                                                   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  STEP 3: FINE-TUNE ON LONG SEQUENCES                                    │
│  ────────────────────────────────────                                    │
│                                                                          │
│  Training data requirements:                                           │
│  • Sequences at target length                                         │
│  • Mix of lengths (don't only train on max length)                   │
│  • Quality long-form content                                          │
│                                                                          │
│  Example fine-tuning config:                                           │
│                                                                          │
│  training_args = {                                                      │
│      "max_steps": 1000,                                                │
│      "per_device_train_batch_size": 1,                                │
│      "gradient_accumulation_steps": 8,                                │
│      "learning_rate": 2e-5,                                            │
│      "max_seq_length": 32768,  # Target length                       │
│      "bf16": True,                                                     │
│      "gradient_checkpointing": True,  # Essential for memory         │
│  }                                                                      │
│                                                                          │
│  Data mixing:                                                           │
│  • 30% original-length sequences (preserve short-context)            │
│  • 40% medium-length sequences                                        │
│  • 30% full target-length sequences                                   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
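To see how PI, NTK-aware scaling, and YaRN treat different dimensions, the sketch below computes the ratio of scaled to original frequency for the fastest and slowest RoPE pairs. It is pure Python and assumes head dimension 128, base 10000, a 4096-token training length, a 4× extension, the NTK-aware base rule base·s^(d/(d-2)), and YaRN's Llama defaults α = 1, β = 32. It omits YaRN's attention-temperature term because that scales logits, not frequencies.

```python
from __future__ import annotations

import math

# Sketch: ratio of scaled to original RoPE frequency per dimension pair.


def rope_freqs(dim: int, base: float = 10000.0) -> list[float]:
    """Original RoPE frequencies theta_i = base**(-2i/dim), i = 0 .. dim/2 - 1."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


def scaled_freqs(method: str, dim: int, scale: float, train_len: int,
                 base: float = 10000.0, alpha: float = 1.0,
                 beta: float = 32.0) -> list[float]:
    orig = rope_freqs(dim, base)
    if method == "pi":
        # Dividing every position by s is equivalent to dividing every frequency by s.
        return [f / scale for f in orig]
    if method == "ntk":
        # NTK-aware: enlarge the base; the slowest pair ends up divided by exactly s.
        return rope_freqs(dim, base * scale ** (dim / (dim - 2)))
    if method == "yarn":
        out = []
        for f in orig:
            r = train_len * f / (2 * math.pi)  # full rotations within train_len
            gamma = min(1.0, max(0.0, (r - alpha) / (beta - alpha)))
            # gamma = 1 keeps the original frequency, gamma = 0 applies full PI.
            out.append((1 - gamma) * f / scale + gamma * f)
        return out
    raise ValueError(f"unknown method: {method}")


if __name__ == "__main__":
    dim, s, train_len = 128, 4.0, 4096
    orig = rope_freqs(dim)
    for m in ("pi", "ntk", "yarn"):
        new = scaled_freqs(m, dim, s, train_len)
        print(f"{m:>4}: fastest x{new[0] / orig[0]:.3f}, slowest x{new[-1] / orig[-1]:.3f}")
```

PI shrinks every frequency by s, while the NTK-aware and YaRN rules leave the fastest pair untouched and shrink only the slowest pair by the full factor, which is why they preserve local patterns better.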

Using Extended Models with vLLM/Transformers

Python
# Using YaRN-extended model with HuggingFace Transformers

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model with rope_scaling configuration. Config overrides passed to
# from_pretrained are applied before the rotary embeddings are built.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    rope_scaling={
        "type": "yarn",  # newer releases name this key "rope_type"; "type" is still accepted
        "factor": 4.0,  # 4× extension
        "original_max_position_embeddings": 4096,
    },
    # 4096 × 4. Set at load time: editing model.config afterwards does not
    # rebuild the already-computed RoPE frequencies.
    max_position_embeddings=16384,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# The model now accepts 16K-token inputs (quality still benefits from fine-tuning)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.model_max_length = 16384

# Generate with long context
long_text = "..." * 10000  # Placeholder long input
inputs = tokenizer(long_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=100)
Python
# Using with vLLM

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
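    max_model_len=32768,  # Extended context: 4096 × 8
    # Older vLLM releases take rope_scaling directly; newer ones expect
    # hf_overrides={"rope_scaling": {...}} instead.
    rope_scaling={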
        "type": "yarn",
        "factor": 8.0,
        "original_max_position_embeddings": 4096,
    },
)

sampling_params = SamplingParams(temperature=0.7, max_tokens=100)
outputs = llm.generate(["Long prompt here..."], sampling_params)
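Both examples only behave as intended when the factor, the original length, and the new maximum length agree (4096 × 4 = 16384 and 4096 × 8 = 32768 above). A small helper, hypothetical and not part of Transformers or vLLM, can derive the factor from the target length so the three values cannot drift apart:

```python
def yarn_rope_scaling(original_len: int, target_len: int) -> dict:
    """Build a YaRN rope_scaling dict whose factor matches target_len.

    Hypothetical helper: pass the result as rope_scaling, and target_len as
    max_position_embeddings (Transformers) or max_model_len (vLLM).
    """
    if target_len <= original_len:
        raise ValueError("target_len must exceed original_len")
    return {
        "type": "yarn",
        "factor": target_len / original_len,
        "original_max_position_embeddings": original_len,
    }


if __name__ == "__main__":
    print(yarn_rope_scaling(4096, 16384))  # factor 4.0, as in the Transformers example
    print(yarn_rope_scaling(4096, 32768))  # factor 8.0, as in the vLLM example
```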

Part VIII: Memory Considerations

Long Context Memory Requirements

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    MEMORY SCALING WITH CONTEXT                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  KV CACHE MEMORY:                                                        │
│  ────────────────                                                        │
│                                                                          │
│  KV cache per token (bytes) =                                           │
│      2 × num_layers × num_kv_heads × head_dim × bytes_per_element       │
│                                                                          │
│  Example: Llama 2 7B                                                    │
│  • 32 layers                                                           │
│  • 32 KV heads                                                         │
│  • 128 head dim                                                        │
│  • BF16 (2 bytes)                                                      │
│                                                                          │
│  Per token: 2 × 32 × 32 × 128 × 2 = 524,288 bytes (0.5 MiB)             │
│                                                                          │
│  Context Length    KV Cache Size    Total (+ ~14 GB weights)            │
│  ─────────────────────────────────────────────────────                  │
│  4K                2.0 GB           16 GB                              │
│  8K                4.0 GB           18 GB                              │
│  16K               8.0 GB           22 GB                              │
│  32K               16.0 GB          30 GB                              │
│  64K               32.0 GB          46 GB                              │
│  128K              64.0 GB          78 GB                              │
│                                                                          │
│  Per sequence (batch 1): KV cache dominates at long contexts!           │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ATTENTION COMPUTATION MEMORY:                                           │
│  ─────────────────────────────                                           │
│                                                                          │
│  Standard attention materializes an N×N score matrix per head:          │
│  Memory per layer = seq_len² × num_heads × bytes_per_element            │
│                                                                          │
│  32K context: 32K × 32K × 32 × 2 bytes = 64 GiB for one layer!          │
│                                                                          │
│  FlashAttention computes attention in tiles and never stores the        │
│  full matrix: extra memory is O(seq_len) instead of O(seq_len²)         │
│                                                                          │
│  CRITICAL: Use FlashAttention (or another fused, memory-efficient       │
│  attention kernel) for long context.                                    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  MEMORY OPTIMIZATION STRATEGIES:                                         │
│  ────────────────────────────────                                        │
│                                                                          │
│  1. Quantized KV cache                                                 │
│     INT8 KV: 2× reduction vs. BF16                                      │
│     INT4 KV: 4× reduction vs. BF16 (quality tradeoff)                   │
│                                                                          │
│  2. Sliding window attention                                           │
│     Fixed KV cache size regardless of sequence length                │
│                                                                          │
│  3. KV cache offloading                                                │
│     Store old KV in CPU RAM, load on demand                          │
│     Adds latency but enables very long contexts                      │
│                                                                          │
│  4. PagedAttention (vLLM)                                              │
│     Non-contiguous KV cache allocation                                │
│     Reduces fragmentation waste                                       │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
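The arithmetic behind the memory tables above, as a sketch: it assumes batch size 1 and BF16 (2-byte) elements, and reports GiB (2³⁰ bytes), which is what the table's "GB" figures actually correspond to. The attention-matrix figure covers a single layer of naive (non-fused) attention.

```python
GIB = 2 ** 30  # bytes per GiB


def kv_cache_bytes(seq_len: int, num_layers: int, num_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2, batch: int = 1) -> int:
    """KV-cache size: one key and one value vector per layer, KV head and token."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem * seq_len * batch


def naive_attn_matrix_bytes(seq_len: int, num_heads: int, bytes_per_elem: int = 2) -> int:
    """Score matrices (seq_len x seq_len per head) for one layer of naive attention."""
    return seq_len * seq_len * num_heads * bytes_per_elem


if __name__ == "__main__":
    # Llama 2 7B: 32 layers, 32 KV heads, head_dim 128, BF16
    print(kv_cache_bytes(1, 32, 32, 128))                 # 524288
    for k in (4, 8, 16, 32, 64, 128):
        gib = kv_cache_bytes(k * 1024, 32, 32, 128) / GIB
        print(f"{k:>3}K tokens: {gib:5.1f} GiB KV cache")
    print(naive_attn_matrix_bytes(32 * 1024, 32) / GIB)   # 64.0
```

Setting bytes_per_elem=1 reproduces the INT8 halving from strategy 1.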

Part IX: Recent Innovations (2024-2025)

LongRoPE: Non-Uniform Position Scaling

LongRoPE introduced a key insight: different RoPE dimensions should be scaled differently based on their importance to the model, not just their frequency:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    LONGROPE: SEARCH-BASED SCALING                        │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE KEY INSIGHT:                                                        │
│  ────────────────                                                        │
│                                                                          │
│  YaRN assumes frequency determines importance.                          │
│  But empirically, some mid-frequency dimensions matter more than others.│
│                                                                          │
│  LongRoPE: Search for optimal per-dimension rescaling factors.          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  NON-UNIFORM SCALING:                                                    │
│  ────────────────────                                                    │
│                                                                          │
│  YaRN uses smooth ramp (frequency-based):                               │
│  λ(d) = function of frequency only                                      │
│                                                                          │
│  LongRoPE searches for optimal λ per dimension:                         │
│  λ = [λ₀, λ₁, λ₂, ..., λ_{d/2−1}]  (one factor per RoPE frequency)      │
│                                                                          │
│  Search objective: Minimize perplexity on validation set                │
│  Method: Evolutionary search over λ values                              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  TWO-STAGE EXTENSION:                                                    │
│  ────────────────────                                                    │
│                                                                          │
│  Stage 1: Extend to intermediate length (4K → 256K)                     │
│  • Search for optimal λ₁                                                │
│  • Fine-tune at the extended length                                     │
│                                                                          │
│  Stage 2: Extend to target length (256K → 2048K)                        │
│  • Search for optimal λ₂ on the fine-tuned model                        │
│  • No further fine-tuning                                               │
│                                                                          │
│  A final re-search at 8K recovers short-context performance.            │
│  Progressive extension works better than a single-stage jump.           │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  RESULTS:                                                                │
│                                                                          │
│  • Extended LLaMA2-7B from 4K to 2048K (512× extension!)               │
│  • Maintained passkey retrieval accuracy at extreme lengths            │
│  • Better perplexity than YaRN at equivalent extensions                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
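A toy version of the search loop makes "evolutionary search over λ" concrete. This is a sketch, not LongRoPE's implementation: the objective is a hypothetical stand-in (the real method scores candidates by the model's perplexity on long validation text), the population sizes are arbitrary, and LongRoPE's other searched parameter (how many initial tokens are left uninterpolated) is omitted. The sketch keeps factors ≥ 1 and non-decreasing from fast to slow dimensions, an assumption modeled on the paper's monotonicity constraint.

```python
from __future__ import annotations

import random
from typing import Callable

# Sketch of a LongRoPE-style evolutionary search over per-dimension factors.


def mutate(lams: list[float], rng: random.Random, step: float = 0.1) -> list[float]:
    """Perturb each factor by up to +/-step, then enforce 1 <= l_0 <= l_1 <= ..."""
    out = [max(1.0, lam * (1 + rng.uniform(-step, step))) for lam in lams]
    for i in range(1, len(out)):
        out[i] = max(out[i], out[i - 1])  # non-decreasing toward slower dims
    return out


def evolutionary_search(num_dims: int, objective: Callable[[list[float]], float],
                        scale: float, pop_size: int = 16, generations: int = 30,
                        seed: int = 0) -> list[float]:
    rng = random.Random(seed)
    pi_solution = [scale] * num_dims  # uniform PI as a starting candidate
    population = [pi_solution] + [mutate(pi_solution, rng)
                                  for _ in range(pop_size - 1)]
    for _ in range(generations):
        population.sort(key=objective)           # lower objective = better
        parents = population[: pop_size // 4]    # elitism: best quarter survives
        children = [mutate(rng.choice(parents), rng)
                    for _ in range(pop_size - len(parents))]
        population = parents + children
    return min(population, key=objective)


if __name__ == "__main__":
    # HYPOTHETICAL objective: distance to a made-up "ideal" profile that ramps
    # from 1x (fast dims) to 8x (slow dims). Real LongRoPE uses model perplexity.
    target = [1 + 7 * i / 31 for i in range(32)]

    def toy_loss(lams: list[float]) -> float:
        return sum((a - b) ** 2 for a, b in zip(lams, target))

    best = evolutionary_search(32, toy_loss, scale=8.0)
    print(f"PI loss {toy_loss([8.0] * 32):.2f} -> searched loss {toy_loss(best):.2f}")
```

Because the best candidates always survive, the returned λ can never score worse than the uniform PI starting point under the given objective.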

LongRoPE2: Efficient Extension with Minimal Data

LongRoPE2 (2025) substantially reduced the training data needed for context extension:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    LONGROPE2: 80× MORE EFFICIENT                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE EFFICIENCY BREAKTHROUGH:                                            │
│  ────────────────────────────                                            │
│                                                                          │
│  Previous approaches (like Meta's Llama 3.1 128K extension):            │
│  • Used ~800B tokens of long-context continued pretraining              │
│  • Expensive compute cost                                               │
│                                                                          │
│  LongRoPE2 achieves similar quality with:                               │
│  • Only 10B tokens (80× reduction)                                     │
│  • Retains over 98.5% of short-context performance                      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  KEY INNOVATIONS:                                                        │
│  ────────────────                                                        │
│                                                                          │
│  1. NEEDLE-DRIVEN PERPLEXITY (NDP):                                     │
│     Perplexity computed only on answer (needle) tokens that depend      │
│     on distant context, instead of averaging over all tokens.           │
│     Better correlates with long-context task performance.               │
│                                                                          │
│  2. EVOLUTIONARY SEARCH WITH CONSTRAINTS:                                │
│     Search space reduced via theoretical bounds.                        │
│     Faster convergence to good scaling factors.                        │
│                                                                          │
│  3. MIXED CONTEXT WINDOW TRAINING:                                      │
│     Long sequences use the rescaled RoPE; short ones keep the           │
│     original RoPE, preserving short-context quality.                    │
│                                                                          │
│  4. PROGRESSIVE TEMPERATURE ADJUSTMENT:                                  │
│     Dynamic attention temperature during training.                      │
│     Smoother adaptation to longer contexts.                            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  RESULTS ON LLAMA3-8B:                                                   │
│                                                                          │
│  Metric                         LongRoPE2    Meta's Method              │
│  ─────────────────────────────────────────────────────────────         │
│  Target context                 128K         128K                       │
│  Training tokens                10B          800B                       │
│  Short-context retention        98.5%        ~99%                       │
│  RULER (128K)                   Comparable   Baseline                   │
│  Needle-in-haystack             97%+         97%+                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  PRODUCTION ADOPTION:                                                    │
│  ────────────────────                                                    │
│                                                                          │
│  LongRoPE2 techniques adopted in:                                       │
│  • Phi-4-mini (Microsoft)                                              │
│  • Phi-4-multimodal                                                    │
│  • Various open-source long-context models                             │
│                                                                          │
│  Practical for organizations without massive compute budgets.          │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
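The difference between standard perplexity and a needle-driven variant is easiest to see numerically. This sketch assumes NDP means perplexity restricted to the answer (needle) tokens of a long document; the helper names and the per-token log-probabilities are made up for illustration.

```python
from __future__ import annotations

import math


def standard_perplexity(token_logprobs: list[float]) -> float:
    """exp of the mean negative log-likelihood over all tokens."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


def needle_driven_perplexity(token_logprobs: list[float], needle_mask: list[bool]) -> float:
    """Perplexity over needle (answer) tokens only; hypothetical helper.

    token_logprobs: natural-log probability the model gave each target token.
    needle_mask: True where the token belongs to the inserted answer.
    """
    picked = [lp for lp, is_needle in zip(token_logprobs, needle_mask) if is_needle]
    if not picked:
        raise ValueError("needle_mask selects no tokens")
    return math.exp(-sum(picked) / len(picked))


if __name__ == "__main__":
    # Made-up numbers: 1,000 easy filler tokens, 4 hard answer tokens.
    logprobs = [-0.1] * 1000 + [-3.0] * 4
    mask = [False] * 1000 + [True] * 4
    print(f"standard PPL: {standard_perplexity(logprobs):.3f}")
    print(f"needle PPL:   {needle_driven_perplexity(logprobs, mask):.3f}")
```

With these numbers, standard perplexity stays near 1.1 while the needle-only score is e³ ≈ 20, so a rescaling that breaks long-range retrieval would barely move the standard metric.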

Dual Chunk Attention (DCA)

Another approach for extending context without fine-tuning:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    DUAL CHUNK ATTENTION (DCA)                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE CONCEPT:                                                            │
│  ────────────────                                                        │
│                                                                          │
│  Process long sequences in chunks, with special handling for           │
│  cross-chunk attention.                                                  │
│                                                                          │
│  Sequence: [chunk_1 | chunk_2 | chunk_3 | ...]                         │
│                                                                          │
│  For attention within chunk: Use standard RoPE                          │
│  For attention across chunks: Remap positions into the training range   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  INTRA-CHUNK vs INTER-CHUNK:                                            │
│  ────────────────────────────                                            │
│                                                                          │
│  Token at position 5000 (in chunk 2, assuming 4K chunks):              │
│                                                                          │
│  Attending to token at position 4500 (same chunk):                     │
│  • Relative position: 500                                              │
│  • Use standard RoPE angle for 500                                     │
│                                                                          │
│  Attending to token at position 500 (chunk 1):                         │
│  • Actual distance: 4500                                               │
│  • Query gets a fixed index; key keeps its within-chunk index, so       │
│    the remapped distance stays inside the training range                │
│                                                                          │
│  Local attention (within chunk) uses FULL resolution.                  │
│  Global attention (across chunks) uses remapped positions.              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  BENEFITS:                                                               │
│  ─────────                                                               │
│                                                                          │
│  • Training-free (no fine-tuning needed)                               │
│  • Preserves local pattern resolution completely                       │
│  • Only long-range attention is interpolated                           │
│  • Works with existing RoPE models directly                            │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
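The two cases in the box can be written as a function that returns the relative distance RoPE actually sees. This is a simplified sketch, assuming the remapping described in the DCA paper (in cross-chunk attention the query takes the fixed index train_len - 1 and the key keeps its index within its chunk); the real method adds a third, successive-chunk case that restores exact distances between neighboring chunks, which is omitted here.

```python
def dca_relative_distance(q_pos: int, k_pos: int, chunk_size: int, train_len: int) -> int:
    """Relative distance RoPE sees for a causal query/key pair (simplified DCA).

    Same chunk: the true distance (full resolution).
    Different chunks: the query takes the fixed index train_len - 1 and the key
    keeps its index within its own chunk, so the result stays below train_len.
    DCA's third, successive-chunk case is omitted in this sketch.
    """
    if k_pos > q_pos:
        raise ValueError("causal attention: key must not come after the query")
    if chunk_size > train_len:
        raise ValueError("chunk_size must not exceed train_len")
    if q_pos // chunk_size == k_pos // chunk_size:
        return q_pos - k_pos                        # intra-chunk
    return (train_len - 1) - (k_pos % chunk_size)   # inter-chunk, remapped


if __name__ == "__main__":
    print(dca_relative_distance(5000, 4500, chunk_size=4096, train_len=4096))  # 500
    print(dca_relative_distance(5000, 500, chunk_size=4096, train_len=4096))   # 3595
```

For the box's example (4K chunks, 4K training length), the pair (5000, 4500) keeps its true distance of 500, while (5000, 500) is mapped from 4500 to 3595, inside the trained range.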

Updated Method Comparison

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    CONTEXT EXTENSION METHODS (2025)                      │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Method          Max Ext.  Fine-tune   Short Ctx    Key Innovation     │
│  ─────────────────────────────────────────────────────────────────────  │
│  PI              8×        Required    Degraded     Position scaling   │
│  NTK-aware       8×        Optional    Good         Base scaling       │
│  Dynamic NTK     16×       Minimal     Preserved    Dynamic scaling    │
│  YaRN            32×       Brief       Very good    Ramp + temperature │
│  LongRoPE        512×      Brief       Good         Search-based λ     │
│  LongRoPE2       32×       Very brief  Excellent    80× efficiency     │
│  DCA             8×        None        Preserved    Chunk-based        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  RECOMMENDATION BY SCENARIO:                                             │
│                                                                          │
│  Zero-shot (no fine-tuning):                                            │
│  • Up to 2×: Dynamic NTK                                               │
│  • Up to 4×: DCA                                                       │
│                                                                          │
│  Brief fine-tuning (100-1K steps):                                      │
│  • Up to 8×: YaRN                                                      │
│  • Up to 32×: LongRoPE2                                                │
│                                                                          │
│  Maximum extension with fine-tuning:                                    │
│  • Up to 512×: LongRoPE (requires careful tuning)                     │
│                                                                          │
│  Best efficiency (limited compute):                                     │
│  • LongRoPE2 (10B tokens for 128K context)                            │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
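The methods in the table differ mainly in how they rewrite RoPE's per-dimension rotation frequencies. Below is a minimal NumPy sketch of the frequency rules for PI, NTK-aware, Dynamic NTK, and YaRN (plus YaRN's attention temperature). It assumes base 10000, an even `head_dim`, and ramp bounds α = 1, β = 32, which the YaRN paper suggests for LLaMA-family models. The function names are hypothetical. The Dynamic NTK variant uses a simplified rule, s = seq_len / train_len, not any specific library's formula. LongRoPE's searched factors and DCA are not covered.

```python
import math

import numpy as np


def rope_inv_freq(head_dim: int, base: float = 10000.0) -> np.ndarray:
    """Standard RoPE inverse frequencies: theta_i = base^(-2i/d), i = 0 .. d/2 - 1."""
    return base ** (-np.arange(0, head_dim, 2, dtype=np.float64) / head_dim)


def pi_inv_freq(head_dim: int, scale: float, base: float = 10000.0) -> np.ndarray:
    """Position Interpolation: m -> m / s is equivalent to dividing every frequency by s."""
    return rope_inv_freq(head_dim, base) / scale


def ntk_inv_freq(head_dim: int, scale: float, base: float = 10000.0) -> np.ndarray:
    """NTK-aware: enlarge the base so the lowest frequency shrinks by exactly s
    while the highest frequency (i = 0) is left untouched."""
    new_base = base * scale ** (head_dim / (head_dim - 2))
    return rope_inv_freq(head_dim, new_base)


def dynamic_ntk_inv_freq(head_dim: int, seq_len: int, train_len: int,
                         base: float = 10000.0) -> np.ndarray:
    """Dynamic NTK (simplified): derive s from the current length; no-op within train_len."""
    scale = max(1.0, seq_len / train_len)
    return ntk_inv_freq(head_dim, scale, base)


def yarn_inv_freq(head_dim: int, scale: float, train_len: int, base: float = 10000.0,
                  alpha: float = 1.0, beta: float = 32.0) -> np.ndarray:
    """YaRN's NTK-by-parts frequencies with a linear ramp between alpha and beta."""
    inv_freq = rope_inv_freq(head_dim, base)
    wavelength = 2 * math.pi / inv_freq
    rotations = train_len / wavelength  # full turns each dimension makes over train_len
    gamma = np.clip((rotations - alpha) / (beta - alpha), 0.0, 1.0)
    # gamma = 1 (many rotations, high frequency): keep the original frequency.
    # gamma = 0 (under one rotation, low frequency): interpolate fully, as in PI.
    return (1.0 - gamma) * inv_freq / scale + gamma * inv_freq


def yarn_mscale(scale: float) -> float:
    """YaRN attention temperature: sqrt(1/t) = 0.1 ln(s) + 1.
    Multiplying both q and k by this factor scales the attention logits by 1/t."""
    return 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0


if __name__ == "__main__":
    d, s, L = 128, 8.0, 4096  # hypothetical: head_dim 128, 8x extension of a 4K model
    orig = rope_inv_freq(d)
    for name, f in [("PI", pi_inv_freq(d, s)),
                    ("NTK", ntk_inv_freq(d, s)),
                    ("YaRN", yarn_inv_freq(d, s, L))]:
        print(f"{name:5s} highest freq x{f[0] / orig[0]:.3f}, lowest freq x{f[-1] / orig[-1]:.3f}")
    print(f"YaRN mscale at s={s:g}: {yarn_mscale(s):.4f}")
```

With these settings, PI divides every frequency by 8. NTK-aware and YaRN leave the highest frequency unchanged and divide only the lowest by 8. They differ in the middle dimensions: NTK-aware changes them geometrically through the base, while YaRN chooses per dimension based on how many rotations it completes within the training window.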

Summary

Context extension enables LLMs to operate far beyond their training length. The key techniques form a progression of increasing sophistication:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    KEY TAKEAWAYS                                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE PROBLEM:                                                            │
│  • Models trained on 4K tokens fail at 8K+ (unseen positions)         │
│  • RoPE low-frequency dimensions hit unseen rotation angles           │
│  • Naive extrapolation produces garbage                               │
│                                                                          │
│  POSITION INTERPOLATION (PI):                                           │
│  • Scale positions to fit within training range                       │
│  • Works but degrades short-context quality                           │
│  • Requires fine-tuning                                               │
│                                                                          │
│  NTK-AWARE SCALING:                                                      │
│  • Scale base instead of positions                                    │
│  • High frequencies unchanged, low frequencies scaled                 │
│  • Better preservation of local patterns                              │
│                                                                          │
│  YARN:                                                                  │
│  • NTK-by-parts with smooth interpolation ramp                       │
│  • Attention temperature adjustment                                   │
│  • Very good short- and long-context quality                         │
│  • Brief fine-tuning recommended                                     │
│                                                                          │
│  LONGROPE/LONGROPE2 (2024-2025):                                        │
│  • Search-based per-dimension scaling factors                        │
│  • LongRoPE: Up to 512× extension with progressive stages           │
│  • LongRoPE2: 80× fewer training tokens (10B vs 800B)                │
│  • Adopted in Phi-3 and Phi-4-mini models                            │
│                                                                          │
│  PRACTICAL GUIDELINES:                                                   │
│  • 2× extension: Dynamic NTK, minimal fine-tuning                    │
│  • 4×-8× extension: YaRN with 500-1000 steps fine-tuning            │
│  • 16×+ extension: LongRoPE2 (if compute-limited)                   │
│  • Use a memory-efficient attention kernel (e.g., FlashAttention)    │
│  • Monitor KV cache memory (often the bottleneck)                    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  STATE OF CONTEXT WINDOWS (2025):                                        │
│                                                                          │
│  Model              Context      Notes                                  │
│  ────────────────────────────────────────────────────────────────      │
│  Magic LTM-2-mini   100M         ~1000× cheaper than 405B attention   │
│  Llama 4 Scout      10M          Industry-leading open, iRoPE         │
│  Grok-4-fast        2M           Long-horizon RL training              │
│  Gemini 2.0 Pro     2M           Google's largest context              │
│  Llama 4 Maverick   1M           128 experts, multimodal               │
│  Qwen3-Next         1M           256K native, extendable to 1M        │
│  Gemini 2.5 Flash   1M           Reasoning + multimodal                │
│  Claude Sonnet 4    200K (1M β)  1M-token beta available              │
│  GPT-5              400K total   Up to 128K output                     │
│  Kimi K2            128K         Moonshot AI                           │
│                                                                          │
│  ULTRA-LONG CONTEXT BREAKTHROUGH (Magic LTM-2-mini):                 │
│  • 100M tokens ≈ 10M lines of code or ~750 novels                    │
│  • Per decoded token, the sequence-dimension algorithm is ~1000×     │
│    cheaper than Llama 3.1 405B's attention at 100M context           │
│  • HashHop evaluation: Random hash pairs force full context usage     │
│                                                                          │
│  Trend: Production windows grew from 128K-200K (2023) to 1M-10M,     │
│         with 100M announced by Magic (LTM-2-mini)                    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
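The advice to monitor KV cache memory comes down to simple arithmetic. The sketch below assumes fp16/bf16 cache entries (2 bytes each), no cache quantization or paging, and hypothetical 7B/8B-class shapes. `kv_cache_bytes` is an illustrative helper, not a library function.

```python
def kv_cache_bytes(seq_len: int, n_layers: int, n_kv_heads: int, head_dim: int,
                   batch: int = 1, bytes_per_elem: int = 2) -> int:
    """Size of the KV cache: one key and one value vector (hence the factor 2)
    per layer, per KV head, per cached position, per sequence in the batch.
    bytes_per_elem = 2 assumes fp16/bf16 storage."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem


GIB = 1024 ** 3

if __name__ == "__main__":
    # Hypothetical shapes: 32 layers, head_dim 128.
    # 32 KV heads ~ full multi-head attention (Llama-2-7B-like);
    # 8 KV heads ~ grouped-query attention (Llama-3-8B-like).
    for label, kv_heads in [("MHA, 32 KV heads", 32), ("GQA, 8 KV heads", 8)]:
        for seq_len in (4_096, 131_072):
            gib = kv_cache_bytes(seq_len, n_layers=32, n_kv_heads=kv_heads, head_dim=128) / GIB
            print(f"{label:17s} {seq_len:>7,d} tokens: {gib:5.1f} GiB")
```

Under these assumed shapes, the cache for a single sequence grows from 2 GiB at 4K tokens to 64 GiB at 128K with full multi-head attention, before counting model weights. The 8-KV-head variant needs a quarter of that (16 GiB at 128K). Because growth is linear in sequence length, the cache rather than the weights often sets the practical context limit on a given GPU.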
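A small generator shows why random hash pairs force a model to use its whole context. This is a toy sketch in the spirit of Magic's HashHop evaluation, not its released code. The hash length, alphabet, `a = b` line format, and prompt wording are all hypothetical choices. Because the hashes are random, a model cannot infer a link from content or position. It has to locate each pair wherever it landed in the shuffled context, and several hops compound that requirement.

```python
import random
import string


def random_hash(rng: random.Random, length: int = 16) -> str:
    """A random token string: nothing about the next link can be guessed from it."""
    alphabet = string.ascii_letters + string.digits
    return "".join(rng.choice(alphabet) for _ in range(length))


def make_hashhop_prompt(n_chains: int, hops: int, seed: int = 0) -> tuple[str, str]:
    """Toy HashHop-style task: shuffled 'a = b' pairs forming hash chains,
    plus a query that requires following `hops` links from one chain's start."""
    rng = random.Random(seed)
    chains = [[random_hash(rng) for _ in range(hops + 1)] for _ in range(n_chains)]
    pairs = [f"{chain[i]} = {chain[i + 1]}" for chain in chains for i in range(hops)]
    rng.shuffle(pairs)  # links of one chain end up scattered across the context
    target = rng.choice(chains)
    prompt = "\n".join(pairs) + f"\n\nStart at {target[0]} and follow {hops} hops. Final hash:"
    return prompt, target[-1]


if __name__ == "__main__":
    prompt, answer = make_hashhop_prompt(n_chains=1000, hops=3, seed=42)
    print(f"{len(prompt.splitlines())} lines, expected answer: {answer}")
```

Scaling `n_chains` controls how many tokens the context fills, and `hops` controls how many separate lookups a correct answer needs.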


Enrico Piovano, PhD

Co-founder & CTO at Goji AI. Former Applied Scientist at Amazon (Alexa & AGI), focused on Agentic AI and LLMs. PhD in Electrical Engineering from Imperial College London. Gold Medalist at the National Mathematical Olympiad.
