Distributed Training: How to Train 70B+ Parameter Models
A comprehensive deep dive into distributed training—how to train models that don't fit on a single GPU. Understand data parallelism, tensor parallelism, pipeline parallelism, ZeRO optimization, and the engineering behind training frontier LLMs.
Why Distributed Training Matters
A 70B parameter model in FP16 requires 140GB just for weights—more than a single 80GB A100 or H100 can hold. Training requires additional memory for gradients (140GB), optimizer states (560GB for Adam), and activations (variable, often 100GB+). Total: roughly 1TB for training, versus 80GB on a single high-end GPU.
Distributed training isn't optional for large models—it's required.
But distributing training is more complex than just "use more GPUs." Communication between GPUs can bottleneck throughput. Memory can be wasted through redundant storage. Load imbalances can leave GPUs idle. Understanding these tradeoffs is essential for efficient large-scale training.
This post covers the core distributed training techniques: data parallelism, tensor parallelism, pipeline parallelism, and ZeRO optimization. You'll understand how teams train 70B+ models and how to apply these techniques yourself.
Part I: The Memory Problem
What Consumes Memory During Training?
Before diving into solutions, let's understand what we're solving:
┌─────────────────────────────────────────────────────────────────────────┐
│ TRAINING MEMORY BREAKDOWN │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ FOR A 7B PARAMETER MODEL (BF16 mixed precision): │
│ ──────────────────────────────────────────────── │
│ │
│ COMPONENT MEMORY EXPLANATION │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ Model Parameters 14 GB 7B × 2 bytes (BF16) │
│ │
│ Gradients 14 GB Same size as params (BF16) │
│ │
│ Optimizer States 56 GB Adam keeps 2 FP32 states per param: │
│ - m (momentum) 28 GB • m (momentum): 7B × 4 bytes │
│ - v (variance) 28 GB • v (variance): 7B × 4 bytes │
│ (an FP32 master copy of the weights adds ~28 GB more if counted here) │
│ │
│ Activations Variable Depends on batch size & seq length │
│ - Batch 1, 2K seq ~8 GB Stored for backward pass │
│ - Batch 8, 2K seq ~64 GB Scales with batch size │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ TOTAL FOR 7B MODEL: │
│ │
│ Parameters: 14 GB │
│ Gradients: 14 GB │
│ Optimizer: 56 GB │
│ Activations: ~32 GB (typical) │
│ ────────────────────── │
│ TOTAL: ~116 GB │
│ │
│ This exceeds a single 80GB A100! │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SCALING TO 70B: │
│ ─────────────── │
│ │
│ Parameters: 140 GB │
│ Gradients: 140 GB │
│ Optimizer: 560 GB │
│ Activations: ~300 GB │
│ ────────────────────── │
│ TOTAL: ~1.1 TB │
│ │
│ Would need 14× 80GB GPUs just for memory! │
│ (Before considering any parallelism efficiency) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
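To make the arithmetic above reusable, here is a small back-of-the-envelope calculator in plain Python. It uses the same simplified accounting as the table (BF16 params and gradients, two FP32 Adam states, and an assumed ~32 bytes per token per layer per hidden unit for activations), so treat the numbers as ballpark estimates, not measurements.
# Back-of-the-envelope training memory estimate (same simplified accounting as above)
def training_memory_gb(n_params_billion, batch=1, seq_len=2048,
                       hidden=4096, n_layers=32):
    GB = 1e9
    params = n_params_billion * 1e9 * 2 / GB       # BF16 weights: 2 bytes/param
    grads = params                                  # BF16 gradients: same size
    optimizer = n_params_billion * 1e9 * 8 / GB     # Adam m + v in FP32: 8 bytes/param
    # Very rough assumption: ~32 bytes per token per layer per hidden unit
    activations = 32 * batch * seq_len * n_layers * hidden / GB
    total = params + grads + optimizer + activations
    return {"params": params, "grads": grads, "optimizer": optimizer,
            "activations": activations, "total": total}

# 7B model (hidden 4096, 32 layers), batch 1, 2K context -> ~93 GB total
print(training_memory_gb(7))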
Memory Optimization Techniques
Before going multi-GPU, there are single-GPU optimizations:
┌─────────────────────────────────────────────────────────────────────────┐
│ SINGLE-GPU MEMORY OPTIMIZATIONS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 1. MIXED PRECISION TRAINING (BF16/FP16): │
│ ───────────────────────────────────────── │
│ │
│ Keep master weights in FP32, compute in BF16/FP16. │
│ Saves ~50% memory for params/grads, maintains accuracy. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 2. GRADIENT CHECKPOINTING: │
│ ────────────────────────── │
│ │
│ Don't store all activations. Recompute them during backward. │
│ │
│ Without checkpointing (store all): │
│ Layer 1 → save act1 → Layer 2 → save act2 → ... → Layer 32 │
│ Memory: O(num_layers × batch × seq × hidden) │
│ │
│ With checkpointing (recompute): │
│ Layer 1 → Layer 2 → ... → Layer 32 (only save checkpoints) │
│ Backward: recompute acts from checkpoints │
│ Memory: O(sqrt(num_layers) × batch × seq × hidden) │
│ │
│ Tradeoff: ~30% more compute, ~5× less activation memory │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 3. GRADIENT ACCUMULATION: │
│ ───────────────────────── │
│ │
│ Can't fit batch size 32? Use batch 4, accumulate 8 steps. │
│ │
│ for i, batch in enumerate(data): │
│ loss = model(batch) / accumulation_steps │
│ loss.backward() │
│ if (i + 1) % accumulation_steps == 0: │
│ optimizer.step() │
│ optimizer.zero_grad() │
│ │
│ Same effective batch size, lower memory. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 4. CPU OFFLOADING: │
│ ────────────────── │
│ │
│ Store optimizer states in CPU RAM, move to GPU when needed. │
│ Slower but enables larger models on single GPU. │
│ Used by: DeepSpeed ZeRO-Offload │
│ │
└─────────────────────────────────────────────────────────────────────────┘
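As a rough illustration of how the first three techniques combine in a single-GPU PyTorch loop, here is a sketch; the model (assumed to expose a .blocks list and a .head module), the dataloader, and the hyperparameters are placeholders, but autocast, checkpoint, and the accumulation pattern are the standard PyTorch APIs.
# Mixed precision + gradient checkpointing + gradient accumulation (sketch)
import torch
from torch.utils.checkpoint import checkpoint

accumulation_steps = 8
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)   # model is a placeholder

def forward_with_checkpointing(model, x):
    # Recompute each block's activations during backward instead of storing them
    for block in model.blocks:                   # assumes the model exposes .blocks
        x = checkpoint(block, x, use_reentrant=False)
    return model.head(x)                         # assumes a final .head module

for step, (inputs, labels) in enumerate(dataloader):         # placeholder dataloader
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = forward_with_checkpointing(model, inputs.cuda())
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), labels.cuda().view(-1)
        ) / accumulation_steps                   # scale loss for accumulation
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()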
Part II: Data Parallelism
The Simplest Approach
Data parallelism is the most straightforward distributed training technique: replicate the model on each GPU, split the data batch across GPUs.
┌─────────────────────────────────────────────────────────────────────────┐
│ DATA PARALLELISM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE CONCEPT: │
│ ──────────── │
│ │
│ Each GPU has a COMPLETE copy of the model. │
│ Data batch is SPLIT across GPUs. │
│ Gradients are SYNCHRONIZED after each step. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ VISUALIZATION: │
│ ────────────── │
│ │
│ Global batch: [sample1, sample2, sample3, sample4, sample5, sample6] │
│ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ SPLIT DATA │ │
│ └─────────────────────────────────────────────────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │
│ │ GPU 0 │ │ GPU 1 │ │ GPU 2 │ │
│ ├─────────────────────┤ ├─────────────────┤ ├─────────────────┤ │
│ │ Model (full copy) │ │ Model (copy) │ │ Model (copy) │ │
│ │ │ │ │ │ │ │
│ │ Data: [s1, s2] │ │ Data: [s3, s4] │ │ Data: [s5, s6] │ │
│ │ │ │ │ │ │ │
│ │ Forward → Loss │ │ Forward → Loss │ │ Forward → Loss │ │
│ │ Backward → Grads │ │ Backward → Grads│ │ Backward → Grads│ │
│ └─────────────────────┘ └─────────────────┘ └─────────────────┘ │
│ │ │ │ │
│ └──────────────────────┼───────────────────┘ │
│ ▼ │
│ ┌─────────────────────────┐ │
│ │ ALL-REDUCE GRADIENTS │ │
│ │ (average gradients) │ │
│ └─────────────────────────┘ │
│ │ │
│ ┌─────────────┼─────────────┐ │
│ ▼ ▼ ▼ │
│ GPU 0 update GPU 1 update GPU 2 update │
│ (identical) (identical) (identical) │
│ │
│ All GPUs end up with identical model weights. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SCALING PROPERTIES: │
│ ─────────────────── │
│ │
│ N GPUs → N× throughput (ideal) │
│ │
│ In practice: │
│ • Communication overhead reduces scaling │
│ • Typical: 90-95% scaling efficiency on good hardware │
│ • 8 GPUs → ~7.5× throughput │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ LIMITATIONS: │
│ ──────────── │
│ │
│ Model must fit on EACH GPU! │
│ 7B model: 116GB needed → doesn't fit on 80GB GPU │
│ Data parallelism alone can't train 7B on A100. │
│ │
│ For larger models, need model parallelism techniques. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
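A minimal data-parallel loop with PyTorch's DistributedDataParallel looks like the sketch below (launched with torchrun; build_model, dataset, compute_loss, and num_epochs are placeholders). DDP registers hooks that overlap the gradient all-reduce with the rest of the backward pass.
# Minimal DDP training sketch (launch: torchrun --nproc_per_node=8 train_ddp.py)
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = build_model().cuda(local_rank)             # placeholder model factory
model = DDP(model, device_ids=[local_rank])        # full replica on every GPU

sampler = DistributedSampler(dataset)              # each rank sees a different shard
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)                       # reshuffle shards each epoch
    for batch in loader:
        loss = compute_loss(model, batch)          # placeholder loss function
        loss.backward()                            # gradients all-reduced here
        optimizer.step()
        optimizer.zero_grad()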
All-Reduce Communication
The gradient synchronization uses a collective operation called "all-reduce":
┌─────────────────────────────────────────────────────────────────────────┐
│ ALL-REDUCE OPERATION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ALL-REDUCE: Each GPU starts with local data, ends with global sum. │
│ │
│ NAIVE APPROACH (bad): │
│ ───────────────────── │
│ Send all gradients to GPU 0, sum, broadcast back. │
│ │
│ Time: O(N × data_size) - doesn't scale │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ RING ALL-REDUCE (efficient): │
│ ──────────────────────────── │
│ │
│ Arrange GPUs in a ring, pass data around incrementally. │
│ │
│ Phase 1: Reduce-Scatter │
│ GPU 0 → GPU 1 → GPU 2 → GPU 3 → GPU 0 (ring) │
│ Each GPU sends 1/N of its data, receives and adds 1/N from neighbor │
│ After N-1 steps: each GPU has 1/N of the total sum │
│ │
│ Phase 2: All-Gather │
│ Same ring pattern, but now sharing the partial sums │
│ After N-1 steps: each GPU has the complete sum │
│ │
│ Time: O(data_size) - independent of N! │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ VISUALIZATION: │
│ ────────────── │
│ │
│ 4 GPUs, gradients = [g0, g1, g2, g3] │
│ Goal: each GPU gets (g0 + g1 + g2 + g3) / 4 │
│ │
│ Initial: │
│ GPU 0: [g0] GPU 1: [g1] GPU 2: [g2] GPU 3: [g3] │
│ │
│ Reduce-Scatter (split into 4 chunks, ring-reduce each): │
│ After: Each GPU has 1/4 of total sum │
│ GPU 0: [Σchunk0] GPU 1: [Σchunk1] GPU 2: [Σchunk2] GPU 3: [Σchunk3]│
│ │
│ All-Gather (share chunks around ring): │
│ After: Each GPU has all 4 chunks │
│ GPU 0: [Σall] GPU 1: [Σall] GPU 2: [Σall] GPU 3: [Σall] │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ COMMUNICATION VOLUME: │
│ ───────────────────── │
│ │
│ For gradient size G and N GPUs: │
│ Ring all-reduce: 2 × G × (N-1)/N bytes per GPU │
│ │
│ As N → ∞: approaches 2G bytes (constant!) │
│ This is why data parallelism scales well. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
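The two phases are easiest to see in code. Below is a single-process simulation of ring all-reduce on plain Python lists; there are no GPUs, NCCL calls, or real communication here, it just mimics the chunked reduce-scatter and all-gather bookkeeping described above.
# Single-process simulation of ring all-reduce (reduce-scatter + all-gather)
def ring_all_reduce(grads_per_gpu):
    n = len(grads_per_gpu)
    chunk = len(grads_per_gpu[0]) // n            # assumes size divisible by n
    # data[r][c] = GPU r's local copy of chunk c
    data = [[list(g[c * chunk:(c + 1) * chunk]) for c in range(n)]
            for g in grads_per_gpu]

    # Phase 1: reduce-scatter. Each step, GPU r sends one chunk to GPU (r+1)%n,
    # which adds it. After n-1 steps, GPU r owns the full sum of chunk (r+1)%n.
    for step in range(n - 1):
        outgoing = [(r, (r - step) % n, list(data[r][(r - step) % n]))
                    for r in range(n)]
        for r, c, payload in outgoing:
            dst = (r + 1) % n
            data[dst][c] = [a + b for a, b in zip(data[dst][c], payload)]

    # Phase 2: all-gather. The completed chunks circulate around the same ring.
    for step in range(n - 1):
        outgoing = [(r, (r + 1 - step) % n, list(data[r][(r + 1 - step) % n]))
                    for r in range(n)]
        for r, c, payload in outgoing:
            data[(r + 1) % n][c] = payload

    # Average (data parallelism wants the mean gradient) and flatten
    return [[x / n for c in gpu for x in c] for gpu in data]

grads = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]   # 2 "GPUs", 4 gradients each
print(ring_all_reduce(grads))                          # both rows: [3.0, 4.0, 5.0, 6.0]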
Part III: Tensor Parallelism
Splitting Layers Across GPUs
When a single layer is too large, split it across GPUs:
┌─────────────────────────────────────────────────────────────────────────┐
│ TENSOR PARALLELISM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE CONCEPT: │
│ ──────────── │
│ │
│ Split individual LAYERS across GPUs. │
│ Each GPU holds part of each layer. │
│ Requires communication within each layer. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ LINEAR LAYER SPLITTING: │
│ ─────────────────────── │
│ │
│ Consider: Y = XW where W is (d_in × d_out) │
│ │
│ COLUMN PARALLEL (split output dimension): │
│ ────────────────────────────────────────── │
│ │
│ W = [W1 | W2] (split columns) │
│ │
│ GPU 0: Y1 = X × W1 (X is replicated) │
│ GPU 1: Y2 = X × W2 │
│ │
│ Result: [Y1 | Y2] = Y (concatenate outputs) │
│ │
│ X (replicated) │
│ │ │
│ ┌─────┴─────┐ │
│ │ │ │
│ ▼ ▼ │
│ ┌──────┐ ┌──────┐ │
│ │ W1 │ │ W2 │ │
│ │(GPU0)│ │(GPU1)│ │
│ └──────┘ └──────┘ │
│ │ │ │
│ ▼ ▼ │
│ Y1 Y2 │
│ │ │ │
│ └─────┬─────┘ │
│ ▼ │
│ [Y1 | Y2] = Y │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ROW PARALLEL (split input dimension): │
│ ───────────────────────────────────── │
│ │
│ W = [W1] (split rows) │
│ [W2] │
│ │
│ X = [X1 | X2] (split input) │
│ │
│ GPU 0: Y1 = X1 × W1 │
│ GPU 1: Y2 = X2 × W2 │
│ │
│ Result: Y = Y1 + Y2 (sum partial results - requires all-reduce!) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ COMBINING FOR TRANSFORMER LAYERS: │
│ ───────────────────────────────── │
│ │
│ ATTENTION: │
│ • Q, K, V projections: Column parallel (split heads) │
│ • Output projection: Row parallel │
│ │
│ FFN: │
│ • First linear: Column parallel (split hidden) │
│ • Second linear: Row parallel (combine) │
│ │
│ This minimizes communication: only 2 all-reduces per layer. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
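The splitting math is easy to sanity-check on one device. This short PyTorch sketch (no actual multi-GPU communication) verifies that column-parallel outputs concatenate, and row-parallel partial products sum, back to the full matmul.
# Verify column-parallel and row-parallel splits of Y = X @ W on one device
import torch

torch.manual_seed(0)
d_in, d_out = 8, 6
X = torch.randn(4, d_in)                     # 4 token vectors
W = torch.randn(d_in, d_out)
Y = X @ W                                    # reference: full linear layer

# Column parallel: split W's output dimension across two "GPUs"
W1, W2 = W[:, :d_out // 2], W[:, d_out // 2:]
Y_col = torch.cat([X @ W1, X @ W2], dim=-1)  # concatenate partial outputs
assert torch.allclose(Y, Y_col, atol=1e-6)

# Row parallel: split W's input dimension (and X accordingly)
Wa, Wb = W[:d_in // 2, :], W[d_in // 2:, :]
Xa, Xb = X[:, :d_in // 2], X[:, d_in // 2:]
Y_row = Xa @ Wa + Xb @ Wb                    # partial results summed (the all-reduce)
assert torch.allclose(Y, Y_row, atol=1e-6)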
Megatron-Style Tensor Parallelism
┌─────────────────────────────────────────────────────────────────────────┐
│ MEGATRON TENSOR PARALLELISM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ MEGATRON'S KEY INSIGHT: │
│ ─────────────────────── │
│ │
│ Arrange column-parallel and row-parallel to minimize communication. │
│ │
│ Column parallel → Activation (no comm) → Row parallel → All-reduce │
│ │
│ Only ONE forward all-reduce per block (attention or FFN)! │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ATTENTION BLOCK (2-way tensor parallel): │
│ ───────────────────────────────────────── │
│ │
│ Input X (replicated on both GPUs) │
│ │ │
│ ├───────────────────┐ │
│ ▼ ▼ │
│ ┌─────────┐ ┌─────────┐ │
│ │ QKV │ │ QKV │ │
│ │(heads │ │(heads │ Column parallel: │
│ │ 0-15) │ │ 16-31) │ Split attention heads │
│ │ GPU 0 │ │ GPU 1 │ │
│ └─────────┘ └─────────┘ │
│ │ │ │
│ ▼ ▼ │
│ Attention Attention (independent per GPU) │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────┐ ┌─────────┐ │
│ │ Output │ │ Output │ Row parallel: │
│ │ Proj │ │ Proj │ Each produces partial result │
│ │ (part) │ │ (part) │ │
│ └─────────┘ └─────────┘ │
│ │ │ │
│ └─────────┬─────────┘ │
│ ▼ │
│ ALL-REDUCE │
│ (sum partials) │
│ │ │
│ ▼ │
│ Output Y │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ FFN BLOCK: │
│ ────────── │
│ │
│ Input X (replicated) │
│ │ │
│ ├───────────────────┐ │
│ ▼ ▼ │
│ ┌─────────┐ ┌─────────┐ │
│ │ FC1 │ │ FC1 │ Column parallel: │
│ │(half of │ │(other │ Split intermediate dim │
│ │ hidden) │ │ half) │ │
│ └─────────┘ └─────────┘ │
│ │ │ │
│ ▼ ▼ │
│ GeLU GeLU (independent) │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────┐ ┌─────────┐ │
│ │ FC2 │ │ FC2 │ Row parallel │
│ │ (part) │ │ (part) │ │
│ └─────────┘ └─────────┘ │
│ │ │ │
│ └─────────┬─────────┘ │
│ ▼ │
│ ALL-REDUCE │
│ │ │
│ ▼ │
│ Output Y │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ TOTAL COMMUNICATION PER TRANSFORMER LAYER: │
│ ────────────────────────────────────────── │
│ │
│ Forward: 2 all-reduces (attention + FFN) │
│ Backward: 2 all-reduces │
│ Total: 4 all-reduces per layer per step │
│ │
│ Communication volume: 4 × batch × seq × hidden × dtype │
│ │
└─────────────────────────────────────────────────────────────────────────┘
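PyTorch ships this column-then-row pattern as ready-made tensor-parallel styles. The sketch below uses torch.distributed.tensor.parallel; the submodule names (attn.q_proj, mlp.fc1, ...) and model.layers are hypothetical and must match your own module names, the attention module must also use its per-rank share of heads, and the job is assumed to be launched with 8 ranks via torchrun.
# Megatron-style TP plan using PyTorch's built-in parallel styles (sketch)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

tp_mesh = init_device_mesh("cuda", (8,))      # 8-way tensor parallel within a node

# Column-parallel in, row-parallel out: one all-reduce per attention/FFN block
tp_plan = {
    "attn.q_proj": ColwiseParallel(),         # split attention heads
    "attn.k_proj": ColwiseParallel(),
    "attn.v_proj": ColwiseParallel(),
    "attn.out_proj": RowwiseParallel(),       # partial outputs summed (all-reduce)
    "mlp.fc1": ColwiseParallel(),             # split intermediate dimension
    "mlp.fc2": RowwiseParallel(),             # all-reduce back to full hidden size
}

for block in model.layers:                    # hypothetical: model exposes .layers
    parallelize_module(block, tp_mesh, tp_plan)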
When to Use Tensor Parallelism
┌─────────────────────────────────────────────────────────────────────────┐
│ TENSOR PARALLELISM TRADEOFFS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ADVANTAGES: │
│ ─────────── │
│ ✓ Reduces memory per GPU (model fits) │
│ ✓ Low latency (all GPUs work on same batch) │
│ ✓ No pipeline bubbles │
│ │
│ DISADVANTAGES: │
│ ────────────── │
│ ✗ High communication frequency (every layer) │
│ ✗ Requires fast interconnect (NVLink) │
│ ✗ Doesn't scale beyond ~8 GPUs efficiently │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SCALING EFFICIENCY: │
│ ─────────────────── │
│ │
│ TP Degree Communication Efficiency (NVLink) │
│ ───────────────────────────────────────────────────────── │
│ 2 Moderate 95-98% │
│ 4 Higher 90-95% │
│ 8 Very high 80-90% │
│ 16 Extreme 60-75% │
│ │
│ Beyond 8-way TP, communication dominates. │
│ Better to combine with other parallelism types. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ HARDWARE REQUIREMENTS: │
│ ────────────────────── │
│ │
│ Tensor parallelism NEEDS fast interconnect: │
│ │
│ • NVLink: 900 GB/s (H100), 600 GB/s (A100) │
│ → Tensor parallelism works well within a node │
│ │
│ • PCIe Gen5 x16: ~64 GB/s │
│ → Too slow for tensor parallelism │
│ │
│ • InfiniBand: 25-50 GB/s per link (200-400 Gb/s) │
│ → Usable for TP across nodes, but much slower │
│ │
│ Rule: Tensor parallelism within node, other methods across nodes. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Part IV: Pipeline Parallelism
Splitting Layers Across Stages
Pipeline parallelism divides the model into sequential stages, each on different GPUs:
┌─────────────────────────────────────────────────────────────────────────┐
│ PIPELINE PARALLELISM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE CONCEPT: │
│ ──────────── │
│ │
│ Split model into STAGES (groups of consecutive layers). │
│ Each stage on a different GPU. │
│ Data flows through stages like a pipeline. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ MODEL SPLIT (32 layers, 4 GPUs): │
│ ──────────────────────────────── │
│ │
│ GPU 0 (Stage 0): Embedding + Layers 0-7 │
│ GPU 1 (Stage 1): Layers 8-15 │
│ GPU 2 (Stage 2): Layers 16-23 │
│ GPU 3 (Stage 3): Layers 24-31 + LM Head │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ NAIVE PIPELINE (Sequential - Bad): │
│ ────────────────────────────────── │
│ │
│ Time → │
│ GPU 0: [F0]─────────────────────────────────[B0] │
│ GPU 1: [F1]─────────────────────────[B1] │
│ GPU 2: [F2]─────────────────[B2] │
│ GPU 3: [F3]────────[B3] │
│ │ │ │
│ Loss Backward starts │
│ │
│ F = Forward, B = Backward │
│ │
│ Problem: HUGE "bubble" where GPUs sit idle! │
│ Only 1 GPU active at a time. 75% idle with 4 GPUs! │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ MICRO-BATCHING (GPipe): │
│ ─────────────────────── │
│ │
│ Split batch into micro-batches, pipeline them: │
│ │
│ Batch: [B1, B2, B3, B4] (4 micro-batches) │
│ │
│ Time → │
│ GPU 0: [F0₁][F0₂][F0₃][F0₄] [B0₄][B0₃][B0₂][B0₁] │
│ GPU 1: [F1₁][F1₂][F1₃][F1₄] [B1₄][B1₃][B1₂][B1₁] │
│ GPU 2: [F2₁][F2₂][F2₃][F2₄][B2₄][B2₃][B2₂][B2₁] │
│ GPU 3: [F3₁][F3₂][F3₃][F3₄][B3₄][B3₃][B3₂][B3₁] │
│ │
│ Much better! GPUs stay busier. │
│ But still has bubble at start and end. │
│ │
│ Bubble fraction ≈ (num_stages - 1) / num_micro_batches │
│ With 4 stages, 16 micro-batches: bubble = 3/16 = 19% │
│ │
└─────────────────────────────────────────────────────────────────────────┘
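The bubble formula above is worth plugging numbers into. A tiny helper (plain Python, using the same approximation of idle time relative to useful compute) shows why the rule of thumb later in this post asks for at least 4× more micro-batches than stages.
# Pipeline bubble as approximated above: (stages - 1) / micro_batches
def bubble_fraction(num_stages, num_micro_batches):
    return (num_stages - 1) / num_micro_batches

for m in (4, 8, 16, 32, 64):
    print(f"4 stages, {m:2d} micro-batches: bubble ~ {bubble_fraction(4, m):.1%}")
# 4 -> 75.0%, 8 -> 37.5%, 16 -> 18.8%, 32 -> 9.4%, 64 -> 4.7%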
1F1B Schedule
Modern pipeline parallelism uses 1F1B (one forward, one backward) scheduling:
┌─────────────────────────────────────────────────────────────────────────┐
│ 1F1B PIPELINE SCHEDULE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 1F1B: Interleave forward and backward passes. │
│ Reduces memory by not storing all activations simultaneously. │
│ │
│ Time → │
│ GPU 0: [F₁][F₂][F₃][F₄][B₁][B₂][B₃][B₄] │
│ GPU 1: [F₁][F₂][F₃|B₁][F₄|B₂][B₃][B₄] │
│ GPU 2: [F₁][F₂|B₁][F₃|B₂][F₄|B₃][B₄] │
│ GPU 3: [F₁|B₁][F₂|B₂][F₃|B₃][F₄|B₄] │
│ │
│ Key insight: Start backward while still doing forward! │
│ Once GPU 3 finishes F₁, it can do B₁ immediately. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ MEMORY COMPARISON: │
│ ────────────────── │
│ │
│ GPipe (all F then all B): │
│ Must store activations for ALL micro-batches during forward. │
│ Memory: O(num_micro_batches × activations_per_micro) │
│ │
│ 1F1B: │
│ Store activations for at most num_stages micro-batches. │
│ Memory: O(num_stages × activations_per_micro) │
│ │
│ With 16 micro-batches and 4 stages: │
│ GPipe: 16× activation memory │
│ 1F1B: 4× activation memory (4× reduction!) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ INTERLEAVED STAGES (Further optimization): │
│ ─────────────────────────────────────────── │
│ │
│ Instead of contiguous layers per GPU, interleave: │
│ │
│ Standard (4 GPUs, 32 layers): │
│ GPU 0: Layers 0-7 │
│ GPU 1: Layers 8-15 │
│ GPU 2: Layers 16-23 │
│ GPU 3: Layers 24-31 │
│ │
│ Interleaved (virtual stages): │
│ GPU 0: Layers 0-3, 16-19 │
│ GPU 1: Layers 4-7, 20-23 │
│ GPU 2: Layers 8-11, 24-27 │
│ GPU 3: Layers 12-15, 28-31 │
│ │
│ More stages = more pipeline opportunities = smaller bubble! │
│ Tradeoff: More communication (cross-GPU between layer groups). │
│ │
└─────────────────────────────────────────────────────────────────────────┘
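The memory claim comes down to how many micro-batches are "in flight" (forwarded but not yet backwarded) at a stage at any moment. A small helper under the same simplified model as the comparison above:
# Peak number of micro-batches whose activations a stage must hold at once
def peak_in_flight(num_stages, num_micro_batches, stage, schedule="1f1b"):
    if schedule == "gpipe":
        return num_micro_batches           # all forwards run before any backward
    # 1F1B: stage s does (num_stages - s) warmup forwards before its first backward
    return min(num_micro_batches, num_stages - stage)

stages, micro = 4, 16
for s in range(stages):
    print(f"stage {s}: GPipe holds {peak_in_flight(stages, micro, s, 'gpipe')}, "
          f"1F1B holds {peak_in_flight(stages, micro, s)}")
# GPipe holds 16 everywhere; 1F1B holds at most 4 (stage 0) -> the 4x reduction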
Part V: ZeRO - Zero Redundancy Optimizer
Eliminating Memory Redundancy
In data parallelism, EVERY GPU stores a complete copy of model, gradients, and optimizer states. ZeRO partitions these across GPUs:
┌─────────────────────────────────────────────────────────────────────────┐
│ ZeRO OPTIMIZATION STAGES │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ DATA PARALLEL REDUNDANCY: │
│ ───────────────────────── │
│ │
│ With 4 GPUs, traditional DP stores on EACH GPU: │
│ • Full model parameters │
│ • Full gradients │
│ • Full optimizer states │
│ │
│ 4× redundant storage! Massive waste. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ZeRO STAGE 1 (Optimizer State Partitioning): │
│ ───────────────────────────────────────────── │
│ │
│ Partition optimizer states across GPUs. │
│ Each GPU stores 1/N of optimizer states. │
│ │
│ Standard DP (4 GPUs, 7B model): │
│ Each GPU: 56 GB optimizer states │
│ │
│ ZeRO-1: │
│ Each GPU: 14 GB optimizer states (1/4) │
│ │
│ Memory saved: 75% of optimizer memory │
│ │
│ Communication: After gradient all-reduce, each GPU updates its │
│ partition, then all-gather to sync parameters. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ZeRO STAGE 2 (+ Gradient Partitioning): │
│ ──────────────────────────────────────── │
│ │
│ Also partition gradients across GPUs. │
│ Each GPU only stores gradients for params it will update. │
│ │
│ ZeRO-2: │
│ Each GPU: 14 GB optimizer + 3.5 GB gradients (vs 14 GB) │
│ │
│ Communication: Reduce-scatter gradients (not all-reduce). │
│ Each GPU gets only the gradients it needs. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ZeRO STAGE 3 (+ Parameter Partitioning): │
│ ───────────────────────────────────────── │
│ │
│ Partition everything: optimizer, gradients, AND parameters. │
│ Each GPU stores only 1/N of the model. │
│ │
│ ZeRO-3: │
│ Each GPU: 14 GB optimizer + 3.5 GB gradients + 3.5 GB params │
│ (vs 14 + 14 + 56 = 84 GB in standard DP) │
│ │
│ Communication: All-gather params before forward/backward. │
│ Higher communication, but enables training much larger models. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ MEMORY COMPARISON (7B model, 4 GPUs): │
│ │
│ Method Params Grads Optimizer Total/GPU Reduction │
│ ───────────────────────────────────────────────────────────────── │
│ Standard DP 14 GB 14 GB 56 GB 84 GB 1× │
│ ZeRO-1 14 GB 14 GB 14 GB 42 GB 2× │
│ ZeRO-2 14 GB 3.5 GB 14 GB 31.5 GB 2.7× │
│ ZeRO-3 3.5 GB 3.5 GB 14 GB 21 GB 4× │
│ │
│ ZeRO-3 enables 4× larger models with same GPU memory! │
│ │
└─────────────────────────────────────────────────────────────────────────┘
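The per-GPU numbers in the table follow mechanically from which states get partitioned. A small calculator in plain Python, using the same simplified accounting as Part I (2-byte params and grads, 8 bytes of Adam state per parameter):
# Per-GPU memory (GB) for params/grads/optimizer under ZeRO stages 0-3
def zero_memory_per_gpu(n_params_billion, n_gpus, stage):
    params = 2 * n_params_billion      # BF16 weights (GB, since count is in billions)
    grads = 2 * n_params_billion       # BF16 gradients
    optim = 8 * n_params_billion       # Adam m + v in FP32
    if stage >= 1:
        optim /= n_gpus                # ZeRO-1: partition optimizer states
    if stage >= 2:
        grads /= n_gpus                # ZeRO-2: partition gradients
    if stage >= 3:
        params /= n_gpus               # ZeRO-3: partition parameters
    return params + grads + optim

for s in range(4):
    print(f"ZeRO-{s}: {zero_memory_per_gpu(7, 4, s):.1f} GB per GPU")
# ZeRO-0: 84.0, ZeRO-1: 42.0, ZeRO-2: 31.5, ZeRO-3: 21.0 (matches the table)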
ZeRO-3 Deep Dive
┌─────────────────────────────────────────────────────────────────────────┐
│ ZeRO-3 PARAMETER PARTITIONING │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ HOW PARAMETER PARTITIONING WORKS: │
│ ───────────────────────────────── │
│ │
│ Model has layers: [L0, L1, L2, L3, L4, L5, L6, L7] │
│ 4 GPUs │
│ │
│ Partitioning (simplified; real ZeRO-3 shards each layer's tensors │
│ across all GPUs, but whole-layer ownership is easier to picture): │
│ GPU 0 stores: L0, L4 │
│ GPU 1 stores: L1, L5 │
│ GPU 2 stores: L2, L6 │
│ GPU 3 stores: L3, L7 │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ FORWARD PASS: │
│ ───────────── │
│ │
│ To compute L0(x), ALL GPUs need L0's parameters. │
│ │
│ Step 1: GPU 0 broadcasts L0 params to all GPUs (all-gather) │
│ Step 2: All GPUs compute L0(x) in parallel │
│ Step 3: Discard L0 params on GPUs 1,2,3 (keep on GPU 0) │
│ Step 4: GPU 1 broadcasts L1 params │
│ Step 5: All GPUs compute L1(...) │
│ ... repeat for each layer │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ BACKWARD PASS: │
│ ───────────── │
│ │
│ Same pattern in reverse: │
│ 1. All-gather params for layer │
│ 2. Compute gradients │
│ 3. Reduce-scatter gradients (each GPU gets its partition) │
│ 4. Update local optimizer states │
│ 5. Update local parameters │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ COMMUNICATION OVERHEAD: │
│ ─────────────────────── │
│ │
│ ZeRO-3 has 1.5× the communication of standard DP: │
│ │
│ Standard DP: 2× model size (all-reduce gradients) │
│ │
│ ZeRO-3: │
│ • 1× all-gather params (forward) │
│ • 1× all-gather params (backward) │
│ • 1× reduce-scatter gradients │
│ Total: 3× model size (updated params stay sharded and are simply │
│ re-gathered at the next forward pass) │
│ │
│ But: Memory savings often more valuable than bandwidth. │
│ Can train larger models → better results. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Part VI: FSDP - Fully Sharded Data Parallelism
PyTorch's ZeRO Implementation
FSDP is PyTorch's native implementation of ZeRO-3-like sharding:
┌─────────────────────────────────────────────────────────────────────────┐
│ FSDP (Fully Sharded Data Parallel) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ FSDP = PyTorch's native ZeRO-3 │
│ ──────────────────────────── │
│ │
│ Key features: │
│ • Shard params, gradients, optimizer states │
│ • Automatic parameter gathering during forward/backward │
│ • Works with PyTorch ecosystem │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ BASIC USAGE: │
│ ──────────── │
│ │
│ from torch.distributed.fsdp import ( │
│ FullyShardedDataParallel as FSDP, │
│ ShardingStrategy, │
│ ) │
│ │
│ # Wrap model with FSDP │
│ model = FSDP( │
│ model, │
│ sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3 │
│ # or SHARD_GRAD_OP for ZeRO-2 │
│ # or NO_SHARD for standard DDP │
│ ) │
│ │
│ # Training loop unchanged │
│ for batch in dataloader: │
│ loss = model(batch) │
│ loss.backward() │
│ optimizer.step() │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SHARDING STRATEGIES: │
│ ──────────────────── │
│ │
│ FULL_SHARD (ZeRO-3): │
│ • Shard params, grads, optimizer │
│ • Maximum memory efficiency │
│ • Highest communication │
│ │
│ SHARD_GRAD_OP (ZeRO-2): │
│ • Shard grads, optimizer │
│ • Params replicated │
│ • Less communication │
│ │
│ NO_SHARD (DDP): │
│ • Nothing sharded │
│ • Standard data parallel │
│ • Lowest communication │
│ │
│ HYBRID_SHARD: │
│ • Full shard within node │
│ • Replicate across nodes │
│ • Balance memory and communication │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ AUTO WRAPPING: │
│ ────────────── │
│ │
│ FSDP wraps at module granularity. Control with policies: │
│ │
│ import functools │
│ from torch.distributed.fsdp.wrap import ( │
│ transformer_auto_wrap_policy, │
│ ) │
│ │
│ # Wrap each transformer block separately │
│ policy = functools.partial( │
│ transformer_auto_wrap_policy, │
│ transformer_layer_cls={TransformerBlock}, │
│ ) │
│ │
│ model = FSDP(model, auto_wrap_policy=policy) │
│ │
│ This ensures each block's params are sharded as a unit. │
│ Communication happens between blocks, not within. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
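Tying the pieces together, a minimal end-to-end FSDP script might look like the sketch below (launched with torchrun; build_model, TransformerBlock, compute_loss, and the dataloader are placeholders). Note that the optimizer is created after wrapping, so it operates on the sharded parameters.
# Minimal FSDP training script (launch: torchrun --nproc_per_node=8 train_fsdp.py)
import functools
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = build_model()                                        # placeholder factory
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},                # placeholder block class
)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,           # ZeRO-3-style sharding
    auto_wrap_policy=wrap_policy,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,
                                   reduce_dtype=torch.bfloat16),
    device_id=local_rank,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)   # created after wrapping

for batch in dataloader:                                     # placeholder dataloader
    loss = compute_loss(model, batch)                        # placeholder loss fn
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()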
Part VII: Combining Parallelism Strategies
3D Parallelism
Real large-scale training combines multiple parallelism types:
┌─────────────────────────────────────────────────────────────────────────┐
│ 3D PARALLELISM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE CONCEPT: │
│ ──────────── │
│ │
│ Combine: │
│ • Data Parallelism (DP): Replicate model, split data │
│ • Tensor Parallelism (TP): Split layers within device group │
│ • Pipeline Parallelism (PP): Split model stages across groups │
│ │
│ Each addresses different constraints: │
│ • TP: Reduces per-GPU memory, high bandwidth needed │
│ • PP: Reduces per-GPU memory, tolerates lower bandwidth │
│ • DP: Increases throughput, requires model to fit │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ EXAMPLE: Training 70B on 64 GPUs (8 nodes × 8 GPUs): │
│ ───────────────────────────────────────────────────── │
│ │
│ Configuration: │
│ • TP = 8 (within each node, using NVLink) │
│ • PP = 4 (across 4 node groups) │
│ • DP = 2 (2 replicas of the pipeline) │
│ │
│ Total: 8 × 4 × 2 = 64 GPUs │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ VISUAL: │
│ │
│ Two data-parallel replicas of the 4-stage pipeline (each node = one TP group) │
│ ┌──────────────────────────────────────────────┐ │
│ │ │ │
│ PP │ DP Replica 0 DP Replica 1 │ │
│ Stage │ (Nodes 0-3) (Nodes 4-7) │ │
│ 0 │ ┌────────┐ ┌────────┐ │ │
│ │ │Node 0 │ │Node 4 │ │ │
│ │ │8×A100 │ │8×A100 │ │ │
│ │ │ └────────┘ └────────┘ │ │
│ │ │ │ │ │ │
│ │ │ ▼ ▼ │ │
│ 1 │ ┌────────┐ ┌────────┐ │ │
│ │ │Node 1 │ │Node 5 │ │ │
│ │ │ └────────┘ └────────┘ │ │
│ │ │ │ │ │ │
│ │ │ ▼ ▼ │ │
│ 2 │ ┌────────┐ ┌────────┐ │ │
│ │ │Node 2 │ │Node 6 │ │ │
│ │ │ └────────┘ └────────┘ │ │
│ │ │ │ │ │ │
│ │ │ ▼ ▼ │ │
│ 3 │ ┌────────┐ ┌────────┐ │ │
│ │ │Node 3 │ │Node 7 │ │ │
│ │ └────────┘ └────────┘ │ │
│ │ │ │
│ └──────────────────────────────────────────────┘ │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ COMMUNICATION PATTERNS: │
│ ─────────────────────── │
│ │
│ Tensor Parallelism (within node): │
│ • All-reduce every layer │
│ • Requires NVLink (900 GB/s on H100) │
│ │
│ Pipeline Parallelism (across nodes): │
│ • Point-to-point between stages │
│ • Uses InfiniBand (400 Gb/s) │
│ • Lower bandwidth OK due to less frequent comm │
│ │
│ Data Parallelism (across replicas): │
│ • All-reduce gradients after backward │
│ • Uses InfiniBand │
│ • Overlaps with compute │
│ │
└─────────────────────────────────────────────────────────────────────────┘
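In PyTorch, the 8 × 4 × 2 layout above maps naturally onto a 3-D DeviceMesh; each rank can then read off which TP group, pipeline stage, and data-parallel replica it belongs to. A sketch (assumes 64 ranks launched with torchrun; the dimension names are a convention, not a requirement):
# Describing TP=8, PP=4, DP=2 as a 3-D device mesh (64 ranks total)
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (2, 4, 8), mesh_dim_names=("dp", "pp", "tp"))

dp_rank = mesh.get_local_rank(0)   # which data-parallel replica of the pipeline
pp_rank = mesh.get_local_rank(1)   # which pipeline stage
tp_rank = mesh.get_local_rank(2)   # position inside the node's 8-way TP group

print(f"rank {dist.get_rank()}: dp={dp_rank}, pp={pp_rank}, tp={tp_rank}")
# mesh["tp"], mesh["pp"], mesh["dp"] give the sub-meshes (and process groups)
# that TP all-reduces, stage-to-stage sends, and DP gradient all-reduces run on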
Choosing Parallelism Configuration
┌─────────────────────────────────────────────────────────────────────────┐
│ PARALLELISM CONFIGURATION GUIDE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ DECISION FRAMEWORK: │
│ ─────────────────── │
│ │
│ 1. Does model fit on 1 GPU? │
│ YES → Use Data Parallel only (or FSDP for memory efficiency) │
│ NO → Continue to step 2 │
│ │
│ 2. Does model fit with TP within node? │
│ Try TP = 2, 4, or 8 (NVLink connected GPUs) │
│ YES → Use TP + DP │
│ NO → Continue to step 3 │
│ │
│ 3. Add Pipeline Parallelism │
│ PP = model_size / (memory_per_gpu × TP) │
│ Use 1F1B schedule with micro-batching │
│ │
│ 4. Fill remaining GPUs with Data Parallel │
│ DP = total_gpus / (TP × PP) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ EXAMPLE CONFIGURATIONS: │
│ │
│ Model Size GPUs TP PP DP Framework │
│ ───────────────────────────────────────────────────────────────── │
│ 7B 8 1 1 8 DDP or FSDP │
│ 13B 8 2 1 4 TP + DP │
│ 30B 16 4 2 2 TP + PP + DP │
│ 70B 64 8 4 2 Full 3D │
│ 175B 512 8 16 4 Full 3D │
│ 500B+ 1000+ 8 32+ 4+ Full 3D + Expert Parallel │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ RULES OF THUMB: │
│ ─────────────── │
│ │
│ • TP within node only (NVLink required) │
│ • TP ≤ 8 (diminishing returns beyond) │
│ • PP micro-batches ≥ 4× PP stages (reduce bubble) │
│ • DP as large as possible (best scaling) │
│ • Memory per GPU: model/(TP×PP) + activations + optimizer │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ COMMON MISTAKES: │
│ ─────────────── │
│ │
│ ✗ Using TP across nodes (too slow without NVLink) │
│ ✗ Too few micro-batches for PP (large bubbles) │
│ ✗ Not enabling gradient checkpointing (OOM on activations) │
│ ✗ Wrong TP degree (not matching attention head count) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
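The decision framework can be sketched as a rough heuristic in plain Python; the 75% headroom and the thresholds below are ballpark assumptions taken from this post, not a substitute for measuring your own workload.
# Rough TP/PP/DP suggestion following the decision framework above
def suggest_parallelism(model_state_gb, total_gpus, gpus_per_node=8, gpu_mem_gb=80):
    budget = 0.75 * gpu_mem_gb                  # leave ~25% headroom for activations
    if model_state_gb <= budget:                # step 1: fits on one GPU
        return {"tp": 1, "pp": 1, "dp": total_gpus}
    for tp in (2, 4, 8):                        # step 2: TP within the node
        if tp <= gpus_per_node and model_state_gb / tp <= budget:
            return {"tp": tp, "pp": 1, "dp": max(1, total_gpus // tp)}
    tp, pp = gpus_per_node, 2                   # step 3: add pipeline stages
    while model_state_gb / (tp * pp) > budget and tp * pp < total_gpus:
        pp *= 2
    return {"tp": tp, "pp": pp, "dp": max(1, total_gpus // (tp * pp))}

# ~70B with the Part I accounting (~1.1 TB of states) on 64 GPUs -> TP=8, PP=4, DP=2
print(suggest_parallelism(1100, 64))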
Part VIII: Practical Implementation
Using DeepSpeed
DeepSpeed is the most popular library for large-scale training:
# DeepSpeed configuration for 70B model training
deepspeed_config = {
    # ZeRO Stage 3 for maximum memory efficiency
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",        # Offload optimizer to CPU
            "pin_memory": True
        },
        "offload_param": {
            "device": "cpu",        # Offload params to CPU (optional)
            "pin_memory": True
        },
        "overlap_comm": True,       # Overlap communication with compute
        "contiguous_gradients": True,
        "reduce_bucket_size": 5e8,
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e6,
    },
    # Mixed precision training
    "bf16": {
        "enabled": True
    },
    # Gradient accumulation
    "gradient_accumulation_steps": 8,
    # Optimizer
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 1e-5,
            "betas": [0.9, 0.95],
            "weight_decay": 0.1
        }
    },
    # Learning rate scheduler
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 1e-5,
            "warmup_num_steps": 100,
            "total_num_steps": 10000
        }
    },
    # Training batch size
    "train_micro_batch_size_per_gpu": 1,
    "gradient_clipping": 1.0,
}
# Launch training
# deepspeed --num_gpus=8 train.py --deepspeed_config config.json
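On the script side, this config is handed to deepspeed.initialize, which returns an engine that takes over ZeRO partitioning, mixed precision, gradient accumulation, and the optimizer/scheduler declared in the config. A minimal sketch (build_model and the dataloader are placeholders, and the model is assumed to return its loss):
# Minimal DeepSpeed training loop consuming the config above (sketch)
import deepspeed

model = build_model()                              # placeholder model factory

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=deepspeed_config,                       # dict above, or the JSON file path
)

for batch in dataloader:                           # placeholder dataloader
    loss = model_engine(batch)                     # forward; assumes model returns loss
    model_engine.backward(loss)                    # handles scaling + accumulation
    model_engine.step()                            # optimizer step + ZeRO bookkeeping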
Using Megatron-LM
For tensor and pipeline parallelism:
# Megatron-LM training launch command
GPUS_PER_NODE=8
NNODES=8
NODE_RANK=$SLURM_PROCID
MASTER_ADDR=$SLURM_SUBMIT_HOST
MASTER_PORT=6000
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
# Model parallelism configuration
TENSOR_PARALLEL=8 # Within node (NVLink)
PIPELINE_PARALLEL=4 # Across nodes
# Data parallel is implicit: 64 / (8 * 4) = 2
MODEL_ARGS="
--num-layers 80 \
--hidden-size 8192 \
--num-attention-heads 64 \
--seq-length 4096 \
--max-position-embeddings 4096 \
--tensor-model-parallel-size $TENSOR_PARALLEL \
--pipeline-model-parallel-size $PIPELINE_PARALLEL \
"
TRAINING_ARGS="
--micro-batch-size 1 \
--global-batch-size 1024 \
--train-iters 100000 \
--lr 1.5e-4 \
--min-lr 1.5e-5 \
--lr-decay-iters 100000 \
--lr-warmup-iters 2000 \
--bf16 \
--use-flash-attn \
--recompute-activations \
"
torchrun $DISTRIBUTED_ARGS \
pretrain_gpt.py \
$MODEL_ARGS \
$TRAINING_ARGS
Part IX: Recent Innovations (2024-2025)
FSDP2 - Next-Generation Sharding
PyTorch introduced FSDP2 in 2024 with significant improvements over the original FSDP:
┌─────────────────────────────────────────────────────────────────────────┐
│ FSDP2 IMPROVEMENTS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ KEY CHANGES FROM FSDP1: │
│ ─────────────────────── │
│ │
│ 1. PER-PARAMETER SHARDING (not per-module): │
│ FSDP1: Wraps entire modules, all params sharded together │
│ FSDP2: Each parameter independently sharded │
│ │
│ Benefits: │
│ • Finer-grained memory control │
│ • Better support for heterogeneous param sizes │
│ • Easier mixed-precision per-parameter │
│ │
│ 2. DTensor INTEGRATION: │
│ Built on DTensor (Distributed Tensor) abstraction │
│ • Unified sharding specification language │
│ • Cleaner composition with tensor parallelism │
│ • Better debugging and visualization │
│ │
│ 3. EXPLICIT PREFETCH CONTROL: │
│ FSDP1: Fixed prefetch behavior │
│ FSDP2: User-controlled prefetch, e.g. │
│ │
│ layer.set_modules_to_backward_prefetch([next_layer]) │
│ │
│ 4. COMPOSABLE API: │
│ Designed to compose cleanly with: │
│ • Tensor Parallel (tp.parallelize_module) │
│ • Pipeline Parallel │
│ • Activation checkpointing │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ FSDP2 BASIC USAGE: │
│ ────────────────── │
│ │
│ import torch │
│ from torch.distributed._composable.fsdp import fully_shard │
│ from torch.distributed.device_mesh import DeviceMesh │
│ │
│ # Create a named device mesh for multi-dimensional parallelism │
│ mesh = DeviceMesh("cuda", torch.arange(8).reshape(2, 4), │
│ mesh_dim_names=("dp", "tp")) │
│ # First dim: data parallel (2), second dim: tensor parallel (4) │
│ │
│ # Shard model parameters │
│ for layer in model.transformer_blocks: │
│ fully_shard(layer, mesh=mesh["dp"]) # Shard on DP dimension │
│ fully_shard(model, mesh=mesh["dp"]) │
│ │
│ # Training loop │
│ for batch in dataloader: │
│ loss = model(batch) │
│ loss.backward() │
│ optimizer.step() │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ COMPOSING WITH TENSOR PARALLEL: │
│ ──────────────────────────────── │
│ │
│ from torch.distributed.tensor.parallel import parallelize_module │
│ │
│ # 2D parallelism: TP + FSDP (named mesh dims again) │
│ mesh_2d = DeviceMesh("cuda", torch.arange(8).reshape(2, 4), │
│ mesh_dim_names=("dp", "tp")) │
│ tp_mesh = mesh_2d["tp"] # 4-way tensor parallel │
│ dp_mesh = mesh_2d["dp"] # 2-way data parallel (FSDP) │
│ │
│ # First apply tensor parallelism (tp_plan maps submodule names to │
│ # ColwiseParallel / RowwiseParallel styles, as in Megatron-style TP) │
│ parallelize_module(model, tp_mesh, tp_plan) │
│ │
│ # Then apply FSDP (shards the TP-split params) │
│ for layer in model.layers: │
│ fully_shard(layer, mesh=dp_mesh) │
│ fully_shard(model, mesh=dp_mesh) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
torch.compile Integration
FSDP2 works seamlessly with torch.compile for additional speedups:
# FSDP2 + torch.compile example
import torch
from torch.distributed._composable.fsdp import fully_shard

# Apply FSDP2 sharding
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

# Compile the model (works with FSDP2)
model = torch.compile(model)

# Training with compiled + sharded model
for batch in dataloader:
    loss = model(batch)
    loss.backward()
    optimizer.step()
# Benefits:
# • Kernel fusion across FSDP boundaries
# • Reduced Python overhead
# • Optimized communication patterns
Context Parallel for Long Sequences
A new parallelism dimension for handling very long sequences:
┌─────────────────────────────────────────────────────────────────────────┐
│ CONTEXT PARALLELISM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE PROBLEM: │
│ ──────────── │
│ │
│ With context lengths of 128K-1M tokens: │
│ • Activation memory: O(seq_len × hidden_size × batch) │
│ • KV cache: O(seq_len × num_layers × num_heads × head_dim) │
│ │
│ Even with TP/PP/DP, single GPU can't hold activations for 1M tokens. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ CONTEXT PARALLELISM (CP): │
│ ───────────────────────── │
│ │
│ Split sequence across GPUs along the sequence dimension. │
│ │
│ Sequence: [tok_0, tok_1, ..., tok_128K] │
│ CP = 4: │
│ GPU 0: [tok_0 ... tok_32K] │
│ GPU 1: [tok_32K ... tok_64K] │
│ GPU 2: [tok_64K ... tok_96K] │
│ GPU 3: [tok_96K ... tok_128K] │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ RING ATTENTION: │
│ ─────────────── │
│ │
│ For attention, each position needs to attend to all previous. │
│ Solution: Ring communication pattern for KV sharing. │
│ │
│ Step 1: Each GPU computes attention for local Q with local KV │
│ Step 2: Pass KV to next GPU in ring, receive from previous │
│ Step 3: Compute attention with received KV │
│ Step 4: Repeat until all KV seen │
│ │
│ Memory: O(seq_len/CP) per GPU │
│ Communication: Overlapped with compute │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ USAGE (Megatron-LM): │
│ │
│ torchrun ... pretrain_gpt.py \ │
│ --context-parallel-size 8 \ │
│ --seq-length 1048576 \ │
│ --tensor-model-parallel-size 8 \ │
│ --pipeline-model-parallel-size 4 │
│ │
│ 4D Parallelism: TP × PP × DP × CP │
│ │
└─────────────────────────────────────────────────────────────────────────┘
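The key enabler behind ring attention is that softmax attention can be accumulated one KV block at a time with a running (online) normalizer, so a GPU never needs all keys and values at once. Below is a single-process sketch of that accumulation in PyTorch; there is no real ring communication here (each loop iteration stands in for receiving the next KV block from a neighbour), and causal masking is omitted for brevity.
# Blockwise (online-softmax) attention: the math behind ring attention (sketch)
import torch

def chunked_attention(q, k, v, num_chunks):
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)                              # running weighted sum of V
    denom = torch.zeros(q.shape[0], 1)                     # running softmax denominator
    running_max = torch.full((q.shape[0], 1), float("-inf"))
    for k_chunk, v_chunk in zip(k.chunk(num_chunks), v.chunk(num_chunks)):
        scores = (q @ k_chunk.T) * scale                   # [q_len, chunk_len]
        new_max = torch.maximum(running_max,
                                scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(running_max - new_max)      # rescale old accumulators
        weights = torch.exp(scores - new_max)
        out = out * correction + weights @ v_chunk
        denom = denom * correction + weights.sum(dim=-1, keepdim=True)
        running_max = new_max
    return out / denom

q, k, v = torch.randn(4, 16), torch.randn(32, 16), torch.randn(32, 16)
reference = torch.softmax((q @ k.T) * 16 ** -0.5, dim=-1) @ v
assert torch.allclose(chunked_attention(q, k, v, num_chunks=4), reference, atol=1e-5)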
Float8 Training
FP8 training enables significant memory and compute savings:
┌─────────────────────────────────────────────────────────────────────────┐
│ FP8 TRAINING (2024-2025) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ FP8 FORMATS: │
│ ──────────── │
│ │
│ E4M3: 4 exponent bits, 3 mantissa bits │
│ • Max value ±448, finer precision (~6% relative step) │
│ • Good for weights and activations │
│ │
│ E5M2: 5 exponent bits, 2 mantissa bits │
│ • Max value ±57344, coarser (~12% step) but wider dynamic range │
│ • Good for gradients (need the larger dynamic range) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ MEMORY SAVINGS: │
│ ─────────────── │
│ │
│ vs BF16: │
│ • Weights: 2× smaller │
│ • Activations: 2× smaller │
│ • Gradients: 2× smaller (with E5M2) │
│ │
│ 70B model in FP8: ~70GB weights (vs 140GB in BF16) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ PYTORCH FP8 TRAINING: │
│ │
│ from torchao.float8 import ( │
│ convert_to_float8_training, │
│ Float8LinearConfig, │
│ ) │
│ │
│ # Convert linear layers to FP8 │
│ config = Float8LinearConfig( │
│ enable_fsdp_float8_all_gather=True, # FP8 comms │
│ ) │
│ convert_to_float8_training(model, config=config) │
│ │
│ # Apply FSDP after FP8 conversion │
│ model = FSDP(model, ...) │
│ │
│ # Training loop unchanged │
│ for batch in dataloader: │
│ loss = model(batch) │
│ loss.backward() │
│ optimizer.step() │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ PRODUCTION ADOPTION: │
│ ──────────────────── │
│ │
│ DeepSeek-V3: First frontier model trained with FP8 │
│ • 671B params, 14.8T tokens │
│ • 2.79M H800 GPU hours (vs expected 5-10M for BF16) │
│ • Proved FP8 works at scale │
│ │
└─────────────────────────────────────────────────────────────────────────┘
torchao and Low-Bit Optimizations
# torchao library for quantized training
from torchao.prototype.low_bit_optim import AdamW8bit, AdamWFp8
# 8-bit optimizer (reduces optimizer memory by 2×)
optimizer = AdamW8bit(model.parameters(), lr=1e-4)
# INT8 weight/activation quantization for further savings
# (int8_dynamic_activation_int8_weight is primarily an inference recipe;
#  torchao's quantized-training support lives in separate prototype APIs)
from torchao.quantization import int8_dynamic_activation_int8_weight
from torchao.quantization import quantize_
# Quantize model to INT8
quantize_(model, int8_dynamic_activation_int8_weight())
# Combined: FP8 training (H100-class GPUs) + 8-bit optimizer + FSDP
# can bring a ~70B model within reach of a single 8-GPU 80GB node
Updated Configuration Guide
┌─────────────────────────────────────────────────────────────────────────┐
│ MODERN PARALLELISM CONFIGURATION (2025) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Model Size GPUs TP PP CP DP Precision Framework │
│ ───────────────────────────────────────────────────────────────────── │
│ 7B 8 1 1 1 8 BF16 FSDP2 │
│ 13B 8 2 1 1 4 BF16 FSDP2 + TP │
│ 30B 16 4 2 1 2 BF16 3D │
│ 70B 64 8 4 1 2 FP8 3D + FP8 │
│ 175B 256 8 8 1 4 FP8 4D │
│ 405B (Llama3) 16K 8 16 - 128 BF16 4D │
│ 671B (DSV3) 2K - - - - FP8 MoE+EP │
│ 1M+ ctx 256 8 4 8 1 BF16 4D + CP │
│ │
│ NEW DIMENSIONS: │
│ • CP (Context Parallel): For 128K+ context lengths │
│ • EP (Expert Parallel): For MoE models │
│ • FP8: 2× memory reduction, proven at DeepSeek scale │
│ │
│ KEY 2025 UPDATES: │
│ • TorchTitan: ICLR 2025 reference platform for 4D parallelism │
│ • FSDP2 replaces FSDP1 for new projects (DTensor-based) │
│ • Float8 + FSDP2: 50% throughput speedup over FSDP1 bf16 │
│ - float8 all_gathers for weight communication │
│ - MXFP8 support for Blackwell GPUs (2025) │
│ • AMD torchtitan fork (Nov 2025) for ROCm optimization │
│ • torch.compile works with distributed training │
│ • FP8 is production-ready (DeepSeek-V3 proved it at 671B scale) │
│ • Context parallel enables 1M+ context (262K tokens on 8 H100s) │
│ • Ring Attention: USP combines Ulysses + Ring for best efficiency │
│ • TorchFT integration for fault tolerance at scale │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Summary
Distributed training enables training models that don't fit on single GPUs. The key techniques:
┌─────────────────────────────────────────────────────────────────────────┐
│ KEY TAKEAWAYS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ DATA PARALLELISM: │
│ • Replicate model, split data │
│ • All-reduce gradients after backward │
│ • Scales well but model must fit per GPU │
│ │
│ TENSOR PARALLELISM: │
│ • Split layers across GPUs │
│ • Communication every layer (needs NVLink) │
│ • Use within node, TP ≤ 8 │
│ │
│ PIPELINE PARALLELISM: │
│ • Split model into stages │
│ • 1F1B schedule minimizes bubble │
│ • Use across nodes │
│ │
│ ZeRO / FSDP: │
│ • Partition optimizer, gradients, parameters │
│ • Stage 1/2/3 trade communication for memory │
│ • Enables much larger models │
│ │
│ 3D PARALLELISM: │
│ • Combine TP (within node) + PP (across nodes) + DP (throughput) │
│ • Required for 70B+ models │
│ │
│ 2024-2025 INNOVATIONS: │
│ • FSDP2: Per-parameter sharding, DTensor integration │
│ • FP8 Training: 2× memory savings, proven at DeepSeek scale │
│ • Context Parallel: 4th dimension for 1M+ context │
│ • torch.compile: Works with distributed training now │
│ │
│ PRACTICAL GUIDELINES: │
│ • Start with FSDP2 for new projects (not FSDP1) │
│ • Add TP when model too large for FSDP alone │
│ • Add PP for very large models across nodes │
│ • Consider FP8 for 70B+ models (2× memory reduction) │
│ • Add CP for 128K+ context training │
│ • Always use gradient checkpointing for memory │
│ • Use bf16/fp8 mixed precision │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Related Articles
Transformer Architecture: A Complete Deep Dive
A comprehensive exploration of the transformer architecture—from embedding layers through attention and feed-forward networks to the output head. Understand why decoder-only models dominate, how residual connections enable deep networks, and the engineering decisions behind GPT, Llama, and modern LLMs.
LLM Inference Optimization: From Quantization to Speculative Decoding
A comprehensive guide to optimizing LLM inference for production—covering quantization, attention optimization, batching strategies, and deployment frameworks.
Mixture of Experts: Scaling LLMs Beyond Dense Models
A comprehensive deep dive into Mixture of Experts (MoE) architecture—how models like Mixtral and GPT-4 achieve massive capacity without proportional compute costs. Understand routing mechanisms, expert specialization, load balancing, and why MoE represents the future of LLM scaling.