
Distributed Training: How to Train 70B+ Parameter Models

A comprehensive deep dive into distributed training—how to train models that don't fit on a single GPU. Understand data parallelism, tensor parallelism, pipeline parallelism, ZeRO optimization, and the engineering behind training frontier LLMs.


Why Distributed Training Matters

A 70B parameter model in FP16 requires 140GB just for weights—far more than any single GPU can hold. Training requires additional memory for gradients (another 140GB), optimizer states (560GB for Adam), and activations (variable, often 100GB+). Total: roughly 1TB for training, versus 80GB on an A100 or H100.

Distributed training isn't optional for large models—it's required.

But distributing training is more complex than just "use more GPUs." Communication between GPUs can bottleneck throughput. Memory can be wasted through redundant storage. Load imbalances can leave GPUs idle. Understanding these tradeoffs is essential for efficient large-scale training.

This post covers the core distributed training techniques: data parallelism, tensor parallelism, pipeline parallelism, and ZeRO optimization. You'll understand how teams train 70B+ models and how to apply these techniques yourself.


Part I: The Memory Problem

What Consumes Memory During Training?

Before diving into solutions, let's understand what we're solving:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    TRAINING MEMORY BREAKDOWN                             │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  FOR A 7B PARAMETER MODEL (BF16 mixed precision):                       │
│  ────────────────────────────────────────────────                        │
│                                                                          │
│  COMPONENT              MEMORY      EXPLANATION                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  Model Parameters       14 GB       7B × 2 bytes (BF16)                 │
│                                                                          │
│  Gradients              14 GB       Same size as params (BF16)          │
│                                                                          │
│  Optimizer States       56 GB       Adam: 2 FP32 states per param:      │
│   - m (momentum)        28 GB         • m (momentum): 7B × 4 bytes     │
│   - v (variance)        28 GB         • v (variance): 7B × 4 bytes     │
│                                                                          │
│  Activations            Variable    Depends on batch size & seq length │
│   - Batch 1, 2K seq     ~8 GB       Stored for backward pass           │
│   - Batch 8, 2K seq     ~64 GB      Scales with batch size             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  TOTAL FOR 7B MODEL:                                                     │
│                                                                          │
│  Parameters:      14 GB                                                 │
│  Gradients:       14 GB                                                 │
│  Optimizer:       56 GB                                                 │
│  Activations:     ~32 GB (typical)                                     │
│  ──────────────────────                                                 │
│  TOTAL:           ~116 GB                                               │
│                                                                          │
│  This exceeds a single 80GB A100!                                      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SCALING TO 70B:                                                         │
│  ───────────────                                                         │
│                                                                          │
│  Parameters:      140 GB                                                │
│  Gradients:       140 GB                                                │
│  Optimizer:       560 GB                                                │
│  Activations:     ~300 GB                                               │
│  ──────────────────────                                                 │
│  TOTAL:           ~1.1 TB                                               │
│                                                                          │
│  Would need 14× 80GB GPUs just for memory!                             │
│  (Before considering any parallelism efficiency)                       │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
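
To make this accounting concrete, here is a minimal back-of-the-envelope estimator (a sketch, not a profiler). It assumes BF16 parameters and gradients at 2 bytes each and two FP32 Adam states at 4 bytes each, exactly as in the table, and takes the activation estimate as an input since that depends on batch size and sequence length.

Code
def training_memory_gb(n_params_billion, activation_gb=32.0):
    """Rough training-memory estimate matching the breakdown above.

    Assumes BF16 params/grads (2 bytes each) and two FP32 Adam states
    (4 bytes each); activations are passed in as a separate estimate.
    """
    n = n_params_billion * 1e9
    params_gb = n * 2 / 1e9           # BF16 weights
    grads_gb = n * 2 / 1e9            # BF16 gradients
    optimizer_gb = n * (4 + 4) / 1e9  # Adam m + v in FP32
    total = params_gb + grads_gb + optimizer_gb + activation_gb
    return params_gb, grads_gb, optimizer_gb, total

# 7B model: 14 + 14 + 56 + 32 ≈ 116 GB, as in the table above
print(training_memory_gb(7))
# 70B model with ~300 GB of activations: ~1.1 TB
print(training_memory_gb(70, activation_gb=300.0))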

Memory Optimization Techniques

Before going multi-GPU, there are single-GPU optimizations:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    SINGLE-GPU MEMORY OPTIMIZATIONS                       │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  1. MIXED PRECISION TRAINING (BF16/FP16):                               │
│  ─────────────────────────────────────────                               │
│                                                                          │
│  Keep master weights in FP32, compute in BF16/FP16.                   │
│  Saves ~50% memory for params/grads, maintains accuracy.              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  2. GRADIENT CHECKPOINTING:                                              │
│  ──────────────────────────                                              │
│                                                                          │
│  Don't store all activations. Recompute them during backward.         │
│                                                                          │
│  Without checkpointing (store all):                                    │
│  Layer 1 → save act1 → Layer 2 → save act2 → ... → Layer 32          │
│  Memory: O(num_layers × batch × seq × hidden)                         │
│                                                                          │
│  With checkpointing (recompute):                                       │
│  Layer 1 → Layer 2 → ... → Layer 32 (only save checkpoints)          │
│  Backward: recompute acts from checkpoints                            │
│  Memory: O(sqrt(num_layers) × batch × seq × hidden)                  │
│                                                                          │
│  Tradeoff: ~30% more compute, ~5× less activation memory             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  3. GRADIENT ACCUMULATION:                                               │
│  ─────────────────────────                                               │
│                                                                          │
│  Can't fit batch size 32? Use batch 4, accumulate 8 steps.           │
│                                                                          │
│  for i, batch in enumerate(data):                                      │
│      loss = model(batch) / accumulation_steps                         │
│      loss.backward()                                                   │
│      if (i + 1) % accumulation_steps == 0:                            │
│          optimizer.step()                                              │
│          optimizer.zero_grad()                                         │
│                                                                          │
│  Same effective batch size, lower memory.                              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  4. CPU OFFLOADING:                                                      │
│  ──────────────────                                                      │
│                                                                          │
│  Store optimizer states in CPU RAM, move to GPU when needed.          │
│  Slower but enables larger models on single GPU.                       │
│  Used by: DeepSpeed ZeRO-Offload                                      │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
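
Here is a minimal PyTorch sketch that combines mixed precision, gradient checkpointing, and gradient accumulation in one loop. The model, data, and loss are placeholders; the point is where each technique plugs into an ordinary training loop.

Code
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

# Toy stand-ins for a real model and data pipeline.
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)]).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accumulation_steps = 8

data = [torch.randn(4, 1024, device="cuda") for _ in range(32)]

for i, batch in enumerate(data):
    # Mixed precision: run the forward/backward math in BF16.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        # Gradient checkpointing: keep activations for only 2 segments,
        # recompute the rest during the backward pass.
        out = checkpoint_sequential(model, 2, batch, use_reentrant=False)
        loss = out.pow(2).mean() / accumulation_steps  # placeholder loss

    loss.backward()

    # Gradient accumulation: step once every `accumulation_steps` micro-batches.
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)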

Part II: Data Parallelism

The Simplest Approach

Data parallelism is the most straightforward distributed training technique: replicate the model on each GPU, split the data batch across GPUs.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    DATA PARALLELISM                                      │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE CONCEPT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  Each GPU has a COMPLETE copy of the model.                            │
│  Data batch is SPLIT across GPUs.                                      │
│  Gradients are SYNCHRONIZED after each step.                           │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  VISUALIZATION:                                                          │
│  ──────────────                                                          │
│                                                                          │
│  Global batch: [sample1, sample2, sample3, sample4, sample5, sample6] │
│                                                                          │
│           ┌─────────────────────────────────────────────────┐          │
│           │             SPLIT DATA                          │          │
│           └─────────────────────────────────────────────────┘          │
│                    │              │              │                      │
│                    ▼              ▼              ▼                      │
│  ┌─────────────────────┐ ┌─────────────────┐ ┌─────────────────┐      │
│  │      GPU 0          │ │     GPU 1       │ │     GPU 2       │      │
│  ├─────────────────────┤ ├─────────────────┤ ├─────────────────┤      │
│  │ Model (full copy)   │ │ Model (copy)    │ │ Model (copy)    │      │
│  │                     │ │                 │ │                 │      │
│  │ Data: [s1, s2]      │ │ Data: [s3, s4]  │ │ Data: [s5, s6]  │      │
│  │                     │ │                 │ │                 │      │
│  │ Forward → Loss      │ │ Forward → Loss  │ │ Forward → Loss  │      │
│  │ Backward → Grads    │ │ Backward → Grads│ │ Backward → Grads│      │
│  └─────────────────────┘ └─────────────────┘ └─────────────────┘      │
│           │                      │                   │                 │
│           └──────────────────────┼───────────────────┘                 │
│                                  ▼                                     │
│                    ┌─────────────────────────┐                         │
│                    │   ALL-REDUCE GRADIENTS  │                         │
│                    │   (average gradients)   │                         │
│                    └─────────────────────────┘                         │
│                                  │                                     │
│                    ┌─────────────┼─────────────┐                       │
│                    ▼             ▼             ▼                       │
│              GPU 0 update   GPU 1 update  GPU 2 update                │
│              (identical)    (identical)   (identical)                 │
│                                                                          │
│  All GPUs end up with identical model weights.                        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SCALING PROPERTIES:                                                     │
│  ───────────────────                                                     │
│                                                                          │
│  N GPUs → N× throughput (ideal)                                        │
│                                                                          │
│  In practice:                                                          │
│  • Communication overhead reduces scaling                              │
│  • Typical: 90-95% scaling efficiency on good hardware                │
│  • 8 GPUs → ~7.5× throughput                                          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  LIMITATIONS:                                                            │
│  ────────────                                                            │
│                                                                          │
│  Model must fit on EACH GPU!                                           │
│  7B model: 116GB needed → doesn't fit on 80GB GPU                     │
│  Data parallelism alone can't train 7B on A100.                       │
│                                                                          │
│  For larger models, need model parallelism techniques.                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
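
In PyTorch, this pattern is what DistributedDataParallel (DDP) implements: each rank keeps a full replica, DistributedSampler hands each rank a disjoint shard of the data, and gradients are all-reduced automatically during backward(). A minimal sketch, meant to be launched with torchrun; the model, dataset, and hyperparameters are placeholders.

Code
# Launch with: torchrun --nproc_per_node=8 train_ddp.py
import os
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Placeholder model and dataset; swap in a real model and dataloader.
model = nn.Linear(1024, 1024).cuda()
model = DDP(model, device_ids=[local_rank])   # full replica on every GPU
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

dataset = TensorDataset(torch.randn(4096, 1024), torch.randn(4096, 1024))
sampler = DistributedSampler(dataset)          # splits the batch across ranks
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)                   # reshuffle the split each epoch
    for x, y in loader:
        x, y = x.cuda(), y.cuda()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()                        # DDP all-reduces gradients here
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

dist.destroy_process_group()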

All-Reduce Communication

The gradient synchronization uses a collective operation called "all-reduce":

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ALL-REDUCE OPERATION                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ALL-REDUCE: Each GPU starts with local data, ends with global sum.   │
│                                                                          │
│  NAIVE APPROACH (bad):                                                  │
│  ─────────────────────                                                   │
│  Send all gradients to GPU 0, sum, broadcast back.                    │
│                                                                          │
│  Time: O(N × data_size) - doesn't scale                               │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  RING ALL-REDUCE (efficient):                                           │
│  ────────────────────────────                                            │
│                                                                          │
│  Arrange GPUs in a ring, pass data around incrementally.              │
│                                                                          │
│  Phase 1: Reduce-Scatter                                               │
│  GPU 0 → GPU 1 → GPU 2 → GPU 3 → GPU 0 (ring)                        │
│  Each GPU sends 1/N of its data, receives and adds 1/N from neighbor │
│  After N-1 steps: each GPU has 1/N of the total sum                  │
│                                                                          │
│  Phase 2: All-Gather                                                    │
│  Same ring pattern, but now sharing the partial sums                  │
│  After N-1 steps: each GPU has the complete sum                       │
│                                                                          │
│  Time: O(data_size) - independent of N!                               │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  VISUALIZATION:                                                          │
│  ──────────────                                                          │
│                                                                          │
│  4 GPUs, gradients = [g0, g1, g2, g3]                                 │
│  Goal: each GPU gets (g0 + g1 + g2 + g3) / 4                         │
│                                                                          │
│  Initial:                                                               │
│  GPU 0: [g0]      GPU 1: [g1]      GPU 2: [g2]      GPU 3: [g3]      │
│                                                                          │
│  Reduce-Scatter (split into 4 chunks, ring-reduce each):             │
│  After: Each GPU has 1/4 of total sum                                │
│  GPU 0: [Σchunk0] GPU 1: [Σchunk1] GPU 2: [Σchunk2] GPU 3: [Σchunk3]│
│                                                                          │
│  All-Gather (share chunks around ring):                               │
│  After: Each GPU has all 4 chunks                                    │
│  GPU 0: [Σall]    GPU 1: [Σall]    GPU 2: [Σall]    GPU 3: [Σall]   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMMUNICATION VOLUME:                                                   │
│  ─────────────────────                                                   │
│                                                                          │
│  For gradient size G and N GPUs:                                       │
│  Ring all-reduce: 2 × G × (N-1)/N bytes per GPU                      │
│                                                                          │
│  As N → ∞: approaches 2G bytes (constant!)                            │
│  This is why data parallelism scales well.                            │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
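
The two phases are easy to see in a toy simulation. The sketch below represents each "GPU" as a list of N chunks on a single process and passes chunks around the ring exactly as described above. It is purely illustrative; real training simply calls torch.distributed.all_reduce and lets NCCL do this under the hood.

Code
def ring_all_reduce(chunks_per_gpu):
    """Simulate ring all-reduce over N 'GPUs', each holding N chunks."""
    n = len(chunks_per_gpu)

    # Phase 1: reduce-scatter. At step s, GPU i sends chunk (i - s) mod n
    # to its neighbor, which adds it in place. After n-1 steps, GPU i owns
    # the fully reduced chunk (i + 1) mod n.
    for step in range(n - 1):
        for i in range(n):
            dst = (i + 1) % n
            c = (i - step) % n
            chunks_per_gpu[dst][c] += chunks_per_gpu[i][c]

    # Phase 2: all-gather. Pass each completed chunk around the ring so
    # every GPU ends up with all n reduced chunks.
    for step in range(n - 1):
        for i in range(n):
            dst = (i + 1) % n
            c = (i + 1 - step) % n
            chunks_per_gpu[dst][c] = chunks_per_gpu[i][c]

    return chunks_per_gpu

# 4 GPUs, each with 4 chunks equal to its rank: every chunk should end up
# as the global sum 0 + 1 + 2 + 3 = 6 on every GPU.
gpus = [[float(rank)] * 4 for rank in range(4)]
print(ring_all_reduce(gpus))  # [[6.0, 6.0, 6.0, 6.0], ...] on every GPU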

Part III: Tensor Parallelism

Splitting Layers Across GPUs

When a single layer is too large, split it across GPUs:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    TENSOR PARALLELISM                                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE CONCEPT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  Split individual LAYERS across GPUs.                                  │
│  Each GPU holds part of each layer.                                    │
│  Requires communication within each layer.                             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  LINEAR LAYER SPLITTING:                                                 │
│  ───────────────────────                                                 │
│                                                                          │
│  Consider: Y = XW where W is (d_in × d_out)                           │
│                                                                          │
│  COLUMN PARALLEL (split output dimension):                             │
│  ──────────────────────────────────────────                              │
│                                                                          │
│  W = [W1 | W2]  (split columns)                                        │
│                                                                          │
│  GPU 0: Y1 = X × W1    (X is replicated)                              │
│  GPU 1: Y2 = X × W2                                                    │
│                                                                          │
│  Result: [Y1 | Y2] = Y  (concatenate outputs)                         │
│                                                                          │
│           X (replicated)                                                │
│           │                                                             │
│     ┌─────┴─────┐                                                       │
│     │           │                                                       │
│     ▼           ▼                                                       │
│  ┌──────┐   ┌──────┐                                                   │
│  │ W1   │   │ W2   │                                                   │
│  │(GPU0)│   │(GPU1)│                                                   │
│  └──────┘   └──────┘                                                   │
│     │           │                                                       │
│     ▼           ▼                                                       │
│    Y1          Y2                                                       │
│     │           │                                                       │
│     └─────┬─────┘                                                       │
│           ▼                                                             │
│      [Y1 | Y2] = Y                                                     │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ROW PARALLEL (split input dimension):                                  │
│  ─────────────────────────────────────                                   │
│                                                                          │
│  W = [W1]  (split rows)                                                │
│      [W2]                                                               │
│                                                                          │
│  X = [X1 | X2]  (split input)                                         │
│                                                                          │
│  GPU 0: Y1 = X1 × W1                                                   │
│  GPU 1: Y2 = X2 × W2                                                   │
│                                                                          │
│  Result: Y = Y1 + Y2  (sum partial results - requires all-reduce!)   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMBINING FOR TRANSFORMER LAYERS:                                       │
│  ─────────────────────────────────                                       │
│                                                                          │
│  ATTENTION:                                                             │
│  • Q, K, V projections: Column parallel (split heads)                 │
│  • Output projection: Row parallel                                     │
│                                                                          │
│  FFN:                                                                   │
│  • First linear: Column parallel (split hidden)                       │
│  • Second linear: Row parallel (combine)                              │
│                                                                          │
│  This minimizes communication: only 2 forward all-reduces per layer. │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
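
The column/row split can be checked numerically without any GPUs at all. The sketch below uses plain tensors to stand in for two devices and verifies that a column-parallel first linear followed by a row-parallel second linear reproduces the unsharded computation; in a real setup, the final sum of partial results is the all-reduce.

Code
import torch

torch.manual_seed(0)
d_in, d_hidden, d_out, batch = 16, 32, 16, 4

X = torch.randn(batch, d_in)
W1 = torch.randn(d_in, d_hidden)   # first linear (column parallel)
W2 = torch.randn(d_hidden, d_out)  # second linear (row parallel)

# Reference: unsharded two-layer computation (no nonlinearity, for clarity).
Y_ref = (X @ W1) @ W2

# "GPU 0" and "GPU 1" each hold half of the weights.
W1_a, W1_b = W1.chunk(2, dim=1)    # column parallel: split output dim
W2_a, W2_b = W2.chunk(2, dim=0)    # row parallel: split input dim

# Column parallel: X is replicated, each shard produces half the hidden units.
H_a = X @ W1_a                     # on GPU 0
H_b = X @ W1_b                     # on GPU 1

# Row parallel: each shard consumes its half of the hidden units and
# produces a PARTIAL output; summing the partials is the all-reduce.
Y = (H_a @ W2_a) + (H_b @ W2_b)

print(torch.allclose(Y, Y_ref, atol=1e-4))  # True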

Megatron-Style Tensor Parallelism

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    MEGATRON TENSOR PARALLELISM                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  MEGATRON'S KEY INSIGHT:                                                 │
│  ───────────────────────                                                 │
│                                                                          │
│  Arrange column-parallel and row-parallel to minimize communication.  │
│                                                                          │
│  Column parallel → Activation (no comm) → Row parallel → All-reduce   │
│                                                                          │
│  Only ONE all-reduce per block (attention or FFN) in the forward pass! │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ATTENTION BLOCK (2-way tensor parallel):                               │
│  ─────────────────────────────────────────                               │
│                                                                          │
│      Input X (replicated on both GPUs)                                 │
│           │                                                             │
│           ├───────────────────┐                                         │
│           ▼                   ▼                                         │
│      ┌─────────┐         ┌─────────┐                                   │
│      │ QKV     │         │ QKV     │                                   │
│      │(heads   │         │(heads   │  Column parallel:                │
│      │ 0-15)   │         │ 16-31)  │  Split attention heads           │
│      │ GPU 0   │         │ GPU 1   │                                   │
│      └─────────┘         └─────────┘                                   │
│           │                   │                                         │
│           ▼                   ▼                                         │
│      Attention           Attention     (independent per GPU)           │
│           │                   │                                         │
│           ▼                   ▼                                         │
│      ┌─────────┐         ┌─────────┐                                   │
│      │ Output  │         │ Output  │  Row parallel:                   │
│      │ Proj    │         │ Proj    │  Each produces partial result   │
│      │ (part)  │         │ (part)  │                                   │
│      └─────────┘         └─────────┘                                   │
│           │                   │                                         │
│           └─────────┬─────────┘                                         │
│                     ▼                                                   │
│               ALL-REDUCE                                                │
│             (sum partials)                                             │
│                     │                                                   │
│                     ▼                                                   │
│               Output Y                                                 │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  FFN BLOCK:                                                              │
│  ──────────                                                              │
│                                                                          │
│      Input X (replicated)                                              │
│           │                                                             │
│           ├───────────────────┐                                         │
│           ▼                   ▼                                         │
│      ┌─────────┐         ┌─────────┐                                   │
│      │ FC1     │         │ FC1     │  Column parallel:                │
│      │(half of │         │(other   │  Split intermediate dim          │
│      │ hidden) │         │ half)   │                                   │
│      └─────────┘         └─────────┘                                   │
│           │                   │                                         │
│           ▼                   ▼                                         │
│        GeLU              GeLU          (independent)                   │
│           │                   │                                         │
│           ▼                   ▼                                         │
│      ┌─────────┐         ┌─────────┐                                   │
│      │ FC2     │         │ FC2     │  Row parallel                    │
│      │ (part)  │         │ (part)  │                                   │
│      └─────────┘         └─────────┘                                   │
│           │                   │                                         │
│           └─────────┬─────────┘                                         │
│                     ▼                                                   │
│               ALL-REDUCE                                                │
│                     │                                                   │
│                     ▼                                                   │
│               Output Y                                                 │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  TOTAL COMMUNICATION PER TRANSFORMER LAYER:                             │
│  ──────────────────────────────────────────                              │
│                                                                          │
│  Forward: 2 all-reduces (attention + FFN)                              │
│  Backward: 2 all-reduces                                               │
│  Total: 4 all-reduces per layer per step                              │
│                                                                          │
│  Communication volume: ~4 × batch × seq × hidden × bytes per element   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
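
A rough sense of scale, using illustrative 70B-class shapes (batch 8, sequence 2K, hidden 8192; the exact numbers are assumptions, and the 2(N-1)/N ring factor is ignored):

Code
# Tensor-parallel communication per transformer layer per step (rough estimate).
batch, seq, hidden = 8, 2048, 8192   # illustrative 70B-class shapes
bytes_per_elem = 2                   # BF16
all_reduces_per_layer = 4            # 2 forward + 2 backward

activation_bytes = batch * seq * hidden * bytes_per_elem
per_layer_gb = all_reduces_per_layer * activation_bytes / 1e9
print(f"~{per_layer_gb:.1f} GB moved per layer per step")       # ~1.1 GB

# Across 80 layers that is ~86 GB of all-reduce traffic per step, which is
# why tensor parallelism wants NVLink-class bandwidth inside a node.
print(f"~{80 * per_layer_gb:.0f} GB per step across 80 layers")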

When to Use Tensor Parallelism

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    TENSOR PARALLELISM TRADEOFFS                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ADVANTAGES:                                                             │
│  ───────────                                                             │
│  ✓ Reduces memory per GPU (model fits)                                │
│  ✓ Low latency (all GPUs work on same batch)                         │
│  ✓ No pipeline bubbles                                                │
│                                                                          │
│  DISADVANTAGES:                                                          │
│  ──────────────                                                          │
│  ✗ High communication frequency (every layer)                         │
│  ✗ Requires fast interconnect (NVLink)                                │
│  ✗ Doesn't scale beyond ~8 GPUs efficiently                          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SCALING EFFICIENCY:                                                     │
│  ───────────────────                                                     │
│                                                                          │
│  TP Degree    Communication         Efficiency (NVLink)                │
│  ─────────────────────────────────────────────────────────            │
│  2            Moderate              95-98%                             │
│  4            Higher                90-95%                             │
│  8            Very high             80-90%                             │
│  16           Extreme               60-75%                             │
│                                                                          │
│  Beyond 8-way TP, communication dominates.                            │
│  Better to combine with other parallelism types.                      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  HARDWARE REQUIREMENTS:                                                  │
│  ──────────────────────                                                  │
│                                                                          │
│  Tensor parallelism NEEDS fast interconnect:                          │
│                                                                          │
│  • NVLink: 900 GB/s (H100), 600 GB/s (A100)                          │
│    → Tensor parallelism works well within a node                      │
│                                                                          │
│  • PCIe: 64 GB/s                                                       │
│    → Too slow for tensor parallelism                                  │
│                                                                          │
│  • InfiniBand: 25-50 GB/s per link (HDR/NDR)                          │
│    → Usable for TP across nodes, but slower                          │
│                                                                          │
│  Rule: Tensor parallelism within node, other methods across nodes.   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Part IV: Pipeline Parallelism

Splitting Layers Across Stages

Pipeline parallelism divides the model into sequential stages, each on different GPUs:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    PIPELINE PARALLELISM                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE CONCEPT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  Split model into STAGES (groups of consecutive layers).               │
│  Each stage on a different GPU.                                        │
│  Data flows through stages like a pipeline.                            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  MODEL SPLIT (32 layers, 4 GPUs):                                       │
│  ────────────────────────────────                                        │
│                                                                          │
│  GPU 0 (Stage 0): Embedding + Layers 0-7                              │
│  GPU 1 (Stage 1): Layers 8-15                                         │
│  GPU 2 (Stage 2): Layers 16-23                                        │
│  GPU 3 (Stage 3): Layers 24-31 + LM Head                              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  NAIVE PIPELINE (Sequential - Bad):                                     │
│  ──────────────────────────────────                                      │
│                                                                          │
│  Time →                                                                │
│  GPU 0: [F0]─────────────────────────────────[B0]                     │
│  GPU 1:     [F1]─────────────────────────[B1]                         │
│  GPU 2:         [F2]─────────────────[B2]                             │
│  GPU 3:             [F3]────────[B3]                                  │
│                         │      │                                       │
│                         Loss   Backward starts                         │
│                                                                          │
│  F = Forward, B = Backward                                             │
│                                                                          │
│  Problem: HUGE "bubble" where GPUs sit idle!                          │
│  Only 1 GPU active at a time. 75% idle with 4 GPUs!                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  MICRO-BATCHING (GPipe):                                                │
│  ───────────────────────                                                 │
│                                                                          │
│  Split batch into micro-batches, pipeline them:                       │
│                                                                          │
│  Batch: [B1, B2, B3, B4] (4 micro-batches)                           │
│                                                                          │
│  Time →                                                                │
│  GPU 0: [F0₁][F0₂][F0₃][F0₄]                [B0₄][B0₃][B0₂][B0₁]     │
│  GPU 1:     [F1₁][F1₂][F1₃][F1₄]        [B1₄][B1₃][B1₂][B1₁]        │
│  GPU 2:         [F2₁][F2₂][F2₃][F2₄][B2₄][B2₃][B2₂][B2₁]           │
│  GPU 3:             [F3₁][F3₂][F3₃][F3₄][B3₄][B3₃][B3₂][B3₁]       │
│                                                                          │
│  Much better! GPUs stay busier.                                       │
│  But still has bubble at start and end.                               │
│                                                                          │
│  Bubble fraction ≈ (num_stages - 1) / num_micro_batches               │
│  With 4 stages, 16 micro-batches: bubble = 3/16 = 19%                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
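
The bubble formula above is worth plugging numbers into: it shows why the micro-batch count has to grow with the number of stages. A quick sketch using the same approximation as above:

Code
def bubble_fraction(num_stages, num_micro_batches):
    """Approximate pipeline bubble fraction, as used above."""
    return (num_stages - 1) / num_micro_batches

# Same setup as above: 4 stages, 16 micro-batches -> ~19% idle time.
print(bubble_fraction(4, 16))   # 0.1875

# Doubling the micro-batch count halves the bubble.
print(bubble_fraction(4, 32))   # 0.09375

# More stages need proportionally more micro-batches to stay efficient.
print(bubble_fraction(8, 16))   # 0.4375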

1F1B Schedule

Modern pipeline parallelism uses 1F1B (one forward, one backward) scheduling:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    1F1B PIPELINE SCHEDULE                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  1F1B: Interleave forward and backward passes.                         │
│  Reduces memory by not storing all activations simultaneously.        │
│                                                                          │
│  Time →                                                                │
│  GPU 0: [F₁][F₂][F₃][F₄][B₁][B₂][B₃][B₄]                            │
│  GPU 1:    [F₁][F₂][F₃][F₄|B₁][B₂][B₃][B₄]                          │
│  GPU 2:       [F₁][F₂][F₃|B₁][F₄|B₂][B₃][B₄]                        │
│  GPU 3:          [F₁][F₂|B₁][F₃|B₂][F₄|B₃][B₄]                      │
│                                                                          │
│  Key insight: Start backward while still doing forward!               │
│  Once GPU 3 finishes F₁, it can do B₁ immediately.                   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  MEMORY COMPARISON:                                                      │
│  ──────────────────                                                      │
│                                                                          │
│  GPipe (all F then all B):                                             │
│  Must store activations for ALL micro-batches during forward.         │
│  Memory: O(num_micro_batches × activations_per_micro)                 │
│                                                                          │
│  1F1B:                                                                  │
│  Store activations for at most num_stages micro-batches.             │
│  Memory: O(num_stages × activations_per_micro)                        │
│                                                                          │
│  With 16 micro-batches and 4 stages:                                  │
│  GPipe: 16× activation memory                                         │
│  1F1B: 4× activation memory (4× reduction!)                          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  INTERLEAVED STAGES (Further optimization):                             │
│  ───────────────────────────────────────────                             │
│                                                                          │
│  Instead of contiguous layers per GPU, interleave:                    │
│                                                                          │
│  Standard (4 GPUs, 32 layers):                                        │
│  GPU 0: Layers 0-7                                                    │
│  GPU 1: Layers 8-15                                                   │
│  GPU 2: Layers 16-23                                                  │
│  GPU 3: Layers 24-31                                                  │
│                                                                          │
│  Interleaved (virtual stages):                                         │
│  GPU 0: Layers 0-3, 16-19                                            │
│  GPU 1: Layers 4-7, 20-23                                            │
│  GPU 2: Layers 8-11, 24-27                                           │
│  GPU 3: Layers 12-15, 28-31                                          │
│                                                                          │
│  More stages = more pipeline opportunities = smaller bubble!         │
│  Tradeoff: More communication (cross-GPU between layer groups).      │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
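
The memory claim can be reduced to a one-line model: the number of micro-batches whose activations are alive at once is the micro-batch count for GPipe, but at most the stage count for 1F1B. A tiny illustration with an assumed per-micro-batch activation size:

Code
def in_flight_activation_gb(schedule, num_stages, num_micro_batches, gb_per_micro_batch):
    """Peak activation memory held by a pipeline stage (simplified model)."""
    if schedule == "gpipe":
        live = num_micro_batches                   # all forwards before any backward
    elif schedule == "1f1b":
        live = min(num_stages, num_micro_batches)  # backward frees activations early
    else:
        raise ValueError(schedule)
    return live * gb_per_micro_batch

# 16 micro-batches, 4 stages, ~2 GB of activations per micro-batch (illustrative):
print(in_flight_activation_gb("gpipe", 4, 16, 2.0))  # 32 GB
print(in_flight_activation_gb("1f1b", 4, 16, 2.0))   # 8 GB -> the 4x reduction above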

Part V: ZeRO - Zero Redundancy Optimizer

Eliminating Memory Redundancy

In data parallelism, EVERY GPU stores a complete copy of model, gradients, and optimizer states. ZeRO partitions these across GPUs:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ZeRO OPTIMIZATION STAGES                              │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  DATA PARALLEL REDUNDANCY:                                               │
│  ─────────────────────────                                               │
│                                                                          │
│  With 4 GPUs, traditional DP stores on EACH GPU:                       │
│  • Full model parameters                                               │
│  • Full gradients                                                      │
│  • Full optimizer states                                               │
│                                                                          │
│  4× redundant storage! Massive waste.                                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ZeRO STAGE 1 (Optimizer State Partitioning):                          │
│  ─────────────────────────────────────────────                           │
│                                                                          │
│  Partition optimizer states across GPUs.                               │
│  Each GPU stores 1/N of optimizer states.                             │
│                                                                          │
│  Standard DP (4 GPUs, 7B model):                                       │
│  Each GPU: 56 GB optimizer states                                     │
│                                                                          │
│  ZeRO-1:                                                                │
│  Each GPU: 14 GB optimizer states (1/4)                               │
│                                                                          │
│  Memory saved: 75% of optimizer memory                                 │
│                                                                          │
│  Communication: After gradient all-reduce, each GPU updates its       │
│  partition, then all-gather to sync parameters.                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ZeRO STAGE 2 (+ Gradient Partitioning):                               │
│  ────────────────────────────────────────                                │
│                                                                          │
│  Also partition gradients across GPUs.                                 │
│  Each GPU only stores gradients for params it will update.           │
│                                                                          │
│  ZeRO-2:                                                                │
│  Each GPU: 14 GB optimizer + 3.5 GB gradients (vs 14 GB)             │
│                                                                          │
│  Communication: Reduce-scatter gradients (not all-reduce).            │
│  Each GPU gets only the gradients it needs.                           │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ZeRO STAGE 3 (+ Parameter Partitioning):                              │
│  ─────────────────────────────────────────                               │
│                                                                          │
│  Partition everything: optimizer, gradients, AND parameters.          │
│  Each GPU stores only 1/N of the model.                               │
│                                                                          │
│  ZeRO-3:                                                                │
│  Each GPU: 14 GB optimizer + 3.5 GB gradients + 3.5 GB params        │
│  (vs 14 + 14 + 56 = 84 GB in standard DP)                            │
│                                                                          │
│  Communication: All-gather params before forward/backward.            │
│  Higher communication, but enables training much larger models.       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  MEMORY COMPARISON (7B model, 4 GPUs):                                  │
│                                                                          │
│  Method      Params   Grads    Optimizer   Total/GPU   Reduction      │
│  ─────────────────────────────────────────────────────────────────    │
│  Standard DP 14 GB    14 GB    56 GB       84 GB       1×             │
│  ZeRO-1      14 GB    14 GB    14 GB       42 GB       2×             │
│  ZeRO-2      14 GB    3.5 GB   14 GB       31.5 GB     2.7×           │
│  ZeRO-3      3.5 GB   3.5 GB   14 GB       21 GB       4×             │
│                                                                          │
│  ZeRO-3 enables 4× larger models with same GPU memory!               │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
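
The per-GPU numbers in the table fall out mechanically from which components get partitioned. A small calculator using the same 7B / BF16 / Adam accounting as Part I (activations excluded, as in the table):

Code
def zero_memory_per_gpu_gb(n_params_billion, num_gpus, stage):
    """Per-GPU memory (GB) for params + grads + Adam states under ZeRO.

    stage 0 = plain data parallelism, 1/2/3 = ZeRO stages.
    Assumes BF16 params/grads (2 B each) and 8 B/param of optimizer state.
    """
    n = n_params_billion  # work in billions so bytes map directly to GB
    params, grads, optim = 2 * n, 2 * n, 8 * n
    if stage >= 1:
        optim /= num_gpus
    if stage >= 2:
        grads /= num_gpus
    if stage >= 3:
        params /= num_gpus
    return params + grads + optim

for stage in range(4):
    print(stage, zero_memory_per_gpu_gb(7, 4, stage))
# 0 -> 84.0, 1 -> 42.0, 2 -> 31.5, 3 -> 21.0  (matches the table above)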

ZeRO-3 Deep Dive

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ZeRO-3 PARAMETER PARTITIONING                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  HOW PARAMETER PARTITIONING WORKS:                                       │
│  ─────────────────────────────────                                       │
│                                                                          │
│  Model has layers: [L0, L1, L2, L3, L4, L5, L6, L7]                    │
│  4 GPUs                                                                 │
│                                                                          │
│  Partitioning:                                                          │
│  GPU 0 stores: L0, L4                                                  │
│  GPU 1 stores: L1, L5                                                  │
│  GPU 2 stores: L2, L6                                                  │
│  GPU 3 stores: L3, L7                                                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  FORWARD PASS:                                                           │
│  ─────────────                                                           │
│                                                                          │
│  To compute L0(x), ALL GPUs need L0's parameters.                      │
│                                                                          │
│  Step 1: GPU 0 broadcasts L0 params to all GPUs (all-gather)          │
│  Step 2: All GPUs compute L0(x) in parallel                           │
│  Step 3: Discard L0 params on GPUs 1,2,3 (keep on GPU 0)             │
│  Step 4: GPU 1 broadcasts L1 params                                    │
│  Step 5: All GPUs compute L1(...)                                     │
│  ... repeat for each layer                                             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  BACKWARD PASS:                                                          │
│  ─────────────                                                           │
│                                                                          │
│  Same pattern in reverse:                                              │
│  1. All-gather params for layer                                        │
│  2. Compute gradients                                                  │
│  3. Reduce-scatter gradients (each GPU gets its partition)            │
│  4. Update local optimizer states                                      │
│  5. Update local parameters                                            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMMUNICATION OVERHEAD:                                                 │
│  ───────────────────────                                                 │
│                                                                          │
│  ZeRO-3 moves ~1.5× the data of standard DP per step:                  │
│                                                                          │
│  Standard DP: 2× model size (all-reduce gradients)                    │
│                                                                          │
│  ZeRO-3:                                                                │
│  • 1× all-gather params (forward)                                     │
│  • 1× all-gather params (backward)                                    │
│  • 1× reduce-scatter gradients                                        │
│  Total: 3× model size (1.5× standard DP)                              │
│                                                                          │
│  But: Memory savings often more valuable than bandwidth.              │
│  Can train larger models → better results.                            │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
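
Conceptually, the ZeRO-3 forward pass is a gather-compute-free loop over layers. The single-process toy below keeps only each "rank's" shard of every weight as persistent state, reassembles the full weight just before using it (the all-gather), and drops it immediately afterwards. It is an illustration only; real implementations (DeepSpeed, FSDP) overlap the next layer's gather with the current layer's compute.

Code
import torch

# Toy ZeRO-3 simulation on one process: each "GPU" permanently stores only
# its shard of every layer's weight.
world_size = 4
d = 16
num_layers = 3

torch.manual_seed(0)
full_weights = [torch.randn(d, d) for _ in range(num_layers)]

# Persistent storage: shards[rank][layer] is the only thing a rank keeps.
shards = [[w.chunk(world_size, dim=0)[r].clone() for w in full_weights]
          for r in range(world_size)]

def zero3_forward(x):
    for layer in range(num_layers):
        # "All-gather": reassemble the full weight from every rank's shard.
        w_full = torch.cat([shards[r][layer] for r in range(world_size)], dim=0)
        x = x @ w_full              # every rank computes the same layer
        del w_full                  # free the gathered copy immediately
    return x

x = torch.randn(2, d)
out = zero3_forward(x)

# Sanity check against the unsharded model.
ref = x
for w in full_weights:
    ref = ref @ w
print(torch.allclose(out, ref, atol=1e-4))  # True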

Part VI: FSDP - Fully Sharded Data Parallelism

PyTorch's ZeRO Implementation

FSDP is PyTorch's native implementation of ZeRO-3-like sharding:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    FSDP (Fully Sharded Data Parallel)                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  FSDP = PyTorch's native ZeRO-3                                         │
│  ────────────────────────────                                            │
│                                                                          │
│  Key features:                                                          │
│  • Shard params, gradients, optimizer states                          │
│  • Automatic parameter gathering during forward/backward              │
│  • Works with PyTorch ecosystem                                        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  BASIC USAGE:                                                            │
│  ────────────                                                            │
│                                                                          │
│  from torch.distributed.fsdp import (                                   │
│      FullyShardedDataParallel as FSDP,                                 │
│      ShardingStrategy,                                                  │
│  )                                                                       │
│                                                                          │
│  # Wrap model with FSDP                                                │
│  model = FSDP(                                                          │
│      model,                                                             │
│      sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3        │
│      # or SHARD_GRAD_OP for ZeRO-2                                    │
│      # or NO_SHARD for standard DDP                                   │
│  )                                                                       │
│                                                                          │
│  # Training loop unchanged                                              │
│  for batch in dataloader:                                               │
│      loss = model(batch)                                               │
│      loss.backward()                                                   │
│      optimizer.step()                                                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SHARDING STRATEGIES:                                                    │
│  ────────────────────                                                    │
│                                                                          │
│  FULL_SHARD (ZeRO-3):                                                  │
│  • Shard params, grads, optimizer                                     │
│  • Maximum memory efficiency                                           │
│  • Highest communication                                               │
│                                                                          │
│  SHARD_GRAD_OP (ZeRO-2):                                               │
│  • Shard grads, optimizer                                             │
│  • Params replicated                                                  │
│  • Less communication                                                 │
│                                                                          │
│  NO_SHARD (DDP):                                                        │
│  • Nothing sharded                                                     │
│  • Standard data parallel                                             │
│  • Lowest communication                                               │
│                                                                          │
│  HYBRID_SHARD:                                                          │
│  • Full shard within node                                             │
│  • Replicate across nodes                                             │
│  • Balance memory and communication                                   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  AUTO WRAPPING:                                                          │
│  ──────────────                                                          │
│                                                                          │
│  FSDP wraps at module granularity. Control with policies:             │
│                                                                          │
│  import functools                                                       │
│  from torch.distributed.fsdp.wrap import (                             │
│      transformer_auto_wrap_policy,                                     │
│  )                                                                       │
│                                                                          │
│  # Wrap each transformer block separately (the policy is a callable,   │
│  # so bind transformer_layer_cls with functools.partial)               │
│  policy = functools.partial(                                            │
│      transformer_auto_wrap_policy,                                     │
│      transformer_layer_cls={TransformerBlock},                         │
│  )                                                                       │
│                                                                          │
│  model = FSDP(model, auto_wrap_policy=policy)                          │
│                                                                          │
│  This ensures each block's params are gathered and re-sharded as a     │
│  unit: one all-gather per block at its boundary, not per parameter.    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
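
Putting the pieces together, below is a minimal end-to-end FSDP sketch meant to be launched with torchrun (for example, torchrun --nproc_per_node=8 train_fsdp.py). The TinyBlock model, sizes, and hyperparameters are placeholders rather than settings from a real run; the important details are wrapping each block as a unit and building the optimizer after wrapping so it sees the sharded parameters.

Python
import functools
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

class TinyBlock(nn.Module):
    """Stand-in for a transformer block."""
    def __init__(self, dim=1024):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(),
                                nn.Linear(4 * dim, dim))
    def forward(self, x):
        return x + self.ff(x)

def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = nn.Sequential(*[TinyBlock() for _ in range(8)]).cuda()

    # Shard each TinyBlock as a unit (ZeRO-3-style FULL_SHARD).
    policy = functools.partial(transformer_auto_wrap_policy,
                               transformer_layer_cls={TinyBlock})
    model = FSDP(model,
                 sharding_strategy=ShardingStrategy.FULL_SHARD,
                 auto_wrap_policy=policy)

    # Build the optimizer AFTER wrapping so it holds the sharded params.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for step in range(10):
        x = torch.randn(4, 128, 1024, device="cuda")
        loss = model(x).pow(2).mean()  # dummy loss for the sketch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()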

Part VII: Combining Parallelism Strategies

3D Parallelism

Real large-scale training combines multiple parallelism types:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    3D PARALLELISM                                        │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE CONCEPT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  Combine:                                                               │
│  • Data Parallelism (DP): Replicate model, split data                 │
│  • Tensor Parallelism (TP): Split layers within device group          │
│  • Pipeline Parallelism (PP): Split model stages across groups        │
│                                                                          │
│  Each addresses different constraints:                                 │
│  • TP: Reduces per-GPU memory, high bandwidth needed                  │
│  • PP: Reduces per-GPU memory, tolerates lower bandwidth              │
│  • DP: Increases throughput, requires model to fit                    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  EXAMPLE: Training 70B on 64 GPUs (8 nodes × 8 GPUs):                  │
│  ─────────────────────────────────────────────────────                   │
│                                                                          │
│  Configuration:                                                         │
│  • TP = 8 (within each node, using NVLink)                            │
│  • PP = 4 (across 4 node groups)                                      │
│  • DP = 2 (2 replicas of the pipeline)                                │
│                                                                          │
│  Total: 8 × 4 × 2 = 64 GPUs                                           │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  VISUAL:                                                                 │
│                                                                          │
│                      Data Parallel Replica 1                           │
│         ┌──────────────────────────────────────────────┐               │
│         │                                              │               │
│  PP     │  TP Group 0    TP Group 1    ...           │               │
│  Stage  │  (8 GPUs)      (8 GPUs)                    │               │
│    0    │  ┌────────┐    ┌────────┐                  │               │
│         │  │Node 0  │    │Node 4  │                  │               │
│         │  │8×A100  │    │8×A100  │                  │               │
│    │    │  └────────┘    └────────┘                  │               │
│    │    │       │              │                      │               │
│    │    │       ▼              ▼                      │               │
│    1    │  ┌────────┐    ┌────────┐                  │               │
│         │  │Node 1  │    │Node 5  │                  │               │
│    │    │  └────────┘    └────────┘                  │               │
│    │    │       │              │                      │               │
│    │    │       ▼              ▼                      │               │
│    2    │  ┌────────┐    ┌────────┐                  │               │
│         │  │Node 2  │    │Node 6  │                  │               │
│    │    │  └────────┘    └────────┘                  │               │
│    │    │       │              │                      │               │
│    │    │       ▼              ▼                      │               │
│    3    │  ┌────────┐    ┌────────┐                  │               │
│         │  │Node 3  │    │Node 7  │                  │               │
│         │  └────────┘    └────────┘                  │               │
│         │                                              │               │
│         └──────────────────────────────────────────────┘               │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMMUNICATION PATTERNS:                                                 │
│  ───────────────────────                                                 │
│                                                                          │
│  Tensor Parallelism (within node):                                     │
│  • All-reduce every layer                                              │
│  • Requires NVLink (600 GB/s on A100, 900 GB/s on H100)               │
│                                                                          │
│  Pipeline Parallelism (across nodes):                                  │
│  • Point-to-point between stages                                      │
│  • Uses InfiniBand (400 Gb/s)                                         │
│  • Lower bandwidth OK due to less frequent comm                       │
│                                                                          │
│  Data Parallelism (across replicas):                                   │
│  • All-reduce gradients after backward                                │
│  • Uses InfiniBand                                                     │
│  • Overlaps with compute                                              │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
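
To see how the 64 ranks map onto this layout, here is a toy helper that converts a global rank into (DP replica, PP stage, TP rank) coordinates, matching the diagram above with TP innermost so each TP group stays inside one node. The exact ordering convention varies between frameworks, so treat this mapping as illustrative.

Python
# Rank layout for TP=8, PP=4, DP=2 (64 GPUs), TP innermost.
TP, PP, DP = 8, 4, 2

def coords(global_rank):
    tp_rank  = global_rank % TP             # position inside the node's TP group
    pp_stage = (global_rank // TP) % PP     # pipeline stage = node within a replica
    dp_rank  = global_rank // (TP * PP)     # which data-parallel replica
    return dp_rank, pp_stage, tp_rank

for r in (0, 7, 8, 32, 63):
    print(r, coords(r))
# ranks 0-7   -> node 0: replica 0, stage 0
# ranks 8-15  -> node 1: replica 0, stage 1
# ranks 32-39 -> node 4: replica 1, stage 0, and so on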

Choosing Parallelism Configuration

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    PARALLELISM CONFIGURATION GUIDE                       │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  DECISION FRAMEWORK:                                                     │
│  ───────────────────                                                     │
│                                                                          │
│  1. Does model fit on 1 GPU?                                           │
│     YES → Use Data Parallel only (or FSDP for memory efficiency)      │
│     NO  → Continue to step 2                                          │
│                                                                          │
│  2. Does model fit with TP within node?                                │
│     Try TP = 2, 4, or 8 (NVLink connected GPUs)                       │
│     YES → Use TP + DP                                                 │
│     NO  → Continue to step 3                                          │
│                                                                          │
│  3. Add Pipeline Parallelism                                           │
│     PP = model_size / (memory_per_gpu × TP)                           │
│     Use 1F1B schedule with micro-batching                             │
│                                                                          │
│  4. Fill remaining GPUs with Data Parallel                            │
│     DP = total_gpus / (TP × PP)                                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  EXAMPLE CONFIGURATIONS:                                                 │
│                                                                          │
│  Model Size    GPUs    TP    PP    DP    Framework                     │
│  ─────────────────────────────────────────────────────────────────    │
│  7B            8       1     1     8     DDP or FSDP                   │
│  13B           8       2     1     4     TP + DP                       │
│  30B           16      4     2     2     TP + PP + DP                  │
│  70B           64      8     4     2     Full 3D                       │
│  175B          512     8     16    4     Full 3D                       │
│  500B+         1000+   8     32+   4+    Full 3D + Expert Parallel    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  RULES OF THUMB:                                                         │
│  ───────────────                                                         │
│                                                                          │
│  • TP within node only (NVLink required)                              │
│  • TP ≤ 8 (diminishing returns beyond)                                │
│  • PP micro-batches ≥ 4× PP stages (reduce bubble)                   │
│  • DP as large as possible (best scaling)                            │
│  • Memory per GPU: model/(TP×PP) + activations + optimizer            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMMON MISTAKES:                                                        │
│  ───────────────                                                         │
│                                                                          │
│  ✗ Using TP across nodes (too slow without NVLink)                    │
│  ✗ Too few micro-batches for PP (large bubbles)                       │
│  ✗ Not enabling gradient checkpointing (OOM on activations)           │
│  ✗ Wrong TP degree (not matching attention head count)                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
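
The rule of thumb at the bottom of the box translates into a quick back-of-envelope estimate. The sketch below assumes BF16 weights and gradients, FP32 Adam states (master copy + m + v, about 12 bytes per parameter), ZeRO-1-style sharding of optimizer states over the DP group, and a flat guess for activation memory; real numbers depend on the framework and on gradient checkpointing.

Python
# Rough per-GPU memory estimate (GB) for a given (TP, PP, DP) configuration.
def per_gpu_memory_gb(params_billions, tp, pp, dp, activations_gb=20.0):
    shard = params_billions / (tp * pp)   # billions of params held per GPU
    weights = shard * 2                   # BF16 weights
    grads   = shard * 2                   # BF16 gradients
    optim   = shard * 12 / dp             # FP32 master + m + v, sharded over DP
    return weights + grads + optim + activations_gb

# 70B with the 64-GPU config above (TP=8, PP=4, DP=2): ~42 GB -> fits in 80 GB
print(per_gpu_memory_gb(70, tp=8, pp=4, dp=2))
# 70B with plain data parallel on 8 GPUs: ~405 GB per GPU -> clearly impossible
print(per_gpu_memory_gb(70, tp=1, pp=1, dp=8))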

Part VIII: Practical Implementation

Using DeepSpeed

DeepSpeed is the most popular library for large-scale training:

Python
# DeepSpeed configuration for 70B model training

deepspeed_config = {
    # ZeRO Stage 3 for maximum memory efficiency
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",  # Offload optimizer to CPU
            "pin_memory": True
        },
        "offload_param": {
            "device": "cpu",  # Offload params to CPU (optional)
            "pin_memory": True
        },
        "overlap_comm": True,  # Overlap communication with compute
        "contiguous_gradients": True,
        "reduce_bucket_size": 5e8,
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e6,
    },

    # Mixed precision training
    "bf16": {
        "enabled": True
    },

    # Gradient accumulation
    "gradient_accumulation_steps": 8,

    # Optimizer
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 1e-5,
            "betas": [0.9, 0.95],
            "weight_decay": 0.1
        }
    },

    # Learning rate scheduler
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 1e-5,
            "warmup_num_steps": 100,
            "total_num_steps": 10000
        }
    },

    # Training batch size
    "train_micro_batch_size_per_gpu": 1,
    "gradient_clipping": 1.0,
}

# Launch training
# deepspeed --num_gpus=8 train.py --deepspeed_config config.json
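
For completeness, here is a minimal sketch of the train.py that the launch command expects, assuming the deepspeed_config dict above is in scope (or passed as a JSON file via --deepspeed_config). The toy model and synthetic batch are placeholders so the skeleton is self-contained; only the deepspeed.initialize / engine.backward / engine.step calls are the DeepSpeed-specific parts.

Python
import deepspeed
import torch
import torch.nn as nn

def main():
    # Placeholder model; a real run would build the 70B model here.
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(),
                          nn.Linear(4096, 1024))

    # DeepSpeed builds the ZeRO-3 engine, AdamW optimizer, and LR scheduler
    # directly from the config.
    engine, _, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=deepspeed_config,
    )

    for step in range(100):
        # Placeholder batch; a real run would iterate a tokenized dataset.
        x = torch.randn(1, 1024, dtype=torch.bfloat16,
                        device=f"cuda:{engine.local_rank}")
        loss = engine(x).float().pow(2).mean()
        engine.backward(loss)   # handles loss scaling / grad accumulation
        engine.step()           # optimizer + scheduler step at accumulation boundaries

if __name__ == "__main__":
    main()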

Using Megatron-LM

For tensor and pipeline parallelism:

Bash
# Megatron-LM training launch command

GPUS_PER_NODE=8
NNODES=8
NODE_RANK=$SLURM_PROCID
MASTER_ADDR=$SLURM_SUBMIT_HOST
MASTER_PORT=6000

DISTRIBUTED_ARGS="
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

# Model parallelism configuration
TENSOR_PARALLEL=8      # Within node (NVLink)
PIPELINE_PARALLEL=4    # Across nodes
# Data parallel is implicit: 64 / (8 * 4) = 2

MODEL_ARGS="
    --num-layers 80 \
    --hidden-size 8192 \
    --num-attention-heads 64 \
    --seq-length 4096 \
    --max-position-embeddings 4096 \
    --tensor-model-parallel-size $TENSOR_PARALLEL \
    --pipeline-model-parallel-size $PIPELINE_PARALLEL \
"

TRAINING_ARGS="
    --micro-batch-size 1 \
    --global-batch-size 1024 \
    --train-iters 100000 \
    --lr 1.5e-4 \
    --min-lr 1.5e-5 \
    --lr-decay-iters 100000 \
    --lr-warmup-iters 2000 \
    --bf16 \
    --use-flash-attn \
    --recompute-activations \
"

torchrun $DISTRIBUTED_ARGS \
    pretrain_gpt.py \
    $MODEL_ARGS \
    $TRAINING_ARGS

Part IX: Recent Innovations (2024-2025)

FSDP2 - Next-Generation Sharding

PyTorch introduced FSDP2 in 2024 with significant improvements over the original FSDP:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    FSDP2 IMPROVEMENTS                                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  KEY CHANGES FROM FSDP1:                                                 │
│  ───────────────────────                                                 │
│                                                                          │
│  1. PER-PARAMETER SHARDING (not per-module):                            │
│     FSDP1: Wraps entire modules, all params sharded together            │
│     FSDP2: Each parameter independently sharded                         │
│                                                                          │
│     Benefits:                                                            │
│     • Finer-grained memory control                                      │
│     • Better support for heterogeneous param sizes                      │
│     • Easier mixed-precision per-parameter                              │
│                                                                          │
│  2. DTensor INTEGRATION:                                                 │
│     Built on DTensor (Distributed Tensor) abstraction                   │
│     • Unified sharding specification language                           │
│     • Cleaner composition with tensor parallelism                       │
│     • Better debugging and visualization                                │
│                                                                          │
│  3. EXPLICIT PREFETCH CONTROL:                                          │
│     FSDP1: Fixed prefetch behavior                                      │
│     FSDP2: User-controlled, per-module prefetch hooks                   │
│     (set_modules_to_forward_prefetch /                                  │
│      set_modules_to_backward_prefetch)                                  │
│                                                                          │
│  4. COMPOSABLE API:                                                      │
│     Designed to compose cleanly with:                                   │
│     • Tensor Parallel (tp.parallelize_module)                          │
│     • Pipeline Parallel                                                  │
│     • Activation checkpointing                                          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  FSDP2 BASIC USAGE:                                                      │
│  ──────────────────                                                      │
│                                                                          │
│  from torch.distributed._composable.fsdp import fully_shard            │
│  from torch.distributed.device_mesh import init_device_mesh             │
│                                                                          │
│  # Create a named device mesh for multi-dimensional parallelism        │
│  mesh = init_device_mesh("cuda", (2, 4),                               │
│                          mesh_dim_names=("dp", "tp"))                  │
│  # First dim: data parallel (2), second dim: tensor parallel (4)       │
│                                                                          │
│  # Shard model parameters                                               │
│  for layer in model.transformer_blocks:                                 │
│      fully_shard(layer, mesh=mesh["dp"])  # Shard on DP dimension     │
│  fully_shard(model, mesh=mesh["dp"])                                   │
│                                                                          │
│  # Training loop                                                        │
│  for batch in dataloader:                                               │
│      loss = model(batch)                                                │
│      loss.backward()                                                    │
│      optimizer.step()                                                   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMPOSING WITH TENSOR PARALLEL:                                        │
│  ────────────────────────────────                                        │
│                                                                          │
│  from torch.distributed.tensor.parallel import parallelize_module      │
│                                                                          │
│  # 2D parallelism: TP + FSDP                                           │
│  mesh_2d = init_device_mesh("cuda", (2, 4),                            │
│                             mesh_dim_names=("dp", "tp"))               │
│  tp_mesh = mesh_2d["tp"]  # 4-way tensor parallel                      │
│  dp_mesh = mesh_2d["dp"]  # 2-way data parallel (FSDP)                 │
│                                                                          │
│  # First apply tensor parallelism                                       │
│  parallelize_module(model, tp_mesh, tp_plan)                           │
│                                                                          │
│  # Then apply FSDP (shards the TP-split params)                        │
│  for layer in model.layers:                                             │
│      fully_shard(layer, mesh=dp_mesh)                                  │
│  fully_shard(model, mesh=dp_mesh)                                      │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
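
Below is a slightly more explicit version of the 2D composition in the box, as a sketch assuming PyTorch 2.4+ and a single-node torchrun --nproc_per_node=8 launch. The FFN module and the column/row-wise plan are illustrative stand-ins; the point is the order of operations: name the mesh dimensions, apply tensor parallelism on the "tp" dimension, then let fully_shard (FSDP2) shard the TP-split parameters over the "dp" dimension.

Python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel, RowwiseParallel, parallelize_module,
)

class FFN(nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.up = nn.Linear(dim, 4 * dim)
        self.down = nn.Linear(4 * dim, dim)
    def forward(self, x):
        return self.down(torch.relu(self.up(x)))

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
model = nn.Sequential(*[FFN() for _ in range(4)]).cuda()

# 1) Tensor parallelism: split each FFN column-wise then row-wise on "tp".
for block in model:
    parallelize_module(block, mesh["tp"],
                       {"up": ColwiseParallel(), "down": RowwiseParallel()})

# 2) FSDP2: shard the TP-split parameters over the "dp" dimension.
for block in model:
    fully_shard(block, mesh=mesh["dp"])
fully_shard(model, mesh=mesh["dp"])

# Same-seeded dummy input so every TP rank sees identical data.
torch.manual_seed(0)
x = torch.randn(8, 1024, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
dist.destroy_process_group()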

torch.compile Integration

FSDP2 works seamlessly with torch.compile for additional speedups:

Python
# FSDP2 + torch.compile example

import torch
from torch.distributed._composable.fsdp import fully_shard

# Apply FSDP2 sharding (assumes init_process_group has already been called)
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

# Build the optimizer after sharding so it holds the sharded (DTensor) params
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Compile the model (works with FSDP2)
model = torch.compile(model)

# Training with compiled + sharded model
for batch in dataloader:
    loss = model(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Benefits:
# • Kernel fusion across FSDP boundaries
# • Reduced Python overhead
# • Optimized communication patterns

Context Parallel for Long Sequences

A new parallelism dimension for handling very long sequences:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    CONTEXT PARALLELISM                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE PROBLEM:                                                            │
│  ────────────                                                            │
│                                                                          │
│  With context lengths of 128K-1M tokens:                                │
│  • Activation memory: O(seq_len × hidden_size × batch)                 │
│  • KV cache: O(seq_len × num_layers × num_heads × head_dim)           │
│                                                                          │
│  Even with TP/PP/DP, single GPU can't hold activations for 1M tokens.  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  CONTEXT PARALLELISM (CP):                                              │
│  ─────────────────────────                                               │
│                                                                          │
│  Split sequence across GPUs along the sequence dimension.               │
│                                                                          │
│  Sequence: [tok_0, tok_1, ..., tok_128K]                               │
│  CP = 4:                                                                 │
│  GPU 0: [tok_0 ... tok_32K]                                            │
│  GPU 1: [tok_32K ... tok_64K]                                          │
│  GPU 2: [tok_64K ... tok_96K]                                          │
│  GPU 3: [tok_96K ... tok_128K]                                         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  RING ATTENTION:                                                         │
│  ───────────────                                                         │
│                                                                          │
│  For attention, each position needs to attend to all previous.         │
│  Solution: Ring communication pattern for KV sharing.                   │
│                                                                          │
│  Step 1: Each GPU computes attention for local Q with local KV         │
│  Step 2: Pass KV to next GPU in ring, receive from previous           │
│  Step 3: Compute attention with received KV                            │
│  Step 4: Repeat until all KV seen                                      │
│                                                                          │
│  Memory: O(seq_len/CP) per GPU                                         │
│  Communication: Overlapped with compute                                 │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  USAGE (Megatron-LM):                                                   │
│                                                                          │
│  torchrun ... pretrain_gpt.py \                                        │
│      --context-parallel-size 8 \                                        │
│      --seq-length 1048576 \                                             │
│      --tensor-model-parallel-size 8 \                                  │
│      --pipeline-model-parallel-size 4                                  │
│                                                                          │
│  4D Parallelism: TP × PP × DP × CP                                     │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
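
The ring pattern is easier to see in code. Below is a single-process simulation: the sequence (with its K/V) is split into CP chunks, each "rank" keeps its own Q chunk, and the K/V chunks rotate around the ring while partial attention outputs are merged via their log-sum-exp statistics. One head, no causal mask, and plain tensors instead of torch.distributed, purely to show the math; real context-parallel implementations overlap the ring exchange with fused attention kernels.

Python
import torch

def merge(out, lse, blk_out, blk_lse):
    # Numerically stable merge of two partial softmax-weighted sums.
    new_lse = torch.logaddexp(lse, blk_lse)
    out = (out * (lse - new_lse).exp().unsqueeze(-1)
           + blk_out * (blk_lse - new_lse).exp().unsqueeze(-1))
    return out, new_lse

def ring_attention(q_chunks, k_chunks, v_chunks):
    cp = len(q_chunks)
    d = q_chunks[0].shape[-1]
    outs = [torch.zeros_like(q) for q in q_chunks]
    lses = [torch.full(q.shape[:-1], float("-inf")) for q in q_chunks]
    k, v = list(k_chunks), list(v_chunks)
    for _ in range(cp):                       # cp ring steps
        for r in range(cp):                   # each "rank" works in parallel
            scores = q_chunks[r] @ k[r].T / d ** 0.5
            blk_lse = torch.logsumexp(scores, dim=-1)
            blk_out = torch.softmax(scores, dim=-1) @ v[r]
            outs[r], lses[r] = merge(outs[r], lses[r], blk_out, blk_lse)
        k = k[-1:] + k[:-1]                   # rotate K/V around the ring
        v = v[-1:] + v[:-1]
    return torch.cat(outs, dim=0)

# Sanity check against full attention over the whole sequence.
seq, d, cp = 64, 32, 4
q, k, v = (torch.randn(seq, d) for _ in range(3))
ring = ring_attention(q.chunk(cp), k.chunk(cp), v.chunk(cp))
full = torch.softmax(q @ k.T / d ** 0.5, dim=-1) @ v
print(torch.allclose(ring, full, atol=1e-5))  # True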

Float8 Training

FP8 training enables significant memory and compute savings:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    FP8 TRAINING (2024-2025)                              │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  FP8 FORMATS:                                                            │
│  ────────────                                                            │
│                                                                          │
│  E4M3: 4 exponent bits, 3 mantissa bits                                │
│  • Range: ±448, relative precision ~6%                                 │
│  • Good for weights and activations                                    │
│                                                                          │
│  E5M2: 5 exponent bits, 2 mantissa bits                                │
│  • Range: ±57344, relative precision ~12%                              │
│  • Good for gradients (larger dynamic range)                           │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  MEMORY SAVINGS:                                                         │
│  ───────────────                                                         │
│                                                                          │
│  vs BF16:                                                               │
│  • Weights: 2× smaller                                                 │
│  • Activations: 2× smaller                                             │
│  • Gradients: 2× smaller (with E5M2)                                   │
│                                                                          │
│  70B model in FP8: ~70GB weights (vs 140GB in BF16)                    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  PYTORCH FP8 TRAINING:                                                  │
│                                                                          │
│  from torchao.float8 import (                                          │
│      convert_to_float8_training,                                        │
│      Float8LinearConfig,                                                │
│  )                                                                       │
│                                                                          │
│  # Convert linear layers to FP8                                        │
│  config = Float8LinearConfig(                                           │
│      enable_fsdp_float8_all_gather=True,  # FP8 comms                  │
│  )                                                                       │
│  convert_to_float8_training(model, config=config)                      │
│                                                                          │
│  # Apply FSDP after FP8 conversion                                     │
│  model = FSDP(model, ...)                                              │
│                                                                          │
│  # Training loop unchanged                                              │
│  for batch in dataloader:                                               │
│      loss = model(batch)                                                │
│      loss.backward()                                                    │
│      optimizer.step()                                                   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  PRODUCTION ADOPTION:                                                    │
│  ────────────────────                                                    │
│                                                                          │
│  DeepSeek-V3: First frontier model trained with FP8                    │
│  • 671B params, 14.8T tokens                                           │
│  • 2.79M H800 GPU hours (vs expected 5-10M for BF16)                  │
│  • Proved FP8 works at scale                                           │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

torchao and Low-Bit Optimizations

Python
# torchao library for quantized training

from torchao.prototype.low_bit_optim import AdamW8bit, AdamWFp8

# 8-bit optimizer (stores Adam's m/v states in 8-bit instead of 32-bit,
# roughly 4× smaller optimizer state)
optimizer = AdamW8bit(model.parameters(), lr=1e-4)

# INT8 quantization for further savings (torchao's int8 dynamic-activation /
# int8-weight path primarily targets inference)
from torchao.quantization import int8_dynamic_activation_int8_weight
from torchao.quantization import quantize_

# Quantize model to INT8
quantize_(model, int8_dynamic_activation_int8_weight())

# Combined: FP8 training + 8-bit optimizer + FSDP
# Enables training 70B on 8× A100 80GB

Updated Configuration Guide

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    MODERN PARALLELISM CONFIGURATION (2025)               │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Model Size    GPUs    TP    PP    CP    DP    Precision   Framework   │
│  ─────────────────────────────────────────────────────────────────────  │
│  7B            8       1     1     1     8     BF16        FSDP2       │
│  13B           8       2     1     1     4     BF16        FSDP2 + TP  │
│  30B           16      4     2     1     2     BF16        3D          │
│  70B           64      8     4     1     2     FP8         3D + FP8    │
│  175B          256     8     8     1     4     FP8         4D          │
│  405B (Llama3) 16K     8     16    -     128   BF16        4D          │
│  671B (DSV3)   2K      -     -     -     -     FP8         MoE+EP      │
│  1M+ ctx       128     8     4     8     -     BF16        4D + CP     │
│                                                                          │
│  NEW DIMENSIONS:                                                         │
│  • CP (Context Parallel): For 128K+ context lengths                    │
│  • EP (Expert Parallel): For MoE models                                │
│  • FP8: 2× memory reduction, proven at DeepSeek scale                  │
│                                                                          │
│  KEY 2025 UPDATES:                                                       │
│  • TorchTitan: ICLR 2025 reference platform for 4D parallelism         │
│  • FSDP2 replaces FSDP1 for new projects (DTensor-based)              │
│  • Float8 + FSDP2: 50% throughput speedup over FSDP1 bf16             │
│    - float8 all_gathers for weight communication                      │
│    - MXFP8 support for Blackwell GPUs (2025)                         │
│  • AMD torchtitan fork (Nov 2025) for ROCm optimization              │
│  • torch.compile works with distributed training                       │
│  • FP8 is production-ready (DeepSeek-V3 proved it at 671B scale)      │
│  • Context parallel enables 1M+ context (262K tokens on 8 H100s)      │
│  • Ring Attention: USP combines Ulysses + Ring for best efficiency    │
│  • TorchFT integration for fault tolerance at scale                   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Summary

Distributed training enables training models that don't fit on single GPUs. The key techniques:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    KEY TAKEAWAYS                                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  DATA PARALLELISM:                                                       │
│  • Replicate model, split data                                         │
│  • All-reduce gradients after backward                                 │
│  • Scales well but model must fit per GPU                             │
│                                                                          │
│  TENSOR PARALLELISM:                                                     │
│  • Split layers across GPUs                                            │
│  • Communication every layer (needs NVLink)                           │
│  • Use within node, TP ≤ 8                                            │
│                                                                          │
│  PIPELINE PARALLELISM:                                                   │
│  • Split model into stages                                             │
│  • 1F1B schedule minimizes bubble                                     │
│  • Use across nodes                                                    │
│                                                                          │
│  ZeRO / FSDP:                                                           │
│  • Partition optimizer, gradients, parameters                         │
│  • Stage 1/2/3 trade communication for memory                        │
│  • Enables much larger models                                         │
│                                                                          │
│  3D PARALLELISM:                                                         │
│  • Combine TP (within node) + PP (across nodes) + DP (throughput)    │
│  • Required for 70B+ models                                           │
│                                                                          │
│  2024-2025 INNOVATIONS:                                                 │
│  • FSDP2: Per-parameter sharding, DTensor integration                │
│  • FP8 Training: 2× memory savings, proven at DeepSeek scale        │
│  • Context Parallel: 4th dimension for 1M+ context                   │
│  • torch.compile: Works with distributed training now                │
│                                                                          │
│  PRACTICAL GUIDELINES:                                                   │
│  • Start with FSDP2 for new projects (not FSDP1)                     │
│  • Add TP when model too large for FSDP alone                        │
│  • Add PP for very large models across nodes                         │
│  • Consider FP8 for 70B+ models (2× memory reduction)               │
│  • Add CP for 128K+ context training                                 │
│  • Always use gradient checkpointing for memory                      │
│  • Use bf16/fp8 mixed precision                                       │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘


Enrico Piovano, PhD

Co-founder & CTO at Goji AI. Former Applied Scientist at Amazon (Alexa & AGI), focused on Agentic AI and LLMs. PhD in Electrical Engineering from Imperial College London. Gold Medalist at the National Mathematical Olympiad.
