
RoPE: Rotary Position Embeddings Explained

A comprehensive mathematical deep dive into Rotary Position Embeddings (RoPE)—the position encoding method that powers Llama, Mistral, Qwen, and most modern LLMs. Complete derivations, proofs, implementation, and the mathematics of context extension.

15 min read

Overview

Rotary Position Embedding (RoPE), introduced by Su et al. in the 2021 paper "RoFormer: Enhanced Transformer with Rotary Position Embedding," has become the dominant position encoding method in modern large language models. Llama 1/2/3/4, Mistral, Mixtral, Qwen, Yi, DeepSeek, Gemma, Phi—virtually every major open-source LLM uses RoPE.

| Model | RoPE Base | Max Context | Notes |
| --- | --- | --- | --- |
| Llama 1 | 10,000 | 2,048 | Original implementation |
| Llama 2 | 10,000 | 4,096 | Extended context |
| Llama 3 | 500,000 | 8,192 | Higher base for extension |
| Llama 3.1 | 500,000 | 128,000 | With YaRN scaling |
| Mistral 7B | 10,000 | 8,192 | Sliding window attention |
| Qwen 2.5 | 1,000,000 | 32,768+ | Very high base |
| DeepSeek-V2 | 10,000 | 128,000 | With YaRN |

This post provides a complete mathematical treatment of RoPE: foundational concepts, full derivations, production implementation, and context extension methods.

Prerequisites: Linear algebra (matrices, dot products), complex numbers, basic calculus. See Positional Embeddings for background.

Paper: RoFormer: Enhanced Transformer with Rotary Position Embedding (Su et al., 2021)


Part I: Mathematical Foundations

The Position Encoding Problem

In self-attention, we compute scores between queries and keys:

$$a_{mn} = \frac{\exp(q_m^\top k_n / \sqrt{d})}{\sum_{j=1}^{N} \exp(q_m^\top k_j / \sqrt{d})}$$

where $q_m = W_q x_m$ is the query at position $m$, $k_n = W_k x_n$ is the key at position $n$, and $d$ is the head dimension.

The core attention operation is the inner product $q_m^\top k_n$. Without position information, this depends only on the content of the tokens at positions $m$ and $n$, not on their positions.

Goal: Design a function $f(x, m)$ that incorporates position $m$ into vector $x$ such that:

$$\langle f(q, m), f(k, n) \rangle = g(q, k, m-n)$$

The inner product should depend on the relative position $(m-n)$, not on the absolute positions $m$ and $n$ individually.

Complex Numbers and Rotation

RoPE's key insight: rotation naturally encodes relative position.

Euler's Formula

$$e^{i\theta} = \cos\theta + i\sin\theta$$

Derived from the Taylor series:

$$e^{i\theta} = \sum_{k=0}^{\infty} \frac{(i\theta)^k}{k!} = \underbrace{\sum_{k=0}^{\infty} \frac{(-1)^k \theta^{2k}}{(2k)!}}_{\cos\theta} + i\underbrace{\sum_{k=0}^{\infty} \frac{(-1)^k \theta^{2k+1}}{(2k+1)!}}_{\sin\theta}$$

Complex Multiplication as Rotation

A complex number $z = a + bi$ in polar form:

$$z = r e^{i\phi} = r(\cos\phi + i\sin\phi)$$

where $r = |z| = \sqrt{a^2 + b^2}$ (magnitude) and $\phi = \arg(z) = \arctan(b/a)$ (phase).

Multiplying by $e^{i\theta}$ rotates $z$ by angle $\theta$:

$$z \cdot e^{i\theta} = r e^{i\phi} \cdot e^{i\theta} = r e^{i(\phi + \theta)}$$

Magnitude preserved, angle shifted—this is rotation.

The 2D Rotation Matrix

For a real 2D vector $(x, y)$, rotation by angle $\theta$:

$$R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}$$

$$R(\theta) \begin{pmatrix} x \\ y \end{pmatrix} = \begin{pmatrix} x\cos\theta - y\sin\theta \\ x\sin\theta + y\cos\theta \end{pmatrix}$$

Key properties:

| Property | Formula | Implication |
| --- | --- | --- |
| Orthogonality | $R(\theta)^\top R(\theta) = I$ | $R(\theta)^{-1} = R(\theta)^\top = R(-\theta)$ |
| Composition | $R(\alpha) R(\beta) = R(\alpha + \beta)$ | Rotations add |
| Determinant | $\det(R(\theta)) = 1$ | Orientation preserved |
| Norm preservation | $\|R(\theta)x\| = \|x\|$ | Length unchanged |

Equivalence: Complex ↔ Matrix

Viewing $(x, y)$ as $z = x + iy$:

$$z \cdot e^{i\theta} = (x + iy)(\cos\theta + i\sin\theta) = (x\cos\theta - y\sin\theta) + i(x\sin\theta + y\cos\theta)$$

This equals the matrix rotation result. Complex multiplication = matrix rotation.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ROTATION: TWO EQUIVALENT VIEWS                        │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  COMPLEX VIEW:                        MATRIX VIEW:                       │
│  ─────────────                        ────────────                       │
│                                                                          │
│  z = x + iy                           v = [x, y]ᵀ                       │
│                                                                          │
│  Rotate by θ:                         Rotate by θ:                       │
│                                                                          │
│  z' = z · e^(iθ)                      v' = R(θ) · v                     │
│     = z · (cosθ + i·sinθ)                                               │
│                                       R(θ) = [cosθ  -sinθ]              │
│                                              [sinθ   cosθ]              │
│                                                                          │
│  Result:                              Result:                            │
│  z' = (x·cosθ - y·sinθ)              v' = [x·cosθ - y·sinθ]            │
│       + i(x·sinθ + y·cosθ)                [x·sinθ + y·cosθ]            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SAME RESULT: Real part = first component, Imaginary = second           │
│                                                                          │
│  Implementation choice:                                                  │
│  • Complex: elegant, single multiplication                              │
│  • Matrix: explicit, works without complex support                      │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Part II: Deriving RoPE

The Inner Product Under Rotation

Consider two 2D vectors $q = (q_0, q_1)$ and $k = (k_0, k_1)$.

Rotate $q$ by angle $\alpha$ and $k$ by angle $\beta$:

$$q' = R(\alpha) q, \quad k' = R(\beta) k$$

Their inner product:

$$(q')^\top k' = (R(\alpha) q)^\top (R(\beta) k)$$

Using $(AB)^\top = B^\top A^\top$:

$$= q^\top R(\alpha)^\top R(\beta) k$$

Using $R(\alpha)^\top = R(-\alpha)$ (orthogonality):

$$= q^\top R(-\alpha) R(\beta) k$$

Using $R(\alpha)R(\beta) = R(\alpha + \beta)$ (composition):

$$= q^\top R(\beta - \alpha) k$$

Key result: The inner product depends only on $(\beta - \alpha)$, the difference of the rotation angles.

Applying to Position Encoding

Set rotation angle proportional to position:

  • Query at position $m$: rotate by $m\theta$
  • Key at position $n$: rotate by $n\theta$

Then:

$$q_m^\top k_n = q^\top R(n\theta - m\theta)\, k = q^\top R((n-m)\theta)\, k$$

The inner product depends only on the relative position $(n - m)$.
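To make this concrete, here is a quick numerical check (a standalone sketch, not part of the reference implementation later in this post): rotating $q$ by $m\theta$ and $k$ by $n\theta$ gives the same dot product whenever $n - m$ is the same, regardless of the absolute positions.

Python
import math
import torch

def rot2d(v: torch.Tensor, angle: float) -> torch.Tensor:
    """Rotate a 2D vector by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return torch.tensor([[c, -s], [s, c]]) @ v

q = torch.tensor([0.3, -1.2])
k = torch.tensor([0.7, 0.5])
theta = 0.1

# Same relative offset (n - m = 5) at two different absolute positions
a = rot2d(q, 2 * theta) @ rot2d(k, 7 * theta)    # m = 2,  n = 7
b = rot2d(q, 40 * theta) @ rot2d(k, 45 * theta)  # m = 40, n = 45
print(torch.allclose(a, b))  # True: only (n - m) matters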

The Complete RoPE Formulation

For high-dimensional vectors (dimension $d$), apply independent 2D rotations to pairs of dimensions. Split the vector into $d/2$ pairs, each rotated at a different frequency.

Block-Diagonal Rotation Matrix

The full rotation matrix is block-diagonal with d/2d/2 rotation blocks:

$$\mathbf{R}_{\Theta,m}^d = \begin{pmatrix} R(m\theta_0) & 0 & \cdots & 0 \\ 0 & R(m\theta_1) & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & R(m\theta_{d/2-1}) \end{pmatrix}$$

where each $R(m\theta_j)$ is a $2 \times 2$ rotation matrix:

$$R(m\theta_j) = \begin{pmatrix} \cos m\theta_j & -\sin m\theta_j \\ \sin m\theta_j & \cos m\theta_j \end{pmatrix}$$

Expanded form (showing all elements):

$$\mathbf{R}_{\Theta,m}^d = \begin{pmatrix} \cos m\theta_0 & -\sin m\theta_0 & 0 & 0 & \cdots & 0 & 0 \\ \sin m\theta_0 & \cos m\theta_0 & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m\theta_1 & -\sin m\theta_1 & \cdots & 0 & 0 \\ 0 & 0 & \sin m\theta_1 & \cos m\theta_1 & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2-1} & -\sin m\theta_{d/2-1} \\ 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2-1} & \cos m\theta_{d/2-1} \end{pmatrix}$$

Compact Notation

Using block diagonal notation:

$$\mathbf{R}_{\Theta,m}^d = \bigoplus_{j=0}^{d/2-1} R(m\theta_j) = \mathrm{diag}(R(m\theta_0), R(m\theta_1), \ldots, R(m\theta_{d/2-1}))$$

The Frequency Schedule

Each dimension pair $j$ rotates at frequency $\theta_j$:

$$\theta_j = b^{-2j/d} = \frac{1}{b^{2j/d}}$$

where $b = 10000$ is the base (a hyperparameter).

The frequencies form a geometric sequence (approximate values shown for $d = 128$):

| Dimension pair $j$ | Frequency $\theta_j$ | Formula |
| --- | --- | --- |
| $j = 0$ | $1.0$ | $10000^0 = 1$ |
| $j = 1$ | $\approx 0.87$ | $10000^{-2/d}$ |
| $j = d/4$ | $0.01$ | $10000^{-0.5}$ |
| $j = d/2 - 1$ | $\approx 0.0001$ | $10000^{-(d-2)/d} \approx 10000^{-1}$ |
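As a sanity check on this table, the following snippet (a small sketch assuming $d = 128$ and $b = 10000$) prints the frequency and wavelength of a few dimension pairs:

Python
import math

d, b = 128, 10000.0
for j in [0, 1, d // 4, d // 2 - 1]:
    theta = b ** (-2 * j / d)
    print(f"j={j:3d}  theta_j={theta:.6f}  wavelength={2 * math.pi / theta:,.0f}")
# j=0 -> theta=1.0, wavelength ≈ 6;  j=32 -> theta=0.01, wavelength ≈ 628
# j=63 -> theta ≈ 0.000115, wavelength ≈ 54,000 (the j -> d/2 limit is 2π·b ≈ 62,832)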

Part III: The Core Theorem

Theorem: RoPE Encodes Relative Position

Statement: For query $q$ at position $m$ and key $k$ at position $n$, with RoPE-transformed vectors $\tilde{q}_m = \mathbf{R}_{\Theta,m}^d\, q$ and $\tilde{k}_n = \mathbf{R}_{\Theta,n}^d\, k$:

$$\boxed{\tilde{q}_m^\top \tilde{k}_n = q^\top \mathbf{R}_{\Theta,n-m}^d\, k}$$

The inner product depends only on the relative position $(n-m)$.

Proof:

Step 1: Start with the inner product definition.

$$\tilde{q}_m^\top \tilde{k}_n = (\mathbf{R}_{\Theta,m}^d q)^\top (\mathbf{R}_{\Theta,n}^d k)$$

Step 2: Apply the transpose property $(AB)^\top = B^\top A^\top$.

$$= q^\top (\mathbf{R}_{\Theta,m}^d)^\top \mathbf{R}_{\Theta,n}^d k$$

Step 3: Each $2 \times 2$ block is orthogonal, so $R(\alpha)^\top = R(-\alpha)$.

$$(\mathbf{R}_{\Theta,m}^d)^\top = \bigoplus_{j=0}^{d/2-1} R(-m\theta_j) = \mathbf{R}_{\Theta,-m}^d$$

Step 4: Substitute.

$$\tilde{q}_m^\top \tilde{k}_n = q^\top \mathbf{R}_{\Theta,-m}^d \mathbf{R}_{\Theta,n}^d k$$

Step 5: Apply the composition property $R(\alpha)R(\beta) = R(\alpha+\beta)$ to each block.

$$\mathbf{R}_{\Theta,-m}^d \mathbf{R}_{\Theta,n}^d = \bigoplus_{j=0}^{d/2-1} R(-m\theta_j) R(n\theta_j) = \bigoplus_{j=0}^{d/2-1} R((n-m)\theta_j) = \mathbf{R}_{\Theta,n-m}^d$$

Step 6: Final result.

$$\tilde{q}_m^\top \tilde{k}_n = q^\top \mathbf{R}_{\Theta,n-m}^d k \quad \blacksquare$$
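The theorem is easy to verify numerically. The sketch below is illustrative only; rope_matrix is a hypothetical helper that materializes the dense block-diagonal matrix, which no real implementation would do:

Python
import math
import torch

def rope_matrix(m: int, d: int, base: float = 10000.0) -> torch.Tensor:
    """Dense block-diagonal rotation matrix R_{Θ,m}^d (for verification only; O(d²))."""
    R = torch.zeros(d, d)
    for j in range(d // 2):
        a = m * base ** (-2 * j / d)
        c, s = math.cos(a), math.sin(a)
        R[2 * j:2 * j + 2, 2 * j:2 * j + 2] = torch.tensor([[c, -s], [s, c]])
    return R

d, m, n = 8, 5, 12
q, k = torch.randn(d), torch.randn(d)
lhs = (rope_matrix(m, d) @ q) @ (rope_matrix(n, d) @ k)
rhs = q @ (rope_matrix(n - m, d) @ k)
print(torch.allclose(lhs, rhs, atol=1e-5))  # True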

Alternative: Complex Number Proof

For each 2D pair $j$, let $q_c^{(j)} = q_{2j} + i q_{2j+1}$ and $k_c^{(j)} = k_{2j} + i k_{2j+1}$.

RoPE rotation in complex form:

$$\tilde{q}_c^{(j)} = q_c^{(j)} \cdot e^{im\theta_j}, \quad \tilde{k}_c^{(j)} = k_c^{(j)} \cdot e^{in\theta_j}$$

The real inner product of the 2D pair equals:

$$\mathrm{Re}(\tilde{q}_c^{(j)} \cdot \overline{\tilde{k}_c^{(j)}})$$

Computing:

$$\tilde{q}_c^{(j)} \cdot \overline{\tilde{k}_c^{(j)}} = q_c^{(j)} e^{im\theta_j} \cdot \overline{k_c^{(j)} e^{in\theta_j}} = q_c^{(j)} e^{im\theta_j} \cdot \overline{k_c^{(j)}} \cdot e^{-in\theta_j} = q_c^{(j)} \cdot \overline{k_c^{(j)}} \cdot e^{i(m-n)\theta_j}$$

Summing over all pairs:

$$\tilde{q}_m^\top \tilde{k}_n = \sum_{j=0}^{d/2-1} \mathrm{Re}\left(q_c^{(j)} \cdot \overline{k_c^{(j)}} \cdot e^{i(m-n)\theta_j}\right)$$

This depends only on $(m-n)$. $\blacksquare$
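The complex route can be checked the same way. A minimal standalone sketch using torch's complex views (not the helpers defined later): shifting both positions by the same offset leaves the inner product unchanged.

Python
import torch

d, m, n, base = 8, 5, 12, 10000.0
q, k = torch.randn(d), torch.randn(d)
theta = 1.0 / base ** (torch.arange(0, d, 2).float() / d)   # θ_j, shape (d/2,)

qc = torch.view_as_complex(q.reshape(-1, 2))                # q_{2j} + i·q_{2j+1}
kc = torch.view_as_complex(k.reshape(-1, 2))

def inner(m_pos: int, n_pos: int) -> torch.Tensor:
    """Σ_j Re(q̃_c^{(j)} · conj(k̃_c^{(j)})) with rotations e^{i·m·θ_j} and e^{i·n·θ_j}."""
    q_rot = qc * torch.polar(torch.ones(d // 2), m_pos * theta)
    k_rot = kc * torch.polar(torch.ones(d // 2), n_pos * theta)
    return (q_rot * k_rot.conj()).real.sum()

print(torch.allclose(inner(m, n), inner(m + 100, n + 100)))  # True: depends on m-n only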


Part IV: Frequency Analysis

Wavelength and Period

Each dimension pair $j$ has frequency $\theta_j = b^{-2j/d}$. The wavelength (positions per complete $2\pi$ rotation) is:

$$\lambda_j = \frac{2\pi}{\theta_j} = 2\pi \cdot b^{2j/d}$$

For $b = 10000$ and $d = 128$ (a typical head dimension):

| Dimension pair | Frequency $\theta_j$ | Wavelength $\lambda_j$ | Interpretation |
| --- | --- | --- | --- |
| $j = 0$ | $1.0$ | $2\pi \approx 6.3$ | Distinguishes positions 1 apart |
| $j = 16$ | $0.1$ | $\approx 63$ | Medium-range patterns |
| $j = 32$ | $0.01$ | $\approx 628$ | Long-range patterns |
| $j = 63$ | $\approx 0.00012$ | $\approx 54{,}000$ | Global position |

Maximum Distinguishable Position

Two positions become indistinguishable when all dimension pairs complete full rotations. The limiting factor is the lowest frequency:

$$m_{\max} \approx \frac{2\pi}{\theta_{d/2-1}} = 2\pi \cdot b^{(d-2)/d} \approx 2\pi \cdot b$$

For $b = 10000$: $m_{\max} \approx 62{,}832$ positions.

This is why higher base values enable longer contexts: Llama 3 uses $b = 500{,}000$, giving $m_{\max} \approx 3.14\text{M}$ positions.
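A quick computation of this bound for the bases in the model table above (a sketch; $2\pi b$ is the approximation derived here):

Python
import math

for base in (10_000, 500_000, 1_000_000):
    print(f"base={base:>9,}  m_max ≈ {2 * math.pi * base:,.0f}")
# base=   10,000  m_max ≈ 62,832
# base=  500,000  m_max ≈ 3,141,593
# base=1,000,000  m_max ≈ 6,283,185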

Fourier Interpretation

The RoPE inner product can be viewed as a Fourier-like decomposition:

$$\tilde{q}_m^\top \tilde{k}_n = \sum_{j=0}^{d/2-1} \mathrm{Re}\left(q_c^{(j)} \cdot \overline{k_c^{(j)}} \cdot e^{i(m-n)\theta_j}\right)$$

Expanding:

$$= \sum_{j=0}^{d/2-1} |q_c^{(j)}| |k_c^{(j)}| \cos(\phi_{qk}^{(j)} + (m-n)\theta_j)$$

where $\phi_{qk}^{(j)} = \arg(q_c^{(j)}) - \arg(k_c^{(j)})$ is the phase difference.

Interpretation: a sum of cosines at different frequencies, i.e. a Fourier series where:

  • The coefficients ($|q_c^{(j)}| |k_c^{(j)}|$) depend on content
  • The frequencies ($\theta_j$) are fixed by the architecture
  • The variable is the relative position $(m-n)$
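The cosine expansion can be verified directly against the complex form. A small standalone sketch with random content vectors and arbitrary positions (not part of the reference implementation below):

Python
import torch

d, base, m, n = 8, 10000.0, 17, 3
q, k = torch.randn(d), torch.randn(d)
theta = 1.0 / base ** (torch.arange(0, d, 2).float() / d)

qc = torch.view_as_complex(q.reshape(-1, 2))
kc = torch.view_as_complex(k.reshape(-1, 2))

# Rotated inner product, computed from the complex form
direct = (qc * torch.polar(torch.ones(d // 2), m * theta)
          * (kc * torch.polar(torch.ones(d // 2), n * theta)).conj()).real.sum()

# Sum-of-cosines form: Σ |q_c||k_c| cos(φ_qk + (m-n)·θ_j)
phase = qc.angle() - kc.angle()
cosine_form = (qc.abs() * kc.abs() * torch.cos(phase + (m - n) * theta)).sum()

print(torch.allclose(direct, cosine_form, atol=1e-5))  # True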
Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ROPE FREQUENCY SPECTRUM                               │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  θ_j = b^(-2j/d)  where b = 10000, d = head_dim                        │
│                                                                          │
│  Frequency                                                               │
│  (log scale)                                                            │
│      │                                                                   │
│   1  │ ●  θ₀ = 1           (high freq: local patterns)                 │
│      │  ╲                                                                │
│ 0.1  │   ╲                                                               │
│      │    ╲                                                              │
│ 0.01 │     ●  θ_{d/4}      (medium freq: sentence-level)               │
│      │      ╲                                                            │
│0.001 │       ╲                                                           │
│      │        ╲                                                          │
│0.0001│         ●  θ_{d/2-1} (low freq: document-level)                 │
│      └────────────────────────────────────────────→ Dimension pair j    │
│           0       d/4      d/2-1                                        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WAVELENGTH λ_j = 2π / θ_j:                                             │
│                                                                          │
│  • j = 0:      λ ≈ 6 positions     → distinguishes adjacent tokens     │
│  • j = d/4:    λ ≈ 600 positions   → paragraph-level patterns          │
│  • j = d/2-1:  λ ≈ 60,000 positions → document-level patterns          │
│                                                                          │
│  Together: unique position "fingerprint" at ALL scales                 │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Part V: Implementation

Complexity Analysis

| Approach | Time Complexity | Space Complexity |
| --- | --- | --- |
| Naive matrix multiply | $O(d^2)$ per vector | $O(d^2)$ for the matrix |
| Efficient (element-wise) | $O(d)$ per vector | $O(d)$ for sin/cos |
| Precomputed | $O(d)$ per vector | $O(L \cdot d)$ for all positions |

The block-diagonal structure enables $O(d)$ computation: no full matrix multiply is needed.

Core Implementation

Frequency computation (rope.py:1-20):

Python
import torch
import math

def compute_rope_frequencies(
    dim: int,
    max_seq_len: int,
    base: float = 10000.0,
    device: torch.device = None
) -> torch.Tensor:
    """
    Precompute RoPE frequencies for all positions.

    θ_j = 1 / base^(2j/dim) for j ∈ [0, dim/2)

    Returns:
        freqs: (max_seq_len, dim/2) - angles m·θ_j for each position m
    """
    # Compute θ_j = base^(-2j/dim)
    j = torch.arange(0, dim, 2, dtype=torch.float32, device=device)
    theta = 1.0 / (base ** (j / dim))  # Shape: (dim/2,)

    # Position indices
    m = torch.arange(max_seq_len, dtype=torch.float32, device=device)

    # Outer product: angles[m, j] = m * θ_j
    angles = torch.outer(m, theta)  # Shape: (max_seq_len, dim/2)

    return angles

Complex implementation (rope.py:22-60):

Python
def precompute_freqs_cis(
    dim: int,
    max_seq_len: int,
    base: float = 10000.0,
    device: torch.device = None
) -> torch.Tensor:
    """
    Precompute complex exponentials e^(i·m·θ_j).

    This is the Llama-style implementation using complex numbers.

    Returns:
        freqs_cis: Complex tensor (max_seq_len, dim/2)
                   freqs_cis[m, j] = cos(m·θ_j) + i·sin(m·θ_j)
    """
    angles = compute_rope_frequencies(dim, max_seq_len, base, device)

    # Convert to complex: e^(i·angle) = cos(angle) + i·sin(angle)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)

    return freqs_cis


def apply_rotary_emb_complex(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    start_pos: int = 0
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply RoPE using complex multiplication.

    Args:
        xq: Query tensor (batch, seq_len, n_heads, head_dim)
        xk: Key tensor (batch, seq_len, n_kv_heads, head_dim)
        freqs_cis: Precomputed frequencies (max_seq_len, head_dim/2)
        start_pos: Starting position (for KV cache)

    Returns:
        Rotated (query, key) tensors, same shapes as input
    """
    batch, seq_len, n_heads, head_dim = xq.shape
    n_kv_heads = xk.shape[2]

    # Reshape to complex: (batch, seq, heads, dim/2, 2) -> complex (batch, seq, heads, dim/2)
    xq_complex = torch.view_as_complex(xq.float().reshape(batch, seq_len, n_heads, -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(batch, seq_len, n_kv_heads, -1, 2))

    # Get frequencies for this sequence
    freqs = freqs_cis[start_pos : start_pos + seq_len]  # (seq_len, dim/2)
    freqs = freqs.unsqueeze(0).unsqueeze(2)  # (1, seq_len, 1, dim/2)

    # Complex multiplication = rotation
    xq_rot = xq_complex * freqs
    xk_rot = xk_complex * freqs

    # Back to real
    xq_out = torch.view_as_real(xq_rot).reshape(batch, seq_len, n_heads, head_dim)
    xk_out = torch.view_as_real(xk_rot).reshape(batch, seq_len, n_kv_heads, head_dim)

    return xq_out.type_as(xq), xk_out.type_as(xk)

Real-number implementation (rope.py:62-110):

Python
def precompute_cos_sin(
    dim: int,
    max_seq_len: int,
    base: float = 10000.0,
    device: torch.device = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute cos and sin for real-number RoPE.

    Returns:
        cos: (max_seq_len, dim/2) - cos(m·θ_j)
        sin: (max_seq_len, dim/2) - sin(m·θ_j)
    """
    angles = compute_rope_frequencies(dim, max_seq_len, base, device)
    return torch.cos(angles), torch.sin(angles)


def apply_rotary_emb_real(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    start_pos: int = 0
) -> torch.Tensor:
    """
    Apply RoPE using only real arithmetic.

    Implements the rotation formula:
        x̃_{2j}   = x_{2j}·cos(mθ_j) - x_{2j+1}·sin(mθ_j)
        x̃_{2j+1} = x_{2j}·sin(mθ_j) + x_{2j+1}·cos(mθ_j)

    Args:
        x: Input tensor (..., head_dim)
        cos: Cosine values (max_seq_len, head_dim/2)
        sin: Sine values (max_seq_len, head_dim/2)
        start_pos: Starting position

    Returns:
        Rotated tensor, same shape as input
    """
    seq_len = x.shape[1]

    # Get cos/sin for this sequence
    cos = cos[start_pos : start_pos + seq_len]  # (seq_len, dim/2)
    sin = sin[start_pos : start_pos + seq_len]

    # Reshape for broadcasting: (1, seq_len, 1, dim/2)
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)

    # Split into even/odd pairs
    x_even = x[..., 0::2]  # x_0, x_2, x_4, ...
    x_odd = x[..., 1::2]   # x_1, x_3, x_5, ...

    # Apply 2D rotation to each pair
    x_even_rot = x_even * cos - x_odd * sin
    x_odd_rot = x_even * sin + x_odd * cos

    # Interleave back
    out = torch.stack([x_even_rot, x_odd_rot], dim=-1)
    return out.reshape(x.shape)

Full Attention Module

Complete attention with RoPE (attention.py:1-80):

Python
class RoPEMultiHeadAttention(torch.nn.Module):
    """
    Multi-head attention with Rotary Position Embeddings.

    Attention computation:
        Q_rot, K_rot = RoPE(Q, pos), RoPE(K, pos)
        Attention = softmax(Q_rot · K_rot^T / √d) · V

    Note: V is NOT rotated - position affects attention weights, not values.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int | None = None,
        max_seq_len: int = 4096,
        rope_base: float = 10000.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads or n_heads
        self.head_dim = dim // n_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        # Check dimensions
        assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"
        assert n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"

        # Projections (no bias, following modern LLM convention)
        self.wq = torch.nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = torch.nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = torch.nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = torch.nn.Linear(n_heads * self.head_dim, dim, bias=False)

        self.dropout = torch.nn.Dropout(dropout)

        # Precompute RoPE frequencies
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(self.head_dim, max_seq_len, rope_base)
        )

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int = 0,
        mask: torch.Tensor | None = None,
        kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        batch, seq_len, _ = x.shape

        # Project to Q, K, V
        q = self.wq(x).view(batch, seq_len, self.n_heads, self.head_dim)
        k = self.wk(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)

        # Apply RoPE to Q and K (NOT V!)
        q, k = apply_rotary_emb_complex(q, k, self.freqs_cis, start_pos)

        # KV cache handling
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=1)
            v = torch.cat([v_cache, v], dim=1)

        # GQA: expand KV heads to match query heads
        if self.n_kv_heads < self.n_heads:
            n_rep = self.n_heads // self.n_kv_heads
            k = k.unsqueeze(3).expand(-1, -1, -1, n_rep, -1).reshape(batch, -1, self.n_heads, self.head_dim)
            v = v.unsqueeze(3).expand(-1, -1, -1, n_rep, -1).reshape(batch, -1, self.n_heads, self.head_dim)

        # Transpose for attention: (batch, n_heads, seq_len, head_dim)
        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]

        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            scores = scores + mask

        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).reshape(batch, seq_len, -1)

        return self.wo(out), (k, v)
Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ROPE IN ATTENTION: DATA FLOW                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  INPUT: x (batch, seq_len, dim)                                         │
│           │                                                              │
│           ▼                                                              │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                    Linear Projections                            │   │
│  │   Q = x @ W_Q     K = x @ W_K     V = x @ W_V                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│           │              │              │                               │
│           ▼              ▼              │                               │
│  ┌─────────────────────────────────┐    │                               │
│  │         Apply RoPE              │    │                               │
│  │  Q_rot = R(m·θ) · Q            │    │  V is NOT rotated!            │
│  │  K_rot = R(n·θ) · K            │    │  (content unchanged)          │
│  └─────────────────────────────────┘    │                               │
│           │              │              │                               │
│           ▼              ▼              ▼                               │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │              Scaled Dot-Product Attention                        │   │
│  │                                                                  │   │
│  │  scores = Q_rot · K_rot^T / √d                                  │   │
│  │           └── depends on (m-n)·θ = RELATIVE POSITION            │   │
│  │                                                                  │   │
│  │  attn = softmax(mask(scores))                                   │   │
│  │  output = attn · V                                              │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│           │                                                              │
│           ▼                                                              │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │  output = output @ W_O                                           │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│           │                                                              │
│           ▼                                                              │
│  OUTPUT: (batch, seq_len, dim)                                          │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Part VI: Context Extension Methods

RoPE enables sophisticated context extension through mathematical manipulation of rotation frequencies.

Position Interpolation (PI)

Paper: "Extending Context Window of Large Language Models via Positional Interpolation" (Chen et al., 2023)

Problem: The model is trained with maximum position $L$. At a position $m > L$, the angles $m\theta_j$ exceed the training distribution.

Solution: Scale positions to fit within $[0, L]$:

$$m' = \frac{m \cdot L}{L'}$$

Modified RoPE:

$$\tilde{q}_m = \mathbf{R}_{\Theta,\, m \cdot L/L'}^d\, q$$

Effect: All rotation angles are scaled by $L/L'$:

$$m\theta_j \rightarrow \frac{mL}{L'}\theta_j = m \cdot \theta_j'$$

where $\theta_j' = \theta_j \cdot (L/L')$.

Implementation:

Python
def apply_position_interpolation(
    angles: torch.Tensor,       # (target_max_len, dim/2) from compute_rope_frequencies
    original_max_len: int,
    target_max_len: int
) -> torch.Tensor:
    """
    Position Interpolation: scale all positions by L/L'.

    Since angles[m, j] = m * θ_j, scaling every position m to m * L / L'
    is the same as uniformly scaling every angle (and hence every frequency) by L/L'.
    """
    scale = original_max_len / target_max_len
    return angles * scale

Limitation: every frequency is shrunk by the same factor, so the high-frequency dimensions that discriminate nearby positions lose the most resolution, degrading local position discrimination.

NTK-Aware Interpolation

Key insight: Instead of scaling positions uniformly, increase the base $b$ so that frequencies are lowered non-uniformly.

Modified frequency:

$$\theta_j' = (b \cdot \alpha)^{-2j/d} = \alpha^{-2j/d} \cdot b^{-2j/d} = \alpha^{-2j/d} \cdot \theta_j$$

Effect by dimension:

| Dimension | Scaling factor | Effect |
| --- | --- | --- |
| $j = 0$ | $\alpha^0 = 1$ | Unchanged (local patterns preserved) |
| $j = d/4$ | $\alpha^{-0.5}$ | Moderately scaled |
| $j = d/2-1$ | $\approx \alpha^{-1}$ | Fully scaled (global patterns extended) |

Choosing $\alpha$: To extend from context $L$ to $L'$:

$$\alpha = \left(\frac{L'}{L}\right)^{d/(d-2)}$$

Implementation:

Python
def compute_ntk_frequencies(
    dim: int,
    max_seq_len: int,
    base: float = 10000.0,
    original_max_len: int = 4096,
    device: torch.device = None
) -> torch.Tensor:
    """
    NTK-aware interpolation: scale base instead of positions.

    High frequencies (local patterns) preserved.
    Low frequencies (global patterns) scaled.
    """
    scale = max_seq_len / original_max_len

    # α = scale^(d/(d-2))
    alpha = scale ** (dim / (dim - 2))

    # Modified base
    ntk_base = base * alpha

    return compute_rope_frequencies(dim, max_seq_len, ntk_base, device)

YaRN (Yet another RoPE extensioN)

Paper: "YaRN: Efficient Context Window Extension of Large Language Models" (Peng et al., 2023)

YaRN combines multiple techniques:

1. Frequency Partitioning

Divide the dimensions into three groups based on wavelength $\lambda_j = 2\pi/\theta_j$:

| Group | Condition | Interpolation |
| --- | --- | --- |
| High-frequency | $\lambda_j < L$ | None (preserve local) |
| Medium-frequency | $L \leq \lambda_j \leq L'$ | Partial (ramp) |
| Low-frequency | $\lambda_j > L'$ | Full PI |

Interpolation factor (a smooth ramp from 0 = no interpolation to 1 = full interpolation):

$$\gamma_j = \begin{cases} 0 & \text{if } \lambda_j < L \text{ (high freq: preserve)} \\ \frac{\lambda_j - L}{L' - L} & \text{if } L \leq \lambda_j \leq L' \text{ (medium: ramp)} \\ 1 & \text{if } \lambda_j > L' \text{ (low freq: full PI)} \end{cases}$$

Modified frequency:

$$\theta_j' = \theta_j \cdot \left(1 - \gamma_j \cdot \left(1 - \frac{L}{L'}\right)\right)$$

2. Attention Temperature Scaling

Scale the attention logits to compensate for the changed score distribution:

$$\text{score} = \frac{q^\top k}{\sqrt{d} \cdot t}$$

where the temperature satisfies

$$\sqrt{1/t} = 0.1 \cdot \ln(s) + 1, \quad s = \frac{L'}{L}$$

In practice this is implemented by scaling the query and key (equivalently, the precomputed cos/sin tables) by $0.1 \ln(s) + 1$, which multiplies each logit by $1/t$; this is the mscale factor returned by the code below.

Implementation:

Python
def yarn_find_correction_range(
    dim: int,
    base: float,
    original_max_len: int
) -> tuple[float, float]:
    """
    Find dimension range for partial interpolation.

    Returns (low_dim, high_dim) where:
    - dims < low_dim: no interpolation
    - low_dim <= dims <= high_dim: partial interpolation
    - dims > high_dim: full interpolation
    """
    # λ_j = 2π · base^(2j/dim)
    # Find j where λ_j = original_max_len

    # For β rotations, wavelength threshold = original_max_len / β
    # High freq cutoff: β = 32 (many rotations)
    # Low freq cutoff: β = 1 (one rotation)

    def find_dim(num_rotations):
        # Solve: 2π · base^(2j/dim) = original_max_len / num_rotations
        return (dim * math.log(original_max_len / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    low_dim = find_dim(32)   # High-frequency boundary
    high_dim = find_dim(1)   # Low-frequency boundary

    return max(0, low_dim), min(dim // 2 - 1, high_dim)


def compute_yarn_frequencies(
    dim: int,
    max_seq_len: int,
    base: float = 10000.0,
    original_max_len: int = 4096,
    device: torch.device = None
) -> tuple[torch.Tensor, float]:
    """
    YaRN: frequency partitioning + attention scaling.

    Returns:
        freqs: Modified frequency tensor
        mscale: Attention temperature scale
    """
    scale = max_seq_len / original_max_len

    if scale <= 1:
        freqs = compute_rope_frequencies(dim, max_seq_len, base, device)
        return freqs, 1.0

    # Find interpolation boundaries
    low_dim, high_dim = yarn_find_correction_range(dim, base, original_max_len)

    # Base frequencies
    j = torch.arange(0, dim, 2, dtype=torch.float32, device=device)
    theta = 1.0 / (base ** (j / dim))

    # Compute interpolation ramp
    ramp = torch.clamp((j / 2 - low_dim) / (high_dim - low_dim), 0, 1)

    # Interpolated scale: 1 for low j, 1/scale for high j
    freq_scale = (1 - ramp) + ramp / scale

    # Apply scaling
    theta_scaled = theta * freq_scale

    # Position angles
    m = torch.arange(max_seq_len, dtype=torch.float32, device=device)
    freqs = torch.outer(m, theta_scaled)

    # Attention temperature
    mscale = 0.1 * math.log(scale) + 1.0

    return freqs, mscale
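
How the returned mscale is used in practice (a hedged sketch: it assumes the compute_yarn_frequencies function above is in scope and follows the common convention of folding the temperature into the cos/sin tables, so that every logit is multiplied by mscale²):

Python
import torch

head_dim, train_len, target_len = 128, 4096, 32768
angles, mscale = compute_yarn_frequencies(head_dim, target_len, original_max_len=train_len)

# Scaling cos/sin by mscale scales the rotated q and k by mscale, so each logit
# picks up a factor of mscale², the attention-temperature correction described above.
cos = torch.cos(angles) * mscale
sin = torch.sin(angles) * mscale
print(round(mscale, 3))  # ≈ 1.208 for an 8x extension (0.1·ln(8) + 1)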

LongRoPE

Key innovations:

  1. Search-based factors: Instead of analytical formulas, search for optimal per-dimension interpolation factors by minimizing perplexity on long-context data.

  2. Progressive extension: Extend in stages (4K → 128K → 256K → 2M), fine-tuning at each stage.

  3. Short context recovery: Use original RoPE for short sequences, scaled RoPE for long—prevents short-context degradation.

Comparison

| Method | High-freq | Low-freq | Training | Complexity |
| --- | --- | --- | --- | --- |
| Position Interpolation | Degraded | Scaled | 1000 steps | Simple |
| NTK-aware | Preserved | Scaled | 1000 steps | Simple |
| YaRN | Preserved | Scaled | 400 steps | Medium |
| LongRoPE | Optimized | Optimized | Search + fine-tune | Complex |

Part VII: Mathematical Properties

Norm Preservation

RoPE preserves vector norms:

$$\|\tilde{x}\| = \|x\|$$

Proof: Each $2 \times 2$ rotation block preserves norm (it is an orthogonal matrix), and the blocks act on disjoint pairs of coordinates:

$$\|\tilde{x}\|^2 = \sum_{j=0}^{d/2-1} \|\tilde{x}_{2j:2j+2}\|^2 = \sum_{j=0}^{d/2-1} \|x_{2j:2j+2}\|^2 = \|x\|^2 \quad \blacksquare$$

Linearity

RoPE is linear in content:

$$\mathbf{R}^d_{\Theta,m}(\alpha x + \beta y) = \alpha \mathbf{R}^d_{\Theta,m} x + \beta \mathbf{R}^d_{\Theta,m} y$$

Follows from matrix multiplication linearity.

Position Additivity

Rotation angles add for sequential positions:

$$\mathbf{R}^d_{\Theta,m} \mathbf{R}^d_{\Theta,n} = \mathbf{R}^d_{\Theta,m+n}$$

This follows from $R(\alpha)R(\beta) = R(\alpha + \beta)$ for each block.

Implication: Relative position $(m-n)$ emerges naturally from the rotation group structure.

Dimension Independence

Different pairs are independent:

$$(\tilde{x}_{2j}, \tilde{x}_{2j+1}) = R(m\theta_j)(x_{2j}, x_{2j+1})$$

Each pair's rotation depends only on that pair's input, not other dimensions. Model can learn different patterns at different frequency scales.
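
These properties are also easy to confirm numerically. A standalone sketch (using a throwaway rotate helper rather than the full implementation) checks norm preservation and position additivity:

Python
import torch

d, base = 16, 10000.0
theta = 1.0 / base ** (torch.arange(0, d, 2).float() / d)

def rotate(x: torch.Tensor, m: int) -> torch.Tensor:
    """Apply the RoPE rotation for position m to a single d-dimensional vector."""
    xc = torch.view_as_complex(x.reshape(-1, 2))
    rotated = xc * torch.polar(torch.ones(d // 2), m * theta)
    return torch.view_as_real(rotated).reshape(-1)

x = torch.randn(d)
print(torch.allclose(rotate(x, 7).norm(), x.norm(), atol=1e-5))          # norm preserved
print(torch.allclose(rotate(rotate(x, 3), 4), rotate(x, 7), atol=1e-5))  # R_3 then R_4 = R_7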


Part VIII: Comparison with Other Methods

RoPE vs Sinusoidal

| Aspect | Sinusoidal | RoPE |
| --- | --- | --- |
| Operation | Additive: $x + \text{PE}(m)$ | Multiplicative: $R(m) \cdot x$ |
| Relative position | Implicit (model must learn) | Explicit (mathematical property) |
| Content-position mixing | Mixed in same vector | Separated (rotation vs. magnitude) |
| Length extrapolation | Poor | Scalable with PI/YaRN/etc. |

RoPE vs ALiBi

| Aspect | ALiBi | RoPE |
| --- | --- | --- |
| Where applied | Attention scores only | Q and K embeddings |
| Formula | $\text{score}_{ij} - m \cdot \lvert i-j \rvert$ (head-specific slope $m$) | Rotation of $q$ and $k$ by position |
| Content interaction | None (bias is fixed) | Content modulates position |
| Extrapolation | Excellent (linear extends) | Good with scaling methods |
| Expressiveness | Linear decay only | Arbitrary learned functions |

RoPE vs Learned

| Aspect | Learned | RoPE |
| --- | --- | --- |
| Parameters | $O(L \cdot d)$ | Zero |
| Max length | Fixed (embedding table size) | Extensible |
| Relative position | Must be learned implicitly | Built-in mathematical property |
| Flexibility | Can learn any pattern | Constrained to rotation |

Part IX: Advanced Topics and Variations

2D RoPE for Vision Transformers

Vision transformers process images as sequences of patches. Each patch has a 2D position $(h, w)$ rather than a 1D position. RoPE can be extended to 2D by splitting dimensions between row and column encoding.

Approach: Allocate half the dimensions to the row position and half to the column position.

$$\text{RoPE}_{2D}(x, h, w) = \begin{pmatrix} R_{h}^{d/2} & 0 \\ 0 & R_{w}^{d/2} \end{pmatrix} x$$

where $R_h^{d/2}$ rotates the first $d/2$ dimensions by row position, and $R_w^{d/2}$ rotates the remaining dimensions by column position.

Python
def precompute_freqs_2d(
    dim: int,
    height: int,
    width: int,
    base: float = 10000.0
) -> torch.Tensor:
    """
    Precompute 2D RoPE frequencies for vision transformers.

    Returns: (height * width, dim/2) complex tensor
    """
    # Split dimensions: half for rows, half for columns
    half_dim = dim // 2

    # Frequencies for each dimension
    theta_h = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim))
    theta_w = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim))

    # Position grids
    h_pos = torch.arange(height)
    w_pos = torch.arange(width)

    # Compute angles
    angles_h = torch.outer(h_pos, theta_h)  # (height, half_dim/2)
    angles_w = torch.outer(w_pos, theta_w)  # (width, half_dim/2)

    # Create 2D grid of frequencies
    # For each (h, w) position, concatenate row and column frequencies
    freqs_h = torch.polar(torch.ones_like(angles_h), angles_h)  # (height, half_dim/2)
    freqs_w = torch.polar(torch.ones_like(angles_w), angles_w)  # (width, half_dim/2)

    # Expand to (height, width, half_dim/2) then combine
    freqs_h = freqs_h.unsqueeze(1).expand(-1, width, -1)  # (height, width, half_dim/2)
    freqs_w = freqs_w.unsqueeze(0).expand(height, -1, -1)  # (height, width, half_dim/2)

    # Concatenate row and column frequencies
    freqs_2d = torch.cat([freqs_h, freqs_w], dim=-1)  # (height, width, half_dim)

    # Flatten to (height * width, half_dim)
    return freqs_2d.reshape(height * width, -1)


def apply_rope_2d(
    x: torch.Tensor,  # (batch, num_patches, heads, head_dim)
    freqs_2d: torch.Tensor,  # (num_patches, head_dim/2)
    patch_height: int,
    patch_width: int
) -> torch.Tensor:
    """Apply 2D RoPE to image patch embeddings."""
    batch, num_patches, heads, dim = x.shape

    # Reshape to complex
    x_complex = torch.view_as_complex(
        x.float().reshape(batch, num_patches, heads, dim // 2, 2)
    )

    # Apply rotation
    freqs = freqs_2d.unsqueeze(0).unsqueeze(2)  # (1, num_patches, 1, dim/2)
    x_rot = x_complex * freqs

    # Back to real
    return torch.view_as_real(x_rot).reshape(batch, num_patches, heads, dim).type_as(x)

Key insight: 2D RoPE preserves the relative position property in both dimensions independently:

  • Attention between patches at $(h_1, w_1)$ and $(h_2, w_2)$ depends on $(h_1-h_2)$ and $(w_1-w_2)$

3D RoPE for Video

For video (time × height × width), extend to 3D by splitting dimensions three ways:

$$\text{RoPE}_{3D}(x, t, h, w) = \begin{pmatrix} R_{t}^{d/3} & 0 & 0 \\ 0 & R_{h}^{d/3} & 0 \\ 0 & 0 & R_{w}^{d/3} \end{pmatrix} x$$

This enables relative position encoding in all three dimensions: temporal, vertical, and horizontal.
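
A minimal sketch of how the frequencies could be precomputed for the 3D case (illustrative only; precompute_freqs_3d is a hypothetical helper, and this particular split of dimensions across axes is one of several reasonable choices):

Python
import torch

def precompute_freqs_3d(
    dim: int, t_len: int, height: int, width: int, base: float = 10000.0
) -> torch.Tensor:
    """Split the head dimension across (time, row, column) axes.
    Returns a (t_len * height * width, dim/2) complex tensor."""
    third = (dim // 3) // 2 * 2                    # dims per axis, rounded down to even

    def axis_freqs(n_pos: int, n_dim: int) -> torch.Tensor:
        theta = 1.0 / base ** (torch.arange(0, n_dim, 2).float() / n_dim)
        angles = torch.outer(torch.arange(n_pos).float(), theta)
        return torch.polar(torch.ones_like(angles), angles)

    ft = axis_freqs(t_len, third)                  # (t_len, third/2)
    fh = axis_freqs(height, third)                 # (height, third/2)
    fw = axis_freqs(width, dim - 2 * third)        # remaining dims on the width axis

    # Broadcast each axis over the (t, h, w) grid, then concatenate along the dim axis
    grid = torch.cat([
        ft[:, None, None, :].expand(t_len, height, width, -1),
        fh[None, :, None, :].expand(t_len, height, width, -1),
        fw[None, None, :, :].expand(t_len, height, width, -1),
    ], dim=-1)
    return grid.reshape(t_len * height * width, -1)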

RoPE with Multi-Head Latent Attention (MLA)

DeepSeek-V2 and V3 use Multi-Head Latent Attention (MLA) with RoPE. MLA compresses KV cache by projecting keys and values through a low-rank bottleneck.

The challenge: In standard RoPE, we rotate $Q$ and $K$ by position. In MLA, keys are compressed:

$$K_{\text{compressed}} = K W_{\text{down}} \in \mathbb{R}^{n \times d_{\text{latent}}}$$

If we apply RoPE to the full key and then compress, position information is mixed with content in the compressed representation.

DeepSeek's solution: Split the query/key dimensions into position-dependent and position-independent parts.

Python
class MLAWithRoPE(nn.Module):
    """
    Multi-Head Latent Attention with RoPE.

    Key insight: Separate RoPE dimensions from compressed dimensions.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_latent: int,  # Compressed KV dimension
        d_rope: int,    # Dimensions for RoPE (not compressed)
        max_seq_len: int = 4096
    ):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.d_latent = d_latent
        self.d_rope = d_rope
        self.d_nope = self.head_dim - d_rope  # Non-RoPE dimensions

        # Query projection
        self.wq = nn.Linear(d_model, n_heads * self.head_dim)

        # Key/Value projections with latent compression
        self.wkv_down = nn.Linear(d_model, d_latent)  # Compress
        self.wk_up = nn.Linear(d_latent, n_heads * self.d_nope)  # Expand K (non-RoPE part)
        self.wk_rope = nn.Linear(d_model, n_heads * d_rope)  # K RoPE part (not compressed)
        self.wv_up = nn.Linear(d_latent, n_heads * self.head_dim)  # Expand V

        self.wo = nn.Linear(n_heads * self.head_dim, d_model)

        # RoPE frequencies
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(d_rope, max_seq_len)
        )

    def forward(self, x: torch.Tensor, start_pos: int = 0):
        batch, seq_len, _ = x.shape

        # Query: full projection, apply RoPE to first d_rope dims
        q = self.wq(x).view(batch, seq_len, self.n_heads, self.head_dim)
        q_rope = q[..., :self.d_rope]
        q_nope = q[..., self.d_rope:]

        # Apply RoPE to query's RoPE dimensions
        freqs = self.freqs_cis[start_pos : start_pos + seq_len]
        q_rope = apply_rope(q_rope, freqs)
        q = torch.cat([q_rope, q_nope], dim=-1)

        # Key: compressed part (no RoPE) + RoPE part (not compressed)
        kv_latent = self.wkv_down(x)  # Compress
        k_nope = self.wk_up(kv_latent).view(batch, seq_len, self.n_heads, self.d_nope)
        k_rope = self.wk_rope(x).view(batch, seq_len, self.n_heads, self.d_rope)

        # Apply RoPE to key's RoPE dimensions
        k_rope = apply_rope(k_rope, freqs)
        k = torch.cat([k_rope, k_nope], dim=-1)

        # Value: from compressed representation
        v = self.wv_up(kv_latent).view(batch, seq_len, self.n_heads, self.head_dim)

        # Standard attention from here...
        # (KV cache stores kv_latent and k_rope separately for efficiency)
        return self.attention(q, k, v)

Why this works:

  1. RoPE dimensions carry position information, not compressed
  2. Non-RoPE dimensions carry content, can be compressed
  3. Best of both: position accuracy + memory efficiency

Partial RoPE

Some models apply RoPE to only a subset of dimensions, leaving others position-agnostic.

Motivation: Not all attention patterns need position. Some heads might learn position-independent patterns (e.g., "always attend to [SEP] token").

Python
def apply_partial_rope(
    x: torch.Tensor,  # (batch, seq, heads, head_dim)
    freqs: torch.Tensor,
    rope_dims: int  # How many dimensions to rotate
) -> torch.Tensor:
    """
    Apply RoPE to only the first rope_dims dimensions.

    Remaining dimensions are position-agnostic.
    """
    # Split dimensions
    x_rope = x[..., :rope_dims]
    x_pass = x[..., rope_dims:]

    # Apply RoPE to subset
    x_rope_rot = apply_rope(x_rope, freqs)

    # Recombine
    return torch.cat([x_rope_rot, x_pass], dim=-1)

Used by: GPT-NeoX- and GPT-J-style models and some Phi variants, which rotate only part of each head dimension via a rotary_pct / rotary_dim setting.


Part X: Visualization and Intuition

Visualizing Rotation in 2D

Consider a single dimension pair. The query and key vectors rotate in a 2D plane as position changes.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    VISUALIZING ROPE ROTATION                             │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Consider dimension pair 0 (highest frequency, θ₀ = 1):                 │
│                                                                          │
│                         Position 0                                       │
│                              │                                           │
│                              ▼                                           │
│                           ──────→ q (no rotation)                       │
│                                                                          │
│                         Position 1                                       │
│                              │                                           │
│                              ▼                                           │
│                           ╱                                              │
│                         ╱   → q rotated by θ = 1 radian ≈ 57°           │
│                                                                          │
│                         Position 2                                       │
│                              │                                           │
│                              ▼                                           │
│                         │                                                │
│                         ↓     q rotated by 2θ ≈ 114°                    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  When computing attention between positions m and n:                    │
│                                                                          │
│  Position m=0     Position n=2                                          │
│       │                │                                                 │
│       ▼                ▼                                                 │
│    ──────→          │                                                   │
│       q_0           ↓ k_2                                               │
│                                                                          │
│  Dot product depends on ANGLE BETWEEN them = 2θ                        │
│  This is (n - m) × θ = relative position × frequency                   │
│                                                                          │
│  Same pattern holds for m=100, n=102: angle is still 2θ               │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Frequency Spectrum Visualization

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ROPE FREQUENCY SPECTRUM                               │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Position    0    10    20    30    40    50    60    70    80          │
│             │     │     │     │     │     │     │     │     │           │
│  Dim 0:    ─╮ ╭─╮ ╭─╮ ╭─╮ ╭─╮ ╭─╮ ╭─╮ ╭─╮ ╭─╮ ╭─╮ ╭─╮ ╭─╮ ╭─╮        │
│  (θ=1)      ╰─╯ ╰─╯ ╰─╯ ╰─╯ ╰─╯ ╰─╯ ╰─╯ ╰─╯ ╰─╯ ╰─╯ ╰─╯ ╰─╯          │
│             Fast oscillation (~6 positions per cycle)                   │
│                                                                          │
│  Dim 16:   ───────╮     ╭───────╮     ╭───────╮     ╭───────╮          │
│  (θ≈0.06)         ╰─────╯       ╰─────╯       ╰─────╯                   │
│             Medium oscillation (~100 positions per cycle)               │
│                                                                          │
│  Dim 32:   ─────────────────────╮                   ╭─────────          │
│  (θ≈0.003)                      ╰───────────────────╯                   │
│             Slow oscillation (~2000 positions per cycle)                │
│                                                                          │
│  Dim 63:   ───────────────────────────────────────────────────          │
│  (θ≈0.0001)                                                             │
│             Very slow (~60000 positions per cycle)                      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMBINED EFFECT: Each position has a unique "fingerprint"             │
│                                                                          │
│  Position 0:  [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, ...]                     │
│  Position 1:  [0.54, 0.84, 0.998, 0.06, 0.9999, 0.006, ...]           │
│  Position 10: [-0.84, 0.54, 0.82, 0.57, 0.997, 0.06, ...]             │
│                                                                          │
│  High freq dims: change rapidly, distinguish neighbors                 │
│  Low freq dims: change slowly, distinguish distant positions           │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Part XI: Common Bugs and Debugging

Bug 1: Wrong Dimension Pairing

Symptom: Model trains but performance is poor, especially on long sequences.

Cause: Pairing wrong dimensions (e.g., [0,1], [1,2], [2,3] instead of [0,1], [2,3], [4,5]).

Python
# WRONG: Overlapping pairs
x_pairs_wrong = x.unfold(-1, 2, 1)  # Creates overlapping windows!

# CORRECT: Non-overlapping pairs
x_pairs = x.reshape(*x.shape[:-1], dim // 2, 2)  # Proper pairing: (0,1), (2,3), ...

Bug 2: Forgetting to Apply RoPE to Keys

Symptom: Model ignores position entirely.

Cause: Only rotating queries, not keys.

Python
# WRONG
q = apply_rope(q, freqs, positions)
# k is not rotated!

# CORRECT
q = apply_rope(q, freqs, positions)
k = apply_rope(k, freqs, positions)  # Both must be rotated!

Bug 3: Wrong Position Indices with KV Cache

Symptom: During inference with KV cache, model outputs degrade after a few tokens.

Cause: Using wrong position indices for new tokens.

Python
# WRONG: Always using positions [0, 1, 2, ...] for new tokens
positions = torch.arange(seq_len)
q = apply_rope(q, freqs, positions)

# CORRECT: Use the actual absolute positions, accounting for the cache
positions = torch.arange(start_pos, start_pos + seq_len)
q = apply_rope(q, freqs, positions)

Bug 4: Rotating Values (V)

Symptom: Strange outputs, model doesn't learn well.

Cause: Applying RoPE to values in addition to queries and keys.

Python
# WRONG
q = apply_rope(q, freqs, positions)
k = apply_rope(k, freqs, positions)
v = apply_rope(v, freqs, positions)  # NO! V should not be rotated

# CORRECT
q = apply_rope(q, freqs, positions)
k = apply_rope(k, freqs, positions)
# v unchanged - position affects attention weights, not values

Bug 5: Base Frequency Mismatch

Symptom: Loading a pretrained model gives garbage outputs.

Cause: Using different RoPE base than the model was trained with.

Python
# Model was trained with base=500000 (Llama 3)
# But you're using base=10000 (default)

# WRONG
freqs = precompute_freqs(dim, max_len, base=10000)

# CORRECT: Match the model's training configuration
freqs = precompute_freqs(dim, max_len, base=500000)  # Check model config!

Bug 6: Context Extension Without Scaling

Symptom: Model works fine up to training length, then output quality drops sharply.

Cause: Using positions beyond training range without scaling.

Python
# Model trained on 4096 tokens, now using 8192

# WRONG: Raw RoPE at extended positions
freqs = precompute_freqs(dim, 8192, base=10000)  # Angles exceed training distribution

# CORRECT: Apply scaling (e.g., YaRN as defined above)
freqs, mscale = compute_yarn_frequencies(dim, 8192, base=10000, original_max_len=4096)

Debugging Checklist

| Check | How to Verify |
| --- | --- |
| Dimension pairing | Print shape after reshape: should be (..., dim/2, 2) |
| Q and K both rotated | Add print statements or breakpoints in attention |
| Position indices | Print the positions tensor, verify it matches actual token positions |
| V unchanged | Verify V is not passed through any RoPE function |
| Base frequency | Check the model config file for rope_theta or rope_base |
| Context scaling | Compare perplexity at training vs. extended lengths |
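
Beyond manual inspection, the relative-position property itself can be tested automatically. A sketch of such a check (assuming the precompute_freqs_cis and apply_rotary_emb_complex functions from Part V are in scope):

Python
import torch

def check_relative_position_property(head_dim: int = 64, offset: int = 100) -> bool:
    """Logits between two tokens must not change when both positions shift by `offset`."""
    freqs = precompute_freqs_cis(head_dim, 1024)
    q = torch.randn(1, 2, 1, head_dim)    # (batch, seq=2, heads=1, head_dim)
    k = torch.randn(1, 2, 1, head_dim)

    q0, k0 = apply_rotary_emb_complex(q, k, freqs, start_pos=0)
    q1, k1 = apply_rotary_emb_complex(q, k, freqs, start_pos=offset)

    logit0 = (q0[0, 1, 0] * k0[0, 0, 0]).sum()    # score(query at pos 1, key at pos 0)
    logit1 = (q1[0, 1, 0] * k1[0, 0, 0]).sum()    # same pair, both shifted by `offset`
    return bool(torch.allclose(logit0, logit1, atol=1e-4))

print(check_relative_position_property())  # True for a correct implementation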

Part XII: LongRoPE2 and Recent Advances (2024-2025)

LongRoPE2

LongRoPE2 improves on LongRoPE with more sophisticated search for optimal interpolation factors.

Key innovations:

  1. Needle-driven evaluation: Instead of just perplexity, use "needle-in-haystack" tasks to evaluate position encoding quality. Can the model find specific information at various positions?

  2. Evolutionary search: Use genetic algorithms to search the space of per-dimension interpolation factors, evolving populations of factor configurations.

  3. Multi-objective optimization: Balance perplexity, needle accuracy, and computational cost.

Python
def longrope2_search(
    model,
    target_length: int,
    original_length: int,
    dim: int,
    population_size: int = 50,
    generations: int = 100,
    needle_weight: float = 0.7,
    perplexity_weight: float = 0.3
):
    """
    Evolutionary search for optimal RoPE scaling factors.

    Returns: (dim/2,) tensor of per-dimension scaling factors
    """
    # Initialize population with random factors around 1.0
    population = torch.randn(population_size, dim // 2) * 0.1 + 1.0

    for gen in range(generations):
        # Evaluate each individual
        fitness = []
        for factors in population:
            # Apply factors to RoPE
            scaled_freqs = apply_longrope_factors(base_freqs, factors)

            # Evaluate on needle tasks (can the model find info at various positions?)
            needle_score = evaluate_needle_in_haystack(model, scaled_freqs, target_length)

            # Evaluate perplexity on validation set
            ppl = evaluate_perplexity(model, scaled_freqs, target_length)

            # Combined fitness
            fit = needle_weight * needle_score - perplexity_weight * math.log(ppl)
            fitness.append(fit)

        # Remember the best individual of this generation before evolving
        best = population[int(torch.argmax(torch.tensor(fitness)))]

        # Selection, crossover, mutation (standard genetic algorithm)
        population = evolve(population, fitness)

    # Return the best individual from the last evaluated generation
    return best

Other Recent Advances

Dynamic RoPE (2024): Adjust RoPE frequencies dynamically based on input content, not just position.

Learned frequency schedules: Instead of the fixed geometric sequence $\theta_j = b^{-2j/d}$, learn the frequencies during training.

Hybrid approaches: Combine RoPE with ALiBi-style biases for the best of both worlds.



Enrico Piovano, PhD

Co-founder & CTO at Goji AI. Former Applied Scientist at Amazon (Alexa & AGI), focused on Agentic AI and LLMs. PhD in Electrical Engineering from Imperial College London. Gold Medalist at the National Mathematical Olympiad.
