
Positional Embeddings: How Transformers Understand Word Order

A comprehensive deep dive into positional embeddings—how transformers encode sequence order. From sinusoidal encodings to learned embeddings, relative positions to ALiBi, understand the evolution that led to modern approaches like RoPE.


Why Position Matters

Transformers have a fundamental problem: they're blind to order. The self-attention mechanism that makes transformers so powerful—allowing any token to attend to any other token—is completely permutation-equivariant. If you shuffle the input tokens, the attention outputs shuffle the same way. The model cannot tell that "dog bites man" differs from "man bites dog."

This is a profound limitation. Word order carries essential meaning in language. "The cat sat on the mat" and "The mat sat on the cat" have the same tokens but completely different meanings. Without position information, a transformer would treat them identically.

Positional embeddings solve this problem by injecting position information into the model. But how you encode position has far-reaching consequences for what the model can learn and how it generalizes to different sequence lengths.

This post traces the evolution of positional embeddings from the original transformer to modern approaches. Understanding this evolution reveals why certain methods work better than others and prepares you for RoPE (Rotary Position Embeddings), the method that dominates modern LLMs.


Historical Timeline

The evolution of positional embeddings reflects our deepening understanding of what transformers need to know about sequence order.

| Year | Method | Paper | Key Innovation |
|------|--------|-------|----------------|
| 2017 | Sinusoidal | "Attention Is All You Need" (Vaswani et al.) | Fixed frequencies, theoretically infinite extrapolation |
| 2018 | Learned | BERT, GPT-2 | Let the model learn position patterns |
| 2019 | Relative (Transformer-XL) | "Transformer-XL" (Dai et al.) | Directly encode relative distance |
| 2020 | T5 Relative Bias | "Exploring the Limits of Transfer Learning" (Raffel et al.) | Scalar bias with logarithmic bucketing |
| 2021 | RoPE | "RoFormer" (Su et al.) | Rotation encodes relative position mathematically |
| 2022 | ALiBi | "Train Short, Test Long" (Press et al.) | Linear bias, zero-shot extrapolation |
| 2023 | Position Interpolation | Chen et al. | Scale positions to extend context |
| 2023 | NTK-aware Scaling | Reddit/community | Non-uniform frequency scaling |
| 2023 | YaRN | Peng et al. | Combined interpolation + attention scaling |
| 2024 | LongRoPE | Microsoft | Search-based optimal scaling factors |
Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    EVOLUTION OF POSITION ENCODING                        │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  2017 ──────────────────────────────────────────────────────────► 2024  │
│                                                                          │
│  ABSOLUTE                      RELATIVE                    ROTATION     │
│  ────────                      ────────                    ────────     │
│                                                                          │
│  Sinusoidal ──► Learned ──► Transformer-XL ──► T5 ──► ALiBi ──► RoPE   │
│      │              │              │            │        │        │     │
│      │              │              │            │        │        │     │
│  Add to         Add to        Modify        Scalar    Linear   Rotate  │
│  embedding      embedding     attention     bias      bias     Q and K │
│                               score                                     │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  KEY INSIGHT PROGRESSION:                                                │
│                                                                          │
│  1. "Position exists" (sinusoidal/learned)                             │
│  2. "Relative position matters more" (Transformer-XL, T5)              │
│  3. "Make relative position mathematical" (RoPE)                       │
│  4. "Enable length extrapolation" (ALiBi, RoPE scaling)                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Part I: The Problem

Attention is Permutation-Equivariant

To understand why positional embeddings are necessary, we need to understand what "permutation-equivariant" means for attention.

Consider the self-attention operation. For each position, we compute:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

The query $Q$, key $K$, and value $V$ matrices are linear projections of the input:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

Now imagine we permute the input sequence $X$ using some permutation matrix $P$. The new input is $PX$. What happens to the attention output?

$$Q' = PXW_Q = PQ, \quad K' = PXW_K = PK, \quad V' = PXW_V = PV$$

The attention scores become:

$$\text{softmax}\left(\frac{Q'K'^T}{\sqrt{d_k}}\right) = \text{softmax}\left(\frac{PQK^TP^T}{\sqrt{d_k}}\right) = P \cdot \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \cdot P^T$$

And the output:

$$\text{Output}' = P \cdot \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \cdot P^T \cdot PV = P \cdot \text{Output}$$

The output permutes exactly the same way as the input. The model cannot distinguish different orderings of the same tokens.
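A quick numerical check makes the algebra concrete. The sketch below builds a toy single-head attention in NumPy (random input and projection matrices, nothing learned) and verifies that permuting the input rows permutes the output rows in exactly the same way:

Python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(X, W_q, W_k, W_v):
    """Single-head attention with no position information."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    return softmax(scores, axis=-1) @ V

rng = np.random.default_rng(0)
seq_len, d = 5, 8
X = rng.normal(size=(seq_len, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

# Permute the input rows with a random permutation matrix P
perm = rng.permutation(seq_len)
P = np.eye(seq_len)[perm]

out = attention(X, W_q, W_k, W_v)
out_permuted = attention(P @ X, W_q, W_k, W_v)

# The output of the permuted input is exactly the permuted original output
assert np.allclose(out_permuted, P @ out)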

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    WHY ATTENTION NEEDS POSITION                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  WITHOUT POSITION INFORMATION:                                           │
│  ─────────────────────────────                                           │
│                                                                          │
│  Input 1: "dog bites man"    →  Tokens: [dog, bites, man]              │
│  Input 2: "man bites dog"    →  Tokens: [man, bites, dog]              │
│                                                                          │
│  The attention mechanism sees:                                          │
│  • Same set of tokens: {dog, bites, man}                               │
│  • Same pairwise relationships between tokens                          │
│  • Cannot distinguish the two inputs!                                  │
│                                                                          │
│  Both produce equivalent representations (up to permutation).          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WITH POSITION INFORMATION:                                              │
│  ──────────────────────────                                              │
│                                                                          │
│  Input 1: "dog bites man"                                               │
│           [dog+pos0, bites+pos1, man+pos2]                             │
│                                                                          │
│  Input 2: "man bites dog"                                               │
│           [man+pos0, bites+pos1, dog+pos2]                             │
│                                                                          │
│  Now:                                                                   │
│  • "dog" at position 0 ≠ "dog" at position 2                          │
│  • The model can learn that position 0 is often the subject           │
│  • The model can learn that position 2 is often the object            │
│  • Different orderings produce different representations!              │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

What Good Positional Embeddings Should Do

Not all positional embeddings are created equal. A good positional embedding scheme should:

  1. Uniquely identify each position: Position 5 should be distinguishable from position 10.

  2. Enable relative position reasoning: The model should be able to determine that positions 3 and 5 are 2 apart, and that positions 100 and 102 are also 2 apart.

  3. Generalize to unseen lengths: If trained on sequences up to length 1024, the model should ideally work (at least somewhat) on length 2048.

  4. Not interfere with semantic content: Position information shouldn't overwhelm or distort the meaning captured in token embeddings.

  5. Be computationally efficient: Position encoding shouldn't become a bottleneck.

Different approaches make different tradeoffs among these goals. Let's explore each major approach.


Part II: Absolute Positional Embeddings

Sinusoidal Positional Encoding (Original Transformer)

The original "Attention Is All You Need" paper introduced sinusoidal positional encodings. The idea was elegant: use sine and cosine functions of different frequencies to create unique position signatures.

For a position $pos$ and dimension $i$:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

where $d_{model}$ is the embedding dimension.

Python
import numpy as np
import matplotlib.pyplot as plt

def sinusoidal_position_encoding(max_len, d_model):
    """
    Generate sinusoidal positional encodings.

    Args:
        max_len: Maximum sequence length
        d_model: Model dimension

    Returns:
        position_encoding: (max_len, d_model) array
    """
    position = np.arange(max_len)[:, np.newaxis]  # (max_len, 1)

    # Compute the division term: 10000^(2i/d_model)
    div_term = np.exp(
        np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model)
    )  # (d_model/2,)

    # Initialize encoding array
    pe = np.zeros((max_len, d_model))

    # Even indices: sine
    pe[:, 0::2] = np.sin(position * div_term)

    # Odd indices: cosine
    pe[:, 1::2] = np.cos(position * div_term)

    return pe

# Generate encodings
pe = sinusoidal_position_encoding(100, 512)

# Visualize the encoding matrix
def visualize_sinusoidal_encoding(pe, num_positions=50, num_dims=128):
    """
    Visualize the sinusoidal position encoding as a heatmap.

    The resulting image shows:
    - X-axis: Embedding dimensions (low = high freq, high = low freq)
    - Y-axis: Positions in the sequence
    - Color: Encoding value (-1 to 1)

    You'll see wave patterns that are:
    - Fast-oscillating on the left (high frequency dimensions)
    - Slow-oscillating on the right (low frequency dimensions)
    """
    plt.figure(figsize=(12, 6))
    plt.imshow(pe[:num_positions, :num_dims], cmap='RdBu', aspect='auto', vmin=-1, vmax=1)
    plt.colorbar(label='Encoding Value')
    plt.xlabel('Embedding Dimension')
    plt.ylabel('Position')
    plt.title('Sinusoidal Position Encoding Heatmap')
    plt.show()

# Analyze position similarity
def position_similarity_matrix(pe):
    """
    Compute cosine similarity between all position encodings.

    Key insight: Similar positions should have similar encodings.
    The diagonal is all 1s (position similar to itself).
    Off-diagonal values show how similar different positions are.
    """
    # Normalize each position vector
    norms = np.linalg.norm(pe, axis=1, keepdims=True)
    pe_normalized = pe / norms

    # Cosine similarity matrix
    similarity = pe_normalized @ pe_normalized.T
    return similarity

similarity = position_similarity_matrix(pe[:50])
# Positions near each other have higher similarity
# Positions far apart have lower (but not zero) similarity
Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    SINUSOIDAL ENCODING HEATMAP                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Dimension →   0      64     128    192    256    320    384    448     │
│                │       │       │       │       │       │       │        │
│  Position      ▼       ▼       ▼       ▼       ▼       ▼       ▼        │
│      0      ▓▓░░▓▓░░▓▓▓░░░▓▓▓▓░░░░▓▓▓▓▓░░░░░▓▓▓▓▓▓░░░░░░▓▓▓▓▓▓▓       │
│      1      ░░▓▓░░▓▓░░▓▓▓░░░▓▓▓▓░░░░▓▓▓▓░░░░░▓▓▓▓▓░░░░░░▓▓▓▓▓▓▓       │
│      2      ▓▓░░▓▓░░▓▓░░▓▓▓░░░▓▓▓▓░░░░▓▓▓▓░░░░░▓▓▓▓▓░░░░░░▓▓▓▓▓       │
│      3      ░░▓▓░░▓▓░░▓▓░░▓▓▓░░░▓▓▓▓░░░░▓▓▓▓░░░░░▓▓▓▓▓░░░░░░▓▓▓       │
│      .                                                                  │
│      .      Fast oscillation    Medium oscillation    Slow oscillation  │
│      .      (local patterns)    (medium-range)        (global patterns) │
│     49      ▓░▓░▓░▓░░▓▓░░▓▓░░░▓▓▓░░░░▓▓▓▓░░░░░▓▓▓▓▓░░░░░░▓▓▓▓▓▓       │
│                                                                          │
│  Legend: ▓ = positive (close to 1), ░ = negative (close to -1)         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  POSITION SIMILARITY MATRIX (cosine similarity):                        │
│                                                                          │
│             pos 0   pos 5   pos 10  pos 20  pos 40                      │
│  pos 0      1.00    0.92    0.71    0.38    0.12                       │
│  pos 5      0.92    1.00    0.92    0.54    0.21                       │
│  pos 10     0.71    0.92    1.00    0.71    0.31                       │
│  pos 20     0.38    0.54    0.71    1.00    0.54                       │
│  pos 40     0.12    0.21    0.31    0.54    1.00                       │
│                                                                          │
│  Nearby positions are more similar → model can learn local patterns    │
│  Similarity decays smoothly with distance                              │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Why these specific formulas? The choice of sine/cosine with exponentially decreasing frequencies was motivated by several insights:

  1. Unique positions: Each position gets a unique pattern across dimensions. Low dimensions have high-frequency oscillations (distinguishing nearby positions), while high dimensions have low-frequency oscillations (distinguishing distant positions).

  2. Relative position via linear transformation: For any fixed offset $k$, the encoding of position $pos+k$ is a fixed linear transformation $M_k$ (depending only on $k$) of the encoding of position $pos$:

$$PE_{pos+k} = M_k \cdot PE_{pos}$$

This is because $\sin(a+b)$ and $\cos(a+b)$ are linear combinations of $\sin(a)$ and $\cos(a)$ with coefficients that depend only on $b$. The authors hypothesized that this would let the model attend by relative position. (The sketch after this list verifies the property numerically.)

  3. Bounded values: Sine and cosine are bounded in $[-1, 1]$, preventing position encodings from dominating token embeddings.

  4. Extrapolation potential: The formula can compute positions beyond the training range.
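The linear-transformation property in point 2 is easy to check numerically. The sketch below reuses sinusoidal_position_encoding from above; the helper shift_by_k is purely illustrative and applies the offset-$k$ rotation to each sin/cos pair:

Python
import numpy as np

d_model, k = 64, 3
pe = sinusoidal_position_encoding(200, d_model)

# The same frequencies used inside sinusoidal_position_encoding
inv_freq = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))

def shift_by_k(pe_row, k):
    """Map PE(pos) to PE(pos + k) via a fixed rotation of each sin/cos pair."""
    out = np.empty_like(pe_row)
    angles = k * inv_freq
    s, c = pe_row[0::2], pe_row[1::2]                       # sin(pos*w), cos(pos*w)
    out[0::2] = s * np.cos(angles) + c * np.sin(angles)     # sin(pos*w + k*w)
    out[1::2] = c * np.cos(angles) - s * np.sin(angles)     # cos(pos*w + k*w)
    return out

pos = 17
# The transformation depends only on k, not on pos
assert np.allclose(shift_by_k(pe[pos], k), pe[pos + k])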

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    SINUSOIDAL POSITION ENCODING                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  INTUITION: Different frequencies capture different position scales     │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  Position:    0    1    2    3    4    5    6    7    8    9   10      │
│                                                                          │
│  Dim 0 (sin): ────╮  ╭────╮  ╭────╮  ╭────╮  ╭────╮  ╭────╮           │
│  High freq        ╰──╯    ╰──╯    ╰──╯    ╰──╯    ╰──╯    ╰──          │
│               Oscillates rapidly → distinguishes adjacent positions     │
│                                                                          │
│  Dim 256:    ──────────────────╮                                        │
│  Med freq                      ╰────────────────────────────────        │
│               Oscillates slowly → captures medium-range patterns        │
│                                                                          │
│  Dim 510:    ────────────────────────────────────────────────────       │
│  Low freq    Nearly constant over short ranges                          │
│               Oscillates very slowly → captures global position         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMBINED: Each position has a unique "fingerprint" across all dims    │
│                                                                          │
│  Position 0: [0.00, 1.00, 0.00, 1.00, 0.00, 1.00, ...]                │
│  Position 1: [0.84, 0.54, 0.04, 0.99, 0.002, 0.99, ...]               │
│  Position 2: [0.91, -0.42, 0.08, 0.99, 0.004, 0.99, ...]              │
│                                                                          │
│  The pattern varies across dimensions, creating unique signatures.     │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Limitations of sinusoidal encodings:

Despite the elegant theory, sinusoidal encodings have practical limitations:

  1. No learned relative position: While the linear transformation property exists mathematically, the model has to learn to exploit it. In practice, models don't naturally learn to compute relative positions this way.

  2. Poor extrapolation in practice: Even though positions can be computed for any length, models trained with sinusoidal encodings don't generalize well beyond their training length. The model hasn't learned what to do with position patterns it's never seen.

  3. Position added to content: Position information is added to token embeddings, potentially interfering with semantic content.

Learned Positional Embeddings (GPT-2, BERT)

A simpler alternative: just learn a separate embedding for each position.

Python
import torch
import torch.nn as nn

class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        # Learnable embedding table: one vector per position
        self.embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model) - token embeddings
        Returns:
            (batch_size, seq_len, d_model) - embeddings + positions
        """
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        pos_embeddings = self.embedding(positions)  # (seq_len, d_model)
        return x + pos_embeddings

This approach is used by GPT-2, BERT, and many other models from that era. It's simple and lets the model learn whatever position representation works best for its task.
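A quick usage sketch of the module above, illustrating both the normal case and the hard length limit (the max_len of 1024 and the shapes are illustrative choices):

Python
pos_emb = LearnedPositionalEmbedding(max_len=1024, d_model=768)

x = torch.randn(2, 512, 768)          # within the trained range
out = pos_emb(x)                      # works: (2, 512, 768)

x_long = torch.randn(2, 2000, 768)    # beyond max_len
try:
    pos_emb(x_long)                   # positions 1024..1999 have no embedding row
except (IndexError, RuntimeError) as e:
    print("No embedding exists for positions >= 1024:", type(e).__name__)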

Advantages:

  • Simple to implement
  • Model can learn optimal position patterns for its specific task
  • No assumptions about what position should look like

Limitations:

  • Hard maximum length: The embedding table has a fixed size. Position 1025 simply doesn't exist if max_len is 1024.
  • No extrapolation: Cannot handle sequences longer than training.
  • No explicit relative position: Positions are absolute; the model must learn relative relationships implicitly.
  • More parameters: Adds $\text{max\_len} \times d_{model}$ parameters (though typically small compared to the full model).
Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    LEARNED POSITIONAL EMBEDDINGS                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ARCHITECTURE:                                                           │
│  ─────────────                                                           │
│                                                                          │
│  Position:     0        1        2       ...     1023                   │
│                │        │        │                │                      │
│                ▼        ▼        ▼                ▼                      │
│             ┌─────┐  ┌─────┐  ┌─────┐         ┌─────┐                   │
│             │ E_0 │  │ E_1 │  │ E_2 │   ...   │E_1023│                  │
│             └─────┘  └─────┘  └─────┘         └─────┘                   │
│                │        │        │                │                      │
│           Each is a learned d_model-dimensional vector                 │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  USAGE:                                                                  │
│                                                                          │
│  Token embeddings:    [0.1, -0.2, 0.3, ...]  [0.4, 0.1, -0.2, ...]    │
│                              +                        +                  │
│  Position embeddings: [0.05, 0.1, -0.1, ...] [0.08, -0.05, 0.2, ...]  │
│                              =                        =                  │
│  Combined:            [0.15, -0.1, 0.2, ...] [0.48, 0.05, 0.0, ...]   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  THE HARD LIMIT PROBLEM:                                                 │
│                                                                          │
│  max_len = 1024 during training                                         │
│                                                                          │
│  At inference with 2000 tokens:                                         │
│  Position 1024 → ??? NO EMBEDDING EXISTS                               │
│  Position 1025 → ??? NO EMBEDDING EXISTS                               │
│  ...                                                                    │
│                                                                          │
│  Options:                                                               │
│  1. Truncate to 1024 tokens (lose information)                        │
│  2. Use sliding window (complex, loses global context)                │
│  3. Initialize new embeddings and fine-tune (expensive)               │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

The Problem with Absolute Positions

Both sinusoidal and learned embeddings share a fundamental limitation: they encode absolute positions. Position 5 always gets the same encoding, regardless of context.

But language often cares more about relative positions. Consider:

  • "The cat that I saw yesterday was cute" - "yesterday" modifies "saw", which is 2 positions back
  • "Yesterday, the cat that I saw was cute" - "yesterday" still modifies "saw", now 4 positions forward

The relationship between "yesterday" and "saw" matters, not their absolute positions in the sentence.

With absolute position encodings, the model must learn that "position 5 attending to position 3" is similar to "position 100 attending to position 98." This is learnable but inefficient—it requires the model to rediscover the same relative relationship at every absolute position.


Part III: Relative Positional Embeddings

The Key Insight

What if, instead of encoding absolute positions, we directly encoded the relationship between positions? When computing attention between position $i$ and position $j$, we could inject information about their relative distance $(i - j)$.

This insight led to several important methods.

Transformer-XL Relative Position

Transformer-XL (Dai et al., 2019) introduced a systematic way to incorporate relative position into attention. The key idea: modify the attention score computation to depend on relative position.

In standard attention:

$$\text{score}_{ij} = q_i^T k_j$$

In Transformer-XL:

$$\text{score}_{ij} = q_i^T k_j + q_i^T r_{i-j} + u^T k_j + v^T r_{i-j}$$

Where:

  • $r_{i-j}$ is a learned embedding for relative position $(i-j)$
  • $u$ and $v$ are learned global bias vectors
Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    TRANSFORMER-XL ATTENTION DECOMPOSITION                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  score(i,j) = (a) + (b) + (c) + (d)                                    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  (a) q_i^T · k_j                                                        │
│      └── Content-to-content: Does query content match key content?     │
│                                                                          │
│  (b) q_i^T · r_{i-j}                                                    │
│      └── Content-to-position: Does query content prefer this distance? │
│                                                                          │
│  (c) u^T · k_j                                                          │
│      └── Global content bias: Is this key generally important?         │
│                                                                          │
│  (d) v^T · r_{i-j}                                                      │
│      └── Global position bias: Is this distance generally important?   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  RELATIVE POSITION EMBEDDINGS r_{i-j}:                                  │
│                                                                          │
│  Distance:   -5    -4    -3    -2    -1     0     1     2     3  ...  │
│               │     │     │     │     │     │     │     │     │        │
│               ▼     ▼     ▼     ▼     ▼     ▼     ▼     ▼     ▼        │
│            [r_-5][r_-4][r_-3][r_-2][r_-1][r_0 ][r_1 ][r_2 ][r_3 ]      │
│                                                                          │
│  Each relative distance has its own learned embedding.                 │
│  Position (5,3) and (100,98) both use r_{-2} → same relationship!     │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
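For concreteness, here is a simplified single-head sketch of the four-term score. The relative-embedding table rel_emb, the biases u and v, and the shapes are illustrative; the actual Transformer-XL implementation also applies a separate projection to the relative embeddings and uses an efficient shift trick instead of materializing a full (i, j) table:

Python
def transformer_xl_scores(q, k, rel_emb, u, v):
    """
    Simplified Transformer-XL score decomposition for one head.

    q, k:    (seq_len, head_dim) queries and keys
    rel_emb: (2 * seq_len - 1, head_dim) embeddings for distances -(L-1)..(L-1)
    u, v:    (head_dim,) learned global bias vectors
    Returns: (seq_len, seq_len) unnormalized attention scores
    """
    seq_len = q.shape[0]

    # r[i, j] is the embedding for relative distance (i - j)
    idx = torch.arange(seq_len).unsqueeze(1) - torch.arange(seq_len).unsqueeze(0)
    r = rel_emb[idx + seq_len - 1]                 # (seq_len, seq_len, head_dim)

    term_a = q @ k.T                               # (a) content-to-content
    term_b = torch.einsum('id,ijd->ij', q, r)      # (b) content-to-position
    term_c = (k @ u).unsqueeze(0)                  # (c) global content bias, per key j
    term_d = torch.einsum('d,ijd->ij', v, r)       # (d) global position bias

    return term_a + term_b + term_c + term_d


# Usage with random tensors (shapes only)
L, d = 8, 16
scores = transformer_xl_scores(
    torch.randn(L, d), torch.randn(L, d),
    rel_emb=torch.randn(2 * L - 1, d),
    u=torch.randn(d), v=torch.randn(d),
)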

Advantages:

  • Directly models relative position
  • Same relative distance always gets the same encoding
  • Better length generalization than absolute methods

Limitations:

  • More complex implementation
  • Requires storing relative position embeddings (typically clipped to a maximum distance)
  • Adds computational overhead to attention

T5's Relative Position Bias

T5 (Raffel et al., 2020) simplified relative position encoding. Instead of modifying the full attention computation, T5 adds a learned scalar bias based on relative position:

$$\text{score}_{ij} = q_i^T k_j + b_{i-j}$$

Where $b_{i-j}$ is a learned scalar (not a vector) for each relative position. To handle arbitrary distances efficiently, T5 uses bucketed relative positions:

Python
def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
    """
    T5's bucketing scheme: exact positions for small distances,
    logarithmic buckets for larger distances.
    """
    # Separate positive and negative positions
    relative_buckets = 0

    # Handle negative positions (use half the buckets)
    num_buckets //= 2
    relative_buckets += (relative_position < 0).long() * num_buckets
    relative_position = torch.abs(relative_position)

    # Exact buckets for small distances
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # Logarithmic buckets for large distances
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact) /
        torch.log(torch.tensor(max_distance / max_exact)) *
        (num_buckets - max_exact)
    ).long()

    relative_position_if_large = torch.min(
        relative_position_if_large,
        torch.full_like(relative_position_if_large, num_buckets - 1)
    )

    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets


class T5RelativeAttention(nn.Module):
    """
    Complete T5-style relative position attention.

    The key difference from standard attention:
    - Adds a learned scalar bias b_{i-j} to attention scores
    - Bias depends only on relative position (i-j)
    - Uses bucketing to handle long distances with limited parameters
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        num_buckets: int = 32,
        max_distance: int = 128,
        is_decoder: bool = True
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.is_decoder = is_decoder

        # Standard attention projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Relative position bias: one scalar per (bucket, head) pair
        # Embedding table shape: (num_buckets, n_heads)
        self.relative_attention_bias = nn.Embedding(num_buckets, n_heads)

    def compute_bias(self, query_length: int, key_length: int) -> torch.Tensor:
        """
        Compute relative position bias matrix.

        Returns: (1, n_heads, query_length, key_length)
        """
        # Create position indices
        query_positions = torch.arange(query_length)
        key_positions = torch.arange(key_length)

        # Relative positions: (query_length, key_length)
        relative_positions = key_positions.unsqueeze(0) - query_positions.unsqueeze(1)

        # Convert to buckets
        buckets = relative_position_bucket(
            relative_positions,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance
        )

        # Look up bias values: (query_length, key_length, n_heads)
        bias = self.relative_attention_bias(buckets)

        # Reshape to (1, n_heads, query_length, key_length)
        return bias.permute(2, 0, 1).unsqueeze(0)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)

        # Transpose for attention: (batch, n_heads, seq_len, head_dim)
        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Add relative position bias (the key T5 innovation!)
        position_bias = self.compute_bias(seq_len, seq_len).to(x.device)
        scores = scores + position_bias

        if mask is not None:
            scores = scores + mask

        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, v)

        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        return self.out_proj(output)
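A quick usage sketch of the module and the bucketing function above (the dimensions are arbitrary):

Python
attn = T5RelativeAttention(d_model=512, n_heads=8)
x = torch.randn(2, 64, 512)
out = attn(x)                                 # (2, 64, 512)

# Nearby distances get exact buckets; distant positions (e.g. 100 and 150)
# fall into shared logarithmic buckets
distances = torch.tensor([1, 2, 8, 16, 100, 150, 1000])
print(relative_position_bucket(distances))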
Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    T5 RELATIVE POSITION BUCKETS                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  PROBLEM: We can't have a separate bias for every possible distance    │
│           (could be thousands of positions apart)                       │
│                                                                          │
│  SOLUTION: Bucket distances into a fixed number of categories          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  BUCKETING SCHEME (32 buckets, 16 for each direction):                 │
│                                                                          │
│  Distance:   -∞  ... -128  -64  -32  -16  -8  -4  -2  -1   0            │
│  Bucket:      0  ...   0    1    2    3   4   5   6   7   8            │
│                                                                          │
│  Distance:    1    2    4    8   16   32   64  128  ... +∞             │
│  Bucket:      9   10   11   12   13   14   15   16  ... 16             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  KEY INSIGHT:                                                            │
│                                                                          │
│  • Exact buckets for small distances (±8 positions)                    │
│    → Precise local relationships matter                                │
│                                                                          │
│  • Logarithmic buckets for large distances                             │
│    → "Far away" is "far away"; exact distance matters less            │
│    → Distances 100 and 150 go in the same bucket                      │
│                                                                          │
│  This matches linguistic intuition: nearby words have precise          │
│  relationships, while distant words just need "somewhere back there"   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Advantages:

  • Simple: just add a scalar bias
  • Memory efficient: fixed number of buckets regardless of sequence length
  • Captures the intuition that exact distance matters locally but not globally

Limitations:

  • Loses precise distance information for far positions
  • Requires choosing bucketing hyperparameters

ALiBi (Attention with Linear Biases)

ALiBi (Press et al., 2022) took an even simpler approach: don't learn position embeddings at all. Instead, apply a fixed linear penalty based on distance:

$$\text{score}_{ij} = q_i^T k_j - m \cdot |i - j|$$

Where $m$ is a fixed (not learned) slope that differs per attention head.

Python
def get_alibi_slopes(num_heads: int) -> list[float]:
    """
    ALiBi slopes: geometric sequence from 2^(-8/n) to 2^(-8).
    Different heads attend at different "ranges".

    The formula ensures:
    - Head 0 has steepest slope (most local attention)
    - Last head has gentlest slope (most global attention)
    - Slopes form a geometric sequence

    For 8 heads: [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
    """
    ratio = 2 ** (-8 / num_heads)
    return [ratio ** i for i in range(1, num_heads + 1)]


def build_alibi_bias(seq_len: int, num_heads: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """
    Build the full ALiBi bias tensor.

    Returns: (1, num_heads, seq_len, seq_len)
    """
    slopes = get_alibi_slopes(num_heads)

    # Create position indices
    positions = torch.arange(seq_len)

    # Distance matrix: (seq_len, seq_len)
    # For causal attention, we only care about positions where j <= i
    # distance[i,j] = i - j (positive for j < i, 0 for j = i, negative for j > i)
    distance = positions.unsqueeze(0) - positions.unsqueeze(1)

    # For each head, multiply by its slope
    # Shape: (num_heads, seq_len, seq_len)
    slopes_tensor = torch.tensor(slopes, dtype=dtype).view(num_heads, 1, 1)
    bias = -slopes_tensor * distance.abs().unsqueeze(0).to(dtype)

    return bias.unsqueeze(0)  # (1, num_heads, seq_len, seq_len)


class ALiBiAttention(nn.Module):
    """
    Complete ALiBi attention module.

    Key features:
    - No learned position parameters (zero additional parameters!)
    - Linear bias computed from fixed slopes
    - Naturally extrapolates to longer sequences
    - Different heads specialize for different attention ranges
    """

    def __init__(self, d_model: int, n_heads: int, max_seq_len: int = 2048):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.max_seq_len = max_seq_len

        # Standard attention projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Precompute ALiBi bias (not learned!)
        # Can be extended at runtime if needed
        self.register_buffer(
            "alibi_bias",
            build_alibi_bias(max_seq_len, n_heads),
            persistent=False
        )

    def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)

        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Add ALiBi bias (the only position information!)
        if seq_len <= self.max_seq_len:
            alibi = self.alibi_bias[:, :, :seq_len, :seq_len]
        else:
            # Extrapolate: just compute a larger bias matrix
            # This is where ALiBi shines - no retraining needed!
            alibi = build_alibi_bias(seq_len, self.n_heads, dtype=x.dtype).to(x.device)

        scores = scores + alibi

        # Apply causal mask if needed
        if causal:
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1
            )
            scores = scores.masked_fill(causal_mask, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, v)

        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        return self.out_proj(output)
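A quick usage sketch of the module above. Note how a sequence longer than max_seq_len is handled by simply rebuilding the bias on the fly (the dimensions are illustrative):

Python
attn = ALiBiAttention(d_model=512, n_heads=8, max_seq_len=1024)

x_short = torch.randn(1, 256, 512)
out_short = attn(x_short)           # uses the precomputed bias buffer

x_long = torch.randn(1, 2048, 512)
out_long = attn(x_long)             # bias is recomputed for the longer length

print(get_alibi_slopes(8))          # geometric sequence: 0.5, 0.25, ..., 0.00390625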
Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ALiBi: LINEAR ATTENTION BIAS                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE IDEA: Penalize attention to distant positions with a linear bias  │
│                                                                          │
│  score(i,j) = q_i · k_j - m × |i - j|                                  │
│                           └── Linear penalty for distance              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  DIFFERENT SLOPES FOR DIFFERENT HEADS:                                  │
│                                                                          │
│  Head 0:  m = 0.5    (steep slope → mostly local attention)              │
│  Head 1:  m = 0.25   (gentler → attends farther)                         │
│  Head 2:  m = 0.125  (gentler still → medium-range attention)            │
│  ...                                                                     │
│  Head 7:  m = 0.004  (almost flat → near-global attention)               │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  VISUALIZATION (8 heads):                                                │
│                                                                          │
│  Attention │    Head 7 (m=0.004) ────────────────────                  │
│  Weight    │    Head 3 (m=0.06) ───────────────                        │
│            │    Head 0 (m=0.5) ─────────                               │
│            │                                                            │
│            └────────────────────────────────────── Distance            │
│                                                                          │
│  Each head specializes: some for local patterns, some for global.     │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WHY IT WORKS FOR LENGTH GENERALIZATION:                                │
│                                                                          │
│  • No learned parameters → no unseen positions                         │
│  • Linear penalty naturally extends to any distance                    │
│  • Train on 1024, use at 2048: the math just works                    │
│                                                                          │
│  Results: Models trained with ALiBi on 1024 tokens can extrapolate    │
│  to 2048+ with minimal performance degradation.                        │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Advantages:

  • Zero learned parameters for position
  • Naturally extrapolates to longer sequences
  • Simple to implement
  • Different heads capture different ranges

Limitations:

  • Linear assumption may not capture all positional relationships
  • Less flexible than learned approaches
  • Doesn't distinguish "3 positions left" from "3 positions right" (uses absolute distance)
  • Has been largely superseded by RoPE in practice

Part IV: The Modern Era - RoPE

Why Something New Was Needed

By 2022, the limitations of existing approaches were becoming clear:

  • Absolute embeddings (sinusoidal/learned): Poor length generalization
  • Relative embeddings (Transformer-XL, T5): Complex, added overhead
  • ALiBi: Good extrapolation but inflexible

What the field needed was a method that:

  1. Encoded relative position naturally
  2. Was computationally efficient
  3. Could be extended to longer contexts
  4. Integrated smoothly with attention

Rotary Position Embedding (RoPE) achieved all of these goals and has become the dominant method in modern LLMs including Llama, Mistral, Qwen, and many others.

The Core Idea

RoPE's key insight: instead of adding position information to embeddings, rotate them. The rotation angle depends on position.

When two rotated vectors are compared via dot product (as in attention), the rotation angles subtract, making the result depend only on relative position.

We cover RoPE in full detail in our dedicated post: RoPE: Rotary Position Embeddings Explained. Here's the key insight:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    RoPE: THE KEY INSIGHT                                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  STANDARD APPROACH: Add position to embeddings                          │
│                                                                          │
│  x_with_pos = x + position_embedding                                   │
│                                                                          │
│  Problem: Position information mixes with content in complex ways      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ROPE: Rotate embeddings by position-dependent angle                    │
│                                                                          │
│  q_rotated = Rotate(q, position × θ)                                   │
│  k_rotated = Rotate(k, position × θ)                                   │
│                                                                          │
│  WHY IT WORKS:                                                          │
│                                                                          │
│  When we compute attention: q_m · k_n                                  │
│                                                                          │
│  q_m is rotated by angle m×θ                                           │
│  k_n is rotated by angle n×θ                                           │
│                                                                          │
│  Due to rotation properties:                                            │
│  q_m · k_n = |q||k|cos(angle_between)                                  │
│            = f(original_content, (m-n)×θ)                              │
│                                      └── Only relative position!       │
│                                                                          │
│  The absolute positions m and n don't matter individually—             │
│  only their difference (m-n) affects the attention score.             │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
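A minimal numerical sketch of this property, using a single 2D pair and one rotation frequency θ (full RoPE applies a different frequency to each pair of dimensions and is covered in the dedicated post):

Python
import torch

def rotate_2d(x, angle):
    """Rotate a 2D vector by the given angle."""
    c, s = torch.cos(angle), torch.sin(angle)
    return torch.stack([x[0] * c - x[1] * s, x[0] * s + x[1] * c])

theta = torch.tensor(0.1)
q = torch.tensor([1.0, 0.5])
k = torch.tensor([0.3, -0.8])

# Rotate q by its position m, k by its position n, then take the dot product
score_a = rotate_2d(q, 5 * theta) @ rotate_2d(k, 3 * theta)      # m=5,   n=3
score_b = rotate_2d(q, 105 * theta) @ rotate_2d(k, 103 * theta)  # m=105, n=103

# Same relative offset (m - n = 2) → same attention score
assert torch.allclose(score_a, score_b, atol=1e-5)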

RoPE is now so important that it has its own ecosystem of scaling methods (Position Interpolation, NTK-aware scaling, YaRN, LongRoPE) that enable context extension. See our Context Extension post for details.


Part V: Comparison and Recommendations

Method Comparison

| Method | Type | Length Generalization | Relative Position | Computational Overhead | Used By |
|--------|------|-----------------------|-------------------|------------------------|---------|
| Sinusoidal | Absolute | Poor | Indirect | None | Original Transformer |
| Learned | Absolute | None (hard limit) | Implicit | None | GPT-2, BERT |
| Transformer-XL | Relative | Good | Direct | Moderate | Transformer-XL |
| T5 Bias | Relative | Good | Direct (bucketed) | Low | T5, FLAN-T5 |
| ALiBi | Relative | Excellent | Direct (linear) | Low | BLOOM, MPT |
| RoPE | Relative | Good (scalable) | Direct | Low | Llama, Mistral, Qwen |

When to Use What

For new LLM projects: Use RoPE. It's the modern standard, well-understood, and has proven scaling methods for long contexts.

For encoder models (BERT-style): Learned embeddings are still common and work well for fixed-length tasks like classification.

For extreme length generalization without fine-tuning: ALiBi can be useful, though RoPE with scaling often performs better with minimal fine-tuning.

For research/experimentation: Consider whether relative position matters for your task. If you're doing tasks where absolute position is meaningful (like slot filling at specific positions), absolute methods might still be appropriate.


Part VI: Benchmarks and Empirical Comparisons

Perplexity at Different Context Lengths

The key metric for position encoding is how well models perform as context length increases, especially beyond training length.

| Method | Trained Length | 1× (in-distribution) | 2× | 4× | 8× |
|--------|----------------|----------------------|-----|-----|-----|
| Sinusoidal | 512 | 18.2 | 142.3 | 1847 | OOM |
| Learned | 512 | 17.9 | N/A | N/A | N/A |
| ALiBi | 512 | 18.5 | 19.2 | 21.4 | 28.7 |
| RoPE | 512 | 17.8 | 89.4 | 412 | 2341 |
| RoPE + PI | 512 | 17.8 | 18.4 | 19.7 | 23.1 |
| RoPE + YaRN | 512 | 17.8 | 18.1 | 18.9 | 20.2 |

Note: These are illustrative numbers showing typical patterns. Actual results vary by model size, dataset, and implementation.

Key observations:

  1. Learned embeddings cannot extrapolate at all (positions don't exist)
  2. Sinusoidal degrades catastrophically beyond training length
  3. ALiBi extrapolates well zero-shot due to simple linear bias
  4. RoPE degrades without scaling but with PI/YaRN matches or beats ALiBi

Length Generalization: Why Some Methods Fail

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    WHY EXTRAPOLATION FAILS                               │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  LEARNED EMBEDDINGS:                                                     │
│  ───────────────────                                                     │
│  Position 512 → Embedding lookup → vector                              │
│  Position 513 → ??? No entry in table → CRASH or garbage               │
│                                                                          │
│  SINUSOIDAL EMBEDDINGS:                                                  │
│  ──────────────────────                                                  │
│  Position 512 → sin/cos computation → vector ✓                         │
│  Position 513 → sin/cos computation → vector ✓ (mathematically valid)  │
│                                                                          │
│  But! The model has never seen these patterns during training.          │
│  High-frequency dimensions oscillate in unfamiliar ways.                │
│  Attention patterns become unpredictable → performance collapses.      │
│                                                                          │
│  ROPE WITHOUT SCALING:                                                   │
│  ─────────────────────                                                   │
│  Position 512 → Rotation by 512θ → ✓ (seen during training)           │
│  Position 513 → Rotation by 513θ → Angle slightly out of distribution │
│  Position 2048 → Rotation by 2048θ → Angle WAY out of distribution    │
│                                                                          │
│  The rotation angles become larger than anything seen in training.     │
│  Softmax over unfamiliar attention patterns → degradation.             │
│                                                                          │
│  ALiBi:                                                                  │
│  ──────                                                                  │
│  Position 512 vs 0 → Bias of -512m (seen in training)                 │
│  Position 513 vs 0 → Bias of -513m (just slightly larger)             │
│  Position 2048 vs 0 → Bias of -2048m (larger but same linear form)    │
│                                                                          │
│  Linear extrapolation! The bias formula doesn't change.                │
│  Attention still prefers nearby tokens (just a bit less strongly).     │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Computational Cost Comparison

| Method | Extra Parameters | FLOPs per Token | Memory | Notes |
|--------|------------------|-----------------|--------|-------|
| Sinusoidal | 0 | Negligible | O(L×d) | Precomputed once |
| Learned | L×d | Negligible | O(L×d) | Lookup only |
| T5 Bias | buckets×heads | O(L²×h) | O(L²×h) | Bias matrix |
| ALiBi | 0 | O(L²×h) | O(L²×h) | Computed bias |
| RoPE | 0 | O(L×d) | O(L×d) | Per-vector rotation |

Where L = sequence length, d = model dimension, h = number of heads.

Takeaway: RoPE has the best efficiency profile—no extra parameters and O(L×d) compute (linear in sequence length for the position encoding itself).


Part VII: Position Encoding in Cross-Attention

Cross-attention (used in encoder-decoder models like T5, BART, or when LLMs attend to retrieved documents) introduces a subtlety: the query and key come from different sequences.

The Challenge

In self-attention:

  • Query position: 5
  • Key position: 3
  • Relative position: 5 - 3 = 2

In cross-attention:

  • Query from decoder position: 5
  • Key from encoder position: 3
  • Relative position: ??? (different sequences!)

Solutions by Method

Learned/Sinusoidal: Add position embeddings to encoder and decoder separately. Cross-attention sees both absolute positions, but they're not directly comparable.

T5: Uses relative position bias in self-attention only. Cross-attention has no position bias (attends purely based on content).

ALiBi: Typically applied to self-attention only. Cross-attention is position-agnostic.

RoPE: Options vary:

  1. No RoPE in cross-attention: Keys are unrotated, queries are rotated. Position information comes only from the query side.
  2. Separate RoPE: Encoder and decoder have independent position indices. Relative position within each sequence is preserved, but cross-sequence doesn't have relative semantics.
  3. Unified positions: Concatenate encoder+decoder positions. Requires knowing both sequences upfront.
Python
class CrossAttentionWithRoPE(nn.Module):
    """
    Cross-attention with RoPE applied to queries only.

    This is the common approach: the decoder query knows its position,
    but encoder keys are position-agnostic in cross-attention.
    """

    def __init__(self, d_model: int, n_heads: int, max_len: int = 4096):
        super().__init__()
        self.head_dim = d_model // n_heads
        self.n_heads = n_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # RoPE frequencies for decoder positions
        self.register_buffer("freqs", precompute_rope_freqs(self.head_dim, max_len))

    def forward(
        self,
        decoder_hidden: torch.Tensor,  # (batch, dec_len, d_model)
        encoder_hidden: torch.Tensor,  # (batch, enc_len, d_model)
        decoder_positions: torch.Tensor  # (batch, dec_len) position indices
    ) -> torch.Tensor:
        batch, dec_len, _ = decoder_hidden.shape
        enc_len = encoder_hidden.shape[1]

        # Project
        q = self.q_proj(decoder_hidden).view(batch, dec_len, self.n_heads, self.head_dim)
        k = self.k_proj(encoder_hidden).view(batch, enc_len, self.n_heads, self.head_dim)
        v = self.v_proj(encoder_hidden).view(batch, enc_len, self.n_heads, self.head_dim)

        # Apply RoPE to queries only (decoder knows its position)
        # Keys are NOT rotated (encoder positions don't matter in cross-attention)
        q = apply_rope(q, self.freqs, decoder_positions)

        # Standard attention from here
        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        return self.out_proj(out.transpose(1, 2).reshape(batch, dec_len, -1))
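The helpers precompute_rope_freqs and apply_rope used above are assumed to come from a standard RoPE implementation (see the dedicated RoPE post). A minimal real-valued sketch compatible with the shapes used here, to be defined before constructing the module, might look like this:

Python
def precompute_rope_freqs(head_dim: int, max_len: int, base: float = 10000.0) -> torch.Tensor:
    """Per-position rotation angles, shape (max_len, head_dim // 2)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_len).float()
    return torch.outer(positions, inv_freq)    # angles[pos, i] = pos * freq_i


def apply_rope(x: torch.Tensor, freqs: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """
    Rotate pairs of dimensions by position-dependent angles.

    x:         (batch, seq_len, n_heads, head_dim)
    freqs:     (max_len, head_dim // 2) precomputed angles
    positions: (batch, seq_len) integer position indices
    """
    angles = freqs[positions]                  # (batch, seq_len, head_dim // 2)
    cos = angles.cos()[:, :, None, :]          # broadcast over heads
    sin = angles.sin()[:, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]        # split into even/odd pairs
    rotated = torch.empty_like(x)
    rotated[..., 0::2] = x1 * cos - x2 * sin
    rotated[..., 1::2] = x1 * sin + x2 * cos
    return rotated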

Best Practices for Cross-Attention

  1. Encoder self-attention: Use full position encoding (RoPE, ALiBi, etc.)
  2. Decoder self-attention: Use full position encoding
  3. Cross-attention: Often no position bias, or query-side only

This reflects the intuition that:

  • Within a sequence, relative position matters (syntax, semantics)
  • Across sequences, content matching matters more than position
