Adam was not the first optimizer, nor the last, but somehow it keeps showing up at every training run like a reliable friend who always picks the restaurant. You could do better, but you probably will not.
*A Perpetually Converging Gradient*

Optimizers are the engine of training. While the model architecture defines what can be learned and the data defines what is taught, the optimizer determines how efficiently the model learns. The choice of optimizer, learning rate schedule, and gradient handling strategy affects training speed, stability, final performance, and memory usage. This section covers the optimizers used in LLM training, from the ubiquitous Adam family to memory-efficient alternatives, along with the training dynamics phenomena that practitioners must understand to diagnose and prevent training failures.
This section assumes familiarity with gradient-based optimization basics and PyTorch tensor operations from Section 0.2. Understanding of transformer layer structure from Section 4.1 will help with the discussion of per-layer training dynamics.
1. Stochastic Gradient Descent and Its Limitations
Vanilla SGD updates parameters by subtracting the gradient scaled by a learning rate: θ ← θ − η ∇L(θ). While SGD with momentum works well for convolutional networks, it performs poorly for transformer training. The core issue is that transformers have parameters operating at vastly different scales: embedding matrices, attention projections, layer norms, and feed-forward layers all have different gradient magnitudes. A single learning rate cannot simultaneously provide appropriate step sizes for all of these.
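The scale-mismatch problem is easy to see in a toy experiment. The sketch below (a hypothetical two-layer model, not a real LLM) compares the gradient norm flowing into an embedding table, where only a few rows receive gradient per batch, against a dense linear projection:

```python
import torch
import torch.nn as nn

# Toy illustration of the scale mismatch: gradients reaching different
# parameter types can differ substantially, so one global SGD learning
# rate cannot provide a good step size for all of them.
torch.manual_seed(0)
emb = nn.Embedding(1000, 64)   # only the 8 sampled rows receive gradient
proj = nn.Linear(64, 64)       # every entry receives gradient

tokens = torch.randint(0, 1000, (8,))
out = proj(emb(tokens))
out.pow(2).mean().backward()

print(f"embedding grad norm: {emb.weight.grad.norm():.4f}")
print(f"linear    grad norm: {proj.weight.grad.norm():.4f}")
```

Adaptive optimizers sidestep this by normalizing each parameter's step by its own gradient statistics.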
2. Adam: The Default Optimizer for Transformers
Adam (Adaptive Moment Estimation) addresses this by maintaining per-parameter adaptive learning rates based on first and second moment estimates of the gradient. The update rules are:

m ← β₁m + (1 − β₁)g
v ← β₂v + (1 − β₂)g²
θ ← θ − η m̂ / (√v̂ + ε)

where m̂ and v̂ are bias-corrected estimates (dividing m by 1 − β₁ᵗ and v by 1 − β₂ᵗ at step t). The typical hyperparameters are β₁ = 0.9, β₂ = 0.999, and ε = 10⁻⁸.
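The update rules translate directly into code. This is a from-scratch sketch of a single Adam step (PyTorch's `torch.optim.Adam` is the production implementation):

```python
import torch

# One Adam step, mirroring the update rules above.
def adam_step(theta, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * grad           # first moment (momentum)
    v = b2 * v + (1 - b2) * grad ** 2      # second moment (uncentered variance)
    m_hat = m / (1 - b1 ** t)              # bias correction for step t
    v_hat = v / (1 - b2 ** t)
    theta = theta - lr * m_hat / (v_hat.sqrt() + eps)
    return theta, m, v

theta = torch.zeros(3)
m = torch.zeros(3)
v = torch.zeros(3)
grad = torch.tensor([1.0, -0.5, 0.01])

theta, m, v = adam_step(theta, grad, m, v, t=1)
print(theta)
```

Note that at t = 1 the bias-corrected step is approximately lr · sign(grad) for every parameter, regardless of the raw gradient magnitude: the per-parameter adaptivity at work.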
Adam stores two additional state tensors (m and v) for every parameter. For a model with N parameters in FP32, the optimizer states require 2N floats, or 8N bytes. Combined with the parameters themselves (4N bytes) and gradients (4N bytes), the total memory footprint is 16N bytes. A 7B parameter model thus needs approximately 112 GB just for parameters, gradients, and optimizer states in FP32.
Key Insight: Your optimizer is bigger than your model. A 7B-parameter model in FP16 occupies 14 GB. But Adam stores two additional FP32 state tensors (momentum and variance) per parameter, consuming 7B × 4 bytes × 2 = 56 GB. The optimizer states alone are 4× the model size. This is why memory-efficient optimizers and FSDP are not optional for large-scale training.
```python
# Understanding Adam's memory footprint
model_params = 7e9      # 7B parameters
bytes_per_param = 4     # FP32

param_memory = model_params * bytes_per_param
gradient_memory = model_params * bytes_per_param
adam_m_memory = model_params * bytes_per_param  # first moment
adam_v_memory = model_params * bytes_per_param  # second moment
total = param_memory + gradient_memory + adam_m_memory + adam_v_memory

print(f"Parameters:       {param_memory / 1e9:.1f} GB")
print(f"Gradients:        {gradient_memory / 1e9:.1f} GB")
print(f"Adam m (1st mom): {adam_m_memory / 1e9:.1f} GB")
print(f"Adam v (2nd mom): {adam_v_memory / 1e9:.1f} GB")
print(f"Total:            {total / 1e9:.1f} GB")
```
3. AdamW: Decoupled Weight Decay
Loshchilov and Hutter (2019) identified a subtle but important flaw in Adam's handling of weight decay (L2 regularization). In standard Adam, weight decay is applied to the gradient before the adaptive scaling, which means the effective regularization strength varies across parameters based on their second moment estimates. AdamW fixes this by decoupling weight decay from the gradient update:

θ ← θ − η (m̂ / (√v̂ + ε) + λθ)
Here, λ is the weight decay coefficient (typically 0.01 to 0.1), and it is applied uniformly to all parameters regardless of their gradient history. This decoupling is particularly important for transformers because different parameter groups (attention weights, embeddings, biases) benefit from consistent regularization strength.
AdamW is the de facto standard optimizer for LLM training. Nearly every major open LLM (GPT, LLaMA, Mistral, Qwen) uses AdamW. The decoupled weight decay provides more consistent regularization across the model's diverse parameter groups, leading to better generalization. Weight decay is typically not applied to bias terms or layer normalization parameters.
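The standard way to exempt biases and normalization parameters from decay is to pass separate parameter groups to the optimizer. A sketch, using dimensionality to separate weight matrices from vector-like parameters (a common heuristic; the toy model here is illustrative):

```python
import torch
import torch.nn as nn

# Apply weight decay only to 2-D weight matrices; biases and LayerNorm
# parameters (both 1-D) get weight_decay=0.
model = nn.Sequential(nn.Linear(64, 64), nn.LayerNorm(64), nn.Linear(64, 8))

decay, no_decay = [], []
for name, p in model.named_parameters():
    (decay if p.dim() >= 2 else no_decay).append(p)

optimizer = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.1},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=3e-4,
)
print(f"decayed tensors: {len(decay)}, undecayed: {len(no_decay)}")
```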
4. Memory-Efficient Optimizer Alternatives
Adafactor
Adafactor (Shazeer and Stern, 2018) reduces Adam's memory overhead by factoring the second-moment matrix. Instead of storing a full v tensor for each parameter matrix, Adafactor stores only the row and column statistics: two vectors whose outer product approximates the full matrix. For a weight matrix of shape (m, n), this reduces the second-moment storage from mn to m + n. Adafactor was used in the T5 model family.
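The factored approximation is a rank-1 reconstruction: the full matrix of squared-gradient statistics is approximated by the outer product of its row and column means, normalized by the global mean. A sketch of the idea (not the full Adafactor algorithm, which also includes update clipping and a relative step size):

```python
import torch

# Factored second moment: store per-row and per-column statistics instead
# of the full (m, n) matrix of squared gradients.
torch.manual_seed(0)
g = torch.randn(512, 256)      # gradient for a (512, 256) weight matrix
sq = g ** 2 + 1e-30            # squared-gradient statistics

row = sq.mean(dim=1)           # shape (512,)
col = sq.mean(dim=0)           # shape (256,)
# Rank-1 reconstruction: outer(row, col) / global mean
v_approx = torch.outer(row, col) / row.mean()

full_count = sq.numel()                      # 131072 floats for exact v
factored_count = row.numel() + col.numel()   # 768 floats
print(f"storage: {full_count} -> {factored_count} floats")
print(f"relative error: {(v_approx - sq).norm() / sq.norm():.3f}")
```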
8-bit Adam
Dettmers et al. (2022) showed that Adam's optimizer states (m and v) can be quantized to 8-bit integers with dynamic block-wise quantization, reducing optimizer memory by 75% with negligible impact on training quality. The key insight is that optimizer states do not need full precision: they accumulate slowly changing statistics, and small quantization errors are averaged out over training steps.
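A simplified sketch of block-wise quantization (linear absmax quantization here; the actual bitsandbytes implementation uses a dynamic, non-linear code): each block of the state tensor stores int8 codes plus one FP32 scale.

```python
import torch

# Block-wise 8-bit quantization of an optimizer state tensor.
def quantize_blockwise(x, block_size=256):
    blocks = x.reshape(-1, block_size)
    scales = blocks.abs().max(dim=1, keepdim=True).values + 1e-12
    codes = torch.round(blocks / scales * 127).to(torch.int8)
    return codes, scales

def dequantize_blockwise(codes, scales):
    return (codes.float() / 127) * scales

torch.manual_seed(0)
state = torch.randn(1024 * 256)          # stand-in for an Adam "v" state
codes, scales = quantize_blockwise(state)
restored = dequantize_blockwise(codes, scales).reshape(-1)

err = (restored - state).abs().max()
print(f"max abs error after 8-bit round-trip: {err:.4f}")
```

Block-wise scales matter because a single global scale would let one outlier value destroy precision for the entire tensor; per-block scales contain the damage to 256 elements.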
LION (Sign-Based Optimizer)
LION (Chen et al., 2023) takes a radically different approach. Instead of using the full gradient magnitude, LION uses only the sign of the momentum update: every parameter is updated by exactly +η or −η. This eliminates the second moment entirely, cutting optimizer memory in half compared to Adam. LION also uses a different momentum interpolation that mixes past momentum with the current gradient. Despite its simplicity, LION matches or exceeds AdamW on many vision and language tasks.
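A sketch of a single LION step, following the published update rule (sign of an interpolated momentum, decoupled weight decay, single state tensor):

```python
import torch

# One LION step: every parameter moves by exactly +/- lr (plus decay).
def lion_step(theta, grad, m, lr=1e-4, b1=0.9, b2=0.99, wd=0.1):
    update = torch.sign(b1 * m + (1 - b1) * grad)  # entries are +1, -1, or 0
    theta = theta - lr * (update + wd * theta)
    m = b2 * m + (1 - b2) * grad                   # momentum uses a second beta
    return theta, m

theta = torch.tensor([1.0, -2.0, 3.0])
m = torch.zeros(3)
grad = torch.tensor([0.3, -0.01, 5.0])

theta, m = lion_step(theta, grad, m)
print(theta)
```

Note that the tiny gradient (−0.01) and the large one (5.0) produce steps of identical magnitude: sign-based updates discard all magnitude information, which is exactly what makes the second moment unnecessary.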
| Optimizer | States per Param | Memory (7B model) | Used In |
|---|---|---|---|
| AdamW | 2 (m, v) | ~56 GB | GPT, LLaMA, Mistral |
| Adafactor | ~0 (factored v only) | <1 GB | T5, PaLM |
| 8-bit Adam | 2 (1 byte each) | ~14 GB | Fine-tuning |
| LION | 1 (m only) | ~28 GB | Research, some vision |
5. Learning Rate Schedules
Warmup
Transformer training universally begins with a warmup phase where the learning rate increases linearly from near-zero to the peak value over several hundred to several thousand steps. Warmup is necessary because at initialization, the model's loss landscape is poorly conditioned: gradients can be very large and noisy. Starting with a high learning rate would cause catastrophic parameter updates that push the model into a bad region of the loss landscape from which it cannot recover. Warmup allows the optimizer's moment estimates to stabilize before applying large updates.
Cosine Decay
After warmup, the learning rate is typically decayed following a cosine schedule:

η(t) = η_min + ½ (η_peak − η_min) · (1 + cos(πt / T))

where t is the current step (counted after warmup) and T is the total number of decay steps. Cosine decay provides a smooth, gradual reduction that spends most of the training budget at moderate learning rates. The minimum learning rate is typically set to 10% of the peak rate. Some variants use cosine decay with warm restarts, periodically resetting the schedule to escape local minima.
```python
import numpy as np

def lr_schedule(step, total_steps, warmup_steps, peak_lr, min_lr):
    """Warmup + cosine decay schedule."""
    if step < warmup_steps:
        # Linear warmup
        return peak_lr * step / warmup_steps
    else:
        # Cosine decay
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return min_lr + 0.5 * (peak_lr - min_lr) * (1 + np.cos(np.pi * progress))

# Typical LLM training schedule
total_steps = 100000
warmup_steps = 2000
peak_lr = 3e-4
min_lr = 3e-5

steps = np.arange(total_steps)
lrs = [lr_schedule(s, total_steps, warmup_steps, peak_lr, min_lr) for s in steps]
print(f"Step 0:     LR = {lrs[0]:.2e}")
print(f"Step 1000:  LR = {lrs[1000]:.2e}")
print(f"Step 2000:  LR = {lrs[2000]:.2e} (peak)")
print(f"Step 50000: LR = {lrs[50000]:.2e}")
print(f"Step 99999: LR = {lrs[99999]:.2e}")
```
WSD: Warmup-Stable-Decay
A newer schedule, increasingly preferred for large-scale training, is the Warmup-Stable-Decay (WSD) schedule (also called trapezoidal). Instead of continuously decaying the learning rate after warmup, WSD maintains a constant "stable" phase at the peak learning rate for most of training, then applies a rapid linear or cosine decay only during a short final phase (typically the last 10-20% of steps).
WSD has a crucial practical advantage: because the learning rate stays constant during the stable phase, you can evaluate checkpoints and decide later when to begin the decay phase. With cosine decay, the total training budget must be known in advance. WSD decouples the schedule from the total step count, making it ideal for continued pretraining and for models where the final training duration is not known at the start. MiniCPM popularized the schedule, and DeepSeek-V3 used a WSD-style schedule (a long constant phase followed by step decay); see Section 7.2 for details.
```python
def wsd_schedule(step, total_steps, warmup_steps, stable_fraction, peak_lr, min_lr):
    """Warmup-Stable-Decay learning rate schedule."""
    decay_start = int(total_steps * stable_fraction)
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    elif step < decay_start:
        return peak_lr  # Constant during stable phase
    else:
        # Linear decay in the final phase
        decay_progress = (step - decay_start) / (total_steps - decay_start)
        return peak_lr + (min_lr - peak_lr) * decay_progress

total_steps = 100000
warmup_steps = 2000
peak_lr, min_lr = 3e-4, 3e-5

# Stable phase covers 80% of training; decay covers the last 20%
for s in [0, 2000, 40000, 79999, 80000, 90000, 99999]:
    lr = wsd_schedule(s, total_steps, warmup_steps, 0.8, peak_lr, min_lr)
    print(f"Step {s:>6}: LR = {lr:.2e}")
```
Why WSD is winning: Cosine decay commits you to a fixed training budget. If you want to train longer, you must restart with a new schedule. WSD lets you extend the stable phase indefinitely and only apply the decay when you are ready to produce the final checkpoint. This makes WSD the natural choice for iterative training workflows and continued pretraining on new data.
6. Gradient Accumulation
Large batch sizes improve training stability and efficiency but require more GPU memory. Gradient accumulation simulates large batches without the memory cost: instead of processing the full batch at once, the training loop processes several smaller micro-batches, accumulating their gradients, and only performing the optimizer step after all micro-batches are complete.
```python
import torch

# Gradient accumulation pseudocode (model, dataloader, optimizer assumed defined)
accumulation_steps = 8  # Effective batch = micro_batch * 8

optimizer.zero_grad()
for i, micro_batch in enumerate(dataloader):
    loss = model(micro_batch) / accumulation_steps  # Scale loss
    loss.backward()                                 # Accumulate gradients
    if (i + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
```
The loss must be divided by the number of accumulation steps to keep the effective gradient magnitude consistent with a true large batch. Forgetting this division is a common bug that results in an effective learning rate that is too large by a factor of the accumulation count.
7. Training Dynamics: The Loss Landscape
The loss landscape of a transformer is a high-dimensional surface with complex geometry. Research by Li et al. (2018) showed that neural network loss landscapes contain wide, flat minima and narrow, sharp minima. Models that converge to flatter minima tend to generalize better because small perturbations to the parameters (as occur with different test inputs) cause smaller changes in loss.
The Grokking Phenomenon
Power et al. (2022) discovered a surprising training dynamic called "grokking": a model can memorize the training set early in training (achieving near-perfect training accuracy) while showing no generalization, and then, after many additional training steps, suddenly transition to perfect generalization. This delayed generalization can occur thousands of steps after the training loss has plateaued.
Grokking challenges the conventional wisdom that training should be stopped when the validation loss plateaus. The phenomenon has been explained through the lens of representation learning: the model first memorizes using inefficient representations, then gradually discovers the underlying algorithm, which requires more training steps to crystallize. Weight decay plays a critical role in enabling grokking by continuously pushing the model away from memorization-based solutions toward simpler, more generalizable representations.
Grokking is related to the double descent phenomenon (Nakkiran et al., 2019), where test loss first decreases, then increases at the interpolation threshold, then decreases again as model capacity grows further. Both phenomena suggest that models pass through a memorization regime before finding generalizable solutions. Weight decay and regularization help models traverse this landscape more quickly: without weight decay, grokking may not occur at all, because the model has no pressure to find simpler representations.
Maximal Update Parametrization (muP)

Maximal Update Parametrization (muP) solves a critical practical problem: hyperparameters tuned on small models do not transfer to large models under standard parametrizations. muP defines a parametrization where the optimal learning rate, initialization scale, and other hyperparameters remain stable as model width increases. This lets practitioners tune on a small model (e.g., 40M parameters) and directly transfer those settings to a much larger model (e.g., 6.7B), dramatically reducing the cost of hyperparameter search. Stanford CS336 covers muP as a key technique for efficient large-scale training.
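A heavily simplified sketch of one muP rule: with Adam, the learning rate for hidden weight matrices is scaled by 1/width relative to the base (tuning) model, while vector-like parameters keep the base rate. This illustrates the idea only; a real muP setup (e.g., the `mup` library) also adjusts initialization and output scaling.

```python
import torch.nn as nn

# Illustrative muP-style learning-rate scaling: hidden matrices get
# base_lr * (base_width / width); 1-D parameters keep base_lr.
# mup_param_groups is a hypothetical helper, not a library function.
def mup_param_groups(model, base_lr, base_width, width):
    ratio = base_width / width
    matrix, vector = [], []
    for p in model.parameters():
        (matrix if p.dim() >= 2 else vector).append(p)
    return [
        {"params": matrix, "lr": base_lr * ratio},  # shrink lr as width grows
        {"params": vector, "lr": base_lr},
    ]

# Hyperparameters tuned at width 128, transferred to width 1024
large = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
groups = mup_param_groups(large, base_lr=1e-3, base_width=128, width=1024)
print(f"hidden-matrix lr at width 1024: {groups[0]['lr']:.2e}")
```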
8. Training Instabilities
Large-scale training runs frequently encounter instabilities that can derail training entirely if not handled.
Loss Spikes
Sudden jumps in training loss, sometimes by several orders of magnitude, are a common occurrence in LLM training. They are typically caused by outlier batches with unusual gradient distributions, numerical overflow in attention computations, or learning rate being too high for the current loss landscape curvature. Loss spikes can often be recovered from (the loss returns to its pre-spike trajectory), but severe spikes may corrupt optimizer states and require rolling back to a checkpoint.
Gradient Clipping
The universal mitigation for gradient instability is gradient clipping: scaling the gradient vector so its global norm does not exceed a threshold (typically 1.0). This prevents any single batch from causing a catastrophically large parameter update.
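The key detail is that the norm is computed over all gradients jointly, and every gradient is rescaled by the same factor, preserving the update's direction. A sketch of what `torch.nn.utils.clip_grad_norm_` does internally:

```python
import torch

# Global-norm gradient clipping: one norm over ALL gradients, one scale factor.
def clip_global_norm(grads, max_norm=1.0):
    total_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    scale = max_norm / (total_norm + 1e-6)
    if scale < 1.0:            # only shrink, never amplify
        for g in grads:
            g.mul_(scale)
    return total_norm

grads = [torch.full((3,), 2.0), torch.full((4,), -2.0)]
norm_before = clip_global_norm(grads, max_norm=1.0)
norm_after = torch.sqrt(sum((g ** 2).sum() for g in grads))
print(f"norm before: {norm_before:.3f}, after: {norm_after:.3f}")
```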
z-Loss Regularization
PaLM introduced z-loss, an auxiliary loss term that penalizes large logits in the output layer: Lz = α · log²(Z), where Z is the sum of exponentials in the softmax denominator and α is small (PaLM used 10⁻⁴). By keeping log(Z) close to zero, z-loss prevents the output logits from drifting to large magnitudes, improving numerical stability and reducing the frequency of loss spikes.
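The term is a one-liner on top of the logits. A sketch (tensor shapes are illustrative):

```python
import torch

# z-loss: penalize the squared log of the softmax normalizer Z.
def z_loss(logits, alpha=1e-4):
    log_z = torch.logsumexp(logits, dim=-1)  # log of the softmax denominator
    return alpha * (log_z ** 2).mean()

torch.manual_seed(0)
logits_small = torch.randn(8, 32000)   # well-behaved logits
logits_large = logits_small * 50       # logits drifting to large magnitudes
print(f"z-loss (normal):  {z_loss(logits_small):.4f}")
print(f"z-loss (drifted): {z_loss(logits_large):.4f}")
```

In training, this term is simply added to the cross-entropy loss; because α is tiny, it is negligible for well-behaved logits and only bites when they begin to drift.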
Key Takeaways
- AdamW is the standard optimizer for LLM training, providing per-parameter adaptive learning rates with properly decoupled weight decay.
- Memory cost of Adam is substantial: two FP32 state tensors add 8 bytes per parameter, bringing the total to 16 bytes per parameter once FP32 parameters and gradients are included.
- Adafactor, 8-bit Adam, and LION reduce optimizer memory through factorization, quantization, and sign-based updates respectively.
- Warmup + cosine decay is the classic learning rate schedule for LLM training, with WSD increasingly preferred for large runs. Warmup stabilizes early training; the decay phase provides smooth convergence.
- Gradient accumulation simulates large batch sizes without additional memory by accumulating gradients across micro-batches.
- Grokking demonstrates that generalization can emerge long after memorization, suggesting that extended training with weight decay enables the discovery of underlying structure.
- Gradient clipping and z-loss are critical stability mechanisms that prevent loss spikes and attention collapse during large-scale training.