Module 06 · Section 6.6

Distributed Training at Scale

Parallelism strategies, communication primitives, and mixed precision for training models across hundreds of GPUs

Training a large model on one GPU is like reading the entire internet through a keyhole. Distributed training lets you knock down the wall, provided all the GPUs agree on which wall and when.

An Overly Synchronized Process
★ Big Picture

No single GPU can train a modern LLM. A 70B parameter model requires over 140 GB just for its parameters in FP16, far exceeding the memory of any single accelerator. Training such models demands distributing computation across dozens to thousands of GPUs, coordinating their work through high-speed interconnects. This section covers the four fundamental parallelism strategies (data, tensor, pipeline, and expert parallelism), the communication primitives that enable them, mixed-precision training (including FP8), and the memory optimization techniques that make large-scale training feasible.

⚙ Prerequisites

This section assumes familiarity with PyTorch tensor operations from Section 0.2 and the transformer architecture from Module 04. Understanding of matrix multiplication is essential for the tensor parallelism discussion. The optimizer memory analysis from Section 6.5 motivates why distributed training is necessary.

1. Communication Primitives

Distributed training relies on collective communication operations to synchronize data between GPUs. Understanding these primitives is essential for reasoning about the communication overhead of different parallelism strategies.

Primitive      | Input                  | Output                          | Use Case
All-Reduce     | Each GPU has a tensor  | All GPUs have the sum           | Gradient synchronization (DDP)
All-Gather     | Each GPU has a shard   | All GPUs have the full tensor   | Parameter reconstruction (FSDP)
Reduce-Scatter | Each GPU has a tensor  | Each GPU has a shard of the sum | Gradient sharding (FSDP)
Broadcast      | One GPU has a tensor   | All GPUs have it                | Weight initialization

These operations are implemented efficiently using ring or tree topologies. In a ring all-reduce with P GPUs, each GPU sends and receives 2(P−1)/P times the tensor size, giving near-optimal bandwidth utilization regardless of the number of GPUs. The NCCL library (NVIDIA Collective Communications Library) provides highly optimized implementations for NVIDIA GPUs.
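The ring algorithm can be illustrated without any GPUs at all. The sketch below is a single-process simulation (names and structure are illustrative, not NCCL's actual implementation): P simulated "GPUs" each hold a buffer, and a reduce-scatter phase followed by an all-gather phase leaves every buffer holding the elementwise sum.

```python
# Single-process simulation of a ring all-reduce: P simulated "GPUs"
# each hold a length-n buffer; after a reduce-scatter phase and an
# all-gather phase, every buffer holds the elementwise sum.
# Illustrative only -- real training uses NCCL's optimized kernels.

def ring_all_reduce(bufs):
    P, n = len(bufs), len(bufs[0])
    assert n % P == 0, "buffer length must split into P equal chunks"
    chunk = n // P
    seg = lambda c: range(c * chunk, (c + 1) * chunk)

    def step(chunk_of, combine):
        # Snapshot all outgoing chunks first, so every "send" in a step
        # uses pre-step values, as concurrent hardware would.
        outgoing = [(r, chunk_of(r), [bufs[r][i] for i in seg(chunk_of(r))])
                    for r in range(P)]
        for r, c, data in outgoing:
            dst = (r + 1) % P                     # ring neighbor
            for i, v in zip(seg(c), data):
                bufs[dst][i] = combine(bufs[dst][i], v)

    # Phase 1: reduce-scatter. After P-1 steps, GPU r owns the fully
    # reduced chunk (r + 1) % P.
    for s in range(P - 1):
        step(lambda r: (r - s) % P, lambda old, new: old + new)
    # Phase 2: all-gather. Circulate each reduced chunk around the ring
    # until every GPU has every reduced chunk.
    for s in range(P - 1):
        step(lambda r: (r + 1 - s) % P, lambda old, new: new)
    return bufs
```

Each GPU sends P−1 chunks of size n/P in each phase, i.e. 2(P−1)·n/P elements total, which is the bandwidth figure quoted above.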

2. Data Parallelism (DDP)

Distributed Data Parallelism is the simplest and most widely used form of parallelism. Each GPU holds a complete copy of the model and processes a different subset of the training data. After each forward-backward pass, gradients are synchronized across all GPUs using all-reduce, ensuring that all copies perform identical parameter updates.

Figure 6.6.1: In DDP, each GPU holds a full model copy and processes different data. Gradients are synchronized via all-reduce after each backward pass.
# DDP training with PyTorch
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train_ddp(rank, world_size, model_class):
    setup_ddp(rank, world_size)

    # Each GPU gets a full model copy
    model = model_class().to(rank)
    model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    # Assumes `dataloader` uses a DistributedSampler so each rank
    # sees a disjoint shard of the data
    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()     # DDP auto-syncs gradients via all-reduce
        optimizer.step()

    dist.destroy_process_group()

3. Fully Sharded Data Parallelism (FSDP) and ZeRO

DDP's limitation is that every GPU must hold a complete copy of the model, gradients, and optimizer states. For a 7B model with AdamW, that is ~112 GB per GPU. FSDP (and the equivalent DeepSpeed ZeRO) resolves this by sharding these tensors across GPUs so each GPU stores only a fraction.
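The per-GPU figures above follow from simple arithmetic. The sketch below assumes the standard mixed-precision AdamW layout from Section 6.5: 2-byte parameters and gradients plus 4-byte FP32 master weights, momentum, and variance, for 16 bytes per parameter.

```python
# Per-GPU training-state memory for a dense model under DDP vs. FSDP
# full sharding. Assumes the standard mixed-precision AdamW layout:
# 2 B params + 2 B grads + 4 B master weights + 4 B momentum
# + 4 B variance = 16 bytes per parameter.

BYTES_PER_PARAM = 2 + 2 + 4 + 4 + 4   # = 16

def ddp_memory_gb(n_params):
    """DDP: every GPU replicates the full training state."""
    return n_params * BYTES_PER_PARAM / 1e9

def fsdp_memory_gb(n_params, world_size):
    """FSDP full sharding: the state is split evenly across GPUs."""
    return ddp_memory_gb(n_params) / world_size

print(f"7B model, DDP:             {ddp_memory_gb(7e9):.0f} GB per GPU")
print(f"7B model, FSDP on 8 GPUs:  {fsdp_memory_gb(7e9, 8):.0f} GB per GPU")
```

The 112 GB DDP figure exceeds any single accelerator's memory; sharding across 8 GPUs brings it down to 14 GB, leaving room for activations.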

ZeRO Optimization Stages

ZeRO removes memory redundancy in three cumulative stages: Stage 1 shards the optimizer states across GPUs, Stage 2 additionally shards the gradients, and Stage 3 additionally shards the parameters themselves. FSDP's FULL_SHARD strategy corresponds to ZeRO Stage 3.

⚡ Key Insight

FSDP Stage 3 trades communication for memory. Each layer's forward pass requires an all-gather to reconstruct the full parameters, and each backward pass requires a reduce-scatter of the gradients. This means each parameter is communicated 3 times per training step (gather for forward, gather for backward, reduce-scatter for gradient). The communication overhead is significant but acceptable when the alternative is not being able to train the model at all.
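The "3 times per parameter" observation can be turned into a rough cost model. The sketch below uses the ring-collective costs from Section 1 (all-reduce moves about 2(P−1)/P bytes per byte of data; all-gather and reduce-scatter about (P−1)/P each); the function names are illustrative.

```python
# Rough per-GPU, per-step communication volume for DDP vs. FSDP full
# sharding, using ring-collective costs: all-reduce ~2(P-1)/P bytes per
# byte of data; all-gather and reduce-scatter ~(P-1)/P each.

def ddp_comm_gb(n_params, P, bytes_per_elem=2):
    # One gradient all-reduce per step
    return 2 * (P - 1) / P * n_params * bytes_per_elem / 1e9

def fsdp_comm_gb(n_params, P, bytes_per_elem=2):
    # All-gather (forward) + all-gather (backward) + reduce-scatter (grads)
    return 3 * (P - 1) / P * n_params * bytes_per_elem / 1e9

P = 8
print(f"7B model, {P} GPUs, DDP:  {ddp_comm_gb(7e9, P):.1f} GB/step")
print(f"7B model, {P} GPUs, FSDP: {fsdp_comm_gb(7e9, P):.1f} GB/step")
# FSDP moves 1.5x the bytes of DDP, in exchange for P-fold less memory
```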

# FSDP training with PyTorch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# Mixed precision policy for FSDP
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Wrap model with FSDP (full sharding = ZeRO Stage 3)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mp_policy,
    device_id=torch.cuda.current_device(),
)

# Training loop is identical to standard PyTorch
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()

4. Tensor Parallelism

Tensor parallelism splits individual layers across GPUs. For a linear layer Y = XW, the weight matrix W can be split along its columns (column parallelism) or rows (row parallelism). Each GPU computes a portion of the output, and an all-reduce or all-gather combines the partial results.

Column Parallelism

Split W into columns: W = [W1 | W2]. GPU 0 computes XW1, GPU 1 computes XW2. The results are concatenated, requiring no communication in the forward pass (but an all-reduce in the backward pass). This is typically used for the first linear layer in the feed-forward network.

Row Parallelism

Split W into rows. Each GPU processes a different slice of the input. The partial outputs are summed via all-reduce in the forward pass. This is typically used for the second linear layer in the feed-forward network.

In Megatron-LM style parallelism, column and row parallelism are combined so that the MLP block requires only one all-reduce in the forward pass and one in the backward pass. Tensor parallelism requires very fast interconnects (NVLink within a node) because communication happens at every layer.

💡 Worked Example: Column-Parallel Matrix Multiply

Consider a feed-forward layer with input X (batch=2, d=4) and weight W (4x8), split across 2 GPUs:

GPU 0: W0 = first 4 columns of W (4x4). Computes Y0 = X · W0, producing a (2x4) result.

GPU 1: W1 = last 4 columns of W (4x4). Computes Y1 = X · W1, producing a (2x4) result.

Combine: Y = [Y0 | Y1] via concatenation (no communication needed). Each GPU did half the work, and the result is identical to a single GPU computing Y = X · W. The catch: the backward pass requires an all-reduce to sum gradients across GPUs.
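The worked example above can be checked directly. The sketch below simulates the two GPUs with column slices of W, using a tiny framework-free matmul helper so it runs anywhere:

```python
# Framework-free check of the column-parallel matmul above: W is split
# into column blocks, each simulated "GPU" computes its slice, and
# concatenation reproduces the single-GPU result exactly.

def matmul(A, B):
    """Plain-Python matrix multiply over nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

X = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]                                    # (2 x 4)
W = [[float(r * 8 + c) for c in range(8)] for r in range(4)]  # (4 x 8)

W0 = [row[:4] for row in W]   # "GPU 0": first 4 columns of W
W1 = [row[4:] for row in W]   # "GPU 1": last 4 columns of W

Y0 = matmul(X, W0)            # (2 x 4) partial result on GPU 0
Y1 = matmul(X, W1)            # (2 x 4) partial result on GPU 1
Y  = [r0 + r1 for r0, r1 in zip(Y0, Y1)]   # concatenate along columns

assert Y == matmul(X, W)      # identical to the single-GPU computation
```

No communication was needed to form Y in the forward pass; only the backward pass would require an all-reduce, exactly as described above.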

ⓘ Analogy: Assembly Line vs. Task Division

Think of distributed training strategies as ways to organize a factory. Data parallelism is like opening duplicate factories that each build complete products from different orders. Tensor parallelism is like splitting each workstation across two workers who each handle half the parts. Pipeline parallelism is like an assembly line where each station does one step. Expert parallelism is like a specialized factory floor where different workers handle different product types, and a router directs each order to the right specialist.

5. Pipeline Parallelism

Pipeline parallelism assigns different layers of the model to different GPUs. GPU 0 runs layers 0-15, GPU 1 runs layers 16-31, and so on. The input flows through the pipeline, with each GPU passing its output to the next.

The naive approach has a severe pipeline bubble problem: while GPU 0 is processing the forward pass, GPUs 1-3 are idle, and while GPU 3 is processing the backward pass, GPUs 0-2 are idle. The 1F1B (one forward, one backward) schedule mitigates this by splitting each batch into micro-batches and interleaving forward and backward passes across micro-batches. This keeps all GPUs active most of the time, though a small bubble remains at the beginning and end of each batch.

Figure 6.6.2: The 1F1B pipeline schedule interleaves forward (F) and backward (B) micro-batches to minimize idle time (bubbles).
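The size of the remaining bubble can be estimated with the standard formula: with p pipeline stages and m micro-batches, the idle fraction is (p−1)/(m+p−1), so more micro-batches shrink the bubble. A quick sketch:

```python
# Pipeline bubble fraction for a p-stage pipeline with m micro-batches:
# (p - 1) / (m + p - 1). Increasing m amortizes the fill/drain cost.

def bubble_fraction(p, m):
    return (p - 1) / (m + p - 1)

for m in (4, 8, 32):
    print(f"4 stages, {m:2d} micro-batches: "
          f"{bubble_fraction(4, m):.0%} of time idle")
# 4 micro-batches leave ~43% idle; 32 micro-batches only ~9%
```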

6. Mixed Precision Training

Mixed precision reduces memory usage and increases throughput by using lower-precision number formats for most computations while keeping critical accumulations in higher precision.

Format     | Bits | Range                        | Use Case
FP32       | 32   | Very large                   | Master weights, loss accumulation
FP16       | 16   | Limited (needs loss scaling) | Older GPUs (V100)
BF16       | 16   | Same as FP32                 | Standard for modern training
FP8 (E4M3) | 8    | Limited                      | Forward pass activations (Hopper+)
FP8 (E5M2) | 8    | Wider range, less precision  | Gradients (Hopper+)
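The "range" column follows from each format's exponent and mantissa bit counts. For IEEE-style formats the largest finite value is (2 − 2^−m) · 2^bias, where bias = 2^(e−1) − 1; the sketch below computes it (note FP8 E4M3 as used on Hopper is a special case whose max is 448, because it reclaims some encodings):

```python
# Largest finite value of an IEEE-style float from its exponent and
# mantissa bit counts: (2 - 2**-man_bits) * 2**bias, bias = 2**(e-1) - 1.
# FP8 E4M3 (Hopper variant) deviates from this rule; its max is 448.

def ieee_max(exp_bits, man_bits):
    bias = 2 ** (exp_bits - 1) - 1
    return (2 - 2 ** -man_bits) * 2.0 ** bias

print(f"FP32 max: {ieee_max(8, 23):.3e}")   # ~3.403e+38
print(f"FP16 max: {ieee_max(5, 10):.0f}")   # 65504
print(f"BF16 max: {ieee_max(8, 7):.3e}")    # ~3.390e+38
print(f"E5M2 max: {ieee_max(5, 2):.0f}")    # 57344
```

The numbers make the table concrete: BF16 shares FP32's 8-bit exponent and hence its ~3.4e38 range, while FP16 overflows past 65504, which is why FP16 needs loss scaling and BF16 does not.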

FP8 Training at Scale

DeepSeek V3 demonstrated successful FP8 mixed-precision training at 671B parameters, the first large-scale demonstration of FP8 for LLM pre-training. The approach uses the E4M3 format for forward-pass activations (more precision, narrower range) and E5M2 for gradients (wider range, less precision). Fine-grained scaling factors, maintained per tile or block of each tensor rather than per tensor, prevent overflow and underflow. FP8 training provides roughly 2x memory reduction and higher throughput compared to BF16 with minimal quality degradation.
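The scaling idea can be simulated in plain Python. The sketch below is illustrative only (not real FP8 hardware, and it shows the simplest per-tensor variant of scaling): pick a scale so the tensor's largest magnitude lands at the format's max finite value, round to the format's coarse mantissa grid, then rescale.

```python
import math

# Illustrative simulation of scaled low-precision quantization (plain
# Python, not real FP8 hardware): scale so the largest magnitude maps
# to the format's max finite value, round to the coarse mantissa grid,
# then rescale back. Simplest per-tensor variant of the scaling idea.

E4M3_MAX, E4M3_MANTISSA_BITS = 448.0, 3   # FP8 E4M3 limits on Hopper

def round_to_mantissa(x, man_bits):
    """Round x to the nearest value with man_bits fractional mantissa bits."""
    if x == 0:
        return 0.0
    m, e = math.frexp(x)                  # x = m * 2**e, 0.5 <= |m| < 1
    grid = 2 ** (man_bits + 1)
    return round(m * grid) / grid * 2.0 ** e

def fp8_roundtrip(values, fmt_max=E4M3_MAX, man_bits=E4M3_MANTISSA_BITS):
    amax = max(abs(v) for v in values)
    scale = fmt_max / amax                # dynamic scaling factor
    return [round_to_mantissa(v * scale, man_bits) / scale for v in values]

acts = [0.0013, -0.42, 3.7, 0.051]
deq = fp8_roundtrip(acts)
errors = [abs(a - b) / abs(a) for a, b in zip(acts, deq)]
assert max(errors) <= 0.0625 + 1e-12     # half-ULP of a 3-bit mantissa
```

Without the scale, small activations like 0.0013 would collapse toward the coarse grid near zero; the scale spends the format's limited dynamic range where the tensor's values actually live.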

7. Gradient Checkpointing

During the backward pass, computing gradients requires the activations from the forward pass. Normally all activations are stored in memory, consuming enormous amounts of GPU memory (proportional to batch size, sequence length, and hidden dimension). Gradient checkpointing (also called activation checkpointing) saves memory by storing only a subset of activations and recomputing the rest during the backward pass. The tradeoff is approximately 33% additional compute in exchange for a large reduction in activation memory.

# Gradient checkpointing in PyTorch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedTransformer(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            # Recompute activations during backward instead of storing
            x = checkpoint(layer, x, use_reentrant=False)
        return x

# Memory comparison
seq_len, hidden, num_layers, batch = 2048, 4096, 32, 8
bytes_per_elem = 2  # BF16

# Without checkpointing: store all layer activations
no_ckpt = batch * seq_len * hidden * num_layers * bytes_per_elem
# With checkpointing: store only input to each checkpointed segment
with_ckpt = batch * seq_len * hidden * 1 * bytes_per_elem  # only 1 activation

print(f"Activation memory without checkpointing: {no_ckpt / 1e9:.1f} GB")
print(f"Activation memory with checkpointing:    {with_ckpt / 1e9:.2f} GB")
print(f"Memory saved: {(1 - with_ckpt/no_ckpt)*100:.0f}%")
Activation memory without checkpointing: 4.3 GB
Activation memory with checkpointing:    0.13 GB
Memory saved: 97%

8. Combining Parallelism Strategies

Real-world large-scale training combines multiple parallelism strategies in a hierarchy. A common configuration for training a 70B model on 512 GPUs might use tensor parallelism with degree 8 within each node (leveraging fast NVLink), pipeline parallelism with degree 8 across nodes, and data parallelism with degree 8 across pipeline-parallel groups. This 3D parallelism approach matches each strategy to the communication bandwidth available at each level of the hardware hierarchy.
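The hierarchy above can be made concrete by mapping each global rank to its (data, pipeline, tensor) coordinates. The layout below is one illustrative convention (tensor parallelism varies fastest, so each TP group of 8 consecutive ranks stays inside one NVLink-connected node); real frameworks let you configure the ordering.

```python
# Mapping a global rank to (data, pipeline, tensor) coordinates in a
# 3D-parallel layout of 512 GPUs. Tensor parallelism varies fastest so
# each TP group of 8 consecutive ranks lands inside one node.

TP, PP, DP = 8, 8, 8          # 8 * 8 * 8 = 512 GPUs

def coords(rank):
    tp = rank % TP                 # position within the tensor group
    pp = (rank // TP) % PP         # pipeline stage
    dp = rank // (TP * PP)         # data-parallel replica
    return dp, pp, tp

assert coords(0) == (0, 0, 0)
assert coords(7) == (0, 0, 7)      # same node, same TP group
assert coords(8) == (0, 1, 0)      # next pipeline stage
assert coords(511) == (7, 7, 7)    # last GPU in the cluster
```

Ranks 0-7 then share the bandwidth-hungry tensor-parallel traffic over NVLink, while pipeline and data-parallel traffic crosses the slower inter-node network, matching each strategy to the hardware hierarchy.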

ⓘ Note

The choice of parallelism strategy depends on the hardware topology. Tensor parallelism demands the highest bandwidth and should use intra-node NVLink (900 GB/s on H100). Pipeline parallelism can tolerate lower bandwidth and can span nodes connected via InfiniBand (400 Gb/s). Data parallelism is the most bandwidth-efficient and can span the widest network distances.

Check Your Understanding

1. What is the fundamental difference between DDP and FSDP?
Show Answer
DDP replicates the full model, optimizer states, and gradients on every GPU. Each GPU processes different data and synchronizes gradients via all-reduce. FSDP shards (splits) the model parameters, gradients, and optimizer states across GPUs so each GPU stores only a fraction. FSDP reconstructs full parameters on-demand for each layer's computation via all-gather, and reduces gradients via reduce-scatter. DDP uses more memory but less communication; FSDP uses less memory but more communication.
2. Why is BF16 preferred over FP16 for LLM training?
Show Answer
BF16 (bfloat16) has the same exponent range as FP32 (8 exponent bits) but with reduced mantissa precision (7 bits instead of 23). FP16 has a much narrower range (5 exponent bits), which means gradients and activations can easily overflow or underflow during training, requiring loss scaling to compensate. BF16's wider range eliminates the need for loss scaling, making training simpler and more stable. The reduced mantissa precision has minimal impact on model quality because the training process is inherently noisy.
3. What causes pipeline bubbles and how does the 1F1B schedule mitigate them?
Show Answer
Pipeline bubbles are idle time on GPUs when they have no work to do. In naive pipeline parallelism, GPU 0 must finish the forward pass for the entire batch before GPU 1 can start, creating a cascade of idle time. The 1F1B schedule splits the batch into micro-batches and interleaves forward and backward passes. Once GPU 0 finishes the forward pass for micro-batch 1, it can start micro-batch 2's forward pass while GPU 1 processes micro-batch 1. This keeps all GPUs busy for most of the training step, though small bubbles remain at the start and end.
4. Explain the memory-compute tradeoff in gradient checkpointing.
Show Answer
Normally, all intermediate activations from the forward pass are kept in memory for use during the backward pass. Gradient checkpointing discards most of these activations and recomputes them during the backward pass when needed. This dramatically reduces activation memory (often by 90%+) at the cost of approximately 33% additional compute, because each checkpointed segment's forward pass is run twice: once during the original forward pass and once during the backward pass to recompute activations. This tradeoff is almost always worthwhile because memory, not compute, is typically the bottleneck.

Key Takeaways