Sparse Attention Patterns: Longformer, BigBird, and Beyond
A comprehensive deep dive into sparse attention mechanisms—how Longformer, BigBird, and other architectures break the O(n²) attention barrier. Understand local attention, global tokens, random attention, and when to use each pattern.
The Quadratic Attention Problem
Standard transformer attention has a fundamental scaling problem: every token attends to every other token, creating an O(n²) computation and memory cost. Double the sequence length, and attention costs quadruple.
┌─────────────────────────────────────────────────────────────────────────┐
│ THE O(n²) PROBLEM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ATTENTION MATRIX SIZE: │
│ ───────────────────── │
│ │
│ Sequence Length Attention Matrix Memory (FP16) │
│ ───────────────────────────────────────────────────── │
│ 512 512 × 512 0.5 MB │
│ 2,048 2K × 2K 8 MB │
│ 8,192 8K × 8K 128 MB │
│ 32,768 32K × 32K 2 GB │
│ 131,072 128K × 128K 32 GB │
│ │
│ Per attention head! Multiply by num_heads × num_layers. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ REAL MEMORY IMPACT (32 layers, 32 heads): │
│ │
│ Sequence Length Total Attention Memory │
│ ───────────────────────────────────────── │
│ 2K 8 GB │
│ 8K 128 GB │
│ 32K 2 TB │
│ │
│ Without optimizations, 32K context is impossible! │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ COMPUTE SCALING: │
│ │
│ Attention FLOPs ∝ n² × d │
│ │
│ At 8K context: 64× more compute than 1K │
│ At 128K context: 16,384× more compute than 1K │
│ │
│ This is unsustainable. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
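The tables above are just arithmetic on the matrix shapes; a quick sketch (the helper name is mine, not from any library) reproduces the numbers:

def attention_matrix_bytes(seq_len, num_layers=32, num_heads=32, bytes_per_elem=2):
    """Bytes needed to materialize every seq_len x seq_len attention matrix (FP16 by default)."""
    return seq_len * seq_len * bytes_per_elem * num_heads * num_layers

for n in (2_048, 8_192, 32_768):
    print(f"{n:>6} tokens -> {attention_matrix_bytes(n) / 2**30:,.0f} GB")
# 2048 -> 8 GB, 8192 -> 128 GB, 32768 -> 2,048 GB (2 TB)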
Solutions to Quadratic Attention
Several approaches address this:
┌─────────────────────────────────────────────────────────────────────────┐
│ APPROACHES TO O(n²) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 1. FLASHATTENTION (Keep full attention, optimize memory) │
│ • Still O(n²) compute │
│ • O(n) memory via tiling │
│ • Best when you need full attention │
│ • Covered in: attention-mechanisms-flash-attention-deep-dive │
│ │
│ 2. SPARSE ATTENTION (Attend to subset of tokens) │
│ • O(n) or O(n√n) compute and memory │
│ • Attends only to "important" positions │
│ • This post's focus │
│ │
│ 3. LINEAR ATTENTION (Approximate attention kernel) │
│ • O(n) compute and memory │
│ • Approximates softmax with kernel tricks │
│ • Examples: Performers, Linear Transformers │
│ │
│ 4. STATE SPACE MODELS (Different architecture) │
│ • O(n) or O(n log n) compute │
│ • Recurrent-like structure │
│ • Examples: Mamba, RWKV │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY SPARSE ATTENTION? │
│ ───────────────────── │
│ │
│ Key observation: Not all attention connections are equally useful. │
│ │
│ In practice, attention is often: │
│ • Local: Nearby tokens matter most │
│ • Sparse: Few tokens get high attention weight │
│ • Structured: Certain positions (start, special tokens) matter │
│ │
│ If we can identify which connections matter, we can skip the rest. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Part I: Attention Pattern Taxonomy
Basic Sparse Patterns
┌─────────────────────────────────────────────────────────────────────────┐
│ SPARSE ATTENTION PATTERNS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ In these diagrams: │
│ ■ = attention connection (query attends to key) │
│ □ = no attention (masked out) │
│ Rows = query positions, Columns = key positions │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 1. FULL ATTENTION (Standard Transformer): │
│ ───────────────────────────────────────── │
│ │
│ Keys: 0 1 2 3 4 5 6 7 │
│ Q 0 ■ □ □ □ □ □ □ □ (causal: attend to past only) │
│ u 1 ■ ■ □ □ □ □ □ □ │
│ e 2 ■ ■ ■ □ □ □ □ □ │
│ r 3 ■ ■ ■ ■ □ □ □ □ │
│ i 4 ■ ■ ■ ■ ■ □ □ □ │
│ e 5 ■ ■ ■ ■ ■ ■ □ □ │
│ s 6 ■ ■ ■ ■ ■ ■ ■ □ │
│ 7 ■ ■ ■ ■ ■ ■ ■ ■ │
│ │
│ Complexity: O(n²) │
│ Every token attends to all previous tokens. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 2. LOCAL/SLIDING WINDOW ATTENTION: │
│ ────────────────────────────────── │
│ │
│ Keys: 0 1 2 3 4 5 6 7 │
│ Q 0 ■ ■ □ □ □ □ □ □ (window size = 3) │
│ u 1 ■ ■ ■ □ □ □ □ □ │
│ e 2 □ ■ ■ ■ □ □ □ □ │
│ r 3 □ □ ■ ■ ■ □ □ □ │
│ i 4 □ □ □ ■ ■ ■ □ □ │
│ e 5 □ □ □ □ ■ ■ ■ □ │
│ s 6 □ □ □ □ □ ■ ■ ■ │
│ 7 □ □ □ □ □ □ ■ ■ │
│ │
│ Complexity: O(n × w) where w = window size │
│ Each token attends to w neighbors. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 3. DILATED/STRIDED ATTENTION: │
│ ───────────────────────────── │
│ │
│ Keys: 0 1 2 3 4 5 6 7 │
│ Q 0 ■ □ ■ □ □ □ □ □ (stride = 2) │
│ u 1 □ ■ □ ■ □ □ □ □ │
│ e 2 ■ □ ■ □ ■ □ □ □ │
│ r 3 □ ■ □ ■ □ ■ □ □ │
│ i 4 ■ □ ■ □ ■ □ ■ □ │
│ e 5 □ ■ □ ■ □ ■ □ ■ │
│ s 6 □ □ ■ □ ■ □ ■ □ │
│ 7 □ □ □ ■ □ ■ □ ■ │
│ │
│ Complexity: O(n × n/stride) │
│ Attends to every k-th token. Captures longer range. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 4. GLOBAL ATTENTION (Selected tokens attend to all): │
│ ──────────────────────────────────────────────────── │
│ │
│ Keys: 0 1 2 3 4 5 6 7 │
│ Q 0 ■ ■ ■ ■ ■ ■ ■ ■ ← Global token (e.g., [CLS]) │
│ u 1 ■ ■ □ □ □ □ □ □ │
│ e 2 ■ □ ■ □ □ □ □ □ │
│ r 3 ■ □ □ ■ □ □ □ □ │
│ i 4 ■ □ □ □ ■ □ □ □ │
│ e 5 ■ □ □ □ □ ■ □ □ │
│ s 6 ■ □ □ □ □ □ ■ □ │
│ 7 ■ □ □ □ □ □ □ ■ │
│ │
│ Token 0 attends to everyone, everyone attends to token 0. │
│ Complexity: O(n × g + n) where g = number of global tokens │
│ │
└─────────────────────────────────────────────────────────────────────────┘
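These patterns are easiest to reason about as boolean masks over the (query, key) matrix. A minimal sketch (True = attend; the function names are my own, and the strided mask is one common formulation of that pattern):

import torch

def sliding_window_mask(n, w):
    """Each query attends to keys within w // 2 positions on either side."""
    idx = torch.arange(n)
    return (idx[:, None] - idx[None, :]).abs() <= w // 2

def strided_mask(n, stride):
    """Each query attends to keys at the same position modulo `stride`."""
    idx = torch.arange(n)
    return (idx[:, None] - idx[None, :]) % stride == 0

def global_mask(n, num_global):
    """Global tokens attend to everyone; everyone attends to global tokens."""
    mask = torch.eye(n, dtype=torch.bool)
    mask[:num_global, :] = True
    mask[:, :num_global] = True
    return mask

print(sliding_window_mask(8, 3).int())  # matches pattern 2 above
print(global_mask(8, 1).int())          # matches pattern 4 above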
Part II: Longformer
The Longformer Architecture
Longformer (Beltagy et al., 2020) combines local sliding window attention with global attention for specific tokens:
┌─────────────────────────────────────────────────────────────────────────┐
│ LONGFORMER ATTENTION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ LONGFORMER COMBINES: │
│ ──────────────────── │
│ │
│ 1. Local sliding window attention (for most tokens) │
│ 2. Global attention (for special tokens like [CLS]) │
│ 3. Optional dilated sliding window (for some layers) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ATTENTION PATTERN (16 tokens, window=4, global=[CLS]): │
│ ─────────────────────────────────────────────────────── │
│ │
│ Keys │
│ [CLS] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 │
│ [CLS] ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ 1 ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ □ □ Q │
│ 2 ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ □ u │
│ 3 ■ □ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ e │
│ 4 ■ □ □ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ r │
│ 5 ■ □ □ □ ■ ■ ■ ■ □ □ □ □ □ □ □ □ i │
│ 6 ■ □ □ □ □ ■ ■ ■ ■ □ □ □ □ □ □ □ e │
│ 7 ■ □ □ □ □ □ ■ ■ ■ ■ □ □ □ □ □ □ s │
│ 8 ■ □ □ □ □ □ □ ■ ■ ■ ■ □ □ □ □ □ │
│ 9 ■ □ □ □ □ □ □ □ ■ ■ ■ ■ □ □ □ □ │
│ 10 ■ □ □ □ □ □ □ □ □ ■ ■ ■ ■ □ □ □ │
│ 11 ■ □ □ □ □ □ □ □ □ □ ■ ■ ■ ■ □ □ │
│ 12 ■ □ □ □ □ □ □ □ □ □ □ ■ ■ ■ ■ □ │
│ 13 ■ □ □ □ □ □ □ □ □ □ □ □ ■ ■ ■ ■ │
│ 14 ■ □ □ □ □ □ □ □ □ □ □ □ □ ■ ■ ■ │
│ 15 ■ □ □ □ □ □ □ □ □ □ □ □ □ □ ■ ■ │
│ │
│ First column: Global attention to [CLS] │
│ First row: [CLS] attends globally │
│ Diagonal bands: Local sliding window │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ COMPLEXITY ANALYSIS: │
│ ──────────────────── │
│ │
│ For sequence length n, window size w, g global tokens: │
│ │
│ Local attention: n × w │
│ Global attention: 2 × g × n (global tokens + tokens attending to global)
│ │
│ Total: O(n × w + g × n) = O(n × (w + g)) │
│ │
│ For w=512, g=1, n=4096: │
│ Full attention: 4096² = 16.7M │
│ Longformer: 4096 × 513 ≈ 2.1M │
│ 8× reduction! │
│ │
│ For w=512, g=1, n=16384: │
│ Full attention: 16384² = 268M │
│ Longformer: 16384 × 513 ≈ 8.4M │
│ 32× reduction! │
│ │
│ Savings increase with sequence length. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
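The savings quoted above can be checked by counting query-key pairs directly (a back-of-the-envelope sketch; the exact count depends on how an implementation handles sequence edges):

def longformer_connections(n, window, num_global=1):
    """Approximate number of query-key pairs: local band + global rows/columns."""
    local = n * (window + 1)        # each token sees ~window neighbours plus itself
    global_ = 2 * num_global * n    # global tokens as queries and as keys
    return local + global_

for n in (4_096, 16_384):
    sparse = longformer_connections(n, window=512)
    print(f"n={n}: {sparse:,} pairs vs {n * n:,} for full ({n * n / sparse:.0f}x fewer)")
# n=4096:  ~2.1M vs 16.8M  (8x fewer)
# n=16384: ~8.4M vs 268M   (32x fewer)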
Dilated Sliding Window
Longformer can use dilated attention in higher layers to increase the effective receptive field:
┌─────────────────────────────────────────────────────────────────────────┐
│ DILATED SLIDING WINDOW │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE IDEA: │
│ ───────── │
│ │
│ In lower layers: Dense local attention (window=512) │
│ In higher layers: Dilated attention (dilation=2, window=512) │
│ │
│ This increases effective receptive field without increasing cost. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ DILATION EXAMPLE: │
│ ───────────────── │
│ │
│   Non-dilated (d=1, w=8):                                                 │
│ Position 8 attends to: 5, 6, 7, 8, 9, 10, 11, 12 │
│ Receptive field: 8 consecutive tokens │
│ │
│   Dilated (d=2, w=8):                                                     │
│ Position 8 attends to: 2, 4, 6, 8, 10, 12, 14, 16 │
│ Receptive field: 16 positions (but only 8 tokens) │
│ │
│ Same cost, 2× receptive field! │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ MULTI-HEAD WITH DIFFERENT DILATIONS: │
│ ───────────────────────────────────── │
│ │
│ Head 0: dilation = 1 (local details) │
│ Head 1: dilation = 2 │
│ Head 2: dilation = 4 │
│ Head 3: dilation = 8 (long range) │
│ │
│ Different heads capture different scales simultaneously. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ EFFECTIVE RECEPTIVE FIELD: │
│ ────────────────────────── │
│ │
│ After L layers with window w and dilation d: │
│ Receptive field = L × w × d │
│ │
│ Example: 12 layers, w=512, d varying 1-4: │
│ Lower layers (d=1): captures local patterns │
│ Higher layers (d=4): combines patterns from wider context │
│ Effective receptive field: ~16K+ tokens │
│ │
└─────────────────────────────────────────────────────────────────────────┘
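The dilation arithmetic can be made concrete with a small helper (my own, using a symmetric window; Longformer's actual implementation applies dilation per head inside its custom kernel):

def dilated_window(position, window, dilation):
    """Key positions a query attends to with a symmetric dilated window."""
    half = window // 2
    return [position + k * dilation for k in range(-half, half + 1)]

print(dilated_window(8, window=8, dilation=1))  # [4, 5, ..., 12]: dense local window
print(dilated_window(8, window=8, dilation=2))  # [0, 2, ..., 16]: same cost, ~2x the span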
Longformer Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
class LongformerAttention(nn.Module):
"""
Longformer-style attention with sliding window + global attention.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
window_size: int = 512,
global_tokens: int = 1, # Usually [CLS] token
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.window_size = window_size
self.global_tokens = global_tokens
self.q_proj = nn.Linear(hidden_size, hidden_size)
self.k_proj = nn.Linear(hidden_size, hidden_size)
self.v_proj = nn.Linear(hidden_size, hidden_size)
self.out_proj = nn.Linear(hidden_size, hidden_size)
# Separate projections for global attention
self.global_q_proj = nn.Linear(hidden_size, hidden_size)
self.global_k_proj = nn.Linear(hidden_size, hidden_size)
self.global_v_proj = nn.Linear(hidden_size, hidden_size)
def forward(self, hidden_states, attention_mask=None):
batch_size, seq_len, _ = hidden_states.shape
# Compute Q, K, V for local attention
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
# Reshape for multi-head attention
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim)
# Local sliding window attention
local_output = self._sliding_window_attention(q, k, v)
# Global attention for first g tokens
if self.global_tokens > 0:
global_output = self._global_attention(
hidden_states, q, k, v
)
# Combine: global tokens use global attention,
# other tokens use local + attend to global
output = self._combine_local_global(
local_output, global_output, hidden_states
)
else:
output = local_output
return self.out_proj(output)
def _sliding_window_attention(self, q, k, v):
"""
Compute attention within sliding windows.
For efficiency, real implementations use custom CUDA kernels.
This is a simplified version for understanding.
"""
batch_size, seq_len, num_heads, head_dim = q.shape
w = self.window_size
output = torch.zeros_like(q)
for i in range(seq_len):
# Window bounds
start = max(0, i - w // 2)
end = min(seq_len, i + w // 2 + 1)
# Get Q for position i, K and V for window
q_i = q[:, i:i+1, :, :] # (batch, 1, heads, dim)
k_window = k[:, start:end, :, :] # (batch, window, heads, dim)
v_window = v[:, start:end, :, :]
# Compute attention
scores = torch.einsum('bqhd,bkhd->bhqk', q_i, k_window)
scores = scores / (head_dim ** 0.5)
attn = F.softmax(scores, dim=-1)
out_i = torch.einsum('bhqk,bkhd->bqhd', attn, v_window)
output[:, i:i+1, :, :] = out_i
return output.view(batch_size, seq_len, -1)
def _global_attention(self, hidden_states, q, k, v):
"""
Global attention for designated global tokens.
"""
batch_size, seq_len, _ = hidden_states.shape
g = self.global_tokens
# Global tokens (e.g., first g tokens)
global_hidden = hidden_states[:, :g, :]
# Global Q, K, V with separate projections
global_q = self.global_q_proj(global_hidden)
global_k = self.global_k_proj(hidden_states) # All tokens as keys
global_v = self.global_v_proj(hidden_states)
# Reshape
global_q = global_q.view(batch_size, g, self.num_heads, self.head_dim)
global_k = global_k.view(batch_size, seq_len, self.num_heads, self.head_dim)
global_v = global_v.view(batch_size, seq_len, self.num_heads, self.head_dim)
# Full attention for global tokens
scores = torch.einsum('bghd,bshd->bhgs', global_q, global_k)
scores = scores / (self.head_dim ** 0.5)
attn = F.softmax(scores, dim=-1)
global_output = torch.einsum('bhgs,bshd->bghd', attn, global_v)
return global_output.view(batch_size, g, -1)
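The forward() method above calls _combine_local_global, which is not shown. A minimal sketch that completes the class, assuming global positions simply take their full-attention outputs (real Longformer also adds the global keys/values to every query's softmax, which this simplified version omits):

    def _combine_local_global(self, local_output, global_output, hidden_states):
        """
        Global positions keep their full-attention output; all other positions
        keep their sliding-window output. hidden_states is unused here but kept
        to match the call in forward().
        """
        g = self.global_tokens
        output = local_output.clone()
        output[:, :g, :] = global_output
        return output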
Part III: BigBird
Adding Random Attention
BigBird (Zaheer et al., 2020) adds random attention patterns to the mix:
┌─────────────────────────────────────────────────────────────────────────┐
│ BIGBIRD ATTENTION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ BIGBIRD COMBINES THREE PATTERNS: │
│ ───────────────────────────────── │
│ │
│ 1. Local sliding window attention │
│ 2. Global attention (select tokens) │
│ 3. Random attention (random token pairs) │
│ │
│ The random attention is key: it provides "shortcuts" through │
│ the sequence, ensuring information can flow between distant tokens. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ BIGBIRD PATTERN (simplified, 12 tokens): │
│ ───────────────────────────────────────── │
│ │
│ G = Global token L = Local window R = Random │
│ │
│ Keys │
│ 0 1 2 3 4 5 6 7 8 9 10 11 │
│ 0 G G G G G G G G G G G G ← Global token │
│ 1 G L L L □ □ □ R □ □ □ □ Q │
│ 2 G L L L L □ □ □ □ R □ □ u │
│ 3 G □ L L L L □ □ □ □ R □ e │
│ 4 G □ □ L L L L □ R □ □ □ r │
│ 5 G □ □ □ L L L L □ □ R □ i │
│ 6 G □ □ R □ L L L L □ □ □ e │
│ 7 G □ R □ □ □ L L L L □ □ s │
│ 8 G □ □ □ R □ □ L L L L □ │
│ 9 G R □ □ □ □ □ □ L L L L │
│ 10 G □ □ □ □ R □ □ □ L L L │
│ 11 G □ □ R □ □ □ □ □ □ L L │
│ │
│ G = Global (column 0): Everyone attends to global tokens │
│ L = Local: Diagonal band for nearby tokens │
│ R = Random: Sparse random connections (different per token) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY RANDOM ATTENTION? │
│ ───────────────────── │
│ │
│ Theory: Random graphs have small diameter. │
│ │
│ With only local attention: │
│ Token 0 → Token 1000 requires ~1000/window hops │
│ Information "diffuses" slowly. │
│ │
│ With random attention: │
│ Each token has "shortcuts" to random distant tokens. │
│ Like small-world networks: O(log n) hops to reach anywhere. │
│ │
│ This is theoretically important: BigBird can simulate full │
│ attention in O(log n) layers, preserving expressiveness. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ COMPLEXITY: │
│ ─────────── │
│ │
│ Window size w, r random connections, g global tokens: │
│ │
│ Per token: w (local) + r (random) + g (global) connections │
│ Total: O(n × (w + r + g)) │
│ │
│ Typically w=64, r=3, g=2: each token attends to ~70 others │
│ vs n for full attention. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
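As a sketch, the three components can be combined into one boolean mask (the helper and its defaults are my own; the real BigBird operates on blocks of tokens rather than individual positions):

import torch

def bigbird_mask(n, window=3, num_global=1, num_random=2, seed=0):
    """Local band + global rows/columns + per-query random shortcuts (True = attend)."""
    gen = torch.Generator().manual_seed(seed)
    idx = torch.arange(n)
    mask = (idx[:, None] - idx[None, :]).abs() <= window // 2   # local band
    mask[:num_global, :] = True                                 # global tokens attend to all
    mask[:, :num_global] = True                                 # all tokens attend to global
    for q in range(num_global, n):                              # random shortcuts
        r = torch.randint(0, n, (num_random,), generator=gen)
        mask[q, r] = True
    return mask

m = bigbird_mask(12)
print(m.int())          # roughly the pattern sketched above
print(m.sum(dim=1))     # each query touches only a handful of keys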
BigBird vs Longformer
┌─────────────────────────────────────────────────────────────────────────┐
│ BIGBIRD VS LONGFORMER │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ SIMILARITIES: │
│ ───────────── │
│ • Both use local sliding window attention │
│ • Both support global attention for special tokens │
│ • Both achieve O(n) complexity │
│ • Both support up to 4K-16K context │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ DIFFERENCES: │
│ ──────────── │
│ │
│ Feature Longformer BigBird │
│ ─────────────────────────────────────────────────────────────── │
│ Random attn No Yes │
│ Dilated attn Yes (optional) No │
│ Theoretical Heuristic Proven expressive │
│ Implementation Custom kernels Block sparse │
│ Primary use Document understanding QA, summarization │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHEN TO USE WHICH: │
│ ────────────────── │
│ │
│ Longformer: │
│ • Document classification │
│ • When global tokens are sufficient for long-range │
│ • When you need dilated attention patterns │
│ │
│ BigBird: │
│ • Question answering over long documents │
│ • When theoretical expressiveness guarantees matter │
│ • When random attention helps your task │
│ │
│ In practice, both work well for long documents. │
│ Choice often comes down to available implementations. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Part IV: Other Sparse Patterns
Block Sparse Attention
┌─────────────────────────────────────────────────────────────────────────┐
│ BLOCK SPARSE ATTENTION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE CONCEPT: │
│ ──────────── │
│ │
│ Divide sequence into blocks, attend to selected blocks. │
│ More GPU-efficient than arbitrary sparse patterns. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ BLOCK PATTERN (block size = 4): │
│ ──────────────────────────────── │
│ │
│ Block 0 Block 1 Block 2 Block 3 │
│ 0 1 2 3 4 5 6 7 8 9 A B C D E F │
│ B 0 ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ □ □ │
│ l 1 ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ □ □ │
│ o 2 ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ □ □ │
│ c 3 ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ ← attends to next block│
│ k ───────────────────────────────────── │
│ 0 4 □ □ □ □ ■ ■ ■ ■ □ □ □ □ □ □ □ □ │
│ 5 □ □ □ □ ■ ■ ■ ■ □ □ □ □ □ □ □ □ │
│ B 6 □ □ □ □ ■ ■ ■ ■ □ □ □ □ □ □ □ □ │
│ l 7 □ □ □ □ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ │
│ o ───────────────────────────────────── │
│ c 8 □ □ □ □ □ □ □ □ ■ ■ ■ ■ □ □ □ □ │
│ k 9 □ □ □ □ □ □ □ □ ■ ■ ■ ■ □ □ □ □ │
│ A □ □ □ □ □ □ □ □ ■ ■ ■ ■ □ □ □ □ │
│ 1 B □ □ □ □ □ □ □ □ ■ ■ ■ ■ ■ ■ ■ ■ │
│ │
│ Each block attends to itself + adjacent blocks. │
│ Matrix operations are on dense blocks → GPU efficient. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ GPU EFFICIENCY: │
│ ─────────────── │
│ │
│ Arbitrary sparsity: Scattered memory access, poor GPU utilization │
│ Block sparsity: Dense operations on blocks, near-optimal GPU use │
│ │
│ Block sparse can achieve 70-90% of dense attention throughput │
│ while reducing FLOPs by 10-100×. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
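Block-sparse kernels are usually driven by a small 0/1 layout matrix that marks which (query block, key block) tiles to compute densely. A sketch of a simple causal layout where each query block attends to itself and the block before it (the helper name is mine):

import torch

def local_block_layout(num_blocks, neighbors=1):
    """1 = compute this (query block, key block) tile densely; 0 = skip it."""
    idx = torch.arange(num_blocks)
    diff = idx[:, None] - idx[None, :]
    return ((diff >= 0) & (diff <= neighbors)).long()

layout = local_block_layout(num_blocks=8)
print(layout)
print(layout.float().mean().item())   # fraction of the attention matrix actually computed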
Axial Attention
┌─────────────────────────────────────────────────────────────────────────┐
│ AXIAL ATTENTION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE CONCEPT: │
│ ──────────── │
│ │
│ For 2D data (images, tables), attend along axes separately. │
│ Instead of n² full attention, do 2 × n√n attention operations. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ EXAMPLE: 4×4 grid (16 tokens) │
│ ───────────────────────────── │
│ │
│ Original positions: │
│ 0 1 2 3 │
│ 4 5 6 7 │
│ 8 9 10 11 │
│ 12 13 14 15 │
│ │
│ Step 1: Row attention (each row attends within itself) │
│ Token 5 attends to: 4, 5, 6, 7 │
│ │
│ Step 2: Column attention (each column attends within itself) │
│ Token 5 attends to: 1, 5, 9, 13 │
│ │
│ Combined: Token 5 can reach any cell in 2 hops! │
│ Complexity: 2 × 16 × 4 = 128 vs 16² = 256 │
│ │
│ For n tokens in √n × √n grid: │
│ Axial: O(n√n) vs Full: O(n²) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ APPLICATIONS: │
│ ───────────── │
│ │
│ • Image processing (ViT variants) │
│ • Tabular data │
│ • Any data with natural 2D structure │
│ • Can extend to 3D (video) with 3 axial attentions │
│ │
└─────────────────────────────────────────────────────────────────────────┘
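A minimal single-head sketch of the two-step row/column attention using PyTorch's scaled_dot_product_attention (the shapes and helper are my own; real axial attention layers add projections, multiple heads, and positional encodings):

import torch
import torch.nn.functional as F

def axial_attention(x):
    """Single-head axial attention over a (batch, H, W, dim) grid:
    attend within each row, then within each column."""
    b, h, w, d = x.shape
    # Row attention: every row of length W is its own sequence
    rows = x.reshape(b * h, 1, w, d)                      # (batch*H, heads=1, W, dim)
    rows = F.scaled_dot_product_attention(rows, rows, rows)
    x = rows.reshape(b, h, w, d)
    # Column attention: every column of length H is its own sequence
    cols = x.permute(0, 2, 1, 3).reshape(b * w, 1, h, d)  # (batch*W, heads=1, H, dim)
    cols = F.scaled_dot_product_attention(cols, cols, cols)
    return cols.reshape(b, w, h, d).permute(0, 2, 1, 3)

print(axial_attention(torch.randn(2, 4, 4, 32)).shape)    # torch.Size([2, 4, 4, 32])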
Sliding Window with Sinks (Mistral's Approach)
┌─────────────────────────────────────────────────────────────────────────┐
│ SLIDING WINDOW WITH ATTENTION SINKS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ MISTRAL'S APPROACH: │
│ ─────────────────── │
│ │
│ Sliding window + always attend to first few tokens (attention sinks) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│   PATTERN (window=2, sinks=2):                                            │
│ ───────────────────────────── │
│ │
│ Keys │
│ 0 1 2 3 4 5 6 7 8 9 10 11 │
│ 0 ■ □ □ □ □ □ □ □ □ □ □ □ Q │
│ 1 ■ ■ □ □ □ □ □ □ □ □ □ □ u │
│ 2 ■ ■ ■ □ □ □ □ □ □ □ □ □ e │
│ 3 ■ ■ ■ ■ □ □ □ □ □ □ □ □ r │
│ 4 ■ ■ □ ■ ■ □ □ □ □ □ □ □ i │
│ 5 ■ ■ □ □ ■ ■ □ □ □ □ □ □ e │
│ 6 ■ ■ □ □ □ ■ ■ □ □ □ □ □ s │
│ 7 ■ ■ □ □ □ □ ■ ■ □ □ □ □ │
│ 8 ■ ■ □ □ □ □ □ ■ ■ □ □ □ │
│ 9 ■ ■ □ □ □ □ □ □ ■ ■ □ □ │
│ 10 ■ ■ □ □ □ □ □ □ □ ■ ■ □ │
│ 11 ■ ■ □ □ □ □ □ □ □ □ ■ ■ │
│ │
│ Columns 0-1 (sinks): Always attended to │
│ Diagonal: Local sliding window │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY ATTENTION SINKS? │
│ ──────────────────── │
│ │
│ Observation: First tokens often receive high attention. │
│ This seems to be where models "dump" attention when uncertain. │
│ │
│ Without sinks: │
│ At position 1000 with window 512, position 0 is unreachable. │
│ Model may behave strangely when it can't access these sinks. │
│ │
│ With sinks: │
│ Position 0-1 always accessible, model behavior more stable. │
│ Minimal additional cost (just 2-4 extra attentions per token). │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ STREAMING INFERENCE: │
│ ──────────────────── │
│ │
│ With sliding window + sinks: │
│ • KV cache size fixed: sink_size + window_size │
│ • Can process infinite sequences! │
│ • Memory: O(1) instead of O(n) │
│ │
│ This enables Mistral's efficient long-context inference. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
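The streaming property comes entirely from the cache policy: keep the sink entries forever and roll the rest. A toy sketch (the class and sizes are my own; real implementations also handle positional re-indexing and the batch/head dimensions):

import torch

class SinkKVCache:
    """Fixed-size KV cache: the first `num_sinks` entries are kept forever,
    plus a rolling window of the most recent `window` entries."""
    def __init__(self, num_sinks=4, window=4096):
        self.num_sinks, self.window = num_sinks, window
        self.keys, self.values = [], []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)
        if len(self.keys) > self.num_sinks + self.window:
            # Evict the oldest non-sink entry: cache size stays O(1)
            del self.keys[self.num_sinks]
            del self.values[self.num_sinks]

cache = SinkKVCache(num_sinks=2, window=4)
for _ in range(100):
    cache.append(torch.randn(8), torch.randn(8))
print(len(cache.keys))   # 6, no matter how long the stream gets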
Part V: Implementation Considerations
Efficient Sparse Attention
┌─────────────────────────────────────────────────────────────────────────┐
│ EFFICIENT IMPLEMENTATION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ CHALLENGE: Making sparse attention fast on GPUs │
│ ───────────────────────────────────────────── │
│ │
│ GPUs are optimized for: │
│ • Dense matrix operations (GEMM) │
│ • Coalesced memory access │
│ • High arithmetic intensity │
│ │
│ Naive sparse attention violates all of these! │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SOLUTION 1: Block-Sparse Operations │
│ ──────────────────────────────────── │
│ │
│ Organize sparsity into blocks. Within each block, compute dense. │
│ │
│ Libraries: │
│ • Triton (OpenAI): Block-sparse attention kernels │
│ • DeepSpeed: Sparse attention support │
│ • xFormers: Memory-efficient attention with sparsity │
│ │
│   Example (xFormers-style block-sparse attention):                        │
│                                                                           │
│   from xformers.components.attention import BlockSparseAttention          │
│ │
│ # Define block sparsity pattern │
│ layout = torch.tril(torch.ones(num_blocks, num_blocks)) │
│ layout[0, :] = 1 # Global row │
│ layout[:, 0] = 1 # Global column │
│ │
│ # Create block-sparse attention │
│ sparse_attn = BlockSparseAttention(layout, block_size=64) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│   SOLUTION 2: Masks and Native Sliding Windows                            │
│   ────────────────────────────────────────────                            │
│                                                                           │
│   PyTorch's F.scaled_dot_product_attention accepts boolean masks          │
│   (True = attend). The flash-attn library's flash_attn_func has no mask   │
│   argument, but supports sliding windows natively via window_size.        │
│                                                                           │
│   # Sliding window mask (True = query may attend to key)                  │
│   def sliding_window_mask(seq_len, window_size):                          │
│       mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)              │
│       for i in range(seq_len):                                            │
│           start = max(0, i - window_size // 2)                            │
│           end = min(seq_len, i + window_size // 2 + 1)                    │
│           mask[i, start:end] = True                                       │
│       return mask                                                         │
│                                                                           │
│   # Works with PyTorch's fused attention and other mask-aware kernels     │
│   out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)           │
│                                                                           │
│   # Or flash-attn's native sliding window (w = window size)               │
│   out = flash_attn_func(q, k, v, window_size=(w // 2, w // 2))            │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SOLUTION 3: Reformulate as Smaller Dense Attentions │
│ ──────────────────────────────────────────────────── │
│ │
│ For sliding window: Chunk sequence, each chunk is dense. │
│ │
│ def chunked_attention(q, k, v, chunk_size): │
│ """Process in chunks with overlapping windows.""" │
│ outputs = [] │
│ for i in range(0, seq_len, chunk_size // 2): │
│ chunk_q = q[:, i:i+chunk_size] │
│ # Include overlap for context │
│ chunk_k = k[:, max(0,i-chunk_size//2):i+chunk_size] │
│ chunk_v = v[:, max(0,i-chunk_size//2):i+chunk_size] │
│ out = standard_attention(chunk_q, chunk_k, chunk_v) │
│ outputs.append(out) │
│ return combine_chunks(outputs) │
│ │
│ This uses standard optimized dense attention on smaller inputs. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Library Support
┌─────────────────────────────────────────────────────────────────────────┐
│ LIBRARY AND MODEL SUPPORT │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ HUGGING FACE TRANSFORMERS: │
│ ────────────────────────── │
│ │
│ Longformer: │
│ from transformers import LongformerModel │
│ model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
│ │
│ BigBird: │
│ from transformers import BigBirdModel │
│ model = BigBirdModel.from_pretrained("google/bigbird-roberta-base") │
│ │
│ Both support up to 4096 tokens out of the box. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ XFORMERS: │
│ ───────── │
│ │
│ from xformers.components.attention import ( │
│ ScaledDotProductAttention, │
│ AttentionMask, │
│ BlockSparseAttention, │
│ ) │
│ │
│ # Block sparse attention │
│ attn = BlockSparseAttention( │
│ layout=my_sparsity_pattern, │
│ block_size=64, │
│ ) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ MODELS USING SPARSE ATTENTION: │
│ ────────────────────────────── │
│ │
│ Model Attention Type Max Context │
│ ───────────────────────────────────────────────────────────── │
│ Longformer Sliding + Global 4,096 │
│ BigBird Sliding + Random 4,096 │
│ Mistral 7B Sliding Window 32,768 (with RoPE) │
│ Mixtral Sliding Window 32,768 │
│ LED Longformer-style 16,384 │
│ LongT5 Local + Transient 16,384 │
│ │
│ Most modern long-context models combine: │
│ • Sparse attention (for efficiency) │
│ • RoPE scaling (for extrapolation) │
│ • FlashAttention (for memory efficiency) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
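With the Hugging Face models above, which tokens get global attention is chosen at runtime via global_attention_mask; a typical pattern is to mark only the first ([CLS]/<s>) token:

import torch
from transformers import LongformerModel, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("A very long document ...", return_tensors="pt")
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1   # 1 = give this position global attention

outputs = model(**inputs, global_attention_mask=global_attention_mask)
print(outputs.last_hidden_state.shape)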
Part VI: When to Use Sparse Attention
Decision Framework
┌─────────────────────────────────────────────────────────────────────────┐
│ WHEN TO USE SPARSE ATTENTION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ USE SPARSE ATTENTION WHEN: │
│ ────────────────────────── │
│ │
│ ✓ Sequence length > 4K tokens │
│ ✓ Task doesn't need full all-to-all attention │
│ ✓ Most relevant context is local (nearby tokens) │
│ ✓ Global context can be captured by few special tokens │
│ │
│ GOOD USE CASES: │
│ ─────────────── │
│ • Document classification (Longformer with [CLS]) │
│ • Long document QA (BigBird) │
│ • Summarization of long texts │
│ • Code analysis (local context usually sufficient) │
│ • Chat with very long history │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ DON'T USE SPARSE ATTENTION WHEN: │
│ ───────────────────────────────── │
│ │
│ ✗ Sequence length < 4K (full attention is fine) │
│ ✗ Task requires precise long-range dependencies │
│ ✗ You have GPUs with lots of memory │
│ ✗ FlashAttention solves your memory problem │
│ │
│ BETTER ALTERNATIVES: │
│ ──────────────────── │
│ • Short sequences: Just use FlashAttention │
│ • Need full attention: FlashAttention (O(n) memory, still O(n²) compute)
│ • Very long but can chunk: Chunk and process separately │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ DECISION TREE: │
│ ────────────── │
│ │
│ Is sequence length > 4K? │
│ NO → Use standard attention + FlashAttention │
│ YES ↓ │
│ │
│ Does task need full attention? │
│ YES → Use FlashAttention (memory-efficient full attention) │
│ NO ↓ │
│ │
│ Is context mostly local? │
│ YES → Sliding window (Mistral-style) │
│ NO ↓ │
│ │
│ Need global context aggregation? │
│ YES → Longformer (sliding + global) │
│ MAYBE → BigBird (sliding + global + random) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
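The same decision tree, written as a small helper for reference (purely illustrative; the thresholds mirror the guidance above):

def pick_attention(seq_len, needs_full_attention, mostly_local, needs_global_summary):
    if seq_len <= 4096:
        return "standard attention + FlashAttention"
    if needs_full_attention:
        return "FlashAttention (full attention, O(n) memory)"
    if mostly_local:
        return "sliding window (Mistral-style)"
    return ("Longformer (sliding + global)" if needs_global_summary
            else "BigBird (sliding + global + random)")

print(pick_attention(32_768, False, False, True))   # Longformer (sliding + global)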
Part VII: Recent Innovations (2024-2025)
DeepSeek Sparse Attention (DSA)
DeepSeek introduced DSA as a sparse attention pattern optimized for inference efficiency:
┌─────────────────────────────────────────────────────────────────────────┐
│ DEEPSEEK SPARSE ATTENTION (DSA) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE INSIGHT: │
│ ──────────── │
│ │
│ DeepSeek observed attention patterns during inference: │
│ • Recent tokens (last ~4K): High attention, need full detail │
│ • Middle tokens: Sparse attention to "landmark" positions │
│ • Beginning tokens: Attention sinks, always accessed │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ DSA PATTERN: │
│ ──────────── │
│ │
│ For position p with context length L: │
│ │
│ Attend to: │
│ 1. Sink tokens: positions 0-127 (always) │
│ 2. Recent window: positions max(0, p-4096) to p (dense) │
│ 3. Landmarks: Every 64th token before the window │
│ │
│ Example at position 10000 (L=128K): │
│ • Sinks: 0-127 (128 tokens) │
│ • Recent: 5904-10000 (4096 tokens, dense) │
│ • Landmarks: 128, 192, 256, ..., 5888 (90 tokens) │
│ • Total: ~4,314 attentions vs 10,000 for full │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ IMPLEMENTATION DETAIL: │
│ ───────────────────── │
│ │
│ def get_dsa_attention_indices(position, context_len): │
│ indices = [] │
│ │
│ # 1. Sink tokens │
│ indices.extend(range(min(128, context_len))) │
│ │
│ # 2. Recent window (dense) │
│ window_start = max(128, position - 4096) │
│ indices.extend(range(window_start, position + 1)) │
│ │
│ # 3. Landmarks (every 64th token before window) │
│ for i in range(128, window_start, 64): │
│ indices.append(i) │
│ │
│ return sorted(set(indices)) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ RESULTS: │
│ │
│ Context Length Full Attention DSA Attention Speedup │
│ ───────────────────────────────────────────────────────────── │
│ 32K ~32K tokens ~4.6K tokens 7× │
│ 128K ~128K tokens ~6.3K tokens 20× │
│ 256K ~256K tokens ~8.3K tokens 31× │
│ │
│ Quality: <0.5% perplexity degradation on benchmarks │
│ │
└─────────────────────────────────────────────────────────────────────────┘
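Assuming get_dsa_attention_indices is defined as in the sketch above, the arithmetic in the example can be checked directly:

indices = get_dsa_attention_indices(position=10_000, context_len=128_000)
print(len(indices))                 # 4316: ~4.3K keys instead of ~10K for full attention
print(indices[:3], indices[-3:])    # [0, 1, 2] ... [9998, 9999, 10000]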
MInference: Dynamic Sparse Patterns
MInference (Microsoft, 2024) discovers optimal sparse patterns dynamically:
┌─────────────────────────────────────────────────────────────────────────┐
│ MINFERENCE: LEARNED SPARSE PATTERNS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE APPROACH: │
│ ────────────── │
│ │
│ Instead of fixed patterns (window, global), learn optimal sparsity │
│ per attention head based on observed attention patterns. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ DISCOVERED PATTERNS (from LLM attention analysis): │
│ ────────────────────────────────────────────────── │
│ │
│ 1. A-SHAPE: Attention to first tokens + recent tokens │
│ Common in: early layers, position-sensitive heads │
│ Pattern: [sink region] + [recent window] │
│ │
│ 2. VERTICAL-SLASH: Strong attention to specific columns │
│ Common in: middle layers, key-focused heads │
│ Pattern: Full attention to a few key positions │
│ │
│ 3. BLOCK-SPARSE: Attention in dense blocks │
│ Common in: all layers, local-focused heads │
│ Pattern: Block diagonal + scattered blocks │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ HEAD-SPECIFIC PATTERN SELECTION: │
│ ───────────────────────────────── │
│ │
│ MInference profiles attention patterns during a calibration run: │
│ │
│ for layer in model.layers: │
│ for head in layer.attention.heads: │
│ pattern = analyze_attention_distribution(head) │
│ if is_a_shape(pattern): │
│ head.sparse_config = AShapeSparse(sink=128, recent=4096) │
│ elif is_vertical(pattern): │
│ head.sparse_config = VerticalSparse(key_positions=[...]) │
│ else: │
│ head.sparse_config = BlockSparse(block_size=64) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ RESULTS (Llama-3-70B, 128K context): │
│ │
│ Method Latency Memory Quality │
│ ───────────────────────────────────────────────────────── │
│ Full Attention Baseline Baseline Baseline │
│ Fixed Window 0.4× 0.3× Degraded │
│ MInference 0.5× 0.4× Near-lossless │
│ │
│ MInference achieves 2× speedup with minimal quality loss. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
SparseAttn: FlashAttention for Sparse Patterns
┌─────────────────────────────────────────────────────────────────────────┐
│ SPARSEATTN: EFFICIENT SPARSE KERNELS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE CHALLENGE: │
│ ─────────────── │
│ │
│ Sparse attention saves FLOPs but naive implementations are slow: │
│ • Irregular memory access patterns │
│ • Can't use optimized dense GEMM kernels │
│ • Overhead from index management │
│ │
│ Result: Theoretical speedup ≠ actual speedup │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SPARSEATTN APPROACH: │
│ ──────────────────── │
│ │
│ Apply FlashAttention's tiling strategy to sparse patterns: │
│ │
│ 1. Divide sparse pattern into blocks │
│ 2. Identify which blocks are non-zero │
│ 3. Only compute attention for non-zero blocks │
│ 4. Use FlashAttention's memory-efficient tiling within blocks │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SUPPORTED PATTERNS: │
│ │
│ Pattern SparseAttn Support Speedup vs Dense │
│ ───────────────────────────────────────────────────────────── │
│ Sliding Window Yes 2-10× (seq dependent) │
│ Block Diagonal Yes Near-linear scaling │
│ Longformer-style Yes 3-8× │
│ BigBird-style Partial 2-5× │
│ Arbitrary Sparse Limited Depends on structure │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ USAGE: │
│ │
│ from sparseattn import sparse_flash_attention │
│ │
│ # Define sparse pattern as block mask │
│ block_mask = create_sliding_window_mask( │
│ seq_len=32768, │
│ block_size=256, │
│ window_blocks=16 # 4K window │
│ ) │
│ │
│ output = sparse_flash_attention( │
│ query, key, value, │
│ block_mask=block_mask, │
│ block_size=256 │
│ ) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
NSA: Native Sparse Attention
┌─────────────────────────────────────────────────────────────────────────┐
│ NATIVE SPARSE ATTENTION (NSA) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE CONCEPT: │
│ ──────────── │
│ │
│ Train models with sparse attention from scratch, not as post-hoc. │
│ The model learns to work with limited attention budget. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ HYBRID DESIGN: │
│ ────────────── │
│ │
│ NSA uses different sparsity at different layers: │
│ │
│ Early layers (0-8): Local window + compression tokens │
│ • Local details need full resolution │
│ • Compress distant context into summary tokens │
│ │
│ Middle layers (8-24): Selective attention │
│ • Attend to positions based on learned selection │
│ • Top-k attention routing │
│ │
│ Late layers (24-32): Mixed local + global summaries │
│ • Local for recent context │
│ • Global summary tokens for distant context │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ COMPRESSION TOKENS: │
│ ─────────────────── │
│ │
│ Instead of attending to all distant tokens, compress them: │
│ │
│ Tokens 0-4096: Summarize into 64 "compression tokens" │
│ Tokens 4097-8192: Summarize into 64 compression tokens │
│ ... and so on │
│ │
│ New tokens attend to: │
│ • Recent 4K tokens (full resolution) │
│ • Compression tokens for older context (64× compression) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ RESULTS (pre-trained from scratch): │
│ │
│ • 10× FLOPs reduction vs full attention │
│ • 95-98% performance retention on benchmarks │
│ • Native 128K context without quality degradation │
│ │
└─────────────────────────────────────────────────────────────────────────┘
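The compression bookkeeping is simple even though the real compression function is learned; a sketch with mean pooling standing in for NSA's learned compressor (all names here are mine):

import torch

def compress_kv(keys, values, block_size=64):
    """Collapse every `block_size` past tokens into one summary token
    (mean pooling stands in for a learned compression module)."""
    t, d = keys.shape
    usable = (t // block_size) * block_size
    ck = keys[:usable].reshape(-1, block_size, d).mean(dim=1)
    cv = values[:usable].reshape(-1, block_size, d).mean(dim=1)
    return ck, cv

keys, values = torch.randn(8192, 128), torch.randn(8192, 128)
ck, cv = compress_kv(keys, values)
print(ck.shape)   # torch.Size([128, 128]): 8K past tokens -> 128 summary tokens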
Updated Model Comparison
┌─────────────────────────────────────────────────────────────────────────┐
│ SPARSE ATTENTION MODELS (2025) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Model Pattern Context Key Innovation │
│ ───────────────────────────────────────────────────────────────────── │
│ Longformer Window + Global 4K Original sparse │
│ BigBird Window + Global + Rand 4K Theoretical basis │
│ Mistral Sliding + Sinks 32K Streaming KV │
│ Mixtral Sliding + Sinks 32K MoE + sparse │
│ Phi-3 Sliding window 128K Efficient small │
│ DeepSeek-V3 DSA (adaptive) 128K Dynamic landmarks │
│ Llama 3.1 Full + FlashAttention 128K Scaling, not sparse│
│ Gemini 1.5 Unknown (likely sparse) 1M+ Best long context │
│ │
│ INFERENCE-ONLY METHODS: │
│ • MInference: Dynamic per-head patterns │
│ • SparseAttn: Efficient sparse kernels │
│ • vLLM sparse: Production sparse inference │
│ │
│ TRAINABLE SPARSE (2025): │
│ • NSA: Dynamic hierarchical, 9× inference speedup │
│ • MoBA: Block-sparse with gating, 6.5× at 1M context │
│ • SeerAttention: Self-distilled gating, 5.67× vs FlashAttention-2 │
│ │
│ FLASHATTENTION-3 (2025): │
│ • 1.5-2× speedup on H100 vs FlashAttention-2 │
│   • 75% GPU utilization with BF16 (up from 35%)                           │
│   • 1.2 PFLOPs/s with FP8 low-precision                                   │
│ • Warp-specialization + asynchronous tensor cores │
│ │
│ BLASST (2025 - Drop-in Sparse): │
│ • Dynamic attention pruning without pre-computation │
│ • Works with MHA, GQA, MQA, and MLA variants │
│ • Fits into existing FlashAttention kernel designs │
│ │
│ MORE 2025 METHODS: │
│ • PBS-Attn: Permuted block-sparse, 2.75× prefill speedup │
│ • InfLLM-V2: Dense-sparse switchable for short-to-long adaptation │
│ • AdaSplash: Adaptive sparse flash attention │
│ • LServe: Efficient long-sequence serving with unified sparse │
│ • GNA: Generalized Neighborhood Attention at speed of light │
│ │
│ TREND: Trainable sparse + FlashAttention-3 + dynamic pruning │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Summary
Sparse attention patterns enable efficient processing of long sequences by attending to subsets of tokens rather than all pairs:
┌─────────────────────────────────────────────────────────────────────────┐
│ KEY TAKEAWAYS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE PROBLEM: │
│ • Full attention is O(n²) in memory and compute │
│ • 32K context → 2TB attention memory (impossible) │
│ │
│ CORE PATTERNS: │
│ • Local/Sliding Window: Attend to nearby tokens O(n×w) │
│ • Global: Select tokens attend to everything O(n×g) │
│ • Random: Random connections for shortcuts O(n×r) │
│ • Block Sparse: Organize sparsity into dense blocks │
│ │
│ KEY MODELS: │
│ • Longformer: Local + Global + optional Dilated │
│ • BigBird: Local + Global + Random (theoretically complete) │
│ • Mistral: Sliding Window + Attention Sinks │
│ │
│ 2024-2025 INNOVATIONS: │
│ • NSA: Native Sparse Attention (ACL 2025 Best Paper) │
│ - 9× faster inference, 6× faster training on 64K sequences │
│ - Dynamic hierarchical: compression + selection + sliding window │
│ • MoBA: Mixture of Block Attention (Kimi, Feb 2025) │
│ - Block sparse with gating (like MoE for attention) │
│ - 6.5× speedup at 1M context, production-deployed at kimi.ai │
│ • SeerAttention: Microsoft NeurIPS 2025 │
│ - Learned sparse patterns via self-distillation │
│ - 90% sparsity, 5.67× speedup over FlashAttention-2 │
│ • DSA: Adaptive sparsity (sinks + window + landmarks) │
│ • MInference: Per-head learned sparse patterns │
│ │
│ PRACTICAL GUIDANCE: │
│ • < 4K tokens: Just use FlashAttention │
│ • > 4K, local context: Sliding window │
│ • > 4K, need global: Longformer or BigBird │
│ • > 32K inference: Consider MInference or DSA │
│ • Modern models: Combine sparse + RoPE scaling + FlashAttention │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Related Articles
Attention Mechanisms: From Self-Attention to FlashAttention
A comprehensive deep dive into attention mechanisms—the core innovation powering modern LLMs. From the intuition behind self-attention to the engineering of FlashAttention, understand how transformers actually work.
Context Extension: How LLMs Scale Beyond Training Length
A comprehensive deep dive into context extension techniques—how models trained on 4K tokens extrapolate to 128K+. Understand RoPE scaling, Position Interpolation, NTK-aware scaling, YaRN, and the mathematics of long-context LLMs.
Transformer Architecture: A Complete Deep Dive
A comprehensive exploration of the transformer architecture—from embedding layers through attention and feed-forward networks to the output head. Understand why decoder-only models dominate, how residual connections enable deep networks, and the engineering decisions behind GPT, Llama, and modern LLMs.