Attention Mechanisms: From Self-Attention to FlashAttention
A comprehensive deep dive into attention mechanisms—the core innovation powering modern LLMs. From the intuition behind self-attention to the engineering of FlashAttention, understand how transformers actually work.
Why Attention Matters
Every modern language model—GPT-4, Claude, Llama, Gemini—is built on attention. It's not hyperbole to say that attention is the most important algorithmic innovation in deep learning since backpropagation. Before attention, neural networks struggled with sequences: they could process them, but they couldn't effectively relate distant elements. A word at the beginning of a long paragraph had little influence on words at the end.
Attention changed everything. It allows every element in a sequence to directly interact with every other element, regardless of distance. This seemingly simple idea unlocked capabilities that were previously impossible: understanding context across thousands of words, generating coherent long-form text, and reasoning about relationships between concepts.
2025 FlashAttention-3 benchmarks: on H100, FlashAttention-3 achieves remarkable utilization:
- BF16: Up to 840 TFLOPs/s (85% utilization)—compared to FlashAttention-2's 35% on H100
- FP8: Reaches 1.3 PFLOPs/s with 2.6x smaller error than baseline FP8
- 1.5-2x faster than FlashAttention-2, surpassing even cuDNN on medium/long sequences
Key techniques for Hopper GPUs: FlashAttention-3 exploits warp-specialization to separate producer (TMA) and consumer (WGMMA) warps, fully hiding load latency. It overlaps softmax of one block with the two GEMMs of the next, eliminating MUFU bottlenecks. Requires H100/H800 with CUDA ≥12.3 (CUDA 12.8 recommended).
But attention comes with a cost. The very property that makes it powerful—all-to-all interaction—creates computational challenges that have driven some of the most important engineering innovations of the past few years: FlashAttention, KV caching, Multi-Query Attention, Ring Attention, and more.
This post takes you from the fundamental intuition behind attention to the cutting-edge optimizations that make modern LLMs possible.
Part I: The Fundamentals
What is Attention? The Intuition
Before diving into equations, let's build intuition. Consider the sentence:
"The cat sat on the mat because it was tired."
What does "it" refer to? As humans, we instantly know "it" refers to "the cat," not "the mat." But how would a neural network figure this out?
Early approaches (RNNs, LSTMs) processed sequences left-to-right, maintaining a hidden state that accumulated information. By the time they reached "it," the information about "cat" had been compressed and mixed with everything else. The connection was fuzzy.
Attention takes a different approach: when processing "it," directly look back at all previous words and compute how relevant each one is. "Cat" scores high (it's an animate noun that can be tired), "mat" scores low (inanimate objects don't get tired), and "sat" scores medium (the action provides context).
┌─────────────────────────────────────────────────────────────────────────┐
│ ATTENTION INTUITION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE QUESTION ATTENTION ANSWERS: │
│ ──────────────────────────────── │
│ "When processing this position, how much should I pay attention │
│ to each other position?" │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ Sentence: "The cat sat on the mat because it was tired." │
│ │
│ When processing "it": │
│ │
│ "The" → low attention (0.02) [generic article] │
│ "cat" → HIGH attention (0.45) [animate, subject] │
│ "sat" → medium (0.15) [action context] │
│ "on" → low (0.03) [preposition] │
│ "the" → low (0.02) [generic article] │
│ "mat" → medium-low (0.08) [possible but unlikely referent] │
│ "because" → low (0.05) [connector] │
│ "it" → medium (0.20) [self-reference] │
│ │
│ The model "attends" mostly to "cat" when processing "it". │
│ This attention is learned from data, not hard-coded. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY THIS IS POWERFUL: │
│ │
│ 1. DIRECT CONNECTION: "it" directly accesses "cat" information │
│ No degradation through many layers of hidden states │
│ │
│ 2. LEARNED RELEVANCE: The model learns what to attend to │
│ Different tasks learn different attention patterns │
│ │
│ 3. PARALLEL COMPUTATION: All positions computed simultaneously │
│ Unlike RNNs that must process sequentially │
│ │
│ 4. INTERPRETABLE: We can visualize where the model "looks" │
│ Attention weights provide some insight into model behavior │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Queries, Keys, and Values: The Core Mechanism
The attention mechanism uses three learned transformations of each input: Queries (Q), Keys (K), and Values (V). This naming comes from information retrieval:
- Query: "What am I looking for?"
- Key: "What do I have to offer?"
- Value: "What information do I actually contain?"
Think of it like a library search:
- Your query is "books about cats"
- Each book has a key (its subject tags, title, description)
- The value is the actual content of the book
- You match your query against keys to find relevant values
┌─────────────────────────────────────────────────────────────────────────┐
│ QUERIES, KEYS, AND VALUES │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ INPUT: Sequence of token embeddings X = [x₁, x₂, ..., xₙ] │
│ Each xᵢ is a d-dimensional vector │
│ │
│ STEP 1: CREATE Q, K, V │
│ ────────────────────── │
│ For each token embedding xᵢ, compute: │
│ │
│ qᵢ = xᵢ × W_Q (Query: "What am I looking for?") │
│ kᵢ = xᵢ × W_K (Key: "What do I contain?") │
│ vᵢ = xᵢ × W_V (Value: "What information to extract?") │
│ │
│ W_Q, W_K, W_V are learned weight matrices (d × d_k, d × d_k, d × d_v) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY THREE SEPARATE TRANSFORMATIONS? │
│ ──────────────────────────────────── │
│ │
│ Using the same vector for matching and information extraction │
│ would be limiting. A token might be relevant for different reasons │
│ than what information we want to extract from it. │
│ │
│ Example: "The big red ball rolled quickly" │
│ │
│ When "quickly" queries for "ball": │
│ - KEY match: "ball" matches because it's the subject of the action │
│ - VALUE extracted: The physical properties (big, red, round) │
│ │
│ The key says "I'm relevant as a subject" │
│ The value says "here are my attributes" │
│ These are different kinds of information! │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ VISUALIZATION: │
│ │
│ x₁ x₂ x₃ x₄ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ ┌──┴──┐ ┌──┴──┐ ┌──┴──┐ ┌──┴──┐ │
│ │ │ │ │ │ │ │ │ │
│ q₁ k₁ v₁ q₂ k₂ v₂ q₃ k₃ v₃ q₄ k₄ v₄ │
│ │
│ Each token produces its own query, key, and value vectors. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
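As a minimal sketch of Step 1 in PyTorch (the dimensions, batch size, and bias-free linear layers are illustrative assumptions, not taken from any particular model):

import torch
import torch.nn as nn

d_model, d_k, d_v = 512, 64, 64           # illustrative sizes
batch, n = 2, 10                          # batch size, sequence length

# Three separate learned projections, one each for Q, K, V
W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_v, bias=False)

x = torch.randn(batch, n, d_model)        # token embeddings x_1 ... x_n
q, k, v = W_Q(x), W_K(x), W_V(x)          # each: (batch, n, d_k) or (batch, n, d_v)
print(q.shape, k.shape, v.shape)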
The Attention Computation
Once we have Q, K, V for all positions, the attention computation happens in three steps:
┌─────────────────────────────────────────────────────────────────────────┐
│ ATTENTION COMPUTATION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ STEP 1: COMPUTE ATTENTION SCORES │
│ ───────────────────────────────── │
│ Score(qᵢ, kⱼ) = qᵢ · kⱼ (dot product) │
│ │
│ Higher dot product = more similar = higher attention │
│ │
│ For all pairs, this creates an n × n attention score matrix: │
│ │
│ k₁ k₂ k₃ k₄ │
│ ┌─────┬─────┬─────┬─────┐ │
│ q₁ │ 2.1 │ 0.5 │-0.3 │ 1.2 │ │
│ ├─────┼─────┼─────┼─────┤ │
│ q₂ │ 0.8 │ 3.2 │ 0.1 │-0.5 │ │
│ ├─────┼─────┼─────┼─────┤ │
│ q₃ │-0.2 │ 1.1 │ 2.8 │ 0.9 │ │
│ ├─────┼─────┼─────┼─────┤ │
│ q₄ │ 0.4 │ 0.3 │ 1.5 │ 2.4 │ │
│ └─────┴─────┴─────┴─────┘ │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ STEP 2: SCALE AND SOFTMAX │
│ ───────────────────────── │
│ │
│ Scale by √d_k (dimension of keys): │
│ scaled_score = score / √d_k │
│ │
│ WHY SCALE? │
│ For large d_k, dot products can become very large, pushing softmax │
│ into regions with extremely small gradients. Scaling keeps values │
│ in a reasonable range. │
│ │
│ If d_k = 64, we divide by √64 = 8 │
│ │
│ Apply softmax row-wise: │
│ attention_weights = softmax(scaled_scores) │
│ │
│ Softmax converts scores to probabilities (sum to 1): │
│ k₁ k₂ k₃ k₄ │
│ ┌──────┬──────┬──────┬──────┐ │
│ q₁ │ 0.45 │ 0.18 │ 0.08 │ 0.29 │ → sum = 1.0 │
│ ├──────┼──────┼──────┼──────┤ │
│ q₂ │ 0.12 │ 0.65 │ 0.11 │ 0.12 │ → sum = 1.0 │
│ ├──────┼──────┼──────┼──────┤ │
│ q₃ │ 0.07 │ 0.22 │ 0.51 │ 0.20 │ → sum = 1.0 │
│ ├──────┼──────┼──────┼──────┤ │
│ q₄ │ 0.14 │ 0.13 │ 0.30 │ 0.43 │ → sum = 1.0 │
│ └──────┴──────┴──────┴──────┘ │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ STEP 3: WEIGHTED SUM OF VALUES │
│ ─────────────────────────────── │
│ │
│ output_i = Σⱼ attention_weight(i,j) × vⱼ │
│ │
│ Each output is a weighted combination of all values, │
│ where weights indicate relevance. │
│ │
│ For position 1: │
│ output₁ = 0.45×v₁ + 0.18×v₂ + 0.08×v₃ + 0.29×v₄ │
│ │
│ Position 1 "attends" mostly to position 1 and 4, │
│ pulling information primarily from those values. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
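The three steps above are only a few lines of PyTorch. This is a minimal single-head sketch (no masking, no dropout, illustrative shapes), not a production implementation:

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    """q, k: (batch, n, d_k); v: (batch, n, d_v). Returns (batch, n, d_v)."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)   # Step 1: scores, plus scaling
    weights = F.softmax(scores, dim=-1)                  # Step 2: row-wise softmax
    return weights @ v                                   # Step 3: weighted sum of values

q, k, v = (torch.randn(1, 4, 64) for _ in range(3))
out = scaled_dot_product_attention(q, k, v)              # (1, 4, 64)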
The Complete Formula: Scaled Dot-Product Attention
Putting it all together, the attention formula is:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Where:
- Q is the matrix of queries (n × d_k)
- K is the matrix of keys (n × d_k)
- V is the matrix of values (n × d_v)
- d_k is the dimension of queries/keys
- Q × K^T produces the n × n attention score matrix
- softmax normalizes each row to probabilities
- Final multiplication with V produces output (n × d_v)
┌─────────────────────────────────────────────────────────────────────────┐
│ ATTENTION FORMULA BREAKDOWN │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Attention(Q, K, V) = softmax(QK^T / √d_k) V │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ DIMENSIONS (example with sequence length n=4, d_k=64, d_v=64): │
│ │
│ Q: (4 × 64) ──┐ │
│ ├──→ QK^T: (4 × 4) ──→ softmax: (4 × 4) │
│ K^T: (64 × 4) ──┘ │
│ │
│ softmax(4 × 4) × V(4 × 64) ──→ output: (4 × 64) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE QUADRATIC PROBLEM: │
│ ────────────────────── │
│ │
│ QK^T creates an n × n matrix. │
│ │
│ Memory: O(n²) │
│ Compute: O(n² × d) │
│ │
│ Sequence Length Attention Matrix Size Memory (FP16) │
│ ─────────────────────────────────────────────────────────── │
│ 512 tokens 262,144 elements 512 KB │
│ 2,048 tokens 4,194,304 elements 8 MB │
│ 8,192 tokens 67,108,864 elements 128 MB │
│ 32,768 tokens 1,073,741,824 elements 2 GB │
│ 131,072 tokens 17 billion elements 32 GB │
│ │
│ Per layer! A 32-layer model multiplies this by 32. │
│ This is why long contexts are expensive and why FlashAttention │
│ matters so much. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Self-Attention vs. Cross-Attention
There are two main types of attention, distinguished by where Q, K, V come from:
┌─────────────────────────────────────────────────────────────────────────┐
│ SELF-ATTENTION VS CROSS-ATTENTION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ SELF-ATTENTION: │
│ ─────────────── │
│ Q, K, V all come from the same sequence. │
│ │
│ Input X ──→ Q = X × W_Q │
│ ──→ K = X × W_K │
│ ──→ V = X × W_V │
│ │
│ Each position attends to all positions in the same sequence. │
│ This is what GPT, Llama, and most LLMs use. │
│ │
│ Use case: Language modeling, where the model processes a single │
│ sequence and needs to understand relationships within it. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ CROSS-ATTENTION: │
│ ──────────────── │
│ Q comes from one sequence, K and V from another. │
│ │
│ Decoder X ──→ Q = X × W_Q │
│ Encoder Y ──→ K = Y × W_K │
│ ──→ V = Y × W_V │
│ │
│ Positions in X attend to positions in Y. │
│ │
│ Use cases: │
│ • Encoder-decoder models (T5, BART): Decoder attends to encoder │
│ • Vision-language models: Text attends to image features │
│ • Retrieval-augmented generation: Query attends to retrieved docs │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ EXAMPLE: Translation (Encoder-Decoder) │
│ │
│ Encoder input: "The cat sat" (English) │
│ Decoder generating: "Le chat" → predicting next word │
│ │
│ Cross-attention: "chat" (French for cat) attends to encoder │
│ Query from "chat" matches keys from "cat" strongly │
│ Value from "cat" position helps generate correct translation │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ VISUAL REPRESENTATION: │
│ │
│ Self-Attention: Cross-Attention: │
│ (same sequence) (different sequences) │
│ │
│ A ←──→ B ←──→ C X₁ ───→ Y₁, Y₂, Y₃ │
│ ↑ ╲ ↑ ╱ ↑ X₂ ───→ Y₁, Y₂, Y₃ │
│ └───╲─┴─╱───┘ X₃ ───→ Y₁, Y₂, Y₃ │
│ ╲ ╱ │
│ ╳ X attends to Y │
│ ╱ ╲ (not to itself) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Masked Attention: Preventing Future Peeking
In autoregressive language models (GPT-style), we generate tokens one at a time, left to right. During training, we process the whole sequence at once for efficiency, but each position should only attend to previous positions—not future ones (that would be cheating!).
Masked attention enforces this by setting future attention scores to -∞ before softmax, which zeros out those attention weights:
┌─────────────────────────────────────────────────────────────────────────┐
│ CAUSAL (MASKED) ATTENTION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE PROBLEM: │
│ ───────────── │
│ During training, we have the full sequence: │
│ "The cat sat on the mat" │
│ │
│ When predicting "sat", the model should only see "The cat" │
│ NOT "on the mat" (future tokens) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE SOLUTION: CAUSAL MASK │
│ ──────────────────────────── │
│ │
│ Before softmax, apply a mask that sets future positions to -∞: │
│ │
│ Raw attention scores: After masking: │
│ k₁ k₂ k₃ k₄ k₁ k₂ k₃ k₄ │
│ ┌────┬────┬────┬────┐ ┌────┬────┬────┬────┐ │
│ q₁ │ 2.1│ 0.5│-0.3│ 1.2│ q₁│ 2.1│ -∞ │ -∞ │ -∞ │ │
│ ├────┼────┼────┼────┤ ├────┼────┼────┼────┤ │
│ q₂ │ 0.8│ 3.2│ 0.1│-0.5│ q₂│ 0.8│ 3.2│ -∞ │ -∞ │ │
│ ├────┼────┼────┼────┤ ├────┼────┼────┼────┤ │
│ q₃ │-0.2│ 1.1│ 2.8│ 0.9│ q₃│-0.2│ 1.1│ 2.8│ -∞ │ │
│ ├────┼────┼────┼────┤ ├────┼────┼────┼────┤ │
│ q₄ │ 0.4│ 0.3│ 1.5│ 2.4│ q₄│ 0.4│ 0.3│ 1.5│ 2.4│ │
│ └────┴────┴────┴────┘ └────┴────┴────┴────┘ │
│ │
│ After softmax: │
│ k₁ k₂ k₃ k₄ │
│ ┌─────┬─────┬─────┬─────┐ │
│ q₁ │ 1.0 │ 0.0 │ 0.0 │ 0.0 │ (can only see position 1) │
│ ├─────┼─────┼─────┼─────┤ │
│ q₂ │ 0.08│ 0.92│ 0.0 │ 0.0 │ (can see positions 1-2) │
│ ├─────┼─────┼─────┼─────┤ │
│ q₃ │ 0.05│ 0.18│ 0.77│ 0.0 │ (can see positions 1-3) │
│ ├─────┼─────┼─────┼─────┤ │
│ q₄ │ 0.13│ 0.12│ 0.32│ 0.43│ (can see all positions) │
│ └─────┴─────┴─────┴─────┘ │
│ │
│ The triangular pattern ensures each position only attends to │
│ itself and previous positions—never future positions. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ IMPLEMENTATION: │
│ ─────────────── │
│ mask = torch.triu(torch.ones(n, n), diagonal=1).bool() │
│ scores = scores.masked_fill(mask, float('-inf')) │
│ attention_weights = F.softmax(scores, dim=-1) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
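Combining the mask with the attention computation from Part I gives a small runnable sketch of causal self-attention (single head, illustrative shapes):

import math
import torch
import torch.nn.functional as F

def causal_attention(q, k, v):
    """Single-head causal attention. q, k, v: (batch, n, d)."""
    n, d = q.size(-2), q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)
    # True above the diagonal = future positions, which get -inf before softmax
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(mask, float('-inf'))
    return F.softmax(scores, dim=-1) @ v

x = torch.randn(1, 6, 64)
out = causal_attention(x, x, x)   # each position sees only itself and earlier positions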
Part II: Multi-Head Attention
Why Multiple Heads?
A single attention head computes one set of attention patterns. But language has many types of relationships: syntactic (subject-verb agreement), semantic (meaning), positional (nearby words), referential (pronouns to nouns), and more.
Multi-head attention runs multiple attention operations in parallel, each with its own learned Q, K, V projections. Different heads can specialize in different types of relationships:
┌─────────────────────────────────────────────────────────────────────────┐
│ MULTI-HEAD ATTENTION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ WHY MULTIPLE HEADS? │
│ ─────────────────── │
│ │
│ Single head: One attention pattern per position │
│ │
│ "The cat sat because it was tired" │
│ │
│ Single head might learn to attend to: │
│ "it" → "cat" (subject reference) │
│ │
│ But we lose other relationships: │
│ "tired" → "sat" (action that led to state) │
│ "was" → "it" (verb agreement) │
│ │
│ Multiple heads: Many attention patterns simultaneously │
│ │
│ Head 1: Syntactic relationships (subject-verb) │
│ Head 2: Semantic relationships (cause-effect) │
│ Head 3: Positional (local context) │
│ Head 4: Long-range dependencies │
│ ... │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE MECHANISM: │
│ ────────────── │
│ │
│ Instead of one attention with d-dimensional Q, K, V: │
│ Use h heads, each with d/h dimensional Q, K, V │
│ │
│ Example: d=512, h=8 heads → each head uses d_k = 64 │
│ │
│ Input X (n × 512) │
│ │ │
│ ┌────────┼────────┐ │
│ ▼ ▼ ▼ │
│ Head 1 Head 2 ... Head 8 │
│ (n × 64) (n × 64) (n × 64) │
│ │ │ │ │
│ └────────┴────────┘ │
│ │ │
│ Concatenate │
│ (n × 512) │
│ │ │
│ Linear W_O │
│ (n × 512) │
│ │ │
│ Output │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE FORMULA: │
│ ──────────── │
│ │
│ MultiHead(Q, K, V) = Concat(head₁, ..., head_h) × W_O │
│ │
│ where head_i = Attention(Q × W_Q^i, K × W_K^i, V × W_V^i) │
│ │
│ Each head has its own projection matrices W_Q^i, W_K^i, W_V^i │
│ W_O projects the concatenated outputs back to model dimension │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ COMPUTATIONAL COST: │
│ ─────────────────── │
│ │
│ Single head attention: O(n² × d) │
│ Multi-head (h heads): O(h × n² × d/h) = O(n² × d) │
│ │
│ Same total compute! We trade head dimension for number of heads. │
│ But we get h different attention patterns. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
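The diagram above maps to a compact module. This is a didactic sketch (bias-free projections, no dropout, no causal mask), not any specific model's implementation:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        # One big projection per Q/K/V is equivalent to h per-head projections
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        b, n, _ = x.shape
        # Project, then split the model dimension into (heads, d_head)
        q = self.w_q(x).view(b, n, self.n_heads, self.d_head).transpose(1, 2)
        k = self.w_k(x).view(b, n, self.n_heads, self.d_head).transpose(1, 2)
        v = self.w_v(x).view(b, n, self.n_heads, self.d_head).transpose(1, 2)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)   # (b, h, n, n)
        out = F.softmax(scores, dim=-1) @ v                          # (b, h, n, d_head)
        out = out.transpose(1, 2).reshape(b, n, -1)                  # concatenate heads
        return self.w_o(out)                                         # final W_O projection

mha = MultiHeadAttention(d_model=512, n_heads=8)
y = mha(torch.randn(2, 10, 512))   # (2, 10, 512)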
What Do Different Heads Learn?
Research has shown that different heads often specialize:
┌─────────────────────────────────────────────────────────────────────────┐
│ HEAD SPECIALIZATION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ OBSERVED PATTERNS IN TRAINED MODELS: │
│ ───────────────────────────────────── │
│ │
│ POSITIONAL HEADS: │
│ Attend to tokens at fixed relative positions │
│ • "Previous token" head: Always attends to position n-1 │
│ • "Local window" head: Attends to nearby positions │
│ │
│ [The] [cat] [sat] [on] [the] [mat] │
│ ↑ │
│ ←───┘ This head always looks at previous word │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SYNTACTIC HEADS: │
│ Attend based on grammatical relationships │
│ • Subject-verb heads │
│ • Noun-adjective heads │
│ • Dependency parsing heads │
│ │
│ [The] [big] [cat] [sat] [quietly] │
│ └────→────┘ │
│ Adjective head: "cat" attends to "big" │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SEMANTIC HEADS: │
│ Attend based on meaning relationships │
│ • Coreference: Pronouns to their referents │
│ • Semantic similarity: Related concepts │
│ │
│ [The] [cat] [...] [it] [was] [tired] │
│ └───────────→──┘ │
│ Coreference head: "it" attends strongly to "cat" │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SPECIAL PATTERN HEADS: │
│ • Delimiter heads: Attend to punctuation, sentence boundaries │
│ • Copy heads: Allow copying from input (important for names) │
│ • Induction heads: Pattern matching for in-context learning │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ HEAD REDUNDANCY: │
│ ──────────────── │
│ Not all heads are equally important. Studies show: │
│ • Many heads can be pruned with minimal quality loss │
│ • Some heads are redundant (similar patterns) │
│ • A few heads are critical—removing them hurts a lot │
│ │
│ This observation motivates MQA and GQA optimizations │
│ (sharing key-value heads to reduce memory). │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Part III: Positional Encodings
The Position Problem
Attention treats its input as a set, not a sequence. Shuffle the tokens and the attention formula produces the same per-token outputs, just shuffled the same way:
Attention([A, B, C]) produces the same patterns as Attention([C, A, B]), merely reordered
But word order matters! "Dog bites man" and "Man bites dog" have very different meanings. We need a way to inject position information.
┌─────────────────────────────────────────────────────────────────────────┐
│ THE POSITION PROBLEM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ATTENTION IS PERMUTATION EQUIVARIANT: │
│ ────────────────────────────────────── │
│ │
│ If you shuffle the input positions, attention just shuffles the │
│ output in the same way. It doesn't inherently "know" that position │
│ 1 comes before position 2. │
│ │
│ Input: [The, cat, sat] → Attention → Output for each position │
│ Input: [sat, The, cat] → Attention → Same outputs, shuffled │
│ │
│ This is a feature (parallel processing) and a bug (loses order). │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY POSITION MATTERS: │
│ ───────────────────── │
│ │
│ "The dog chased the cat" vs "The cat chased the dog" │
│ ↑ ↑ │
│ Same words, different meanings based on position │
│ │
│ "Not bad" vs "Bad not" (word order creates negation patterns) │
│ │
│ Code: Order determines execution (x = 1; y = x; vs y = x; x = 1;) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SOLUTION: POSITIONAL ENCODINGS │
│ ─────────────────────────────── │
│ │
│ Add position information to token embeddings before attention. │
│ │
│ token_embedding + positional_encoding → input to attention │
│ │
│ Now the model can distinguish positions: │
│ • Position 1 always gets the same position encoding │
│ • Position 2 gets a different one │
│ • The model learns to use this information │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Sinusoidal Positional Encodings (Original Transformer)
The original Transformer used fixed sinusoidal functions to encode positions:
┌─────────────────────────────────────────────────────────────────────────┐
│ SINUSOIDAL POSITIONAL ENCODINGS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE FORMULA: │
│ ──────────── │
│ │
│ PE(pos, 2i) = sin(pos / 10000^(2i/d)) │
│ PE(pos, 2i+1) = cos(pos / 10000^(2i/d)) │
│ │
│ Where: │
│ • pos = position in sequence (0, 1, 2, ...) │
│ • i = dimension index (0 to d/2) │
│ • d = embedding dimension │
│ │
│ Each dimension uses a different frequency of sin/cos waves. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ VISUALIZATION: │
│ │
│ Position → 0 1 2 3 4 5 6 7 │
│ │
│ Dim 0 (sin): ─ ╱ │ ╲ ─ ╱ │ ╲ [fast wave] │
│ Dim 1 (cos): │ ╲ ─ ╱ │ ╲ ─ ╱ │
│ Dim 2 (sin): ─ ─ ╱ ╱ │ │ ╲ ╲ [slower] │
│ Dim 3 (cos): │ │ ╲ ╲ ─ ─ ╱ ╱ │
│ ... │
│ Dim d (sin): ─────────────────╱───────────────── [very slow] │
│ │
│ Early dimensions: High frequency (change rapidly with position) │
│ Later dimensions: Low frequency (change slowly) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY THIS WORKS: │
│ ─────────────── │
│ │
│ 1. UNIQUE POSITIONS: Each position gets a unique encoding │
│ (like binary counting with continuous values) │
│ │
│ 2. RELATIVE POSITION: PE(pos+k) can be computed as a linear │
│ function of PE(pos). This helps the model learn relative │
│ position relationships. │
│ │
│ PE(pos+k) = R_k × PE(pos) where R_k is a rotation matrix │
│ │
│ 3. EXTRAPOLATION: Can extend to longer sequences than training │
│ (though performance degrades) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ LIMITATIONS: │
│ ──────────── │
│ • Fixed, not learned—can't adapt to task │
│ • Extrapolation works poorly in practice │
│ • Doesn't scale well to very long contexts │
│ │
│ Most modern models use learned or relative positional encodings. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
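The formula translates directly into code; this sketch assumes an even embedding dimension and returns a (max_len, d) table that is added to the token embeddings:

import torch

def sinusoidal_positional_encoding(max_len: int, d: int) -> torch.Tensor:
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d))."""
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # (max_len, 1)
    two_i = torch.arange(0, d, 2, dtype=torch.float32)              # 0, 2, 4, ...
    freq = 1.0 / (10000 ** (two_i / d))                             # 10000^(-2i/d)
    pe = torch.zeros(max_len, d)
    pe[:, 0::2] = torch.sin(pos * freq)
    pe[:, 1::2] = torch.cos(pos * freq)
    return pe

pe = sinusoidal_positional_encoding(max_len=128, d=512)
# input to the first layer: token_embeddings + pe[:seq_len]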
Learned Positional Embeddings
A simpler approach: just learn the position embeddings like any other embedding:
┌─────────────────────────────────────────────────────────────────────────┐
│ LEARNED POSITIONAL EMBEDDINGS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE IDEA: │
│ ───────── │
│ Treat positions like a vocabulary. Learn an embedding for each. │
│ │
│ Position embedding matrix: E_pos ∈ R^(max_len × d) │
│ │
│ Position 0 → E_pos[0] (learned d-dimensional vector) │
│ Position 1 → E_pos[1] │
│ Position 2 → E_pos[2] │
│ ... │
│ Position max_len-1 → E_pos[max_len-1] │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ADVANTAGES: │
│ ─────────── │
│ ✓ Simple to implement │
│ ✓ Can learn task-specific patterns │
│ ✓ Often works as well as sinusoidal │
│ │
│ DISADVANTAGES: │
│ ────────────── │
│ ✗ Fixed maximum length (can't extrapolate) │
│ ✗ Positions beyond max_len are undefined │
│ ✗ No built-in notion of relative position │
│ │
│ Used by: GPT-2, BERT (original) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
RoPE: Rotary Position Embedding
RoPE is the most widely used positional encoding in modern LLMs (Llama, Qwen, Mistral, etc.). It's elegant, efficient, and extrapolates better than alternatives.
┌─────────────────────────────────────────────────────────────────────────┐
│ ROTARY POSITION EMBEDDING (RoPE) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE KEY INSIGHT: │
│ ──────────────── │
│ Instead of adding position to the embedding, encode position as │
│ a ROTATION applied to query and key vectors. │
│ │
│ The dot product between rotated q and k naturally encodes │
│ RELATIVE position—how far apart two tokens are. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE CLOCK ANALOGY: │
│ ────────────────── │
│ Imagine multiple clock hands, each rotating at different speeds. │
│ │
│ Fast hand Medium hand Slow hand │
│ (dim 0-1) (dim 2-3) (dim d-2, d-1) │
│ │
│ ┌────┐ ┌────┐ ┌────┐ │
│ │ / │ │ | │ │ | │ │
│ │/ │ │ | │ │ | │ │
│ └────┘ └────┘ └────┘ │
│ Pos 1 Pos 1 Pos 1 │
│ │
│ ┌────┐ ┌────┐ ┌────┐ │
│ │── │ │ / │ │ | │ │
│ │ │ │/ │ │ | │ │
│ └────┘ └────┘ └────┘ │
│ Pos 2 Pos 2 Pos 2 │
│ (rotated 90°) (rotated 45°) (barely moved) │
│ │
│ Each pair of dimensions is rotated by an angle that depends on: │
│ • Position (further positions = more rotation) │
│ • Dimension (early dims rotate faster than late dims) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE MATHEMATICS: │
│ ──────────────── │
│ │
│ For each pair of dimensions (2i, 2i+1), apply rotation: │
│ │
│ θ_i = 10000^(-2i/d) (angle frequency for dimension i) │
│ │
│ For position m, rotate the (x, y) pair by angle m × θ_i: │
│ │
│ [x'] [cos(mθ_i) -sin(mθ_i)] [x] │
│ [y'] = [sin(mθ_i) cos(mθ_i)] [y] │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY ROTATION ENCODES RELATIVE POSITION: │
│ ──────────────────────────────────────── │
│ │
│ When we compute q_m · k_n (query at position m, key at position n): │
│ │
│ Both vectors are rotated by their respective positions. │
│ The dot product depends on the DIFFERENCE in rotation angles. │
│ │
│ q_m · k_n = f(m - n) (depends only on relative position!) │
│ │
│ This is exactly what we want: attention should depend on how far │
│ apart tokens are, not their absolute positions. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ADVANTAGES OF RoPE: │
│ ─────────────────── │
│ ✓ Parameter-free (no learned position embeddings) │
│ ✓ Naturally encodes relative position │
│ ✓ Can extrapolate to longer sequences (with scaling tricks) │
│ ✓ Efficient implementation (element-wise operations) │
│ ✓ Works with FlashAttention │
│ │
│ Used by: Llama, Llama 2, Llama 3, Qwen, Mistral, Phi, and most │
│ modern open-source LLMs. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ EXTENDING CONTEXT LENGTH WITH RoPE: │
│ ──────────────────────────────────── │
│ │
│ RoPE can be "stretched" to handle longer contexts than training: │
│ │
│ 1. Linear scaling: Divide position by scale factor │
│ θ' = θ / scale │
│ Simple but degrades quality │
│ │
│ 2. NTK-aware scaling: Scale base frequency instead │
│ base' = base × scale^(d/(d-2)) │
│ Better preservation of short-range patterns │
│ │
│ 3. YaRN: Combined interpolation and attention scaling │
│ Best quality for extended context │
│ │
│ This is how models trained on 4K context can inference on 100K+. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
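A simplified RoPE sketch that rotates consecutive dimension pairs of a vector by position-dependent angles. Real implementations (Llama-style code, for instance) often split or interleave dimensions differently and cache the cos/sin tables; this version only illustrates the math above:

import torch

def apply_rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """x: (batch, n, d) with d even. Rotates each pair (2i, 2i+1) by m * theta_i."""
    b, n, d = x.shape
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)   # theta_i
    m = torch.arange(n, dtype=torch.float32).unsqueeze(1)               # positions
    angles = m * theta                                                  # (n, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]            # the two components of each pair
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin           # 2D rotation of (x1, x2)
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q, k = torch.randn(1, 8, 64), torch.randn(1, 8, 64)
q_rot, k_rot = apply_rope(q), apply_rope(k)
# Because both q and k are rotated, q_rot @ k_rot^T depends only on m - n.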
ALiBi: Attention with Linear Biases
An alternative approach that adds a bias to attention scores based on distance:
┌─────────────────────────────────────────────────────────────────────────┐
│ ALiBi (Attention with Linear Biases) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE IDEA: │
│ ───────── │
│ Don't modify embeddings at all. Instead, add a penalty to attention │
│ scores based on distance between positions. │
│ │
│ score(i, j) = q_i · k_j - m × |i - j| │
│ │
│ Where m is a head-specific slope (different for each head). │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ VISUALIZATION: │
│ ────────────── │
│ Attention scores get penalized by distance: │
│ │
│ Position: 1 2 3 4 5 │
│ │
│ Query at 3: -2m -1m 0 -1m -2m │
│ (penalty added to raw attention scores) │
│ │
│ Close tokens get small penalty (high attention allowed) │
│ Far tokens get large penalty (attention suppressed) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ HEAD-SPECIFIC SLOPES: │
│ ───────────────────── │
│ Each head uses a different slope m, set as powers of 2: │
│ │
│ m_1 = 2^(-8/n) (largest slope, attends most locally) │
│ m_2 = 2^(-16/n) (smaller slope, attends farther) │
│ m_3 = 2^(-24/n) │
│ ... │
│ │
│ Some heads specialize in local context, others in global. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ADVANTAGES: │
│ ─────────── │
│ ✓ No learned parameters │
│ ✓ Extrapolates well to longer sequences │
│ ✓ Simple to implement (just add bias matrix) │
│ ✓ Can be combined with FlashAttention │
│ │
│ DISADVANTAGES: │
│ ────────────── │
│ ✗ Strong inductive bias (linear decay) │
│ ✗ Less flexible than RoPE for complex position patterns │
│ │
│ Used by: BLOOM, MPT, Falcon (some versions) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
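A sketch of how the ALiBi bias could be built and added to raw attention scores before softmax, using the powers-of-two slope schedule above (shapes and names are illustrative):

import torch

def alibi_bias(n: int, n_heads: int) -> torch.Tensor:
    """Returns (n_heads, n, n) bias: -m_h * |i - j| with head-specific slopes m_h."""
    # Slope schedule: m_h = 2^(-8h/n_heads) for h = 1 .. n_heads
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / n_heads) for h in range(n_heads)])
    pos = torch.arange(n)
    dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()   # |i - j|, shape (n, n)
    return -slopes.view(n_heads, 1, 1) * dist            # broadcast over heads

scores = torch.randn(8, 16, 16)                  # (heads, n, n) raw attention scores
biased = scores + alibi_bias(n=16, n_heads=8)    # penalty grows with distance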
Part IV: The KV Cache
Why Caching Matters
During autoregressive generation, we generate tokens one at a time. Each new token needs to attend to all previous tokens. Without caching, we'd recompute Q, K, V for all previous tokens at every step—massively wasteful.
┌─────────────────────────────────────────────────────────────────────────┐
│ THE KV CACHE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE PROBLEM: │
│ ───────────── │
│ │
│ Autoregressive generation: Generate one token at a time. │
│ │
│ Step 1: Input [A] → Compute attention → Output token B │
│ Step 2: Input [A, B] → Compute attention → Output token C │
│ Step 3: Input [A, B, C] → Compute attention → Output token D │
│ ... │
│ │
│ WITHOUT CACHING: │
│ At step 100, we have 100 tokens. │
│ We compute Q, K, V for ALL 100 tokens. │
│ But we only need the NEW token's Q! │
│ The K, V for positions 1-99 are the same as last step. │
│ │
│ Wasteful! Recomputing K, V for old tokens is pure overhead. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE SOLUTION: KV CACHE │
│ ────────────────────────── │
│ │
│ Cache the K and V vectors for all previous tokens. │
│ Only compute Q, K, V for the NEW token. │
│ Look up cached K, V for previous tokens. │
│ │
│ Step 1: Input [A] │
│ Compute K_A, V_A → Cache them │
│ Compute Q_A, attention, output B │
│ │
│ Step 2: Input [B] (just the new token!) │
│ Compute K_B, V_B → Add to cache │
│ Retrieve cached K_A, V_A │
│ Compute Q_B │
│ Attention: Q_B attends to [K_A, K_B], [V_A, V_B] │
│ Output C │
│ │
│ Step 3: Input [C] │
│ Compute K_C, V_C → Add to cache │
│ Retrieve cached K_A, K_B, V_A, V_B │
│ Compute Q_C │
│ Attention: Q_C attends to [K_A, K_B, K_C], [V_A, V_B, V_C] │
│ Output D │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SPEEDUP: │
│ ──────── │
│ Without cache: Each step processes all n tokens → O(n²) total │
│ With cache: Each step processes 1 new token → O(n) total │
│ │
│ For generating 1000 tokens, this is ~1000× less computation! │
│ │
└─────────────────────────────────────────────────────────────────────────┘
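The cache mechanics for a single head fit in a short sketch: at each step only the new token is projected, and its K and V are appended to the cache (shapes and the toy loop are illustrative; real engines also handle batching, heads, and layers):

import math
import torch
import torch.nn.functional as F

def decode_step(x_new, W_q, W_k, W_v, cache_k, cache_v):
    """x_new: (batch, 1, d) embedding of the newest token only.
    cache_k / cache_v: (batch, t, d) keys/values of all previous tokens, or None."""
    q, k, v = x_new @ W_q, x_new @ W_k, x_new @ W_v    # compute only the new Q, K, V
    if cache_k is not None:                            # append new K, V to the cache
        k = torch.cat([cache_k, k], dim=1)
        v = torch.cat([cache_v, v], dim=1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))   # (batch, 1, t+1)
    out = F.softmax(scores, dim=-1) @ v
    return out, k, v                                   # k, v become the updated cache

d = 64
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
cache_k = cache_v = None
for step in range(5):                                  # pretend we generate 5 tokens
    x_new = torch.randn(1, 1, d)                       # embedding of the latest token
    out, cache_k, cache_v = decode_step(x_new, W_q, W_k, W_v, cache_k, cache_v)
print(cache_k.shape)                                   # (1, 5, 64): one K per token so far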
KV Cache Memory Requirements
The KV cache is often the memory bottleneck for long-context inference:
┌─────────────────────────────────────────────────────────────────────────┐
│ KV CACHE MEMORY CALCULATION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ FORMULA: │
│ ──────── │
│ KV cache size per token = 2 × n_layers × n_heads × d_head × dtype │
│ │
│ • 2: We store both K and V │
│ • n_layers: Cache for each transformer layer │
│ • n_heads: Each head has its own K, V │
│ • d_head: Dimension of each head (typically d_model / n_heads) │
│ • dtype: Bytes per element (2 for FP16/BF16, 1 for INT8) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ EXAMPLE: Llama 2 7B │
│ ─────────────────── │
│ • n_layers = 32 │
│ • n_heads = 32 │
│ • d_head = 128 │
│ • dtype = 2 bytes (BF16) │
│ │
│ Per token: 2 × 32 × 32 × 128 × 2 = 524,288 bytes ≈ 512 KB │
│ │
│ For context lengths: │
│ • 2K tokens: 2,048 × 512 KB = 1 GB │
│ • 4K tokens: 4,096 × 512 KB = 2 GB │
│ • 8K tokens: 8,192 × 512 KB = 4 GB │
│ • 32K tokens: 32,768 × 512 KB = 16 GB │
│ • 128K tokens: 131,072 × 512 KB = 64 GB │
│ │
│ This is PER REQUEST. Batch of 10 requests = 10× memory. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ EXAMPLE: Llama 3 70B │
│ ──────────────────── │
│ • n_layers = 80 │
│ • n_kv_heads = 8 (uses GQA!) │
│ • d_head = 128 │
│ │
│ Per token: 2 × 80 × 8 × 128 × 2 = 327,680 bytes ≈ 320 KB │
│ │
│ For 128K context: 128 × 320KB = 40 GB per request │
│ │
│ GQA reduces KV cache by 8× compared to full MHA! │
│ (8 KV heads vs 64 query heads) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ KV CACHE VS MODEL WEIGHTS: │
│ ────────────────────────── │
│ │
│ Model Weights KV Cache (4K) KV Cache (128K) │
│ ────────────────────────────────────────────────────────── │
│ Llama 7B 14 GB 2 GB 64 GB │
│ Llama 70B 140 GB ~1.3 GB* ~40 GB* │
│ │
│ *70B uses GQA, reducing KV cache size │
│ │
│ For long contexts, KV cache can exceed model weights! │
│ This is why KV cache optimization is critical. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
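The formula above as a small helper, using the configurations quoted in the box (treated as given); it reproduces the ~512 KB/token and ~40 GB figures:

def kv_cache_bytes(n_layers, n_kv_heads, d_head, n_tokens, bytes_per_elem=2):
    """2 (K and V) x layers x KV heads x head dim x dtype size, times tokens."""
    per_token = 2 * n_layers * n_kv_heads * d_head * bytes_per_elem
    return per_token * n_tokens

# Llama 2 7B (full MHA: 32 KV heads), 4K context, BF16
print(kv_cache_bytes(32, 32, 128, 4096) / 2**30)        # ~2.0 GiB

# Llama 3 70B (GQA: 8 KV heads), 128K context, BF16
print(kv_cache_bytes(80, 8, 128, 128 * 1024) / 2**30)   # ~40 GiB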
KV Cache Optimization Techniques
┌─────────────────────────────────────────────────────────────────────────┐
│ KV CACHE OPTIMIZATIONS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 1. MULTI-QUERY ATTENTION (MQA) │
│ ─────────────────────────────── │
│ Share a single K, V head across all query heads. │
│ │
│ Standard: 32 query heads × 32 KV heads │
│ MQA: 32 query heads × 1 KV head │
│ │
│ KV cache reduction: 32× │
│ Quality impact: Some degradation, especially for complex tasks │
│ │
│ Used by: PaLM, Falcon, StarCoder │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 2. GROUPED-QUERY ATTENTION (GQA) │
│ ───────────────────────────────── │
│ Compromise: Multiple query heads share each KV head. │
│ │
│ Example: 32 query heads, 8 KV heads │
│ Each KV head serves 4 query heads │
│ │
│ KV cache reduction: 4× │
│ Quality impact: Minimal—nearly matches full attention │
│ │
│ Used by: Llama 2 70B, Llama 3, Mistral, Mixtral │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 3. KV CACHE QUANTIZATION │
│ ───────────────────────── │
│ Store K, V in lower precision. │
│ │
│ FP16 → INT8: 2× reduction │
│ FP16 → INT4: 4× reduction │
│ FP16 → FP8: 2× reduction with less quality loss │
│ │
│ Modern inference engines (vLLM, TensorRT-LLM) support this. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 4. PAGED ATTENTION (vLLM) │
│ ────────────────────────── │
│ Problem: Continuous memory allocation wastes space. │
│ Different requests have different lengths; pre-allocating for │
│ maximum length wastes memory. │
│ │
│ Solution: Allocate KV cache in fixed-size "pages." │
│ Pages can be non-contiguous. │
│ Allocate pages as needed, reclaim when done. │
│ │
│ Result: Memory waste reduced from ~70% to <4% │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 5. KV CACHE OFFLOADING │
│ ────────────────────── │
│ Move inactive KV cache to CPU memory or SSD. │
│ Bring back when needed. │
│ │
│ Increases latency but enables longer contexts than GPU memory. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 6. SLIDING WINDOW │
│ ───────────────── │
│ Only attend to recent tokens (e.g., last 4K). │
│ KV cache has fixed maximum size. │
│ │
│ Used by: Mistral (4K sliding window) │
│ │
│ Trade-off: Can't use information from very old tokens. │
│ Mitigated by having some layers with global attention. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
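To show how GQA shares KV heads at compute time, here is a sketch of the common "expand KV heads" step: only the small KV tensor is stored in the cache, and each KV head is repeated to serve its group of query heads just before the attention matmul (names and shapes are illustrative):

import torch

def expand_kv_heads(kv: torch.Tensor, n_q_heads: int) -> torch.Tensor:
    """kv: (batch, n_kv_heads, seq, d_head) -> (batch, n_q_heads, seq, d_head)."""
    b, n_kv_heads, seq, d_head = kv.shape
    group = n_q_heads // n_kv_heads            # query heads served per KV head
    return kv.repeat_interleave(group, dim=1)

k_cached = torch.randn(1, 8, 1024, 128)        # only 8 KV heads live in the cache
k_full = expand_kv_heads(k_cached, n_q_heads=32)
print(k_full.shape)                            # (1, 32, 1024, 128)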
Part V: FlashAttention
The Memory Bandwidth Problem
Standard attention has a hidden bottleneck: memory bandwidth, not compute.
┌─────────────────────────────────────────────────────────────────────────┐
│ THE MEMORY BANDWIDTH BOTTLENECK │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ GPU MEMORY HIERARCHY: │
│ ───────────────────── │
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Registers (fastest) │ │
│ │ ~20,000 per SM │ │
│ ├─────────────────────────────────────────────────────────────────┤ │
│ │ SRAM / Shared Memory │ │
│ │ ~20 MB total (H100) │ │
│ │ ~19 TB/s bandwidth │ │
│ ├─────────────────────────────────────────────────────────────────┤ │
│ │ HBM (High Bandwidth Memory) │ │
│ │ 80 GB (H100) │ │
│ │ ~3 TB/s bandwidth │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ SRAM is ~6× faster than HBM, but ~4000× smaller. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ STANDARD ATTENTION MEMORY ACCESSES: │
│ ──────────────────────────────────── │
│ │
│ 1. Read Q, K from HBM │
│ 2. Compute S = QK^T │
│ 3. Write S to HBM (n × n matrix!) │
│ 4. Read S from HBM │
│ 5. Compute softmax │
│ 6. Write softmax result to HBM │
│ 7. Read softmax result and V from HBM │
│ 8. Compute output │
│ 9. Write output to HBM │
│ │
│ The n × n attention matrix is: │
│ • Written to HBM (step 3) │
│ • Read from HBM (step 4) │
│ • Written again after softmax (step 6) │
│ • Read again for value multiplication (step 7) │
│ │
│ 4 HBM accesses for each element of an n × n matrix! │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE BOTTLENECK: │
│ ─────────────── │
│ │
│ Modern GPUs have massive compute (H100: 1979 TFLOPS for FP16) │
│ But memory bandwidth is limited (H100: 3.35 TB/s) │
│ │
│ Compute-to-memory ratio: ~590 FLOPS per byte │
│ │
│ Standard attention: │
│ • Does O(n² × d) compute │
│ • Reads/writes O(n²) data multiple times │
│ • Arithmetic intensity: ~10 FLOPS per byte │
│ │
│ We're using ~2% of available compute! │
│ The operation is MEMORY BOUND, not COMPUTE BOUND. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
FlashAttention: The Algorithm
FlashAttention's key insight: never materialize the full attention matrix in HBM. Compute attention in tiles that fit in SRAM.
┌─────────────────────────────────────────────────────────────────────────┐
│ FLASHATTENTION ALGORITHM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE KEY INSIGHT: │
│ ──────────────── │
│ We don't need the full n × n attention matrix at once. │
│ We can compute it in blocks, accumulating the output as we go. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE TILING APPROACH: │
│ ──────────────────── │
│ │
│ Divide Q, K, V into blocks that fit in SRAM: │
│ Q: [Q₁, Q₂, Q₃, ...] each block has B_q rows │
│ K: [K₁, K₂, K₃, ...] each block has B_k rows │
│ V: [V₁, V₂, V₃, ...] each block has B_k rows │
│ │
│ For each Q block: │
│ For each K, V block: │
│ 1. Load Q block, K block, V block into SRAM │
│ 2. Compute attention for this tile │
│ 3. Update running output (no HBM write!) │
│ Write final output to HBM │
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ ATTENTION MATRIX │ │
│ │ │ │
│ │ K₁ K₂ K₃ K₄ K₅ │ │
│ │ ┌─────┬─────┬─────┬─────┬─────┐ │ │
│ │ │█████│ │ │ │ │ Q₁ × K_blocks │ │
│ │ ├─────┼─────┼─────┼─────┼─────┤ │ │
│ │ │ │█████│ │ │ │ Q₂ × K_blocks │ │
│ │ ├─────┼─────┼─────┼─────┼─────┤ │ │
│ │ │ │ │█████│ │ │ ... │ │
│ │ └─────┴─────┴─────┴─────┴─────┘ │ │
│ │ │ │
│ │ Process one tile (█) at a time in SRAM. │ │
│ │ Never store the full matrix in HBM. │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE ONLINE SOFTMAX TRICK: │
│ ───────────────────────── │
│ Problem: Softmax needs the full row to compute denominators. │
│ softmax(x)_i = exp(x_i) / Σ_j exp(x_j) │
│ │
│ Solution: Online (streaming) softmax algorithm. │
│ │
│ Track running statistics as we process blocks: │
│ • m: running maximum (for numerical stability) │
│ • l: running sum of exp(x - m) │
│ │
│ When processing new block: │
│ 1. Compute local max m_new for this block │
│ 2. Update global max: m = max(m, m_new) │
│ 3. Rescale previous sum: l = l × exp(m_old - m) │
│ 4. Add new values: l += Σ exp(x_block - m) │
│ 5. Update output accumulator with rescaling │
│ │
│ At the end, we have correct softmax without ever storing full row. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ MEMORY COMPLEXITY: │
│ ────────────────── │
│ │
│ Standard attention: O(n²) - store full attention matrix │
│ FlashAttention: O(n) - only store tiles and accumulators │
│ │
│ For 32K sequence: 32K² × 2 bytes = 2 GB → ~65 KB (tiles) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
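An unoptimized but numerically faithful sketch of the tiling plus online-softmax idea: K and V are processed in blocks while a running max m, running sum l, and a rescaled output accumulator are maintained. This runs on plain PyTorch tensors and mirrors the steps above; the real FlashAttention kernel does the same bookkeeping inside SRAM with fused CUDA code:

import torch

def tiled_attention(q, k, v, block: int = 128):
    """q, k, v: (n, d). Computes softmax(q k^T / sqrt(d)) v one K/V tile at a time."""
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros(n, d)
    m = torch.full((n, 1), float('-inf'))        # running row-wise max
    l = torch.zeros(n, 1)                        # running sum of exp(score - m)
    for start in range(0, n, block):
        kb, vb = k[start:start + block], v[start:start + block]   # one tile ("in SRAM")
        s = (q @ kb.T) * scale                                    # tile scores: (n, block)
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)                 # exponentials w.r.t. the new max
        correction = torch.exp(m - m_new)        # rescales the old sum and old output
        l = l * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        m = m_new
    return out / l                               # final softmax normalization

n, d = 512, 64
q, k, v = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
ref = torch.softmax((q @ k.T) / d**0.5, dim=-1) @ v
print(torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4))   # True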
FlashAttention-2: Key Improvements
┌─────────────────────────────────────────────────────────────────────────┐
│ FLASHATTENTION-2 IMPROVEMENTS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ FLASHATTENTION-1 LIMITATIONS: │
│ ───────────────────────────── │
│ • ~35% GPU utilization on A100 │
│ • Non-optimal work partitioning │
│ • Unnecessary shared memory reads/writes │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ FLASHATTENTION-2 IMPROVEMENTS: │
│ ─────────────────────────────── │
│ │
│ 1. BETTER PARALLELISM │
│ FA1: Parallelize over batch and heads │
│ FA2: Also parallelize over sequence length │
│ │
│ Result: Better utilization of GPU SMs │
│ │
│ 2. REDUCED NON-MATMUL FLOPS │
│ FA1: Many rescaling operations in inner loop │
│ FA2: Rearranged algorithm to minimize non-matmul operations │
│ │
│ On A100, matmul throughput is 16× higher than non-matmul. │
│ Reducing non-matmul ops directly improves speed. │
│ │
│ 3. WORK PARTITIONING │
│ FA1: Each thread block handles one head │
│ FA2: Work partitioned across thread blocks more efficiently │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ PERFORMANCE: │
│ ──────────── │
│ │
│ FlashAttention-1: ~35% utilization on A100 │
│ FlashAttention-2: ~70% utilization on A100 │
│ │
│ 2× speedup over FA1, 5-9× over standard attention. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ADDITIONAL FEATURES: │
│ ──────────────────── │
│ • Support for head dimensions up to 256 │
│ • Native support for MQA and GQA │
│ • Better support for variable sequence lengths │
│ │
└─────────────────────────────────────────────────────────────────────────┘
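In practice you rarely write these kernels yourself. One convenient route is PyTorch's torch.nn.functional.scaled_dot_product_attention, which can dispatch to a FlashAttention kernel on supported GPUs; exactly which backend is chosen depends on your PyTorch version, hardware, dtype, and head dimension, so treat this as a sketch rather than a guarantee:

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Layout expected by scaled_dot_product_attention: (batch, heads, seq, d_head)
q = torch.randn(2, 8, 4096, 64, dtype=dtype, device=device)
k = torch.randn(2, 8, 4096, 64, dtype=dtype, device=device)
v = torch.randn(2, 8, 4096, 64, dtype=dtype, device=device)

# PyTorch selects a fused kernel (FlashAttention, memory-efficient, or math fallback)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)   # (2, 8, 4096, 64)

The standalone flash-attention package (linked in the FlashAttention-3 section below) exposes its own interface with additional options.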
FlashAttention-3: Hopper Optimizations
FlashAttention-3 targets NVIDIA Hopper GPUs (H100) with hardware-specific optimizations:
┌─────────────────────────────────────────────────────────────────────────┐
│ FLASHATTENTION-3 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ HOPPER-SPECIFIC FEATURES LEVERAGED: │
│ ──────────────────────────────────── │
│ │
│ 1. TMA (Tensor Memory Accelerator) │
│ Hardware unit that automates data movement between HBM and SRAM. │
│ FA3 uses TMA to pre-load tiles, reducing kernel complexity. │
│ 30-40% throughput improvement from TMA alone. │
│ │
│ 2. WGMMA (Warp Group Matrix Multiply-Accumulate) │
│ New instruction for matrix multiplication on Hopper. │
│ Asynchronous execution—can overlap with other operations. │
│ │
│ 3. FP8 Support │
│ Hopper has native FP8 tensor cores. │
│ 2× compute throughput compared to FP16. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THREE KEY TECHNIQUES: │
│ ───────────────────── │
│ │
│ 1. WARP SPECIALIZATION (Producer-Consumer Asynchrony) │
│ Split warps into producers (load data) and consumers (compute). │
│ Producers use TMA to load next tile while consumers compute. │
│ Hides memory latency behind computation. │
│ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Producer warps: Load tile N+1 ──→ Load tile N+2 ──→ ... │ │
│ │ ↓ ↓ │ │
│ │ Consumer warps: Compute tile N ──→ Compute tile N+1 ──→ │ │
│ │ (overlapped!) │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │
│ 2. SOFTMAX-MATMUL OVERLAP │
│ FA2: softmax completes before next GEMM starts │
│ FA3: softmax runs in parallel with GEMM using different units │
│ │
│ WGMMA (matrix multiply) is asynchronous. │
│ While waiting for WGMMA, compute softmax on different warps. │
│ │
│ 3. FP8 WITH INCOHERENT PROCESSING │
│ Problem: FP8 has limited range, outliers cause errors. │
│ Solution: Apply Hadamard transform to "spread" outliers. │
│ │
│ Result: FP8 attention with 2.6× lower error than naive FP8. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ PERFORMANCE: │
│ ──────────── │
│ │
│ FA2 on H100: ~35-50% utilization │
│ FA3 on H100: ~75-85% utilization │
│ │
│ BF16: Up to 840 TFLOPS (85% of peak) │
│ FP8: Up to 1.3 PFLOPS │
│ │
│ 1.5-2× speedup over FA2 on H100. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ REQUIREMENTS: │
│ ───────────── │
│ • NVIDIA Hopper GPU (H100) │
│ • CUDA 12.3+ (12.8 recommended) │
│ • Available at: github.com/Dao-AILab/flash-attention │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Part VI: Long Context Solutions
The Long Context Challenge
Modern applications demand long contexts: entire codebases, full documents, multi-hour conversations. But attention's O(n²) complexity makes this expensive.
┌─────────────────────────────────────────────────────────────────────────┐
│ LONG CONTEXT SCALING │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ QUADRATIC SCALING REALITY: │
│ ────────────────────────── │
│ │
│ Context Attention matrix entries Memory (FP16) Time (A100) │
│ ────────────────────────────────────────────────────────────────── │
│ 4K 16 million 32 MB ~1 ms │
│ 8K 64 million 128 MB ~4 ms │
│ 32K 1 billion 2 GB ~60 ms │
│ 128K 16 billion 32 GB ~1 sec │
│ 512K 262 billion 512 GB ~16 sec │
│ 1M 1 trillion 2 TB ~64 sec │
│ │
│ 4× longer context = 16× more compute and memory. │
│ │
│ At 1M tokens, attention alone takes over a minute! │
│ (Times are rough estimates for a single attention layer) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ APPROACHES TO LONG CONTEXT: │
│ ─────────────────────────── │
│ │
│ 1. Better algorithms: FlashAttention (same complexity, less memory) │
│ 2. Sparse attention: O(n × k) where k << n │
│ 3. Distributed attention: Split across multiple devices │
│ 4. Approximate attention: Linear complexity approximations │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Ring Attention: Distributed Long Context
Ring Attention distributes attention computation across multiple devices, enabling context lengths that scale with device count:
┌─────────────────────────────────────────────────────────────────────────┐
│ RING ATTENTION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE IDEA: │
│ ───────── │
│ Distribute the sequence across devices in a ring. │
│ Each device holds a chunk of Q, K, V. │
│ Rotate K, V around the ring while computing attention. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ RING TOPOLOGY: │
│ ─────────────── │
│ │
│ Device 0 ←───→ Device 1 │
│ ↑ ↓ │
│ Device 3 ←───→ Device 2 │
│ │
│ Each device: │
│ • Holds Q chunk (stays local) │
│ • Holds K, V chunks (rotate around ring) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE ALGORITHM: │
│ ─────────────── │
│ │
│ Sequence split into 4 chunks: [Q₀K₀V₀, Q₁K₁V₁, Q₂K₂V₂, Q₃K₃V₃] │
│ │
│ Initial state: │
│ Device 0: Q₀, K₀, V₀ │
│ Device 1: Q₁, K₁, V₁ │
│ Device 2: Q₂, K₂, V₂ │
│ Device 3: Q₃, K₃, V₃ │
│ │
│ Step 1: Each device computes attention with its local K, V │
│ Device 0: Attention(Q₀, K₀, V₀) │
│ Device 1: Attention(Q₁, K₁, V₁) │
│ ... │
│ │
│ Step 2: Rotate K, V to next device │
│ Device 0 receives K₃, V₃ (sends K₀, V₀ to Device 1) │
│ Device 0: Attention(Q₀, K₃, V₃) + accumulate │
│ │
│ Step 3: Rotate again │
│ Device 0 receives K₂, V₂ │
│ Device 0: Attention(Q₀, K₂, V₂) + accumulate │
│ │
│ Step 4: Rotate again │
│ Device 0 receives K₁, V₁ │
│ Device 0: Attention(Q₀, K₁, V₁) + accumulate │
│ │
│ After 4 steps, each device has computed full attention for its Q. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ KEY INSIGHT: OVERLAP COMMUNICATION AND COMPUTE │
│ ─────────────────────────────────────────────── │
│ While computing attention for current K, V block, │
│ asynchronously send/receive next K, V block. │
│ │
│ Communication is hidden behind computation. │
│ Near-perfect scaling with device count. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SCALING: │
│ ──────── │
│ Single device: max context = what fits in memory │
│ Ring Attention: max context = device_count × single_device_context │
│ │
│ 32 devices with 32K each → 1M context length │
│ │
│ Used by: PyTorch TorchTitan (Context Parallel), various research │
│ │
└─────────────────────────────────────────────────────────────────────────┘
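As a tiny illustration of the rotation schedule (indices only, no real communication), the loop below prints which K/V chunk each of four devices holds at each step; the partial attention outputs are combined with the same online-softmax rescaling used in the FlashAttention sketch earlier:

n_devices = 4
for step in range(n_devices):
    holds = {dev: (dev - step) % n_devices for dev in range(n_devices)}
    print(f"step {step}: device -> K/V chunk {holds}")
# step 0: device d computes Attention(Q_d, K_d, V_d)
# step 1: device d computes Attention(Q_d, K_{d-1}, V_{d-1}), and so on,
# accumulating partial results until every device has seen every K/V chunk.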
Sparse and Sliding Window Attention
Instead of attending to all positions, attend to a subset:
┌─────────────────────────────────────────────────────────────────────────┐
│ SPARSE ATTENTION PATTERNS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ SLIDING WINDOW ATTENTION (Mistral): │
│ ──────────────────────────────────── │
│ Each position only attends to W previous positions. │
│ │
│ Full attention: Sliding window (W=3): │
│ ┌─────────────┐ ┌─────────────┐ │
│ │█████████████│ │█ │ │
│ │██████████████│ │██ │ │
│ │███████████████│ │███ │ │
│ │████████████████│ │ ███ │ │
│ │█████████████████│ │ ███ │ │
│ └─────────────────┘ └────────────┘ │
│ │
│ Complexity: O(n²) → O(n × W) │
│ │
│ Used by: Mistral (W=4096) │
│ Information flow: Local patterns handled by window, │
│ global patterns emerge through stacking layers. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ LONGFORMER PATTERN: │
│ ─────────────────── │
│ Sliding window + global tokens (attend to/from all positions). │
│ │
│ ┌─────────────────────────────────────┐ │
│ │█ █ █ █ █ █ █ █│ Global tokens (row) │
│ │███ │ Sliding window │
│ │████ │ │
│ │ ████ │ │
│ │ ████ │ │
│ │█ █ █ █ █ █ █ █│ Global tokens (col) │
│ └─────────────────────────────────────┘ │
│ │
│ Useful for tasks with important anchor tokens (e.g., [CLS] in BERT). │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ BIGBIRD PATTERN: │
│ ──────────────── │
│ Sliding window + global + random attention. │
│ │
│ • Local: Attend to nearby tokens │
│ • Global: Some tokens attend to all │
│ • Random: Random connections for global information flow │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ TRADE-OFFS: │
│ ─────────── │
│ Sparse attention is faster and uses less memory. │
│ But it can miss important long-range dependencies. │
│ │
│ Modern consensus: Use full attention when possible (with │
│ FlashAttention), sparse attention for extreme lengths. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
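A sliding-window causal mask is a small change to the causal mask from Part I; a sketch with an illustrative window size:

import torch

def sliding_window_causal_mask(n: int, window: int) -> torch.Tensor:
    """True where attention is ALLOWED: position i sees j if i - window < j <= i."""
    i = torch.arange(n).unsqueeze(1)
    j = torch.arange(n).unsqueeze(0)
    return (j <= i) & (j > i - window)

print(sliding_window_causal_mask(n=8, window=3).int())
# Scores outside the window are set to -inf before softmax, so the KV cache
# only ever needs to hold the most recent `window` tokens per layer.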
Summary
Attention is both elegant and computationally demanding. Understanding it deeply reveals why modern LLMs work the way they do:
┌─────────────────────────────────────────────────────────────────────────┐
│ KEY TAKEAWAYS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE FUNDAMENTALS: │
│ ───────────────── │
│ • Attention allows all-to-all interaction in sequences │
│ • Q, K, V transform inputs for matching and information extraction │
│ • Softmax normalizes scores to attention weights │
│ • Complexity is O(n²) in sequence length │
│ │
│ MULTI-HEAD ATTENTION: │
│ ───────────────────── │
│ • Multiple heads capture different relationship types │
│ • Same compute as single head (trade dimension for heads) │
│ • Heads often specialize (positional, syntactic, semantic) │
│ │
│ POSITIONAL ENCODINGS: │
│ ────────────────────── │
│ • Attention is permutation-equivariant—needs position info │
│ • RoPE: Rotations encode relative position (most common today) │
│ • ALiBi: Linear bias by distance (simpler alternative) │
│ │
│ KV CACHE: │
│ ───────── │
│ • Essential for efficient autoregressive generation │
│ • Memory scales linearly with context length │
│ • MQA/GQA reduce KV cache size by sharing heads │
│ │
│ FLASHATTENTION: │
│ ─────────────── │
│ • Tiles computation to avoid materializing n×n matrix │
│ • Memory: O(n²) → O(n), Speed: 2-9× faster │
│ • FA3 achieves 85% GPU utilization on H100 │
│ │
│ LONG CONTEXT: │
│ ───────────── │
│ • Ring Attention: Distribute across devices │
│ • Sparse attention: Attend to subset of positions │
│ • Context length scales with engineering, not just algorithms │
│ │
└─────────────────────────────────────────────────────────────────────────┘