Module 06 · Section 6.2

Pre-training Objectives & Paradigms

The objectives that teach language models to understand and generate text

Tell a model to predict the next word and it learns grammar. Mask a word and it learns meaning. Corrupt a span and it learns to complain about the corruption, then fix it anyway.

A Methodical Tokenizer
★ Big Picture

Why does the training objective matter? A language model's pre-training objective is the task it solves trillions of times during training. This choice shapes everything: what the model learns to represent, what it can generate, and what downstream tasks it excels at. Causal language modeling produces powerful generators. Masked language modeling produces powerful encoders. Span corruption creates versatile encoder-decoders. Newer objectives extend these ideas: fill-in-the-middle teaches a decoder to insert text between a given prefix and suffix, and multi-token prediction trains it to predict several tokens ahead. Understanding these objectives is essential for selecting the right model for your application and for designing new training procedures.

⚙ Prerequisites

This section builds directly on the Transformer architecture from Module 04 (encoder, decoder, and attention masks). Understanding of tokenization from Module 02 is assumed. The discussion of multi-token prediction connects forward to Section 7.2 (DeepSeek V3).

1. Causal Language Modeling (CLM)

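Causal language modeling, also called autoregressive language modeling, trains a model to predict the next token given all previous tokens. Formally, given a sequence of tokens $x = (x_1, x_2, \ldots, x_T)$, the model maximizes the log-likelihood:

$$\mathcal{L}_{\text{CLM}} = \sum_{t=1}^{T} \log P(x_t \mid x_1, \ldots, x_{t-1}; \theta)$$

In practice, implementations minimize the negative of this quantity, averaged over tokens, which is the standard cross-entropy loss.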

The model processes tokens left-to-right, with a causal attention mask that prevents each position from attending to future positions. This makes CLM naturally suited for text generation: at inference time, the model generates one token at a time, feeding each prediction back as input for the next step.
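The masking step can be sketched in a few lines of PyTorch. This is a minimal single-head illustration under a simplifying assumption: the raw inputs serve directly as queries, keys, and values (no learned projections). It builds the mask with `torch.triu`, then checks that perturbing the last token leaves the outputs at all earlier positions unchanged.

```python
import torch

def causal_self_attention(x):
    """Single-head self-attention with a causal mask.
    Simplification: x is used as queries, keys, and values (no learned weights).
    x: (seq_len, dim)
    """
    seq_len, dim = x.shape
    scores = x @ x.T / dim ** 0.5  # (seq_len, seq_len) scaled dot-product scores
    # True strictly above the diagonal = future positions that must be hidden
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    weights = torch.softmax(scores, dim=-1)  # row t spreads weight over positions <= t
    return weights @ x, weights

torch.manual_seed(0)
x = torch.randn(5, 8)
out, weights = causal_self_attention(x)

# Change only the last token: outputs at positions 0..3 must stay the same
x_perturbed = x.clone()
x_perturbed[-1] += 10.0
out_perturbed, _ = causal_self_attention(x_perturbed)

print(torch.allclose(out[:-1], out_perturbed[:-1]))  # True
print(weights[0])  # tensor([1., 0., 0., 0., 0.]): position 0 sees only itself
```

The same lower-triangular structure is what lets a single forward pass train every position at once: each row of the attention matrix only ever sees its own past.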

Why CLM Dominates Modern LLMs

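Several properties make CLM the preferred objective for large-scale models. Every position in a sequence provides a prediction target, so almost every token contributes to the loss. The training task is also the same next-token prediction the model performs when generating text, so a CLM-trained model can generate without any architectural changes. Finally, the loss is simple to implement: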

# Implementing causal language modeling loss from scratch
import torch
import torch.nn.functional as F

def causal_lm_loss(logits, labels):
    """
    logits: (batch, seq_len, vocab_size)
    labels: (batch, seq_len) - the input token ids themselves (the shift by 1 happens below)
    """
    # Shift: predict token t+1 from position t
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()

    # Cross-entropy loss, ignoring padding tokens (label = -100)
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100
    )
    return loss

# Example: batch of 2, sequence length 5, vocab size 100
logits = torch.randn(2, 5, 100)
labels = torch.randint(0, 100, (2, 5))
loss = causal_lm_loss(logits, labels)
print(f"CLM Loss: {loss.item():.4f}")
print(f"Perplexity: {torch.exp(loss).item():.2f}")
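Output varies between runs because no random seed is set. With random logits the model is effectively guessing, so expect a loss near or somewhat above ln(100) ≈ 4.61 and a perplexity on the order of the vocabulary size (100).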
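Perplexity is the exponential of the average cross-entropy, so it can be read as the effective number of equally likely choices the model is weighing per token. A quick sanity check, using all-zero logits as a stand-in for a model that has learned nothing (an assumption for illustration), shows that a uniform predictor over a 100-token vocabulary has loss ln(100) and perplexity exactly 100:

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 100
# All-zero logits = a uniform distribution over the vocabulary
logits = torch.zeros(2, 4, vocab_size)
targets = torch.randint(0, vocab_size, (2, 4))

loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
perplexity = torch.exp(loss).item()

print(f"loss = {loss.item():.4f}, ln(V) = {math.log(vocab_size):.4f}")  # both 4.6052
print(f"perplexity = {perplexity:.2f}")  # 100.00
```

Any trained model should land well below this baseline; a perplexity above the vocabulary size means the model is doing worse than uniform guessing.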

2. Masked Language Modeling (MLM)

Masked language modeling, introduced by BERT, randomly masks a fraction of input tokens and trains the model to reconstruct them from the surrounding (bidirectional) context. The standard recipe masks 15% of tokens, with 80% replaced by [MASK], 10% replaced by a random token, and 10% kept unchanged.

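$$\mathcal{L}_{\text{MLM}} = \sum_{i \in M} \log P(x_i \mid x_{\setminus M}; \theta)$$

where $M$ is the set of masked positions and $x_{\setminus M}$ denotes the corrupted input (the sequence with the masking applied at the positions in $M$).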
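The 80/10/10 recipe can be sketched in PyTorch. This is a minimal version under stated assumptions: each token is selected by an independent 15% coin flip (rather than an exact count), special tokens such as [CLS] and [SEP] are not protected, and the [MASK] id and vocabulary in the usage example are hypothetical. Positions that are not selected get the label -100, so `F.cross_entropy(..., ignore_index=-100)` skips them, matching the CLM loss above.

```python
import torch

def bert_mask(input_ids, mask_token_id, vocab_size, mask_prob=0.15):
    """BERT-style masking sketch. Returns (corrupted_ids, labels).
    labels is -100 everywhere except at the positions selected for prediction."""
    labels = input_ids.clone()
    corrupted = input_ids.clone()

    # 1. Select ~15% of positions as prediction targets
    selected = torch.rand(input_ids.shape) < mask_prob
    labels[~selected] = -100

    # 2. 80% of the selected positions become [MASK]
    use_mask = selected & (torch.rand(input_ids.shape) < 0.8)
    corrupted[use_mask] = mask_token_id

    # 3. Half of the remaining 20% (10% overall) become a random token
    use_random = selected & ~use_mask & (torch.rand(input_ids.shape) < 0.5)
    random_ids = torch.randint(vocab_size, input_ids.shape)
    corrupted[use_random] = random_ids[use_random]

    # 4. The last 10% keep their original token (nothing to do)
    return corrupted, labels

torch.manual_seed(0)
MASK_ID = 4                                # hypothetical [MASK] id
ids = torch.randint(5, 1000, (64, 512))    # hypothetical ids; never equal to MASK_ID
corrupted, labels = bert_mask(ids, MASK_ID, vocab_size=1000)

selected = labels != -100
print(f"selected:               {selected.float().mean().item():.3f}")   # ~0.15
print(f"[MASK] among selected:  {(corrupted[selected] == MASK_ID).float().mean().item():.3f}")  # ~0.80
print(f"unchanged among selected: {(corrupted[selected] == ids[selected]).float().mean().item():.3f}")  # ~0.10
```

Keeping 10% of targets unchanged and randomizing another 10% means the model cannot assume that every non-[MASK] token is correct, which softens the train-test mismatch discussed below.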

Strengths and Limitations

MLM's bidirectionality is its greatest strength for understanding tasks. A model that sees both left and right context can build richer representations of each token. This is why BERT-style models dominate in classification, named entity recognition, and other tasks where the full context is available.

However, MLM has significant limitations. Only 15% of tokens provide training signal per forward pass, making it less sample-efficient than CLM. The [MASK] token never appears during inference, creating a train-test mismatch. And MLM models cannot naturally generate text because they assume access to future context during prediction.

[Figure: side-by-side panels. Causal LM (GPT): input "The cat sat ???" predicts "on" (next token) from left context only; every position trains; best for generation. Masked LM (BERT): input "The [MASK] sat on" predicts "cat" (masked token) from bidirectional context; only 15% of tokens train; best for understanding.]
Figure 6.2.1: CLM sees only left context and trains on every token. MLM sees full context but trains only on masked positions.

3. Span Corruption and Denoising Objectives

T5 introduced span corruption, a variant of MLM that masks contiguous spans of tokens rather than individual tokens. Roughly 15% of tokens are corrupted, in contiguous spans averaging 3 tokens, and each span is replaced with a single unique sentinel token (such as <extra_id_0>). The model then generates the dropped spans in order, each preceded by its sentinel, and the target ends with one final sentinel.

Why Spans Are Better Than Single Tokens

Span corruption is more efficient than single-token masking for several reasons. First, sequences get shorter: each span collapses to a single sentinel in the input, and the target contains only the dropped spans plus their sentinels rather than the full sequence, which reduces training cost. Second, the model must predict multiple consecutive tokens per span, encouraging it to learn phrase-level patterns rather than just word-level predictions.

# Simulating T5-style span corruption (simplified: span positions and lengths are random)
import random

def span_corruption(tokens, mask_ratio=0.15, mean_span_length=3):
    """Apply simplified T5-style span corruption to a token list."""
    n = len(tokens)
    num_masked = max(1, round(n * mask_ratio))  # corruption budget in tokens

    # Sample random spans until the budget is spent, never counting a token twice
    mask = [False] * n
    masked_so_far = 0
    while masked_so_far < num_masked:
        span_len = random.randint(1, 2 * mean_span_length - 1)  # uniform, mean = mean_span_length
        start = random.randint(0, n - 1)
        for i in range(start, min(start + span_len, n)):
            if masked_so_far >= num_masked:
                break
            if not mask[i]:
                mask[i] = True
                masked_so_far += 1

    # Build corrupted input and target; each span collapses to one sentinel
    corrupted, target = [], []
    sentinel_id = 0
    in_span = False
    for i, tok in enumerate(tokens):
        if mask[i]:
            if not in_span:
                corrupted.append(f"<extra_id_{sentinel_id}>")
                target.append(f"<extra_id_{sentinel_id}>")
                sentinel_id += 1
                in_span = True
            target.append(tok)
        else:
            corrupted.append(tok)
            in_span = False

    # T5 closes the target with one final sentinel
    target.append(f"<extra_id_{sentinel_id}>")
    return corrupted, target

tokens = "The quick brown fox jumps over the lazy dog".split()
random.seed(42)
# A high mask ratio so this 9-token example shows more than one masked token
corrupted, target = span_corruption(tokens, mask_ratio=0.4)
print(f"Input:  {' '.join(corrupted)}")
print(f"Target: {' '.join(target)}")

One possible result (the exact spans depend on the random draw):
Input:  The quick <extra_id_0> jumps over <extra_id_1> dog
Target: <extra_id_0> brown fox <extra_id_1> the lazy <extra_id_2>

UL2: Mixture of Denoisers

UL2 (from "UL2: Unifying Language Learning Paradigms", 2022) took the denoising approach further by mixing multiple corruption strategies during pre-training. It combined three modes: (1) R-denoiser (regular denoising, like T5, with short spans and a low corruption rate), (2) S-denoiser (sequential denoising, similar to prefix LM: the model sees a prefix and predicts the suffix), and (3) X-denoiser (extreme denoising, with long spans, high corruption rates, or both). A mode token prepended to each example tells the model which type of denoising to perform. The result was a single model that performed well on both understanding and generation tasks.
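Mode switching can be sketched as a small data pipeline: each example is assigned one denoiser and that denoiser's mode token is prepended to the input. The mode-token strings ([R], [X], [S]), span lengths, corruption rates, and uniform mixing below are illustrative assumptions rather than UL2's published configuration, and the span corruption is a simplified scheme, not the original implementation.

```python
import random

# Illustrative settings (assumptions for this sketch, not UL2's published mixture)
DENOISERS = [
    {"mode": "[R]", "kind": "span", "mean_span": 3, "rate": 0.15},  # regular: short spans
    {"mode": "[X]", "kind": "span", "mean_span": 12, "rate": 0.5},  # extreme: long spans, high rate
    {"mode": "[S]", "kind": "prefix"},                              # sequential: predict a suffix
]

def corrupt_spans(tokens, mean_span, rate):
    """Simplified T5-style span corruption returning (inputs, target)."""
    n = len(tokens)
    budget = max(1, round(n * rate))
    mask = [False] * n
    while sum(mask) < budget:  # may overshoot slightly, which is fine for a sketch
        start = random.randrange(n)
        length = random.randint(1, 2 * mean_span - 1)  # uniform, mean = mean_span
        for i in range(start, min(start + length, n)):
            mask[i] = True

    inputs, target, sentinel_id = [], [], 0
    for i, tok in enumerate(tokens):
        if mask[i]:
            if i == 0 or not mask[i - 1]:  # first token of a span gets a sentinel
                sentinel = f"<extra_id_{sentinel_id}>"
                inputs.append(sentinel)
                target.append(sentinel)
                sentinel_id += 1
            target.append(tok)
        else:
            inputs.append(tok)
    target.append(f"<extra_id_{sentinel_id}>")  # closing sentinel
    return inputs, target

def ul2_example(tokens):
    """Pick one denoiser uniformly at random and prepend its mode token."""
    d = random.choice(DENOISERS)
    if d["kind"] == "prefix":
        split = random.randint(1, len(tokens) - 1)  # non-empty prefix and suffix
        inputs, target = tokens[:split], tokens[split:]
    else:
        inputs, target = corrupt_spans(tokens, d["mean_span"], d["rate"])
    return [d["mode"]] + inputs, target

random.seed(1)
tokens = "the model sees a different kind of corruption on every example".split()
for _ in range(3):
    inputs, target = ul2_example(tokens)
    print(" ".join(inputs), "->", " ".join(target))
```

Because the mode token is part of the input, the same trick works at fine-tuning or inference time: prepending a given mode token asks the model to behave like the corresponding denoiser.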

4. Prefix Language Modeling

Prefix LM is a hybrid approach used by models like UniLM and GLM. The input is divided into a prefix (which uses bidirectional attention) and a suffix (which uses causal attention). The prefix provides full context for encoding the input, while the suffix generates output autoregressively, with the loss typically computed on the suffix tokens only. This combines the bidirectional encoding of MLM-style encoders with the generation capability of CLM.

📝 Note

Prefix LM is implemented simply by modifying the attention mask. For a sequence of length T where the first P tokens are the prefix, positions 1 through P attend to all positions 1 through P (bidirectional), while positions P+1 through T attend only to positions 1 through their own index (causal). No architectural changes are needed.
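The mask described in the note can be sketched in PyTorch, using 0-based indices instead of the 1-based positions above. The function name `prefix_lm_mask` is a hypothetical helper; entry `[i, j]` is True when position i may attend to position j.

```python
import torch

def prefix_lm_mask(seq_len, prefix_len):
    """Boolean prefix-LM attention mask (True = may attend).
    Positions 0..prefix_len-1 see the whole prefix; later positions are causal."""
    # Start from a standard causal (lower-triangular) mask
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Open up the prefix block so prefix positions see each other bidirectionally
    mask[:prefix_len, :prefix_len] = True
    return mask

print(prefix_lm_mask(seq_len=5, prefix_len=2).int())
# rows: [1,1,0,0,0], [1,1,0,0,0], [1,1,1,0,0], [1,1,1,1,0], [1,1,1,1,1]
```

With `prefix_len=0` this reduces to the ordinary causal mask, and with `prefix_len=seq_len` it becomes fully bidirectional. The True-means-attend convention matches the boolean `attn_mask` of `torch.nn.functional.scaled_dot_product_attention`.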

5. Fill-in-the-Middle (FIM)

Fill-in-the-middle is a training objective designed primarily for code models (such as Codex, StarCoder, and Code Llama). The key observation is that programmers frequently need to insert code at an arbitrary position within existing code, but standard CLM only supports appending tokens at the end.

FIM works by splitting a document into three parts: prefix, middle, and suffix. During training, these parts are rearranged so the model sees the prefix and suffix first, then generates the middle. The most common variant, called PSM (Prefix-Suffix-Middle), presents the input as:

<PRE> prefix <SUF> suffix <MID> middle

An alternative variant, SPM (Suffix-Prefix-Middle), places the suffix before the prefix. In both variants, the model learns to condition on the context on both sides of the gap to generate the infilling content.

# Fill-in-the-Middle (FIM) transformation (character-level split)
import random

def apply_fim(document, fim_rate=0.5, mode="PSM"):
    """Transform a document for FIM training.
    <PRE>/<SUF>/<MID> are placeholders; real tokenizers define their own special tokens."""
    if random.random() > fim_rate:
        return document  # Keep as regular CLM with probability (1 - fim_rate)

    # Pick two cut points uniformly at random and sort them: prefix | middle | suffix
    n = len(document)
    split1, split2 = sorted(random.randint(0, n) for _ in range(2))

    prefix = document[:split1]
    middle = document[split1:split2]
    suffix = document[split2:]

    if mode == "PSM":
        return f"<PRE>{prefix}<SUF>{suffix}<MID>{middle}"
    elif mode == "SPM":
        # Simplified SPM ordering: suffix first, then prefix, then middle
        return f"<SUF>{suffix}<PRE>{prefix}<MID>{middle}"
    raise ValueError(f"Unknown FIM mode: {mode}")
# Example with code
code = """def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)"""

random.seed(0)
fim_example = apply_fim(code, fim_rate=1.0)
print(fim_example)
💡 Key Insight

FIM is remarkably efficient to add. The original paper showed that applying FIM transformations to 50% of training documents during standard CLM pre-training adds the infilling capability with essentially zero degradation to left-to-right generation quality, a result its authors called the "FIM-for-free" property. This makes FIM a near-free capability and a strong default for code models.

6. Multi-Token Prediction

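Standard CLM predicts only one token ahead. Multi-token prediction (MTP), popularized by researchers at Meta in 2024 and later adopted in a sequential variant by DeepSeek V3, trains the model to predict several future tokens at each position. Meta's architecture adds N independent prediction heads to a shared Transformer backbone: head $k$ predicts token $x_{t+k}$ given tokens $x_1, \ldots, x_t$.

$$\mathcal{L}_{\text{MTP}} = \sum_{k=1}^{N} \sum_{t=1}^{T-k} \log P_k(x_{t+k} \mid x_1, \ldots, x_t; \theta)$$

For $N = 1$ this reduces to the CLM objective, apart from the term for the first token $x_1$.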

Why Predict Multiple Tokens?

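The benefits of multi-token prediction are both representational and practical. Each position receives N prediction targets instead of one, which pushes the shared backbone to encode information about the more distant future rather than only the next token. At inference time, the extra heads can propose draft tokens for speculative decoding, speeding up generation.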

[Figure: the input "The cat sat on" feeds one shared Transformer backbone; Head 1 (next token) predicts "the", Head 2 (t+2) predicts "mat", Head 3 (t+3) predicts "."; each head predicts a different future position independently.]
Figure 6.2.2: Multi-token prediction uses N independent heads on top of a shared backbone, each predicting a different future token.

# Multi-token prediction: conceptual implementation
import torch
import torch.nn as nn

class MultiTokenPredictionHead(nn.Module):
    """N independent prediction heads sharing a transformer backbone."""

    def __init__(self, hidden_dim, vocab_size, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        # Each head: LayerNorm + Linear projection to vocab
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden_dim),
                nn.Linear(hidden_dim, vocab_size)
            )
            for _ in range(n_heads)
        ])

    def forward(self, hidden_states, labels):
        """
        hidden_states: (batch, seq_len, hidden_dim)
        labels: (batch, seq_len) - original token ids
        """
        total_loss = 0.0
        for k, head in enumerate(self.heads, start=1):
            logits = head(hidden_states)  # (batch, seq_len, vocab)
            # Shift by k positions: predict token at t+k from position t
            shift_logits = logits[:, :-k, :].contiguous()
            shift_labels = labels[:, k:].contiguous()
            loss_k = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )
            total_loss += loss_k
        return total_loss / self.n_heads

mtp = MultiTokenPredictionHead(hidden_dim=512, vocab_size=32000, n_heads=4)
hidden = torch.randn(2, 128, 512)
labels = torch.randint(0, 32000, (2, 128))
loss = mtp(hidden, labels)
print(f"MTP Loss (4 heads): {loss.item():.4f}")

7. Comparison of Pre-training Objectives

| Objective | Architecture | Context | Training signal | Best for |
|---|---|---|---|---|
| CLM | Decoder-only | Left-to-right | 100% of tokens | Generation, prompting |
| MLM | Encoder-only | Bidirectional | ~15% of tokens (masked positions) | Classification, NER |
| Span Corruption | Encoder-decoder | Bidirectional encoder, causal decoder | ~15% of tokens, shorter target | Seq2seq, translation |
| Prefix LM | Decoder-only | Bidirectional prefix, causal suffix | Suffix tokens | Conditional generation |
| FIM | Decoder-only | Prefix + suffix | 100% of tokens | Code infilling |
| MTP | Decoder-only | Left-to-right | N targets per position | Better representations |

ⓘ MTP Validated at Scale

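DeepSeek V3 (2024) is the most prominent large-scale use of multi-token prediction, applying it during pre-training of a 671B-parameter MoE model. Its design differs from the independent parallel heads shown above: DeepSeek V3 predicts the additional tokens sequentially through chained MTP modules, keeping the full causal chain at each prediction depth, and uses a depth of one (a single extra token per position). The MTP module densifies the training signal during pre-training; at inference it can either be discarded or repurposed as the draft component for self-speculative decoding, so no separate draft model is needed. See Section 7.2 for the full DeepSeek V3 architecture discussion.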
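The draft-and-verify loop behind self-speculative decoding can be sketched in plain Python. Everything here is a toy under stated assumptions: `next_token` is a hypothetical deterministic stand-in for the next-token head, `draft` is a hypothetical stand-in for the extra prediction heads (deliberately wrong on its third guess), verification is greedy, and drafting is counted as free because real MTP drafts come out of the same forward pass. This is a generic greedy scheme, not DeepSeek V3's inference code.

```python
# Hypothetical toy "next-token head": the greedy choice is (last + second-to-last) mod 10
def next_token(seq):
    return (seq[-1] + seq[-2]) % 10

# Hypothetical "draft heads": propose k future tokens, with a deliberate mistake on guess 3
def draft(seq, k):
    proposal, s = [], list(seq)
    for i in range(k):
        tok = next_token(s)
        if i == 2:
            tok = (tok + 1) % 10  # imperfect drafter, so some drafts get rejected
        proposal.append(tok)
        s.append(tok)
    return proposal

def greedy_decode(prompt, n_new):
    seq = list(prompt)
    for _ in range(n_new):  # one forward pass per generated token
        seq.append(next_token(seq))
    return seq, n_new

def self_speculative_decode(prompt, n_new, k=3):
    seq, passes = list(prompt), 0
    target_len = len(prompt) + n_new
    while len(seq) < target_len:
        proposal = draft(seq, k)
        passes += 1  # one simulated forward pass scores every drafted position at once
        accepted = []
        for tok in proposal:
            expected = next_token(seq + accepted)
            if tok != expected:
                accepted.append(expected)  # keep the verifier's token and stop
                break
            accepted.append(tok)
        else:
            accepted.append(next_token(seq + accepted))  # all drafts matched: one bonus token
        seq.extend(accepted[: target_len - len(seq)])
    return seq, passes

prompt = [1, 1]
base, base_passes = greedy_decode(prompt, 12)
fast, fast_passes = self_speculative_decode(prompt, 12, k=3)
print(fast == base)              # True
print(base_passes, fast_passes)  # 12 4
```

Because a draft token is kept only when it equals what the next-token head would have produced anyway, the output matches plain greedy decoding; the saving comes entirely from accepting several tokens per verification pass.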

🔮 Where This Leads Next: Beyond Attention

All models discussed here use a Transformer backbone, although the objectives themselves do not depend on it: alternative architectures such as Mamba (Gu and Dao, 2023) and RWKV are also trained with causal language modeling. These architectures replace attention with state-space models (SSMs) or linear recurrences, achieving linear scaling with sequence length (versus quadratic for attention) and processing tokens with a constant-size state during inference. Hybrid architectures like Jamba interleave attention and Mamba layers to combine the strengths of both. While Transformers remain dominant, SSMs are a rapidly maturing alternative covered by Stanford CS336 and Berkeley CS294.

Check Your Understanding

1. Why does CLM provide more training signal per sequence than MLM?
In CLM, every token in the sequence contributes to the loss (predicting the next token at each position), so 100% of tokens provide gradient signal. In MLM, only the 15% of tokens that are masked contribute to the loss. For a 512-token sequence, CLM gets ~511 prediction targets while MLM gets only ~77.
2. How does FIM avoid degrading standard left-to-right generation quality?
FIM applies the fill-in-the-middle transformation to only a fraction (typically 50%) of training documents; the rest remain ordinary left-to-right sequences. Even the transformed documents are still trained with the standard next-token loss, just on reordered text, so the model keeps practicing left-to-right prediction throughout. The FIM paper found essentially no degradation in left-to-right generation (the "FIM-for-free" property), while the model gains the infilling capability.
3. What advantage does multi-token prediction offer for inference speed?
The additional prediction heads from MTP training can serve as a built-in draft model for speculative decoding. At each step, the auxiliary heads propose several future tokens, and the next forward pass checks those drafts against the next-token head's own predictions. Every matching draft token is accepted, so several tokens can be emitted per forward pass; with greedy verification, the output is identical to ordinary greedy decoding. Meta reported inference speedups of up to about 3x for models trained with 4-token prediction. This is especially valuable because no separate draft model needs to be loaded.
4. Explain the difference between T5's span corruption and BERT's token-level masking.
BERT masks individual tokens independently (each token has a 15% chance of being selected) and replaces them one-for-one, so the corrupted input keeps its original length and the encoder predicts each masked token at its own position. T5's span corruption selects contiguous spans of tokens and replaces each span with a single sentinel token, so the corrupted input is shorter than the original, and the decoder's target contains only the dropped spans (each introduced by its sentinel) rather than the whole sequence. Additionally, T5 must predict multiple consecutive tokens per span, learning phrase-level patterns rather than just individual word predictions.

Key Takeaways