Transformer Architecture: A Complete Deep Dive
A comprehensive exploration of the transformer architecture—from embedding layers through attention and feed-forward networks to the output head. Understand why decoder-only models dominate, how residual connections enable deep networks, and the engineering decisions behind GPT, Llama, and modern LLMs.
Why Understanding Transformer Architecture Matters
Every modern language model—GPT-4, Claude, Llama, Gemini, Mistral—is built on the transformer architecture. While our attention mechanisms post covered how attention works in detail, this post zooms out to show how attention fits into the complete picture: the full transformer architecture from input tokens to output probabilities.
The transformer architecture, introduced in the 2017 paper "Attention Is All You Need," represented a fundamental departure from previous sequence modeling approaches. Before transformers, the field relied heavily on recurrent neural networks (RNNs) and their variants like LSTMs and GRUs. These architectures processed sequences one element at a time, maintaining a hidden state that accumulated information as they moved through the sequence. While elegant in concept, this sequential processing created two major problems: it was inherently slow because each step depended on the previous one, and it struggled to maintain information across long sequences due to the vanishing gradient problem.
The transformer solved both problems with a radically different approach: process all positions simultaneously using attention mechanisms to model relationships between any pair of positions directly. This parallel processing made transformers dramatically faster to train on modern GPUs, and the direct connections between positions eliminated the information bottleneck of sequential processing.
Understanding the architecture deeply reveals several crucial insights:
Why certain design choices work: The evolution from the original transformer to modern LLMs like Llama and GPT-4 involved numerous refinements. Pre-norm vs post-norm layer placement, SwiGLU vs ReLU activations, RoPE vs learned positional embeddings—each choice has specific tradeoffs that become clear only when you understand how information flows through the network and how gradients propagate during training.
Where compute and memory go: A common misconception is that attention dominates transformer computation. In reality, for most model sizes, the feed-forward network (FFN) consumes approximately two-thirds of the parameters and a significant portion of computation. Understanding this distribution helps you make informed decisions about optimization strategies and model scaling.
How to debug and optimize: When training fails or inference is slow, architectural knowledge helps you diagnose problems. Is the model experiencing gradient issues? Are certain components bottlenecking memory? Which optimizations will have the most impact?
Why decoder-only won: The original transformer had both encoder and decoder stacks. Yet modern LLMs almost universally use decoder-only architectures. This wasn't obvious in 2017—understanding why it happened reveals deep insights about how these models learn and what makes them effective for general-purpose language tasks.
This post is your complete guide to the transformer decoder block—the building block of all modern LLMs. We'll trace the path from input tokens through every component to output predictions, explaining not just what each piece does but why it's designed that way.
Part I: The Big Picture
From Text to Predictions: The Complete Journey
Before diving into individual components, let's understand the complete journey that transforms a text prompt into a prediction for the next token. This high-level view provides context for everything that follows.
When you send a prompt like "The capital of France is" to an LLM, a remarkable sequence of transformations occurs. First, the text is broken into tokens—subword units that the model understands. These tokens are converted to integer IDs based on the model's vocabulary. For our example, this produces a short list of integer IDs—roughly one per word—where each number corresponds to a specific token in the model's vocabulary.
These integer IDs then pass through an embedding layer, which converts each ID into a dense vector of floating-point numbers. If the model has a hidden dimension of 4096, each token becomes a 4096-dimensional vector. These vectors aren't random—they're learned during training to capture semantic relationships between tokens. Similar tokens (like "cat" and "dog") will have similar vectors, while unrelated tokens will be far apart in this high-dimensional space.
Next, positional information is injected. The embedding vectors carry information about what each token means, but not where it appears in the sequence. Since transformers process all positions in parallel, they have no inherent notion of order. Position encodings add this crucial information, allowing the model to distinguish "dog bites man" from "man bites dog."
The positioned embeddings then pass through a stack of transformer blocks—typically 32 to 96 of them in modern LLMs. Each block performs two main operations: multi-head self-attention (which allows positions to exchange information) and a feed-forward network (which transforms the information at each position). Both operations use residual connections to preserve information flow and layer normalization to stabilize activations.
After all transformer blocks, a final layer normalization prepares the hidden states for the output layer. The output projection (often called the "LM head") converts each position's hidden state into a probability distribution over the entire vocabulary. For our example, this produces probabilities for every possible next token: P("Paris") = 0.85, P("Lyon") = 0.03, P("the") = 0.02, and so on across 32,000 or more vocabulary entries.
┌─────────────────────────────────────────────────────────────────────────┐
│ TRANSFORMER LLM: COMPLETE ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ INPUT: "The cat sat on the" │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ TOKENIZER │ │
│ │ "The cat sat on the" → [464, 3797, 3332, 319, 262] │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ TOKEN EMBEDDING │ │
│ │ [464, 3797, ...] → [[0.1, -0.2, ...], [0.3, 0.1, ...], ...] │ │
│ │ (vocab_size × d_model) │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ POSITIONAL ENCODING │ │
│ │ Add position information (RoPE, learned, or sinusoidal) │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ╔═════════════════════════════════════════════════════════════════╗ │
│ ║ ║ │
│ ║ TRANSFORMER BLOCK × N_LAYERS ║ │
│ ║ ║ │
│ ║ ┌───────────────────────────────────────────────────────────┐ ║ │
│ ║ │ LAYER NORM │ ║ │
│ ║ └───────────────────────────────────────────────────────────┘ ║ │
│ ║ │ ║ │
│ ║ ▼ ║ │
│ ║ ┌───────────────────────────────────────────────────────────┐ ║ │
│ ║ │ MULTI-HEAD SELF-ATTENTION │ ║ │
│ ║ │ (with causal mask for autoregressive) │ ║ │
│ ║ └───────────────────────────────────────────────────────────┘ ║ │
│ ║ │ ║ │
│ ║ ▼ ║ │
│ ║ ┌───────────────────────────────────────────────────────────┐ ║ │
│ ║ │ RESIDUAL CONNECTION │ ║ │
│ ║ │ x = x + attention(x) │ ║ │
│ ║ └───────────────────────────────────────────────────────────┘ ║ │
│ ║ │ ║ │
│ ║ ▼ ║ │
│ ║ ┌───────────────────────────────────────────────────────────┐ ║ │
│ ║ │ LAYER NORM │ ║ │
│ ║ └───────────────────────────────────────────────────────────┘ ║ │
│ ║ │ ║ │
│ ║ ▼ ║ │
│ ║ ┌───────────────────────────────────────────────────────────┐ ║ │
│ ║ │ FEED-FORWARD NETWORK (FFN/MLP) │ ║ │
│ ║ │ (expand → activate → project) │ ║ │
│ ║ └───────────────────────────────────────────────────────────┘ ║ │
│ ║ │ ║ │
│ ║ ▼ ║ │
│ ║ ┌───────────────────────────────────────────────────────────┐ ║ │
│ ║ │ RESIDUAL CONNECTION │ ║ │
│ ║ │ x = x + ffn(x) │ ║ │
│ ║ └───────────────────────────────────────────────────────────┘ ║ │
│ ║ ║ │
│ ╚═════════════════════════════════════════════════════════════════╝ │
│ │ │
│ ▼ (repeat N times) │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ FINAL LAYER NORM │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ OUTPUT PROJECTION (LM HEAD) │ │
│ │ hidden_state (d_model) → logits (vocab_size) │ │
│ │ Often tied with embedding weights │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ SOFTMAX │ │
│ │ logits → probabilities over vocabulary │ │
│ │ P("mat") = 0.42, P("floor") = 0.15, P("bed") = 0.08, ... │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ OUTPUT: Next token prediction ("mat" with p=0.42) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
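To make this journey concrete, here is a minimal, illustrative sketch of the same pipeline in PyTorch. It is not any production model's code: the dimensions, layer count, and activations are toy values, positional information is omitted for brevity (RoPE is covered in Part III), and the hard-coded token IDs stand in for a real tokenizer.

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Pre-norm transformer block: x + attn(norm(x)), then x + ffn(norm(x))."""
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x):
        seq_len = x.size(1)
        causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        h = self.norm1(x)                                  # pre-norm before attention
        attn_out, _ = self.attn(h, h, h, attn_mask=causal, need_weights=False)
        x = x + attn_out                                   # residual around attention
        x = x + self.ffn(self.norm2(x))                    # residual around FFN
        return x

class TinyDecoderLM(nn.Module):
    """Toy decoder-only LM: embed -> N pre-norm blocks -> final norm -> LM head."""
    def __init__(self, vocab_size=32000, d_model=256, n_layers=4, n_heads=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([Block(d_model, n_heads) for _ in range(n_layers)])
        self.final_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, token_ids):                          # token_ids: (batch, seq)
        x = self.embed(token_ids)                          # (batch, seq, d_model)
        for block in self.blocks:                          # the repeated transformer blocks
            x = block(x)
        return self.lm_head(self.final_norm(x))            # logits: (batch, seq, vocab)

ids = torch.tensor([[464, 3797, 3332, 319, 262]])          # "The cat sat on the" (GPT-2-style IDs)
logits = TinyDecoderLM()(ids)
probs = torch.softmax(logits[0, -1], dim=-1)               # next-token distribution at the last position
print(probs.shape)                                         # torch.Size([32000])
```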
Understanding Parameter Distribution: Where the Model Lives
One of the most important yet misunderstood aspects of transformer architecture is where the parameters actually reside. When we say a model has "7 billion parameters," what does that actually mean? Where are those billions of numbers, and what do they do?
Understanding parameter distribution matters for several practical reasons. First, it tells you what to optimize: if you're trying to reduce model size or speed up inference, you should focus on the components with the most parameters. Second, it reveals the model's computational profile: larger components generally require more compute. Third, it helps you understand scaling behavior: how does the parameter distribution change as models grow from 7B to 70B to 700B?
Let's trace through a concrete example: Llama 2 7B, a well-documented open model with 32 transformer layers, a hidden dimension of 4096, 32 attention heads, and a vocabulary of 32,000 tokens.
The embedding layer creates the model's vocabulary representation. With 32,000 vocabulary entries and 4096 dimensions per entry, this matrix contains 32,000 × 4,096 = 131 million parameters. This seems like a lot, but it's only about 2% of the total model. Some models (GPT-2 and Gemma, for example) "tie" this embedding matrix with the output projection—the same weights serve both input embeddings and output predictions—which acts as both a regularization technique and a parameter saving. Llama 2, however, keeps the input embedding and the output projection as separate matrices, so an equally sized 131-million-parameter LM head appears again at the output.
The transformer blocks contain the vast majority of parameters. Each block has two major components: multi-head attention and the feed-forward network.
The attention mechanism in each layer requires four weight matrices: W_Q, W_K, W_V for projecting inputs into queries, keys, and values, and W_O for projecting the attention output back to the hidden dimension. Each of these matrices has shape (4096 × 4096), contributing 16.8 million parameters each, for a total of 67.1 million parameters per layer for attention.
The feed-forward network in Llama uses the SwiGLU architecture, which requires three weight matrices rather than the two in the original transformer. These matrices project from the hidden dimension (4096) to an intermediate dimension (11008) and back. With three matrices of sizes (4096 × 11008), (4096 × 11008), and (11008 × 4096), the FFN contributes 135.3 million parameters per layer.
Notice something striking: the FFN has approximately twice as many parameters as attention in each layer! This ratio holds across most transformer architectures. The feed-forward network, despite being conceptually simpler than attention, is actually the parameter-heavy component.
Layer normalization adds negligible parameters—just 2 × 4096 = 8,192 per layer, roughly 0.004% of the layer's total.
Summing it up: Each transformer block contains about 202 million parameters (67M attention + 135M FFN + negligible normalization). With 32 layers, that's 6.5 billion parameters in the transformer stack. Add 131 million for the input embeddings, another 131 million for the untied output projection, and a few thousand for the final norm, and we reach approximately 6.7 billion parameters—matching Llama 2 7B's actual count of about 6.74 billion.
The critical insight here is that roughly two-thirds of the parameters in the transformer stack live in feed-forward networks, while only about one-third sit in attention layers. This has profound implications:
First, FFN optimization matters enormously. Techniques like Mixture of Experts (MoE), which route different inputs to different FFN "experts," can dramatically increase model capacity without proportionally increasing compute, precisely because FFN is where most parameters live.
Second, attention efficiency improvements have limits. Techniques like FlashAttention and Multi-Query Attention are valuable, but even perfect attention optimization can only address about a third of the model's parameters.
Third, the balance shifts with context length. Parameters stay FFN-dominated as models scale up, but compute does not track parameters exactly: attention cost grows quadratically with sequence length while FFN cost grows linearly, so at very long contexts attention claims an increasingly large share of the compute.
┌─────────────────────────────────────────────────────────────────────────┐
│ PARAMETER DISTRIBUTION IN LLAMA 2 7B │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ MODEL CONFIGURATION: │
│ d_model = 4096, n_layers = 32, n_heads = 32, vocab_size = 32000 │
│ intermediate_size = 11008 (FFN hidden dimension) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ EMBEDDINGS: │
│ Token embeddings: vocab_size × d_model = 32000 × 4096 = 131M params │
│         (Llama's LM head is a separate 131M matrix)                     │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ PER TRANSFORMER BLOCK: │
│ │
│ Attention (4 matrices, each d_model × d_model): │
│ • W_Q: 4096 × 4096 = 16.8M │
│ • W_K: 4096 × 4096 = 16.8M │
│ • W_V: 4096 × 4096 = 16.8M │
│ • W_O: 4096 × 4096 = 16.8M │
│ Attention total: 67.1M params per layer │
│ │
│ Feed-Forward Network (SwiGLU, 3 matrices): │
│ • W_gate: 4096 × 11008 = 45.1M │
│ • W_up: 4096 × 11008 = 45.1M │
│ • W_down: 11008 × 4096 = 45.1M │
│ FFN total: 135.3M params per layer │
│ │
│ Layer Norms: 2 × 4096 = 8K params (negligible) │
│ │
│ Total per block: ~202M params │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ TOTAL MODEL: │
│ Embeddings: 131M ( 2%) │
│ 32 Transformer blocks: 6.5B (98%) │
│ - Attention: 2.1B (33% of total) │
│ - FFN: 4.3B (67% of total) │
│ │
│ Total: ~6.7B parameters │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ KEY INSIGHT: │
│ FFN contains ~2× more parameters than attention! │
│ This is why FFN optimization (SwiGLU, MoE) matters so much. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
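The tally above is easy to reproduce. This short back-of-the-envelope script recomputes the Llama 2 7B parameter counts from its published configuration values; it counts matrix shapes rather than inspecting an actual checkpoint.

```python
# Back-of-the-envelope parameter count for a Llama-2-7B-like configuration.
d_model, n_layers, vocab_size, d_ff = 4096, 32, 32000, 11008

attn_per_layer = 4 * d_model * d_model          # W_Q, W_K, W_V, W_O
ffn_per_layer = 3 * d_model * d_ff              # W_gate, W_up, W_down (SwiGLU)
norm_per_layer = 2 * d_model                    # two RMSNorm gain vectors

block = attn_per_layer + ffn_per_layer + norm_per_layer
embeddings = vocab_size * d_model               # input embedding matrix
lm_head = vocab_size * d_model                  # separate (untied) output projection

total = n_layers * block + embeddings + lm_head + d_model  # + final RMSNorm
print(f"attention: {n_layers * attn_per_layer / 1e9:.2f}B")  # ~2.15B
print(f"ffn:       {n_layers * ffn_per_layer / 1e9:.2f}B")   # ~4.33B
print(f"total:     {total / 1e9:.2f}B")                      # ~6.74B
```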
Part II: The Embedding Layer
Token Embeddings: From Discrete to Continuous
The embedding layer performs one of the most fundamental operations in neural language modeling: converting discrete token IDs into continuous vector representations. This transformation is essential because neural networks operate on continuous values—they perform multiplications, additions, and differentiable functions that require real numbers, not discrete symbols.
To understand why this matters, consider how a traditional computer represents text. A word like "cat" might be stored as ASCII codes [99, 97, 116], where each letter has an arbitrary numeric value. These numbers carry no semantic meaning—99 represents 'c' only by convention, and the fact that 99 > 97 tells us nothing about the relationship between 'c' and 'a'. A neural network operating on these raw numbers would need to discover, from scratch, that the sequence [99, 97, 116] represents a small furry animal.
Embeddings solve this problem by learning meaningful representations. Instead of arbitrary codes, each token gets a dense vector—typically between 768 and 12,288 dimensions depending on the model. These vectors are initialized randomly but adjusted during training so that tokens appearing in similar contexts develop similar vectors. After training on billions of text samples, "cat" and "dog" end up with similar embeddings (both are common pets, appearing in similar sentence structures), while "cat" and "economics" end up far apart.
The embedding lookup itself is remarkably simple: we maintain a matrix of shape (vocabulary_size × embedding_dimension), and to embed a token, we simply look up the row corresponding to its ID. For a vocabulary of 32,000 tokens and embedding dimension of 4096, this means a 131-million parameter matrix where row i contains the learned representation for token i.
This simplicity belies the embedding's importance. The embedding layer is where the model's understanding of language begins. All downstream processing—attention, feed-forward transformations, output predictions—operates on these learned representations. Poor embeddings mean the model starts with a handicap; good embeddings provide a strong foundation.
One subtle detail affects embedding quality: scaling. Embedding weights are commonly initialized with small values whose variance shrinks as the dimension grows (for example, drawn from a normal distribution with variance 1/d_model). Without correction, the token embedding's contribution can be drowned out by whatever is added to it downstream, such as positional encodings. The original transformer therefore multiplies embeddings by the square root of the embedding dimension; some modern models keep this scaling, while others (with different initialization or rotary position embeddings) omit it. Either way, the goal is the same: ensure the embedding's contribution enters downstream computation at a sensible scale regardless of model size.
The embedding layer also reveals an interesting property of neural language models: they treat all tokens as fundamentally similar entities. Whether a token represents a common word like "the," a rare technical term, a punctuation mark, or even a number, it gets the same treatment—lookup in the embedding matrix, yielding a vector of the same dimension. The model must learn to handle all these different token types through the same computational path, which partly explains why LLMs need such enormous scale to handle language's full diversity.
┌─────────────────────────────────────────────────────────────────────────┐
│ TOKEN EMBEDDINGS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE FUNDAMENTAL TRANSFORMATION: │
│ ─────────────────────────────── │
│ │
│ Raw text contains no inherent numeric meaning: │
│ "cat" → ASCII [99, 97, 116] → arbitrary, meaningless numbers │
│ │
│ Embeddings create meaningful numeric representations: │
│ "cat" → token_id=3797 → lookup → [0.12, -0.34, 0.56, ..., 0.78] │
│ (4096 dimensions) │
│ │
│ These vectors LEARN semantic meaning during training: │
│ • "cat" and "dog" → similar vectors (similar contexts) │
│ • "cat" and "economics" → distant vectors │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE EMBEDDING MATRIX: │
│ ───────────────────── │
│ │
│ Shape: (vocab_size × d_model) = (32000 × 4096) │
│ │
│ Each row i contains the learned representation for token i: │
│ │
│ Row 0: [0.02, -0.15, 0.08, ..., 0.21] ← embedding for token 0 │
│ Row 1: [0.11, 0.03, -0.19, ..., 0.45] ← embedding for token 1 │
│ ... │
│ Row 3797: [0.12, -0.34, 0.56, ..., 0.78] ← embedding for "cat" │
│ ... │
│ Row 31999: [-0.07, 0.22, 0.01, ..., 0.33] ← embedding for last token │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY EMBEDDINGS CAPTURE MEANING: │
│ ─────────────────────────────── │
│ │
│ During training, gradients adjust embeddings so that: │
│ │
│ 1. Tokens in SIMILAR CONTEXTS get SIMILAR vectors │
│ "The [cat] sat on the mat" ↔ "The [dog] sat on the rug" │
│ Both cat and dog fill the same grammatical slot with similar │
│ surrounding words → their embeddings converge │
│ │
│ 2. The famous "king - man + woman ≈ queen" relationship emerges │
│ because gender and royalty are separate directions in the space │
│ │
│ 3. Semantic relationships become geometric relationships │
│ Similarity = cosine distance in embedding space │
│ Analogy = vector arithmetic │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ EMBEDDING DIMENSION BY MODEL: │
│ ───────────────────────────── │
│ │
│ Model d_model Vocab Size Embedding Params │
│ ────────────────────────────────────────────────────────── │
│ GPT-2 Small 768 50,257 38.6M │
│ GPT-2 Large 1,280 50,257 64.3M │
│ BERT-base 768 30,522 23.4M │
│ Llama 7B 4,096 32,000 131M │
│ Llama 70B 8,192 32,000 262M │
│ GPT-4 (est.) ~12,288 ~100,000 ~1.2B │
│ │
│ Rule: Larger models use higher dimensions to capture more nuance. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
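In code, the lookup really is just a row selection. Here is a minimal sketch using PyTorch's nn.Embedding; the "dog" ID is an arbitrary placeholder, and the weights are untrained, so the similarity printed at the end is meaningless until training has shaped the matrix.

```python
import torch
import torch.nn as nn

vocab_size, d_model = 32000, 4096
embedding = nn.Embedding(vocab_size, d_model)   # the (32000 x 4096) learned matrix

token_ids = torch.tensor([[464, 3797, 3332]])   # a batch with one three-token sequence
vectors = embedding(token_ids)                  # pure row lookup, no matrix multiply
print(vectors.shape)                            # torch.Size([1, 3, 4096])

# After training, semantic similarity shows up as geometric similarity between rows.
cat_id, dog_id = 3797, 5678                     # dog_id is a placeholder, not a real index
similarity = torch.cosine_similarity(embedding.weight[cat_id], embedding.weight[dog_id], dim=0)
print(similarity)                               # near 0 here because the weights are still random
```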
Embedding Scaling: A Subtle but Important Detail
A subtle detail that many explanations overlook is embedding scaling. Embedding matrices are commonly initialized with small values—for example, elements drawn from a normal distribution with variance 1/d_model, as in Xavier-style schemes. Under this kind of initialization, individual elements shrink as the embedding dimension grows, and the embedding's contribution can be drowned out by other signals that are added to it or compared against it downstream.
This matters because several transformer components combine embeddings with other quantities: additive positional encodings (whose elements stay in [-1, 1]), the residual stream, and dot products inside attention. For these operations to behave consistently across model sizes, the token embedding needs to arrive at a sensible, size-independent scale.
The classic solution, from the original transformer, is to multiply embeddings by the square root of the embedding dimension. This scaling ensures that:
- Embedding elements end up with roughly unit variance regardless of dimension
- Token embeddings and positional encodings contribute on comparable scales
- The model behaves more consistently across scales
Implementations that use this trick apply it in the embedding lookup or immediately after; models with different initialization schemes or rotary position embeddings often skip it entirely. It's a small detail, but getting the scale wrong can cause training instabilities, especially when moving between very large and very small models.
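A sketch of the classic scaling trick, under the assumption of a small (1/d_model-variance) initialization as in the original transformer:

```python
import math
import torch
import torch.nn as nn

d_model = 4096
embedding = nn.Embedding(32000, d_model)
nn.init.normal_(embedding.weight, std=d_model ** -0.5)  # small init: element variance 1/d_model

x = embedding(torch.tensor([[464, 3797]]))
x = x * math.sqrt(d_model)      # rescale so elements have roughly unit variance before
                                # additive positional encodings are mixed in
print(x.std())                  # ~1.0
```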
Part III: Positional Information
The Fundamental Problem: Transformers Can't Count
One of the most counterintuitive aspects of the transformer architecture is that it has no built-in notion of sequence order. When you process a sequence like "The cat sat on the mat" through a transformer, the architecture treats each token identically regardless of its position. Without intervention, the model would produce the exact same representations for "The cat sat on the mat" and "mat the on sat cat The"—both are just sets of tokens as far as the raw attention mechanism is concerned.
This property, called permutation equivariance, arises directly from how attention works. The attention operation computes relationships between all pairs of tokens using only their content (their embedding vectors). Position doesn't enter the computation anywhere. If you shuffle the input tokens, attention simply shuffles the outputs in the same way.
Why does this happen? Consider the attention formula: each token's query vector is compared against all other tokens' key vectors through dot products. Nothing in this computation references where any token sits in the sequence. Position 1 and position 100 are computed identically; only their content differs.
This is a feature for some purposes—it's why transformers can parallelize across sequence positions during training. But for language modeling, it's a critical bug. "Dog bites man" and "Man bites dog" have completely different meanings despite containing the same words. Code depends critically on ordering: x = y; z = x is very different from z = x; x = y. The model needs position information to function.
The solution is to explicitly inject position information into the model. There are several approaches, each with tradeoffs:
Absolute positional embeddings treat positions like tokens and learn an embedding for each one. Position 0 gets embedding P[0], position 1 gets P[1], and so on. These position embeddings are added to the token embeddings before entering the transformer. The approach is simple and works well within the training context length, but it has a fundamental limitation: positions beyond the maximum trained length have no embeddings. A model trained on 2048 positions simply doesn't know what position 2049 means.
Sinusoidal encodings, from the original transformer paper, use mathematical functions (sines and cosines of different frequencies) to generate position representations. These don't require learning and can, in principle, extrapolate to longer sequences. However, practical experience showed that extrapolation quality was poor—the model's behavior degraded significantly beyond training lengths.
Rotary Position Embeddings (RoPE) take a completely different approach. Instead of adding position information to the embeddings, RoPE encodes position through rotations applied to the query and key vectors in attention. The key insight is that when you compute the dot product of two rotated vectors, the result depends on the difference in their rotation angles. If token at position m is rotated by angle θm and token at position n is rotated by angle θn, their dot product depends on (θm - θn)—the relative position, not the absolute positions.
This relative position encoding is more robust and aligns with linguistic intuition. When predicting the next word, what matters is that "the" appeared three tokens ago, not that it appeared at position 47. RoPE naturally encodes this relativity while still allowing extrapolation through techniques like NTK-aware scaling and YaRN.
ALiBi (Attention with Linear Biases) is even simpler: don't modify the embeddings at all. Instead, add a distance-based penalty directly to attention scores. Tokens far apart have their attention scores reduced by an amount proportional to their distance. This implements a simple but effective inductive bias: nearby tokens are more relevant than distant ones. Different attention heads use different penalty slopes, allowing some heads to focus locally while others attend more globally.
Modern LLMs have converged on RoPE as the dominant approach, used by Llama, Llama 2, Llama 3, Mistral, Qwen, Phi, and many others. Its combination of relative position awareness, good extrapolation properties (with scaling tricks), and efficiency has made it the standard choice.
┌─────────────────────────────────────────────────────────────────────────┐
│ THE POSITION PROBLEM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ WHY POSITION MATTERS: │
│ ───────────────────── │
│ │
│ Without position information: │
│ "Dog bites man" ≡ "Man bites dog" (same tokens, same output!) │
│ │
│ The attention mechanism sees only CONTENT, not POSITION: │
│ Q[i] · K[j] depends on what tokens i and j contain │
│ Nothing in this computation knows i comes before j │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ APPROACH 1: LEARNED ABSOLUTE POSITIONS │
│ ─────────────────────────────────────── │
│ │
│ Treat positions like vocabulary tokens—learn an embedding for each: │
│ │
│ Position 0 → P[0] = [0.1, 0.2, ...] │
│ Position 1 → P[1] = [0.3, -0.1, ...] │
│ Position 2 → P[2] = [0.0, 0.4, ...] │
│ ... │
│ Position 2047 → P[2047] = [...] │
│ │
│ final_embedding = token_embedding + position_embedding │
│ │
│ ✓ Simple and effective within trained context │
│ ✗ Cannot extrapolate: position 2048 has no embedding! │
│ Used by: GPT-2, BERT │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ APPROACH 2: ROPE (Rotary Position Embeddings) │
│ ────────────────────────────────────────────── │
│ │
│ Key insight: Encode position as ROTATIONS, not additions. │
│ │
│ Query and Key vectors are ROTATED by their positions: │
│ q_m = rotate(q, angle=m×θ) │
│ k_n = rotate(k, angle=n×θ) │
│ │
│ When computing attention score q_m · k_n: │
│ The dot product depends on the DIFFERENCE (m - n)! │
│ │
│ This means: │
│ • Attention depends on RELATIVE position, not absolute │
│ • "3 tokens apart" encodes the same regardless of where in sequence │
│ • Can extrapolate with scaling tricks (NTK, YaRN) │
│ │
│ ✓ Relative position naturally encoded │
│ ✓ Better extrapolation to longer sequences │
│ ✓ No extra parameters (rotation angles computed, not learned) │
│ Used by: Llama, Llama 2, Llama 3, Mistral, Qwen, Phi │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ APPROACH 3: ALIBI (Attention with Linear Biases) │
│ ───────────────────────────────────────────────── │
│ │
│ Don't modify embeddings—add distance penalty to attention scores: │
│ │
│ attention_score[i,j] -= m × |i - j| │
│ │
│ Close tokens: small penalty (high attention allowed) │
│ Far tokens: large penalty (attention suppressed) │
│ │
│ Different heads use different slopes m: │
│ Head 1: m = 0.5 (attends far) │
│ Head 8: m = 4.0 (attends only nearby) │
│ │
│ ✓ No learned parameters │
│ ✓ Excellent extrapolation │
│ ✗ Strong inductive bias (may not suit all tasks) │
│   Used by: BLOOM, MPT, others                                           │
│ │
└─────────────────────────────────────────────────────────────────────────┘
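To make RoPE's "rotate the queries and keys" idea concrete, here is a compact sketch of the standard formulation: dimension 2i is paired with 2i+1, and each pair is rotated by a position- and frequency-dependent angle. Real implementations differ in memory layout and caching, so treat this as illustrative rather than as any particular library's code.

```python
import torch

def rope(x, base=10000.0):
    """Rotate pairs of dimensions by position-dependent angles.
    x: (batch, seq, n_heads, head_dim) with even head_dim."""
    b, seq, h, d = x.shape
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)   # (d/2,) frequencies
    angles = torch.arange(seq, dtype=torch.float32)[:, None] * inv_freq    # (seq, d/2) = position * freq
    cos = angles.cos()[None, :, None, :]    # broadcast over batch and heads
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]     # even/odd dimension pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin    # 2x2 rotation applied to each pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# Because q and k are rotated by angles proportional to their positions,
# the score between position m and position n depends on position only
# through the offset (m - n), never on the absolute indices.
q = torch.randn(1, 8, 4, 64)
k = torch.randn(1, 8, 4, 64)
scores = torch.einsum("bqhd,bkhd->bhqk", rope(q), rope(k))
```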
Part IV: The Transformer Block
The Core Repeating Unit
The transformer block is the fundamental building block of modern LLMs. A 7B parameter model typically has 32 of these blocks stacked sequentially; a 70B model might have 80. Each block performs the same operations with different learned parameters, progressively refining the representations as information flows through the network.
Each block contains two main sub-components: a multi-head self-attention layer and a feed-forward network (FFN). These are connected through residual connections and separated by layer normalization. The specific arrangement of these elements—particularly where normalization occurs—has been a subject of significant research and has evolved since the original transformer.
Understanding the transformer block requires understanding both what each component does and why they're arranged in this particular way. The attention layer allows positions to exchange information with each other—it's where context is built. The FFN processes each position independently, applying the same learned transformation to all positions—it's where this contextual information is processed and refined. The residual connections ensure information can flow easily through many stacked blocks. The normalization keeps activations in a stable range, preventing training from diverging.
Let's examine each component in detail, starting with the crucial question of where to place normalization.
Pre-Norm vs Post-Norm: Why Modern LLMs Use Pre-Norm
The placement of layer normalization might seem like a minor implementation detail, but it profoundly affects training dynamics. The original transformer paper placed normalization after each sub-layer ("post-norm"), but virtually every modern LLM uses normalization before each sub-layer ("pre-norm"). Understanding why requires diving into gradient flow.
In the post-norm configuration, the computation looks like: output = LayerNorm(x + Sublayer(x)).
The residual connection adds the sublayer output to the input, and then normalization is applied to the sum. This seems natural—normalize the combined result—but it creates problems during training.
The issue is gradient flow. During backpropagation, gradients must flow from the loss through every layer back to the earliest parameters. In post-norm, every gradient path passes through every LayerNorm operation. LayerNorm involves computing means and variances across dimensions, then dividing and scaling. The gradient through these operations is complex and, critically, can attenuate signals. With 32 layers, gradients must pass through 64 LayerNorm operations (one for attention and one for FFN in each block). The compound effect of many normalization layers in the gradient path can make optimization difficult, especially early in training when the network hasn't yet found a reasonable operating point.
Pre-norm rearranges this to: output = x + Sublayer(LayerNorm(x)).
Now normalization happens inside the residual branch, before the sublayer rather than after the addition. This creates a "clean" residual path: gradients flowing through the residual connection don't pass through any normalization. The gradient from layer output to layer input is simply 1 plus the gradient through the sublayer. That "1" is crucial—it guarantees that some gradient signal reaches every layer, regardless of what the sublayer does.
The practical impact is dramatic. Pre-norm transformers:
- Train stably without careful learning rate warmup
- Tolerate larger learning rates
- Show more consistent loss curves
- Rarely experience training collapse or divergence
Post-norm transformers:
- Require careful learning rate warmup schedules
- Are sensitive to initialization
- Can diverge during training, especially for deep networks
- May achieve slightly better final performance when training succeeds
The slight potential performance advantage of post-norm isn't worth the training difficulties for most applications. Modern LLMs universally use pre-norm. The architecture prioritizes reliable training at scale over marginal quality improvements that might not materialize anyway.
┌─────────────────────────────────────────────────────────────────────────┐
│ PRE-NORM VS POST-NORM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ POST-NORM (Original Transformer, 2017): │
│ ──────────────────────────────────────── │
│ │
│ input ─────────────────────────┐ │
│ │ │ │
│ ▼ │ │
│ ┌─────────┐ │ │
│ │ Sublayer│ (attention or FFN) │ │
│ └─────────┘ │ │
│ │ │ │
│ ▼ │ │
│ ADD ◄─────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │LayerNorm│ ← Normalize AFTER residual │
│ └─────────┘ │
│ │ │
│ ▼ │
│ output │
│ │
│ GRADIENT PROBLEM: │
│ Every gradient must pass through every LayerNorm. │
│ 64 LayerNorms in a 32-layer model compound to attenuate signals. │
│ Training becomes unstable, especially early on. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ PRE-NORM (GPT-2 onwards, ALL modern LLMs): │
│ ─────────────────────────────────────────── │
│ │
│ input ─────────────────────────┐ │
│ │ │ │
│ ▼ │ │
│ ┌─────────┐ │ │
│ │LayerNorm│ ← Normalize BEFORE sublayer │
│ └─────────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌─────────┐ │ │
│ │ Sublayer│ (attention or FFN) │ │
│ └─────────┘ │ │
│ │ │ │
│ ▼ │ │
│ ADD ◄─────────────────────────┘ │
│ │ │
│ ▼ │
│ output │
│ │
│ GRADIENT ADVANTAGE: │
│ The residual path is "clean"—gradient flows as: │
│ ∂output/∂input = 1 + ∂Sublayer/∂input │
│ │
│ That "1" guarantees gradient reaches every layer! │
│ No compounding attenuation from normalization layers. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ PRACTICAL IMPACT: │
│ │
│ Pre-Norm: │
│ ✓ Trains stably without learning rate warmup │
│ ✓ Tolerates larger learning rates │
│ ✓ Rarely diverges │
│ ✓ Consistent loss curves │
│ │
│ Post-Norm: │
│ ✗ Requires careful warmup schedules │
│ ✗ Sensitive to initialization │
│ ✗ Can diverge, especially for deep nets │
│ ○ May achieve slightly better final performance (when it works) │
│ │
│ Conclusion: Pre-Norm is universally preferred for its stability. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
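In code, the two arrangements differ only in where the normalization sits relative to the residual add. A sketch, where sublayer stands for either the attention module or the FFN:

```python
import torch
import torch.nn as nn

class PostNormBlock(nn.Module):
    """Original 2017 arrangement: normalize AFTER the residual add."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer, self.norm = sublayer, nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self.sublayer(x))   # every gradient path crosses the norm

class PreNormBlock(nn.Module):
    """Modern arrangement: normalize BEFORE the sublayer; the residual path stays clean."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer, self.norm = sublayer, nn.LayerNorm(d_model)

    def forward(self, x):
        return x + self.sublayer(self.norm(x))   # the identity path bypasses the norm

ffn = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
block = PreNormBlock(512, ffn)
print(block(torch.randn(1, 4, 512)).shape)       # torch.Size([1, 4, 512])
```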
Residual Connections: The Secret to Deep Networks
Residual connections are perhaps the single most important architectural innovation enabling deep neural networks. Without them, training a 32-layer transformer would be extremely difficult; training a 96-layer model would be nearly impossible. Understanding why requires understanding the fundamental challenge of training deep networks.
Consider a simple neural network without residual connections. Each layer transforms its input: layer_output = f(layer_input). When we backpropagate gradients through this network, the chain rule multiplies one Jacobian term per layer: ∂Loss/∂layer₁ = ∂Loss/∂layer₃₂ × ∂layer₃₂/∂layer₃₁ × ... × ∂layer₂/∂layer₁.
Each term ∂layer_i/∂layer_{i-1} typically has magnitude less than 1—saturating activations like sigmoid and tanh have derivatives well below 1 over most of their range. When you multiply 31 terms that are all less than 1, the product becomes vanishingly small. This is the vanishing gradient problem: gradients shrink exponentially as they propagate backward, so early layers receive almost no learning signal.
The opposite can also occur: if gradients are greater than 1, the product explodes. This exploding gradient problem causes unstable training with wildly oscillating or diverging losses.
Residual connections provide an elegant solution. Instead of layer_output = f(layer_input), we use: layer_output = layer_input + f(layer_input).
Now the gradient through the residual connection is: ∂layer_output/∂layer_input = 1 + ∂f/∂layer_input.
That "+1" is transformative. Even if ∂f/∂layer_input is very small (or even zero), the identity path carries the gradient through undiminished. Gradients can flow directly from later layers to earlier layers through the residual path without attenuation.
Think of it as a highway system. Without residuals, information must take local roads through every layer, potentially getting stuck or lost. Residual connections add a highway that bypasses the layers entirely. Information can flow on the highway (unchanged through residual) or take exits through layers (transformed by sublayers). The network learns which path is useful for which information.
This leads to an interesting perspective: a residual network is actually an ensemble of paths. Consider a 3-layer residual network. Information can:
- Skip all layers (pure residual): input → output
- Use only layer 1: input → layer1 → output
- Use only layer 2: input → layer2 → output
- Use layers 1 and 2: input → layer1 → layer2 → output
- ...and so on
With n residual blocks, there are 2^n possible paths. The network effectively learns which combination of paths is useful for each input. This "ensemble" view helps explain why residual networks are so powerful and robust.
In transformers, each block has two residual connections: one around attention and one around the FFN. A 32-layer transformer thus has 64 residual connections, creating 2^64 possible paths—far more than could ever be explicitly enumerated, but the gradient flow through all of them keeps training stable.
┌─────────────────────────────────────────────────────────────────────────┐
│ RESIDUAL CONNECTIONS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE DEEP NETWORK PROBLEM: │
│ ────────────────────────── │
│ │
│ Without residuals, gradients must chain through every layer: │
│ │
│ ∂Loss/∂layer1 = ∂Loss/∂layer32 × ∂layer32/∂layer31 × ... × ∂layer2/∂1 │
│ │
│ If each factor ≈ 0.9, then: 0.9^31 ≈ 0.04 │
│ If each factor ≈ 0.5, then: 0.5^31 ≈ 0.0000000005 │
│ │
│ Gradients VANISH exponentially → early layers don't learn! │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE RESIDUAL SOLUTION: │
│ ─────────────────────── │
│ │
│ Instead of: output = f(input) │
│ Use: output = input + f(input) │
│ │
│ input ─────────────────────┐ │
│ │ │ │
│ ▼ │ │
│ ┌─────────┐ │ │
│ │ f(·) │ │ (skip connection) │
│ └─────────┘ │ │
│ │ │ │
│ └───────► + ◄────────────┘ │
│ │ │
│ ▼ │
│ output │
│ │
│ Now the gradient is: │
│ ∂output/∂input = 1 + ∂f(input)/∂input │
│ ↑ │
│ This "1" guarantees gradient ≥ 1! │
│ │
│ Gradients can ALWAYS flow through the residual path. │
│ Early layers receive strong learning signals even in deep networks. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE ENSEMBLE PERSPECTIVE: │
│ ───────────────────────── │
│ │
│ With 3 residual blocks, information can take these paths: │
│ │
│ • input → output (skip all) │
│ • input → block1 → output (use block 1 only) │
│ • input → block2 → output (use block 2 only) │
│ • input → block3 → output (use block 3 only) │
│ • input → block1 → block2 → output (use 1 and 2) │
│ • input → block1 → block3 → output (use 1 and 3) │
│ • input → block2 → block3 → output (use 2 and 3) │
│ • input → block1 → block2 → block3 → output (use all) │
│ │
│ 2^3 = 8 paths! The network learns which to use for each input. │
│ │
│ With n=64 residuals in a 32-layer transformer: │
│ 2^64 ≈ 18 quintillion possible paths │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ IN TRANSFORMERS: │
│ ──────────────── │
│ │
│ Each block has TWO residual connections: │
│ │
│ x → [LayerNorm → Attention] → ADD ← x (residual around attention) │
│ │ │
│ ▼ │
│ x → [LayerNorm → FFN] → ADD ← x (residual around FFN) │
│ │ │
│ ▼ │
│ next block │
│ │
│ This double-residual structure is critical for deep transformers. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
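You can watch the "+1" do its work with a few lines of autograd. The toy experiment below compares the gradient reaching the input of a 30-layer stack with and without residual connections; the width, depth, and tanh activations are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layers = nn.ModuleList([nn.Sequential(nn.Linear(64, 64), nn.Tanh()) for _ in range(30)])

def forward(x, use_residual):
    for f in layers:
        x = x + f(x) if use_residual else f(x)
    return x.sum()

for use_residual in (False, True):
    x = torch.randn(1, 64, requires_grad=True)
    forward(x, use_residual).backward()
    print(f"residual={use_residual}: gradient norm at input = {x.grad.norm():.3e}")

# With default initialization, the plain stack's input gradient collapses toward zero,
# while the residual version stays orders of magnitude larger.
```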
Part V: Layer Normalization
What Problem Does Normalization Solve?
Layer normalization addresses a fundamental challenge in training deep neural networks: the internal covariate shift problem. As training progresses, the distribution of activations within the network changes. Each layer receives inputs whose distribution shifts as earlier layers update their weights. This creates a moving target problem—each layer must constantly adapt to changing input distributions while also trying to learn useful transformations.
Consider what happens without normalization. In one training batch, a particular layer might see activations with mean 0.5 and standard deviation 2.0. In the next batch, those statistics might be 0.8 and 3.5. The layer's weights were tuned for the first distribution, so they may not work well for the second. This constant adaptation slows learning and can cause training instability.
Normalization stabilizes training by ensuring that each layer sees inputs with consistent statistics. Layer normalization specifically normalizes across the feature dimension (the hidden dimension in transformers), computing mean and variance statistics independently for each position in the sequence and each example in the batch.
The normalization process has four steps:
- Compute the mean across the hidden dimension: μ = (1/d) Σᵢ xᵢ
- Compute the variance: σ² = (1/d) Σᵢ (xᵢ - μ)²
- Normalize: x̂ᵢ = (xᵢ - μ) / √(σ² + ε)
- Scale and shift: yᵢ = γᵢ × x̂ᵢ + βᵢ
The first three steps normalize the input to have zero mean and unit variance. The fourth step applies learned parameters γ (gain) and β (bias) that allow the network to undo or modify the normalization if that helps learning. These parameters give the network flexibility—it can learn to use the normalized activations directly, or it can learn to shift them back toward a different distribution if that's more useful.
The small constant ε (typically 1e-5 or 1e-6) prevents division by zero when variance is very small.
Why Layer Norm Instead of Batch Norm?
The original breakthrough in normalization was Batch Normalization (BatchNorm), which normalizes across the batch dimension rather than the feature dimension. BatchNorm computes statistics like "what is the mean activation for neuron 37 across all examples in this batch?" This works extremely well for computer vision, where batches of images are relatively homogeneous.
For transformers processing text, BatchNorm has serious problems:
Variable sequence lengths: Different sequences in a batch have different lengths. Position 50 might exist in some sequences but not others. Normalizing across batch at position 50 would mix examples that have very different contexts at that position.
Small batch sizes: At inference time, we often process single examples (batch size 1). BatchNorm requires a batch to compute meaningful statistics—it falls back to running averages at inference, but this creates a train/inference mismatch.
Train/test mismatch: BatchNorm maintains running statistics during training and uses them at inference. This creates subtle differences between training and inference behavior that can cause problems.
Layer normalization avoids all these issues by normalizing within each example independently. Each position in each example is normalized using only its own hidden dimension values. This means:
- Variable lengths are fine—each position is handled independently
- Batch size 1 works identically to batch size 1000
- No running statistics—identical computation at training and inference
The tradeoff is that LayerNorm doesn't normalize across the batch, so it doesn't provide the batch-level regularization effect of BatchNorm. However, transformers have plenty of other regularization (dropout, large datasets, etc.), so this isn't a significant loss.
RMSNorm: The Simpler, Faster Alternative
Modern LLMs increasingly use RMSNorm (Root Mean Square Normalization) instead of full LayerNorm. The key insight is that mean subtraction may not be necessary—the primary benefit of normalization is controlling scale, not centering.
RMSNorm simplifies the computation to: yᵢ = γᵢ × xᵢ / √((1/d) Σⱼ xⱼ² + ε).
This is just LayerNorm without the mean subtraction step (and without the bias parameter β). The normalization divides by the root-mean-square of the input, which controls the scale without shifting the center.
Why does this work? Research and empirical evidence suggest that the main benefit of normalization comes from controlling the magnitude of activations. The centering (mean subtraction) provides some benefit, but much less than the scaling. By removing centering, RMSNorm:
- Requires fewer operations (no mean computation or subtraction)
- Has fewer parameters (no β bias vector)
- Is approximately 10-30% faster
- Achieves essentially identical model quality
The speed improvement might seem small, but normalization is applied many times per forward pass (twice per transformer block, so 64 times in a 32-layer model). Those 10-30% savings compound across all those applications.
RMSNorm is now the standard choice. Llama, Llama 2, Llama 3, Mistral, Qwen, and most other recent open models all use RMSNorm. There's no reason to use full LayerNorm for new models.
┌─────────────────────────────────────────────────────────────────────────┐
│ LAYER NORMALIZATION VS RMSNORM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE PROBLEM NORMALIZATION SOLVES: │
│ ───────────────────────────────── │
│ │
│ Without normalization, activation magnitudes can vary wildly: │
│ • Different layers produce different scales │
│ • Different inputs produce different scales │
│ • Scales drift during training as weights update │
│ │
│ This makes optimization unstable—gradients depend on scale, │
│ and changing scales mean the model constantly chases a moving target. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ LAYER NORM: │
│ ─────────── │
│ │
│ For each position, normalize across hidden dimension: │
│ │
│ Input: x = [x₁, x₂, ..., x_d] (one position's hidden state) │
│ │
│ 1. μ = mean(x) = (1/d) Σᵢ xᵢ │
│ 2. σ² = var(x) = (1/d) Σᵢ (xᵢ - μ)² │
│ 3. x̂ᵢ = (xᵢ - μ) / √(σ² + ε) (normalize) │
│ 4. yᵢ = γᵢ × x̂ᵢ + βᵢ (scale and shift) │
│ │
│ Output has mean≈0, variance≈1, then scaled by learned γ, β. │
│ │
│ Parameters: γ ∈ R^d (gain), β ∈ R^d (bias) │
│ Total: 2d parameters per LayerNorm │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ RMSNORM: │
│ ──────── │
│ │
│ Key insight: Mean subtraction may not be necessary! │
│ The main benefit is controlling SCALE, not centering. │
│ │
│ Simplified formula: │
│ y = γ × x / √(mean(x²) + ε) │
│ │
│ No mean subtraction, no β bias. │
│ Just divide by root-mean-square, then scale. │
│ │
│ Parameters: γ ∈ R^d only │
│ Total: d parameters per RMSNorm (half of LayerNorm!) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY RMSNORM WORKS: │
│ ────────────────── │
│ │
│ Empirical finding: Centering provides marginal benefit. │
│ The critical operation is scale normalization. │
│ │
│ Benefits of RMSNorm: │
│ • ~10-30% faster (fewer operations) │
│ • Fewer parameters (no bias) │
│ • Same model quality as LayerNorm │
│ │
│ Applied 64 times per forward pass in a 32-layer model, │
│ those 10-30% savings add up significantly. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ USAGE: │
│ │
│ LayerNorm: GPT-2, GPT-3, BERT, older models │
│ RMSNorm: Llama, Llama 2, Llama 3, Mistral, Qwen, Phi │
│ │
│ RMSNorm is now the standard. No reason to use LayerNorm in new work. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
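Here is a minimal RMSNorm written directly from the formula above, shown next to PyTorch's built-in LayerNorm for comparison. It mirrors the equations, not any specific model's implementation.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(d_model))   # gamma only, no bias
        self.eps = eps

    def forward(self, x):
        # Divide by the root-mean-square over the hidden dimension; no mean subtraction.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gain * x / rms

x = torch.randn(2, 5, 4096)
layer_norm = nn.LayerNorm(4096)     # subtracts the mean, divides by std, applies gamma and beta
rms_norm = RMSNorm(4096)            # divides by RMS, applies gamma only
print(layer_norm(x).shape, rms_norm(x).shape)
```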
Part VI: The Feed-Forward Network
The Role of FFN: Processing, Not Communicating
While attention allows positions to exchange information, the feed-forward network (FFN) processes that information independently at each position. This division of labor is fundamental to how transformers work:
- Attention: "What information from other positions is relevant to me?"
- FFN: "Given all this gathered information, what should I do with it?"
The FFN applies the same transformation to every position in parallel. It doesn't see or care about other positions—each position is processed as if it were alone. This position-independent processing might seem limiting, but it's actually what makes the FFN powerful. The FFN can learn complex, nonlinear transformations without worrying about sequence structure. That structure is handled by attention; the FFN focuses on transformation.
The basic FFN structure is simple: project to a larger dimension, apply a nonlinearity, project back. This "expand-activate-contract" pattern is common in neural networks. The expansion creates space for the network to learn more complex transformations than would be possible in the original dimension. The nonlinearity introduces the crucial element of, well, nonlinearity—without it, stacking layers would be pointless because a stack of linear transformations is just another linear transformation. The contraction brings the representation back to the model dimension for the next layer.
The expansion ratio is important. The original transformer used an intermediate dimension 4× the model dimension: for d_model=512, the FFN expanded to d_ff=2048. Modern models often use ratios around 2.67× to 4×, balancing capacity against parameter efficiency.
Why FFN Has More Parameters Than Attention
It's worth pausing on a counterintuitive fact: the FFN contains roughly twice as many parameters as attention in each transformer block. Attention is the signature innovation of the transformer, the subject of countless papers and optimizations, yet the humble FFN is the parameter-heavy component.
The math is straightforward. For attention with d_model hidden dimension:
- W_Q: d_model × d_model
- W_K: d_model × d_model
- W_V: d_model × d_model
- W_O: d_model × d_model
- Total: 4 × d_model²
For the original FFN with 4× expansion:
- W_up: d_model × 4×d_model
- W_down: 4×d_model × d_model
- Total: 8 × d_model²
For SwiGLU FFN with 2.67× expansion (to match parameter count):
- W_gate: d_model × 2.67×d_model
- W_up: d_model × 2.67×d_model
- W_down: 2.67×d_model × d_model
- Total: 8 × d_model² (approximately)
In both cases, FFN has about twice the parameters of attention. This explains why FFN optimization techniques—Mixture of Experts (MoE), efficient FFN architectures, FFN pruning—are so impactful. Optimizing the larger component has more leverage.
SwiGLU: The Modern FFN Activation
The original transformer FFN used ReLU activation: FFN(x) = W₂ · ReLU(W₁ · x). ReLU is simple and effective but has limitations. It completely zeroes out negative values, which loses information. It has a non-differentiable point at zero, which can cause gradient issues. And it can suffer from "dying ReLU"—neurons that get stuck at zero and never recover.
Modern LLMs almost universally use SwiGLU (or similar gated activations). The key innovation is the gating mechanism: instead of a simple activation function, the network learns which parts of the transformation to keep and which to suppress.
SwiGLU works as follows:
- Compute a "gate" value: gate = Swish(x × W_gate)
- Compute an "up" projection: up = x × W_up
- Multiply element-wise: hidden = gate ⊙ up
- Project down: output = hidden × W_down
The Swish activation (also called SiLU) is x × σ(x), where σ is the sigmoid function. It's smooth everywhere (unlike ReLU) and has a slight negative region that can help express certain functions.
The gating mechanism is the crucial element. Instead of applying the same transformation everywhere, the gate learns to selectively enable or suppress different dimensions of the transformation. Think of it as a learned attention mechanism over the FFN dimensions—which features should be active for this particular input?
This gating significantly improves model quality. Research from Google showed SwiGLU improving perplexity by about 5% compared to ReLU FFNs at the same parameter count. Given that LLM training is extremely expensive, a 5% improvement for free (same parameters, similar compute) is huge.
The cost is additional complexity and one more weight matrix. SwiGLU has three weight matrices instead of two, which would increase parameters by 50%. To compensate, the intermediate dimension is reduced from 4× to about 2.67× (specifically 8/3), keeping total parameters constant while gaining the quality benefit.
┌─────────────────────────────────────────────────────────────────────────┐
│ FEED-FORWARD NETWORK │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE FFN'S ROLE: │
│ ─────────────── │
│ │
│ While attention allows positions to COMMUNICATE: │
│ "What information from other tokens do I need?" │
│ │
│ FFN allows positions to PROCESS independently: │
│ "Given the information I've gathered, what do I do with it?" │
│ │
│ Each position is processed identically and independently. │
│ The same weights transform position 1 and position 100. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ BASIC STRUCTURE (Expand → Activate → Contract): │
│ ─────────────────────────────────────────────── │
│ │
│ x (d_model = 4096) │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │ W_up │ Project to larger dimension │
│ └─────────┘ │
│ │ │
│ │ (d_ff = 11008 or 4 × d_model) │
│ ▼ │
│ ┌─────────┐ │
│ │ Activate │ Nonlinearity (ReLU, Swish, etc.) │
│ └─────────┘ │
│ │ │
│ │ (still d_ff) │
│ ▼ │
│ ┌─────────┐ │
│ │ W_down │ Project back to model dimension │
│ └─────────┘ │
│ │ │
│ │ (d_model = 4096) │
│ ▼ │
│ output │
│ │
│ WHY EXPAND? │
│ The larger intermediate dimension provides "working space" for │
│ complex transformations. Each dimension can learn to detect a │
│ specific pattern or feature. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SWIGLU (Modern LLM Standard): │
│ ────────────────────────────── │
│ │
│ Innovation: GATED activation—learn which features to keep. │
│ │
│ x (d_model) │
│ │ │
│ ┌───────┴───────┐ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────┐ ┌─────────┐ │
│ │ W_gate │ │ W_up │ │
│ └─────────┘ └─────────┘ │
│ │ │ │
│ │ (d_ff) │ (d_ff) │
│ ▼ │ │
│ ┌─────────┐ │ │
│ │ Swish │ │ │
│ └─────────┘ │ │
│ │ │ │
│ └───────⊙───────┘ (element-wise multiply) │
│ │ │
│ │ (d_ff) │
│ ▼ │
│ ┌─────────┐ │
│ │ W_down │ │
│ └─────────┘ │
│ │ │
│ │ (d_model) │
│ ▼ │
│ output │
│ │
│ The gate (Swish(x × W_gate)) learns to selectively enable/suppress │
│ different dimensions of the transformation. This is like an internal │
│ attention mechanism over FFN features. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY SWIGLU IS BETTER: │
│ ───────────────────── │
│ │
│ 1. SMOOTH: Swish is differentiable everywhere (unlike ReLU at 0) │
│ 2. GATED: Learns what to pass through, not just thresholding │
│ 3. NON-MONOTONIC: Swish has slight negative region for flexibility │
│ │
│ Result: ~5% perplexity improvement vs ReLU FFN at same params. │
│ For expensive LLM training, this free improvement is significant. │
│ │
│ Used by: Llama, Llama 2, Llama 3, Mistral, Qwen, and most modern LLMs│
│ │
└─────────────────────────────────────────────────────────────────────────┘
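Here is the SwiGLU FFN written out as a module, with the Llama 2 7B dimensions as defaults. The module is a generic sketch of the pattern described above, not a copy of any model's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    def __init__(self, d_model=4096, d_ff=11008):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.w_gate(x))     # Swish/SiLU: x * sigmoid(x)
        up = self.w_up(x)                 # plain linear "up" projection
        return self.w_down(gate * up)     # the gate selects which features pass through

ffn = SwiGLUFFN()
out = ffn(torch.randn(1, 5, 4096))        # applied identically at every position
print(out.shape)                          # torch.Size([1, 5, 4096])
```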
FFN as Learned Key-Value Memory
An intriguing perspective on the FFN comes from viewing it as a learned key-value memory. This isn't just a metaphor—it reveals something fundamental about what FFN learns and how it stores knowledge.
Consider the basic FFN: output = W₂ · ReLU(W₁ · x). We can reinterpret this as:
- Keys: Each row of W₁ is a "key" pattern that the FFN looks for
- Matching: W₁ · x computes how well the input x matches each key
- Selection: ReLU selects which keys "match" (positive activations)
- Values: W₂ columns are "values" associated with matching keys
- Retrieval: The output is a weighted sum of values for matching keys
From this view, the FFN's intermediate dimension represents the number of key-value pairs. A 4096 → 11008 → 4096 FFN has 11,008 key-value pairs per layer. With 32 layers, that's over 350,000 key-value pairs in the model.
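A tiny sketch makes the reinterpretation concrete. The dimensions are toy-sized and the weights random; it uses the plain ReLU FFN from the equation above.

```python
import torch

d_model, d_ff = 8, 32                 # toy sizes; a real layer might be 4096 and ~11008
W1 = torch.randn(d_ff, d_model)       # each row of W1 is a "key" pattern
W2 = torch.randn(d_model, d_ff)       # each column of W2 is the "value" paired with that key

x = torch.randn(d_model)              # one position's hidden state
match = W1 @ x                        # how strongly x matches each of the d_ff keys
active = torch.relu(match)            # keep only the keys that matched (positive scores)
output = W2 @ active                  # weighted sum of the values for the matching keys
# This is exactly the basic FFN: output == W2 @ relu(W1 @ x)
```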
This perspective explains several phenomena:
Why larger FFNs know more facts: More parameters mean more key-value pairs, which means more facts can be stored. The intermediate dimension directly controls memory capacity.
Why knowledge is distributed across layers: Different layers store different types of knowledge. Early layers might store syntactic patterns; later layers might store factual associations. Each layer's FFN is a separate memory bank.
How model editing works: Recent research on editing factual knowledge in LLMs (e.g., changing "The Eiffel Tower is in Paris" to "The Eiffel Tower is in London") often targets FFN weights. The key-value memory view explains why: we're updating the value associated with a particular key pattern.
Why Mixture of Experts works: MoE replaces the single FFN with multiple "expert" FFNs and a router that selects which experts to use. From the memory view, this massively increases the number of key-value pairs without proportionally increasing computation—different experts store different knowledge, and the router selects relevant experts for each input.
Part VII: The Output Layer
From Hidden States to Token Probabilities
After information has flowed through all transformer blocks, we have a sequence of hidden state vectors, one for each position in the input. The final task is converting these hidden states into predictions: for each position, what token is likely to come next?
This conversion happens in three stages:
Final Layer Normalization: One more normalization is applied to the hidden states emerging from the last transformer block. This ensures the values are in a consistent range before the final projection. Without this, the output logits could have inconsistent scales depending on what happened in the transformer stack.
Output Projection (LM Head): The hidden states are projected from the model dimension (e.g., 4096) to the vocabulary size (e.g., 32,000). This is a simple linear transformation: logits = hidden_state × W_out, where W_out has shape (d_model × vocab_size). Each of the 32,000 vocabulary entries gets a score (logit) representing how well the hidden state matches that token.
Softmax: The logits are converted to probabilities using the softmax function. Each position gets a probability distribution over the vocabulary, with probabilities summing to 1. Because softmax exponentiates the logits, it amplifies differences: higher logits receive disproportionately more of the probability mass.
The output projection deserves special attention. It's computing a score for every vocabulary token based on the hidden state. Intuitively, it's asking: "How similar is this hidden representation to each possible next token?" The token with highest similarity (highest logit) gets highest probability.
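Here is a minimal sketch of those three stages in PyTorch, with random weights, a hand-written RMSNorm for the final normalization, and the Llama-7B-style sizes used above.

```python
import torch
import torch.nn as nn

d_model, vocab_size, seq_len = 4096, 32000, 10
hidden = torch.randn(1, seq_len, d_model)           # hidden states from the last block

# 1. Final normalization (RMSNorm written out by hand)
gain = torch.ones(d_model)                          # learned per-dimension scale
normed = hidden * torch.rsqrt(hidden.pow(2).mean(-1, keepdim=True) + 1e-6) * gain

# 2. Output projection (LM head): one score per vocabulary token
lm_head = nn.Linear(d_model, vocab_size, bias=False)
logits = lm_head(normed)                            # (1, seq_len, 32000)

# 3. Softmax: probabilities over the vocabulary, summing to 1 per position
probs = torch.softmax(logits, dim=-1)
next_token_id = probs[0, -1].argmax()               # greedy choice for the last position
```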
Weight Tying: An Elegant Parameter Reduction
One of the most elegant optimizations in transformer LMs is weight tying: using the same weights for the input embedding layer and the output projection layer.
At first, this seems odd. The embedding layer converts token IDs to vectors (lookup), while the output projection converts vectors to token scores (linear transformation). These seem like different operations.
But consider what they're doing semantically:
- Embedding: "Given this token, what's its vector representation?"
- Output: "Given this vector, which token does it represent?"
These are almost inverse operations! If the embedding for "cat" is some vector v, then seeing vector v at the output should strongly suggest "cat" as the prediction. Using the same weights enforces this symmetry: the embedding and output spaces are identical.
Mathematically, with weight tying:
- Embedding lookup: embedding[token_id] → retrieves that token's row of the shared matrix E
- Output projection: logits = hidden_state × E^T → computes the dot product of the hidden state with every row of E
The output for token i is hidden_state · embedding[i]—the dot product between the hidden state and that token's embedding. If the hidden state is "close" to a token's embedding, that token gets a high score. This makes perfect sense: generate tokens whose embeddings match the current representation.
The benefits are substantial:
- Parameter reduction: For a 32,000-token vocabulary and 4096-dimensional embeddings, weight tying saves 131 million parameters—not trivial even for a 7B model.
- Regularization: Forcing embeddings to be useful for both input and output constrains them to learn more general representations. This often improves generalization.
- Semantic consistency: The model's "understanding" of a word (input embedding) is the same as what it "means" to generate that word (output space). There's a unified semantic space.
Weight tying is now standard practice. Nearly all modern LLMs use it. The approach is so successful that not tying weights is considered unusual and requires justification.
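In code, weight tying amounts to a one-line sharing of the embedding matrix with the LM head. This is a sketch of the common PyTorch pattern; the module names and sizes are illustrative.

```python
import torch
import torch.nn as nn

vocab_size, d_model = 32000, 4096

embedding = nn.Embedding(vocab_size, d_model)           # E: (vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)    # W_out: same shape as E
lm_head.weight = embedding.weight                       # tie: one shared matrix for both

tokens = torch.tensor([[1, 42, 7]])
x = embedding(tokens)             # input side: look up rows of E
# ... transformer blocks would run here ...
logits = lm_head(x)               # output side: hidden @ E.T, dot product with every row of E

# One shared 131M-parameter matrix instead of two separate ones:
print(embedding.weight.data_ptr() == lm_head.weight.data_ptr())   # True
```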
┌─────────────────────────────────────────────────────────────────────────┐
│ OUTPUT PROJECTION AND WEIGHT TYING │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE OUTPUT PIPELINE: │
│ ──────────────────── │
│ │
│ After transformer blocks, we have hidden states (one per position). │
│ Now convert these to token probabilities: │
│ │
│ Hidden state for position n │
│ ┌────────────────────────────────┐ │
│ │ h_n: [0.23, -0.17, ..., 0.42] │ (d_model = 4096 dimensions) │
│ └────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │ RMSNorm │ Final normalization │
│ └─────────┘ │
│ │ │
│ ▼ │
│ ┌────────────────────────────────┐ │
│ │ W_out (LM Head) │ │
│ │ (d_model × vocab_size) │ │
│ │ (4096 × 32000) │ │
│ └────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌────────────────────────────────────────────────────────────────┐ │
│ │ logits: [2.1, -0.5, 0.3, 5.2, ..., -1.2] (32000 scores) │ │
│ │ "the" "cat" "dog" "Paris" "xyz" │ │
│ └────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │ Softmax │ Convert to probabilities │
│ └─────────┘ │
│ │ │
│ ▼ │
│ ┌────────────────────────────────────────────────────────────────┐ │
│ │ probs: [0.05, 0.004, 0.008, 0.85, ..., 0.00001] (sum = 1) │ │
│ │ "the" "cat" "dog" "Paris" "xyz" │ │
│ └────────────────────────────────────────────────────────────────┘ │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WEIGHT TYING: │
│ ───────────── │
│ │
│ Observation: Embedding and output are almost INVERSE operations! │
│ │
│ Embedding: "Given token 'cat', what's its vector?" │
│ → Look up row 3797 of embedding matrix │
│ │
│ Output: "Given this vector, which token does it represent?" │
│ → Compute similarity to all embedding rows │
│ │
│ SOLUTION: Use SAME weights for both! │
│ │
│ E = embedding matrix (vocab_size × d_model) │
│ │
│ Input: x = E[token_id] (lookup row) │
│ Output: logits = hidden × E^T (multiply by transpose) │
│ │
│ logit_i = hidden · E[i] = dot product with token i's embedding │
│ │
│ Interpretation: Tokens whose embeddings are SIMILAR to the hidden │
│ state get HIGH logits → HIGH probability of being generated. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY WEIGHT TYING WORKS: │
│ ─────────────────────── │
│ │
│ 1. PARAMETER SAVINGS │
│ Without tying: 32000 × 4096 × 2 = 262M params │
│ With tying: 32000 × 4096 = 131M params │
│ Saves 131M parameters! │
│ │
│ 2. REGULARIZATION │
│ Embeddings must be useful for BOTH input understanding and │
│ output generation. This constraint helps them learn more │
│ general, robust representations. │
│ │
│ 3. SEMANTIC CONSISTENCY │
│ The model's "understanding" of a word (input embedding) is the │
│ same as what it "means" to generate that word (output space). │
│ One unified semantic space for everything. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ USAGE: │
│ Weight tying is STANDARD. Nearly all modern LLMs use it: │
│ GPT-2, GPT-3, Llama, Llama 2, Mistral, Qwen, etc. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Part VIII: Encoder-Decoder vs Decoder-Only
The Three Transformer Architectures
The original transformer paper from 2017 introduced an encoder-decoder architecture for machine translation. Since then, three main architectural patterns have emerged, each suited to different tasks:
Encoder-only models (like BERT) process input bidirectionally—every position can attend to every other position. They're excellent for understanding tasks like classification, named entity recognition, and creating embeddings, but they can't generate text autoregressively.
Encoder-decoder models (like T5, BART) have separate components for understanding input (encoder, bidirectional) and generating output (decoder, autoregressive). The decoder attends to the encoder's output through cross-attention, allowing it to reference the input while generating.
Decoder-only models (like GPT, Llama, Claude) use a single autoregressive stack where each position can only attend to earlier positions. They treat input and output uniformly—everything is just a sequence of tokens being predicted.
The striking fact about modern LLMs is that decoder-only architectures have completely dominated. GPT-3, GPT-4, Claude, Llama, Llama 2, Llama 3, Mistral, Falcon, Qwen—virtually every major model is decoder-only. Why did this architecture win so decisively?
Why Decoder-Only Dominates
The dominance of decoder-only architectures wasn't obvious in 2017 or even 2019. BERT (encoder-only) achieved remarkable results on NLU benchmarks. T5 (encoder-decoder) showed strong performance on a wide range of tasks. The victory of decoder-only emerged gradually as models scaled and use cases broadened.
Simplicity: A decoder-only model has one type of attention: causal self-attention where each position attends to itself and earlier positions. An encoder-decoder model has three: bidirectional self-attention in the encoder, causal self-attention in the decoder, and cross-attention from decoder to encoder. This complexity adds engineering burden, potential for bugs, and optimization challenges. Simpler architectures are easier to scale, easier to optimize, and easier to reason about.
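That single attention type boils down to one lower-triangular mask. A toy sketch, with random tensors standing in for projected queries, keys, and values:

```python
import torch

T, d = 5, 16                                    # toy sequence length and head dimension
q = k = v = torch.randn(1, T, d)

scores = q @ k.transpose(-2, -1) / d ** 0.5     # (1, T, T): every position against every other
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))   # forbid attention to future positions
out = torch.softmax(scores, dim=-1) @ v         # each position mixes only itself and earlier ones
```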
Unified input-output: In decoder-only models, input and output are processed identically—they're all tokens in a sequence, with each token predicted based on preceding tokens. This enables remarkable flexibility:
- In-context learning: Few-shot examples become part of the prompt, processed the same way as the actual query. The model naturally learns from examples in its context.
- Multi-turn conversation: Chat history is just more tokens. There's no special handling for "previous turns" vs "current input."
- Flexible task specification: Tasks are described in natural language within the prompt. No special architecture for different task types.
Encoder-decoder models have a hard boundary between "input" (processed by encoder) and "output" (generated by decoder). This forces decisions about what counts as input vs output and prevents the fluid mixing that makes decoder-only models so flexible.
Scaling efficiency: Given a fixed compute budget, decoder-only models achieve better loss than encoder-decoder models. This empirical finding from scaling experiments was decisive. The reasons aren't fully understood but likely relate to:
- Every token provides a training signal (predict the next token), not just tokens in the "output" portion
- No parameter overhead from cross-attention
- Simpler optimization landscape
KV cache efficiency: During autoregressive generation, decoder-only models cache key and value vectors for all past tokens and reuse them when generating each new token. This cache is straightforward: just append new KV pairs as tokens are generated.
Encoder-decoder models must maintain both the encoder's output (for cross-attention) and the decoder's KV cache. Cross-attention at every decoder layer adds memory and computation. The simpler caching story of decoder-only models makes them easier to optimize for inference.
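A minimal sketch of the decoder-only caching pattern, with hypothetical shapes and random tensors standing in for the real key/value projections:

```python
import torch

n_heads, d_head = 32, 128
k_cache = torch.empty(1, n_heads, 0, d_head)    # grows along the sequence dimension
v_cache = torch.empty(1, n_heads, 0, d_head)

for step in range(4):                           # one iteration per generated token
    k_new = torch.randn(1, n_heads, 1, d_head)  # K/V for the single new token
    v_new = torch.randn(1, n_heads, 1, d_head)
    k_cache = torch.cat([k_cache, k_new], dim=2)    # just append; nothing is recomputed
    v_cache = torch.cat([v_cache, v_new], dim=2)

    q_new = torch.randn(1, n_heads, 1, d_head)      # query only for the newest position
    scores = q_new @ k_cache.transpose(-2, -1) / d_head ** 0.5
    out = torch.softmax(scores, dim=-1) @ v_cache   # attend over all cached positions
```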
The GPT-3 moment: Perhaps most importantly, GPT-3 demonstrated that decoder-only models could do everything. With enough scale and the right prompting, a decoder-only model could translate, summarize, answer questions, write code, and engage in conversation—all without any task-specific architecture. This generality, combined with the simplicity and scaling properties, sealed the victory.
┌─────────────────────────────────────────────────────────────────────────┐
│ WHY DECODER-ONLY DOMINATES │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THREE ARCHITECTURES: │
│ ──────────────────── │
│ │
│ ENCODER-ONLY (BERT): │
│ • Bidirectional attention—every position sees all others │
│ • Great for understanding: classification, NER, embeddings │
│ • Cannot generate text autoregressively │
│ │
│ ENCODER-DECODER (T5, BART): │
│ • Encoder: bidirectional processing of input │
│ • Decoder: autoregressive generation of output │
│ • Cross-attention: decoder attends to encoder output │
│ • Natural for sequence-to-sequence tasks │
│ │
│ DECODER-ONLY (GPT, Llama, Claude): │
│ • Single autoregressive stack │
│ • Each position attends only to earlier positions │
│ • Input and output treated identically │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY DECODER-ONLY WON: │
│ ───────────────────── │
│ │
│ 1. SIMPLICITY │
│ Encoder-decoder: 3 attention types (enc self, dec self, cross) │
│ Decoder-only: 1 attention type (causal self) │
│ │
│ Simpler = easier to scale, optimize, debug, reason about. │
│ │
│ 2. UNIFIED INPUT-OUTPUT │
│ Everything is just tokens in a sequence. This enables: │
│ │
│ • In-context learning: Few-shot examples in the prompt │
│ work naturally—they're processed like everything else │
│ │
│ • Multi-turn chat: History is just more tokens, no special │
│ handling needed │
│ │
│ • Task flexibility: Describe any task in natural language │
│ as part of the prompt │
│ │
│ 3. SCALING EFFICIENCY │
│ Empirical finding: Given fixed compute, decoder-only models │
│ achieve better loss than encoder-decoder models. │
│ │
│ Why? Every token is a prediction target (more training signal), │
│ no parameter overhead from cross-attention. │
│ │
│ 4. INFERENCE EFFICIENCY │
│ KV cache is simpler—just append new KV pairs. │
│ No encoder output to maintain, no cross-attention overhead. │
│ │
│ 5. THE GPT-3 DEMONSTRATION │
│ GPT-3 proved decoder-only could do EVERYTHING: │
│ Translation, summarization, QA, code, chat—all in one model. │
│ No task-specific architecture needed. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ HISTORICAL EVOLUTION: │
│ ───────────────────── │
│ │
│ 2017: Original Transformer (encoder-decoder for translation) │
│ 2018: BERT demonstrates encoder-only power for NLU │
│ 2018: GPT-1 shows decoder-only can learn from pretraining │
│ 2019: GPT-2 scales up decoder-only, shows broad capability │
│ 2020: GPT-3 (175B) → decoder-only can do EVERYTHING │
│ 2020: T5 shows encoder-decoder competitive but more complex │
│ 2021+: Decoder-only dominates: PaLM, Chinchilla, Claude │
│ 2023+: Llama, GPT-4, Claude 2, Mistral—all decoder-only │
│ │
│ The field converged on decoder-only for general-purpose LLMs. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHEN ENCODER-DECODER STILL MAKES SENSE: │
│ ─────────────────────────────────────── │
│ │
│ • Machine translation (clear input/output boundary) │
│ • Summarization (long input → short output) │
│ • When bidirectional input encoding is critical │
│ │
│ But for general-purpose LLMs, decoder-only won decisively. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Part IX: Architectural Evolution
From GPT-2 to Llama 3: What Changed
The transformer architecture has evolved significantly since GPT-2 in 2019. While the basic structure remains—embeddings, transformer blocks with attention and FFN, output projection—almost every component has been refined.
Normalization: GPT-2 used LayerNorm. Modern models use RMSNorm, which is simpler, faster, and equally effective. This change happened around 2022-2023 with models like Llama.
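The difference is easy to see side by side. A minimal sketch with hand-written versions of both norms (epsilon placement simplified):

```python
import torch

def layernorm(x, gain, bias, eps=1e-6):
    # LayerNorm: subtract the mean, divide by the standard deviation, then scale and shift
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps) * gain + bias

def rmsnorm(x, gain, eps=1e-6):
    # RMSNorm: no mean, no bias; just rescale each vector by its root-mean-square
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * gain

d = 4096
x = torch.randn(2, 10, d)
y_ln = layernorm(x, torch.ones(d), torch.zeros(d))
y_rms = rmsnorm(x, torch.ones(d))
```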
Activation function: GPT-2 used GELU (a smooth approximation to ReLU). Most modern models use SwiGLU, which adds gating for better expressiveness. The shift began with PaLM in 2022 and is now nearly universal.
Positional encoding: GPT-2 used learned absolute position embeddings. Modern models use RoPE (Rotary Position Embeddings), which encodes relative position and extrapolates better. This shift also happened around 2022.
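A minimal sketch of the RoPE idea in its "rotate-half" form, as used by Llama-style implementations; the base frequency and shapes are illustrative.

```python
import torch

def rope(x, base=10000.0):
    # x: (seq_len, head_dim), head_dim even. Rotate pairs of dimensions by angles that
    # grow with position, so the q.k dot product depends only on relative position.
    seq_len, dim = x.shape
    half = dim // 2
    inv_freq = base ** (-torch.arange(half, dtype=torch.float32) / half)     # one frequency per pair
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * inv_freq  # (seq_len, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, :half], x[:, half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = rope(torch.randn(8, 128))   # applied to queries and keys before the attention dot product
```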
Attention efficiency: While GPT-2 used standard multi-head attention, larger modern models such as Llama 2 70B, Llama 3, and Mistral use Grouped-Query Attention (GQA), which shares each key-value head among multiple query heads to reduce KV cache memory.
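A sketch of the key idea: fewer K/V heads than query heads, with each K/V head broadcast to a group of query heads. The 32-query/8-KV split here is illustrative, not any specific model's configuration.

```python
import torch

n_heads, n_kv_heads, d_head, T = 32, 8, 128, 16     # 4 query heads share each KV head
q = torch.randn(1, n_heads, T, d_head)
k = torch.randn(1, n_kv_heads, T, d_head)            # KV cache is 4x smaller than full MHA
v = torch.randn(1, n_kv_heads, T, d_head)

group = n_heads // n_kv_heads
k = k.repeat_interleave(group, dim=1)                 # broadcast each KV head to its query group
v = v.repeat_interleave(group, dim=1)
attn = torch.softmax(q @ k.transpose(-2, -1) / d_head ** 0.5, dim=-1) @ v   # standard attention
```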
FFN ratio: The original transformer used 4× expansion in FFN. Modern models with SwiGLU use ~2.67× (8/3) to maintain parameter parity while gaining the benefits of gating.
Bias terms: GPT-2 used bias terms in linear layers. Modern models often remove biases entirely—they're not needed and removal saves parameters and simplifies computation.
These changes seem incremental, but together they represent substantial improvements in training stability, inference efficiency, and model quality. A 7B-parameter Llama 2 significantly outperforms a 7B GPT-2-style model, even when trained on similar data, thanks to these architectural refinements.
┌─────────────────────────────────────────────────────────────────────────┐
│ ARCHITECTURAL EVOLUTION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ GPT-2 (2019) → LLAMA 3 (2024) │
│ ────────────────────────────────────────────────────────────────────── │
│ │
│ Normalization: │
│ LayerNorm → RMSNorm │
│ (mean + var + bias) (RMS only, faster) │
│ │
│ Activation: │
│ GELU → SwiGLU │
│ (smooth ReLU approx) (gated, more expressive) │
│ │
│ Position encoding: │
│ Learned absolute → RoPE │
│ (fixed max length) (relative, extrapolates) │
│ │
│ Attention: │
│ Standard MHA → GQA (for large models) │
│ (all heads independent) (shared KV heads, less memory) │
│ │
│ FFN expansion: │
│ 4× → ~2.67× with SwiGLU │
│   (4d_model intermediate)      (keeps param count, adds gating)         │
│ │
│ Biases: │
│ Yes → No │
│ (in all linear layers) (removed for simplicity) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ IMPACT: │
│ A modern 7B model substantially outperforms a GPT-2-style 7B model, │
│ even with similar training data, due to these architectural advances.│
│ │
│ The changes seem incremental but compound to significant gains. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Summary
The transformer architecture, despite being introduced in 2017, remains the foundation of all modern large language models. Understanding its components deeply—from embeddings through attention and FFN to output—reveals why certain design choices work and how the field has evolved.
Key insights from this deep dive:
The FFN dominates parameter count: Approximately 67% of transformer parameters live in feed-forward networks, not attention. This explains why FFN optimization (SwiGLU, MoE) has such impact.
Residual connections enable depth: Without residual connections, training 32+ layer networks would be extremely difficult. The residual path provides a gradient highway that keeps information and gradients flowing.
Pre-norm beats post-norm for training stability: Placing normalization before sublayers creates cleaner gradient flow, enabling stable training without careful warmup schedules.
RMSNorm and SwiGLU are the modern standards: These refinements provide efficiency and quality gains with no downside, and are now used by virtually all new models.
Decoder-only won for generality: The simplicity, scaling properties, and flexibility of decoder-only architectures made them the clear choice for general-purpose LLMs.
Weight tying is elegant and effective: Using the same weights for input embedding and output projection saves parameters and improves regularization.
The architecture continues to evolve, with innovations like Mixture of Experts expanding capacity and new attention variants addressing efficiency. But the core transformer structure—embeddings, self-attention, feed-forward networks, residual connections, normalization—has proven remarkably robust and likely to remain central to language models for years to come.
Related Articles
Attention Mechanisms: From Self-Attention to FlashAttention
A comprehensive deep dive into attention mechanisms—the core innovation powering modern LLMs. From the intuition behind self-attention to the engineering of FlashAttention, understand how transformers actually work.
LLM Pre-training: Building Foundation Models from Scratch
A comprehensive guide to pre-training large language models—from data curation and architecture decisions to scaling laws and distributed training infrastructure. Understanding how GPT, Llama, and other foundation models are built.
LLM Inference Optimization: From Quantization to Speculative Decoding
A comprehensive guide to optimizing LLM inference for production—covering quantization, attention optimization, batching strategies, and deployment frameworks.
Text Generation & Decoding Strategies: A Complete Guide
A comprehensive guide to how LLMs actually generate text—from greedy decoding to beam search, temperature scaling, nucleus sampling, speculative decoding, and structured generation. Master the techniques that control LLM output quality, creativity, and speed.