Tell a model to predict the next word and it learns grammar. Mask a word and it learns meaning. Corrupt a span and it learns to complain about the corruption, then fix it anyway.
Why does the training objective matter? A language model's pre-training objective is the task it solves trillions of times during training. This choice shapes everything: what the model learns to represent, what it can generate, and what downstream tasks it excels at. Causal language modeling produces powerful generators. Masked language modeling produces powerful encoders. Span corruption creates versatile encoder-decoders. Newer objectives like fill-in-the-middle and multi-token prediction push the boundaries further. Understanding these objectives is essential for selecting the right model for your application and for designing new training procedures.
This section builds directly on the Transformer architecture from Module 04 (encoder, decoder, and attention masks). Understanding of tokenization from Module 02 is assumed. The discussion of multi-token prediction connects forward to Section 7.2 (DeepSeek V3).
1. Causal Language Modeling (CLM)
Causal language modeling, also called autoregressive language modeling, trains a model to predict the next token given all previous tokens. Formally, given a sequence of tokens x = (x1, x2, ..., xT), the model maximizes the log-likelihood:

L_CLM(θ) = Σ_{t=1}^{T} log p_θ(x_t | x_1, ..., x_{t-1})
The model processes tokens left-to-right, with a causal attention mask that prevents each position from attending to future positions. This makes CLM naturally suited for text generation: at inference time, the model generates one token at a time, feeding each prediction back as input for the next step.
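The generation loop described here can be sketched as greedy decoding. The `model` below is a stand-in that returns random logits (not a trained network), just to show the feed-back structure:

```python
import torch

def greedy_generate(model, input_ids, max_new_tokens=20):
    """Autoregressive decoding: each predicted token is appended
    to the input and fed back for the next step."""
    for _ in range(max_new_tokens):
        logits = model(input_ids)  # (batch, seq_len, vocab)
        # Take the most likely token at the last position
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
    return input_ids

# Toy "model": random logits, standing in for a trained transformer
vocab_size = 100
model = lambda ids: torch.randn(ids.size(0), ids.size(1), vocab_size)
out = greedy_generate(model, torch.tensor([[1, 2, 3]]), max_new_tokens=5)
print(out.shape)  # torch.Size([1, 8])
```

In practice this loop is where KV caching pays off: the keys and values for the already-processed prefix are reused at every step.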
Why CLM Dominates Modern LLMs
Several properties make CLM the preferred objective for large-scale models:
- Every token is a training signal: Unlike MLM (which only learns from masked tokens), CLM provides a gradient signal at every position in the sequence, making training more sample-efficient.
- Natural generation: The training objective exactly matches the inference procedure, avoiding train-test mismatch.
- Scalability: The causal attention mask enables efficient KV caching during inference, where previously computed key-value pairs are reused.
- Flexibility: The same model can be prompted for classification, generation, translation, and reasoning, all through the text completion interface.
```python
# Implementing causal language modeling loss from scratch
import torch
import torch.nn.functional as F

def causal_lm_loss(logits, labels):
    """
    logits: (batch, seq_len, vocab_size)
    labels: (batch, seq_len) - the input token ids (the shift happens inside)
    """
    # Shift: predict token t+1 from position t
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    # Cross-entropy loss, ignoring padding tokens (label = -100)
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100
    )
    return loss

# Example: batch of 2, sequence length 5, vocab size 100
logits = torch.randn(2, 5, 100)
labels = torch.randint(0, 100, (2, 5))
loss = causal_lm_loss(logits, labels)
print(f"CLM Loss: {loss.item():.4f}")
print(f"Perplexity: {torch.exp(loss).item():.2f}")
```
2. Masked Language Modeling (MLM)
Masked language modeling, introduced by BERT, randomly masks a fraction of input tokens and trains the model to reconstruct them from the surrounding (bidirectional) context. The standard recipe masks 15% of tokens, with 80% replaced by [MASK], 10% replaced by a random token, and 10% kept unchanged.
L_MLM(θ) = Σ_{t ∈ M} log p_θ(x_t | x\M)

where M is the set of masked positions and x\M denotes the corrupted input (all tokens with masks applied).
Strengths and Limitations
MLM's bidirectionality is its greatest strength for understanding tasks. A model that sees both left and right context can build richer representations of each token. This is why BERT-style models dominate in classification, named entity recognition, and other tasks where the full context is available.
However, MLM has significant limitations. Only 15% of tokens provide training signal per forward pass, making it less sample-efficient than CLM. The [MASK] token never appears during inference, creating a train-test mismatch. And MLM models cannot naturally generate text because they assume access to future context during prediction.
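The 80/10/10 corruption recipe described above can be sketched as follows. The `MASK_ID` value and vocabulary size are illustrative placeholders, not tied to any particular tokenizer:

```python
import random

MASK_ID = 103  # placeholder id for the [MASK] token

def mlm_corrupt(token_ids, vocab_size, mask_ratio=0.15, seed=0):
    """BERT-style masking: of the selected 15% of positions,
    80% -> [MASK], 10% -> random token, 10% -> unchanged."""
    rng = random.Random(seed)
    corrupted = list(token_ids)
    labels = [-100] * len(token_ids)  # -100 = ignored by cross-entropy
    for i, tok in enumerate(token_ids):
        if rng.random() < mask_ratio:
            labels[i] = tok  # only masked positions contribute to the loss
            r = rng.random()
            if r < 0.8:
                corrupted[i] = MASK_ID          # 80%: replace with [MASK]
            elif r < 0.9:
                corrupted[i] = rng.randrange(vocab_size)  # 10%: random token
            # else: 10% keep the original token
    return corrupted, labels

tokens = list(range(10, 30))
corrupted, labels = mlm_corrupt(tokens, vocab_size=1000)
print(corrupted)
print(labels)
```

Note how the `labels` list makes the sample-efficiency point concrete: every position not selected for masking is set to -100 and contributes no gradient.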
3. Span Corruption and Denoising Objectives
T5 introduced span corruption, a variant of MLM that masks contiguous spans of tokens rather than individual tokens. A random 15% of tokens are selected, grouped into contiguous spans, and each span is replaced with a single unique sentinel token (like <extra_id_0>). The model then generates the corrupted spans in order, separated by sentinel tokens.
Why Spans Are Better Than Single Tokens
Span corruption is more efficient than single-token masking for several reasons. First, the target sequence is shorter because multiple masked tokens are represented by a single sentinel, reducing the computational cost of the decoder. Second, the model must predict multiple consecutive tokens per span, encouraging it to learn phrase-level and sentence-level patterns rather than just word-level predictions.
```python
# Simulating T5 span corruption
import random

def span_corruption(tokens, mask_ratio=0.15, mean_span_length=3):
    """Apply T5-style span corruption to a token list."""
    n = len(tokens)
    num_masked = int(n * mask_ratio)
    num_spans = max(1, num_masked // mean_span_length)

    # Generate random span starts and lengths
    mask = [False] * n
    masked_so_far = 0
    for _ in range(num_spans):
        if masked_so_far >= num_masked:
            break
        span_len = random.randint(1, mean_span_length * 2)
        start = random.randint(0, n - 1)
        for i in range(start, min(start + span_len, n)):
            if not mask[i]:  # avoid double-counting overlapping spans
                mask[i] = True
                masked_so_far += 1

    # Build corrupted input and target
    corrupted, target = [], []
    sentinel_id = 0
    in_span = False
    for i, tok in enumerate(tokens):
        if mask[i]:
            if not in_span:
                corrupted.append(f"<extra_id_{sentinel_id}>")
                target.append(f"<extra_id_{sentinel_id}>")
                sentinel_id += 1
                in_span = True
            target.append(tok)
        else:
            corrupted.append(tok)
            in_span = False
    return corrupted, target

tokens = "The quick brown fox jumps over the lazy dog".split()
random.seed(42)
corrupted, target = span_corruption(tokens)
print(f"Input:  {' '.join(corrupted)}")
print(f"Target: {' '.join(target)}")
```
UL2: Mixture of Denoisers
UL2 (Unifying Language Learning Paradigms, 2022) took the denoising approach further by mixing multiple corruption strategies during pre-training. It combined three modes: (1) R-denoiser (regular denoising, like T5, with short spans), (2) S-denoiser (sequential denoising, similar to prefix LM, masking a suffix), and (3) X-denoiser (extreme denoising, with long spans and high mask ratios). A mode token prepended to each example tells the model which type of denoising to perform. This produced a single model that excelled at both understanding and generation tasks.
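The mixture mechanism can be sketched as mode sampling. The mode-token names and parameter values below are illustrative, in the spirit of the paper rather than its exact settings:

```python
import random

# Illustrative UL2-style denoiser configs (parameters are placeholders)
DENOISERS = {
    "[R]": {"mask_ratio": 0.15, "mean_span": 3},     # regular, T5-like
    "[S]": {"mask_ratio": 0.25, "mean_span": None},  # sequential: mask a suffix
    "[X]": {"mask_ratio": 0.50, "mean_span": 32},    # extreme: long spans
}

def sample_denoiser(rng=random):
    """Pick a denoising mode; the mode token is prepended to the example
    so the model knows which kind of reconstruction to perform."""
    mode = rng.choice(list(DENOISERS))
    return mode, DENOISERS[mode]

random.seed(0)
mode, cfg = sample_denoiser()
print(mode, cfg)
```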
4. Prefix Language Modeling
Prefix LM is a hybrid approach used by models like UniLM and GLM. The input is divided into a prefix (which uses bidirectional attention) and a suffix (which uses causal attention). The prefix provides full context for encoding the input, while the suffix generates output autoregressively. This combines the encoding strength of MLM with the generation capability of CLM.
Prefix LM is implemented simply by modifying the attention mask. For a sequence of length T where the first P tokens are the prefix, positions 1 through P attend to all positions 1 through P (bidirectional), while positions P+1 through T attend only to positions 1 through their own index (causal). No architectural changes are needed.
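Since the objective is just a mask change, it fits in a few lines. A minimal sketch of the mask construction (boolean convention: `mask[i, j]` is True when position i may attend to position j):

```python
import torch

def prefix_lm_mask(seq_len, prefix_len):
    """Prefix LM attention mask: bidirectional within the prefix,
    causal everywhere else."""
    # Start from a standard causal (lower-triangular) mask
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Let prefix positions attend to the whole prefix (bidirectional)
    mask[:prefix_len, :prefix_len] = True
    return mask

m = prefix_lm_mask(5, 2)
print(m.int())
```

With `prefix_len=2`, positions 0 and 1 see each other in both directions, while positions 2-4 keep the usual causal pattern.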
5. Fill-in-the-Middle (FIM)
Fill-in-the-middle is a training objective designed specifically for code models (like Codex, StarCoder, and CodeLlama). The key observation is that programmers frequently need to insert code at an arbitrary position within existing code, but standard CLM only supports appending tokens at the end.
FIM works by splitting a document into three parts: prefix, middle, and suffix. During training, these parts are rearranged so the model sees the prefix and suffix first, then generates the middle. The most common variant, called PSM (Prefix-Suffix-Middle), presents the input as:

<PRE> prefix <SUF> suffix <MID> middle

where <PRE>, <SUF>, and <MID> are special sentinel tokens delimiting the three parts.
An alternative variant, SPM (Suffix-Prefix-Middle), places the suffix before the prefix. The model learns to condition on both surrounding context to generate the infilling content.
```python
# Fill-in-the-Middle (FIM) transformation
import random

def apply_fim(document, fim_rate=0.5, mode="PSM"):
    """Transform a document for FIM training."""
    if random.random() > fim_rate:
        # Keep as regular CLM with probability (1 - fim_rate)
        return document

    # Choose two random split points delimiting the middle section
    n = len(document)
    split1 = random.randint(0, n)
    split2 = random.randint(split1, n)
    prefix = document[:split1]
    middle = document[split1:split2]
    suffix = document[split2:]

    if mode == "PSM":
        return f"<PRE>{prefix}<SUF>{suffix}<MID>{middle}"
    elif mode == "SPM":
        return f"<SUF>{suffix}<PRE>{prefix}<MID>{middle}"

# Example with code
code = """def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)"""
random.seed(0)
fim_example = apply_fim(code, fim_rate=1.0)
print(fim_example)
```
FIM is remarkably efficient to add. The original paper showed that applying FIM transformations to 50% of training documents during standard CLM pre-training adds the infilling capability with essentially zero degradation to left-to-right generation quality. This makes it a "free" capability that all code models should include.
6. Multi-Token Prediction
Standard CLM predicts only one token ahead. Multi-token prediction (MTP), introduced by Meta in 2024 and later adopted by DeepSeek V3, trains the model to predict several future tokens simultaneously. The architecture adds N independent prediction heads to a shared transformer backbone. Head k predicts token xt+k given tokens x1, ..., xt.
Why Predict Multiple Tokens?
The benefits of multi-token prediction are both theoretical and practical:
- Better representations: To predict multiple future tokens simultaneously, the model must encode richer information about the sequence at each position, including longer-range dependencies.
- Improved sample efficiency: Each training example provides N times more gradient signal per position.
- Speculative decoding compatibility: The additional prediction heads can be used as draft models for speculative decoding, accelerating inference by 2-3x without any additional model.
- One caveat is diminishing returns: experiments show most of the benefit arrives by N=4 heads, and adding more increases training memory without proportional gains.
```python
# Multi-token prediction: conceptual implementation
import torch
import torch.nn as nn

class MultiTokenPredictionHead(nn.Module):
    """N independent prediction heads sharing a transformer backbone."""
    def __init__(self, hidden_dim, vocab_size, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        # Each head: LayerNorm + Linear projection to vocab
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden_dim),
                nn.Linear(hidden_dim, vocab_size)
            )
            for _ in range(n_heads)
        ])

    def forward(self, hidden_states, labels):
        """
        hidden_states: (batch, seq_len, hidden_dim)
        labels: (batch, seq_len) - original token ids
        """
        total_loss = 0.0
        for k, head in enumerate(self.heads, start=1):
            logits = head(hidden_states)  # (batch, seq_len, vocab)
            # Shift by k positions: predict token at t+k from position t
            shift_logits = logits[:, :-k, :].contiguous()
            shift_labels = labels[:, k:].contiguous()
            loss_k = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )
            total_loss += loss_k
        return total_loss / self.n_heads

mtp = MultiTokenPredictionHead(hidden_dim=512, vocab_size=32000, n_heads=4)
hidden = torch.randn(2, 128, 512)
labels = torch.randint(0, 32000, (2, 128))
loss = mtp(hidden, labels)
print(f"MTP Loss (4 heads): {loss.item():.4f}")
```
7. Comparison of Pre-training Objectives
| Objective | Architecture | Context | Training Efficiency | Best For |
|---|---|---|---|---|
| CLM | Decoder-only | Left-to-right | 100% tokens train | Generation, prompting |
| MLM | Encoder-only | Bidirectional | 15% tokens train | Classification, NER |
| Span Corruption | Encoder-Decoder | Bidirectional encoder | 15% tokens, shorter target | Seq2seq, translation |
| Prefix LM | Decoder-only | Hybrid | Suffix tokens train | Conditional generation |
| FIM | Decoder-only | Prefix + Suffix | 100% tokens train | Code infilling |
| MTP | Decoder-only | Left-to-right | N x signals per position | Better representations |
DeepSeek V3 (2024) provided the strongest validation of multi-token prediction, adopting a sequential MTP variant (an additional module that predicts one extra future token) during pre-training of its 671B-parameter MoE model. The MTP module served double duty: improving representation quality during training and enabling self-speculative decoding during inference, eliminating the need for a separate draft model. See Section 7.2 for the full DeepSeek V3 architecture discussion.
All objectives discussed here assume a transformer backbone. An active line of research explores alternative architectures like Mamba (Gu and Dao, 2023) and RWKV, which replace attention with state-space models (SSMs) or linear recurrences. These achieve linear scaling with sequence length (versus quadratic for attention) and process tokens in constant memory during inference. Hybrid architectures like Jamba interleave attention and Mamba layers for the best of both worlds. While transformers remain dominant, SSMs are a rapidly maturing alternative covered by Stanford CS336 and Berkeley CS294.
Key Takeaways
- CLM (next-token prediction) is the dominant pre-training objective because every token trains the model and the objective naturally matches autoregressive generation.
- MLM produces superior representations for understanding tasks by leveraging bidirectional context, but only the ~15% of masked tokens per sequence contribute training signal.
- Span corruption (T5) improves on MLM by masking contiguous spans, producing shorter and more efficient target sequences.
- Fill-in-the-middle adds infilling capability to CLM models at essentially zero cost by rearranging a fraction of training documents.
- Multi-token prediction enriches representations by requiring the model to plan further ahead, and enables faster inference through speculative decoding.
- The choice of pre-training objective shapes a model's strengths; there is no universally best objective, only the right one for your use case.