Training a large model on one GPU is like reading the entire internet through a keyhole. Distributed training lets you knock down the wall, provided all the GPUs agree on which wall and when.
An Overly Synchronized Process

No single GPU can train a modern LLM. A 70B parameter model requires over 140 GB just for its parameters in FP16, far exceeding the memory of any single accelerator. Training such models demands distributing computation across dozens to thousands of GPUs, coordinating their work through high-speed interconnects. This section covers the four fundamental parallelism strategies (data, tensor, pipeline, and expert parallelism), the communication primitives that enable them, mixed-precision training (including FP8), and the memory optimization techniques that make large-scale training feasible.
This section assumes familiarity with PyTorch tensor operations from Section 0.2 and the transformer architecture from Module 04. Understanding of matrix multiplication is essential for the tensor parallelism discussion. The optimizer memory analysis from Section 6.5 motivates why distributed training is necessary.
1. Communication Primitives
Distributed training relies on collective communication operations to synchronize data between GPUs. Understanding these primitives is essential for reasoning about the communication overhead of different parallelism strategies.
| Primitive | Input | Output | Use Case |
|---|---|---|---|
| All-Reduce | Each GPU has a tensor | All GPUs have the sum | Gradient synchronization (DDP) |
| All-Gather | Each GPU has a shard | All GPUs have the full tensor | Parameter reconstruction (FSDP) |
| Reduce-Scatter | Each GPU has a tensor | Each GPU has a shard of the sum | Gradient sharding (FSDP) |
| Broadcast | One GPU has a tensor | All GPUs have it | Weight initialization |
These operations are implemented efficiently using ring or tree topologies. In a ring all-reduce with P GPUs, each GPU sends and receives 2(P−1)/P times the tensor size, giving near-optimal bandwidth utilization regardless of the number of GPUs. The NCCL library (NVIDIA Collective Communications Library) provides highly optimized implementations for NVIDIA GPUs.
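The bandwidth claim above can be made concrete with a short calculation. This is an illustrative sketch (the function name `ring_all_reduce_volume` is ours, not an NCCL API): it computes the bytes each GPU sends in a ring all-reduce, which approaches 2x the tensor size as P grows.

```python
# Communication volume of a ring all-reduce (illustrative sketch).
def ring_all_reduce_volume(tensor_bytes: float, num_gpus: int) -> float:
    """Bytes each GPU sends (and receives) in a ring all-reduce."""
    # Reduce-scatter phase moves (P-1)/P of the tensor,
    # then the all-gather phase moves another (P-1)/P.
    return 2 * (num_gpus - 1) / num_gpus * tensor_bytes

grad_bytes = 7e9 * 2  # gradients of a 7B-parameter model in BF16
for p in (2, 8, 64, 512):
    vol = ring_all_reduce_volume(grad_bytes, p)
    print(f"P={p:4d}: {vol / 1e9:.2f} GB sent per GPU")
```

Note that per-GPU volume saturates near 2x the tensor size no matter how many GPUs participate, which is why ring all-reduce scales so well.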
2. Data Parallelism (DDP)
Distributed Data Parallelism is the simplest and most widely used form of parallelism. Each GPU holds a complete copy of the model and processes a different subset of the training data. After each forward-backward pass, gradients are synchronized across all GPUs using all-reduce, ensuring that all copies perform identical parameter updates.
```python
# DDP training with PyTorch
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train_ddp(rank, world_size, model_class):
    setup_ddp(rank, world_size)
    # Each GPU gets a full model copy
    model = model_class().to(rank)
    model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()  # DDP auto-syncs gradients via all-reduce
        optimizer.step()
    dist.destroy_process_group()
```
3. Fully Sharded Data Parallelism (FSDP) and ZeRO
DDP's limitation is that every GPU must hold a complete copy of the model, gradients, and optimizer states. For a 7B model with AdamW, that is ~112 GB per GPU. FSDP (and the equivalent DeepSpeed ZeRO) resolves this by sharding these tensors across GPUs so each GPU stores only a fraction.
ZeRO Optimization Stages
- Stage 1: Shard optimizer states only. Each GPU stores 1/P of the optimizer states but keeps full parameters and gradients. Memory savings: optimizer memory drops to 1/P; since optimizer states dominate AdamW's footprint, total model-state memory shrinks by up to ~4x at large P.
- Stage 2: Shard optimizer states and gradients. After the backward pass, gradients are reduce-scattered so each GPU holds only its shard. Memory savings: further ~2x reduction.
- Stage 3: Shard everything (optimizer states, gradients, and parameters). Parameters are gathered on-demand for each layer's forward and backward pass, then discarded. Memory savings: total memory per GPU is 1/P of the full model state.
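The three stages can be compared with a back-of-the-envelope memory model. This sketch (the function `zero_memory_gb` is ours) assumes BF16 parameters and gradients plus FP32 AdamW states (master weights, momentum, variance = 12 bytes/param), ignoring activations and fragmentation; stage 0 reproduces the ~112 GB DDP figure for a 7B model.

```python
# Per-GPU model-state memory under the ZeRO stages (rough sketch).
def zero_memory_gb(n_params: float, world_size: int, stage: int) -> float:
    param = 2 * n_params   # BF16 parameters
    grad = 2 * n_params    # BF16 gradients
    opt = 12 * n_params    # FP32 master weights + Adam m, v
    P = world_size
    if stage == 0:                       # plain DDP: everything replicated
        total = param + grad + opt
    elif stage == 1:                     # shard optimizer states
        total = param + grad + opt / P
    elif stage == 2:                     # shard optimizer states + gradients
        total = param + (grad + opt) / P
    else:                                # stage 3 / FULL_SHARD: shard everything
        total = (param + grad + opt) / P
    return total / 1e9

for stage in range(4):
    print(f"7B model, 8 GPUs, stage {stage}: "
          f"{zero_memory_gb(7e9, 8, stage):.2f} GB/GPU")
```

On 8 GPUs the per-GPU footprint falls from 112 GB (stage 0) to 14 GB (stage 3), which is why FULL_SHARD makes otherwise untrainable models fit.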
FSDP Stage 3 trades communication for memory. Each layer's forward pass requires an all-gather to reconstruct the full parameters, and each backward pass requires a reduce-scatter of the gradients. This means each parameter is communicated 3 times per training step (gather for forward, gather for backward, reduce-scatter for gradient). The communication overhead is significant but acceptable when the alternative is not being able to train the model at all.
```python
# FSDP training with PyTorch
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# Mixed precision policy for FSDP
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Wrap model with FSDP (full sharding = ZeRO Stage 3)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mp_policy,
    device_id=torch.cuda.current_device(),
)

# Training loop is identical to standard PyTorch
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()
```
4. Tensor Parallelism
Tensor parallelism splits individual layers across GPUs. For a linear layer Y = XW, the weight matrix W can be split along its columns (column parallelism) or rows (row parallelism). Each GPU computes a portion of the output, and an all-reduce or all-gather combines the partial results.
Column Parallelism
Split W into columns: W = [W1 | W2]. GPU 0 computes XW1, GPU 1 computes XW2. The results are concatenated, requiring no communication in the forward pass (but an all-reduce in the backward pass). This is typically used for the first linear layer in the feed-forward network.
Row Parallelism
Split W into rows. Each GPU processes a different slice of the input. The partial outputs are summed via all-reduce in the forward pass. This is typically used for the second linear layer in the feed-forward network.
In Megatron-LM style parallelism, column and row parallelism are combined so that the MLP block requires only one all-reduce in the forward pass and one in the backward pass. Tensor parallelism requires very fast interconnects (NVLink within a node) because communication happens at every layer.
Consider a feed-forward layer with input X (batch=2, d=4) and weight W (4x8), split across 2 GPUs:
GPU 0: W0 = first 4 columns of W (4x4). Computes Y0 = X · W0, producing a (2x4) result.
GPU 1: W1 = last 4 columns of W (4x4). Computes Y1 = X · W1, producing a (2x4) result.
Combine: Y = [Y0 | Y1] via concatenation (no communication needed). Each GPU did half the work, and the result is identical to a single GPU computing Y = X · W. The catch: the backward pass requires an all-reduce to sum gradients across GPUs.
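The worked example above can be checked numerically. A dependency-free sketch (the `matmul` helper stands in for a GPU kernel) splits W into column halves, computes the two partial products, and confirms that concatenation reproduces the single-GPU result:

```python
# Verify column-parallel matmul reproduces the full result (pure-Python sketch).
def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

X = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]                                    # (2 x 4) input
W = [[float(i * 8 + j) for j in range(8)] for i in range(4)]  # (4 x 8) weight

W0 = [row[:4] for row in W]   # "GPU 0": first 4 columns
W1 = [row[4:] for row in W]   # "GPU 1": last 4 columns

Y0, Y1 = matmul(X, W0), matmul(X, W1)             # each (2 x 4)
Y_parallel = [r0 + r1 for r0, r1 in zip(Y0, Y1)]  # concatenate along columns
assert Y_parallel == matmul(X, W)                 # identical to one-GPU matmul
```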
Think of distributed training strategies as ways to organize a factory. Data parallelism is like opening duplicate factories that each build complete products from different orders. Tensor parallelism is like splitting each workstation across two workers who each handle half the parts. Pipeline parallelism is like an assembly line where each station does one step. Expert parallelism is like a specialized factory floor where different workers handle different product types, and a router directs each order to the right specialist.
5. Pipeline Parallelism
Pipeline parallelism assigns different layers of the model to different GPUs. GPU 0 runs layers 0-15, GPU 1 runs layers 16-31, and so on. The input flows through the pipeline, with each GPU passing its output to the next.
The naive approach has a severe pipeline bubble problem: while GPU 0 is processing the forward pass, GPUs 1-3 are idle, and while GPU 3 is processing the backward pass, GPUs 0-2 are idle. The 1F1B (one forward, one backward) schedule mitigates this by splitting each batch into micro-batches and interleaving forward and backward passes across micro-batches. This keeps all GPUs active most of the time, though a small bubble remains at the beginning and end of each batch.
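The bubble cost can be quantified with the standard formula: with p pipeline stages and m micro-batches, the idle fraction is (p−1)/(m+p−1). A small sketch (the function name is ours) shows why more micro-batches shrink the bubble:

```python
# Pipeline bubble fraction for a p-stage pipeline with m micro-batches.
def bubble_fraction(stages: int, micro_batches: int) -> float:
    """Idle fraction of the schedule: (p - 1) / (m + p - 1)."""
    return (stages - 1) / (micro_batches + stages - 1)

for m in (1, 4, 16, 64):
    print(f"4 stages, {m:2d} micro-batches: "
          f"bubble = {bubble_fraction(4, m):.1%}")
```

With a single micro-batch (the naive approach) a 4-stage pipeline idles 75% of the time; with 64 micro-batches the bubble falls below 5%.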
6. Mixed Precision Training
Mixed precision reduces memory usage and increases throughput by using lower-precision number formats for most computations while keeping critical accumulations in higher precision.
| Format | Bits | Range | Use Case |
|---|---|---|---|
| FP32 | 32 | Very large | Master weights, loss accumulation |
| FP16 | 16 | Limited (needs loss scaling) | Older GPUs (V100) |
| BF16 | 16 | Same as FP32 | Standard for modern training |
| FP8 (E4M3) | 8 | Limited | Forward pass activations (Hopper+) |
| FP8 (E5M2) | 8 | Wider range, less precision | Gradients (Hopper+) |
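The ranges in the table follow directly from each format's bit layout. A sketch (the `fmt_max` helper is ours) computes the largest finite value from the exponent and mantissa widths; note that OCP E4M3 is an exception, since it reclaims infinity encodings to reach a max of 448 rather than the 240 the standard formula predicts.

```python
# Largest finite value of an IEEE-style float from its bit layout (sketch).
# Exception: OCP E4M3 reclaims inf encodings, so its true max is 448,
# not the 240 this formula would predict.
def fmt_max(exp_bits: int, mantissa_bits: int) -> float:
    bias = 2 ** (exp_bits - 1) - 1
    return (2 - 2 ** -mantissa_bits) * 2.0 ** bias

print(f"FP32 max: {fmt_max(8, 23):.3e}")  # ~3.403e38
print(f"BF16 max: {fmt_max(8, 7):.3e}")   # ~3.390e38 -- same exponent range as FP32
print(f"FP16 max: {fmt_max(5, 10):.0f}")  # 65504 -- why loss scaling is needed
print(f"E5M2 max: {fmt_max(5, 2):.0f}")   # 57344
```

BF16's 8 exponent bits match FP32, which is exactly why it avoids FP16's overflow problems despite having fewer mantissa bits.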
FP8 Training at Scale
DeepSeek V3 demonstrated successful FP8 mixed-precision training at 671B parameters, the first large-scale demonstration of FP8 for LLM pre-training. The approach uses E4M3 format for forward pass activations (more precision, narrower range) and E5M2 for gradients (wider range, less precision). Dynamic scaling factors, maintained per tensor in simpler recipes and at finer granularity (per tile or block) in DeepSeek V3's, prevent overflow and underflow. FP8 training provides roughly 2x memory reduction and higher throughput compared to BF16 with minimal quality degradation.
7. Gradient Checkpointing
During the backward pass, computing gradients requires the activations from the forward pass. Normally all activations are stored in memory, consuming enormous amounts of GPU memory (proportional to batch size, sequence length, and hidden dimension). Gradient checkpointing (also called activation checkpointing) saves memory by storing only a subset of activations and recomputing the rest during the backward pass. The tradeoff is approximately 33% additional compute in exchange for a large reduction in activation memory.
```python
# Gradient checkpointing in PyTorch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedTransformer(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            # Recompute activations during backward instead of storing them
            x = checkpoint(layer, x, use_reentrant=False)
        return x

# Memory comparison
seq_len, hidden, num_layers, batch = 2048, 4096, 32, 8
bytes_per_elem = 2  # BF16

# Without checkpointing: store all layer activations
no_ckpt = batch * seq_len * hidden * num_layers * bytes_per_elem
# With checkpointing: store only the input to each checkpointed segment
with_ckpt = batch * seq_len * hidden * 1 * bytes_per_elem

print(f"Activation memory without checkpointing: {no_ckpt / 1e9:.1f} GB")
print(f"Activation memory with checkpointing: {with_ckpt / 1e9:.2f} GB")
print(f"Memory saved: {(1 - with_ckpt / no_ckpt) * 100:.0f}%")
```
8. Combining Parallelism Strategies
Real-world large-scale training combines multiple parallelism strategies in a hierarchy. A common configuration for training a 70B model on 512 GPUs might use tensor parallelism with degree 8 within each node (leveraging fast NVLink), pipeline parallelism with degree 8 across nodes, and data parallelism with degree 8 across pipeline-parallel groups. This 3D parallelism approach matches each strategy to the communication bandwidth available at each level of the hardware hierarchy.
The choice of parallelism strategy depends on the hardware topology. Tensor parallelism demands the highest bandwidth and should use intra-node NVLink (900 GB/s on H100). Pipeline parallelism can tolerate lower bandwidth and can span nodes connected via InfiniBand (400 Gb/s). Data parallelism is the most bandwidth-efficient and can span the widest network distances.
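The hierarchy described above is conventionally laid out with tensor parallelism innermost, so that TP groups occupy consecutive ranks within a node. A sketch of the rank-to-coordinate mapping (the function `rank_to_coords` and its argument defaults are ours, matching the 8x8x8 example):

```python
# Map a global rank to (dp, pp, tp) coordinates for 3D parallelism (sketch).
def rank_to_coords(rank: int, tp: int = 8, pp: int = 8, dp: int = 8):
    """TP varies fastest so TP groups stay within a node (fastest interconnect);
    DP varies slowest so DP groups span the widest network distances."""
    assert 0 <= rank < tp * pp * dp
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx

# Ranks 0-7 share a node and form one tensor-parallel group:
print([rank_to_coords(r) for r in range(8)])
```

Frameworks expose the same idea declaratively; for example, PyTorch's `init_device_mesh` builds a named (dp, pp, tp) mesh over the world size.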
Key Takeaways
- DDP is the simplest distributed training approach: replicate the model on each GPU and synchronize gradients via all-reduce.
- FSDP/ZeRO shards parameters, gradients, and optimizer states across GPUs to reduce per-GPU memory, enabling training of much larger models.
- Tensor parallelism splits individual layers across GPUs and requires fast intra-node interconnects (NVLink).
- Pipeline parallelism assigns different layers to different GPUs; the 1F1B schedule minimizes idle time.
- BF16 is the standard precision for LLM training; FP8 (demonstrated by DeepSeek V3) provides further memory and throughput improvements on Hopper GPUs.
- Gradient checkpointing trades ~33% extra compute for massive activation memory savings.
- Real-world training uses 3D parallelism, combining tensor, pipeline, and data parallelism to match hardware topology.