Module 05 · Section 5.3

Advanced Decoding & Structured Generation

Contrastive decoding, speculative decoding, grammar constraints, watermarking, and MBR decoding

Speculative decoding lets a small model draft and a big model verify. It's like having an intern write your emails, except this actually works.

— Speculative Sara, an efficiency-obsessed decoder
★ Big Picture

Sections 5.1 and 5.2 covered the foundational decoding strategies that every practitioner should know. This section ventures into more advanced territory: techniques that improve quality by comparing models against each other (contrastive decoding), that accelerate generation without changing output quality (speculative decoding), that guarantee structured output (grammar-constrained generation), that embed invisible signals for attribution (watermarking), and that select the best output from multiple candidates (minimum Bayes risk decoding). These methods represent the cutting edge of practical text generation.

1. Contrastive Decoding

🔬 Research Topic

Contrastive decoding was introduced by Li et al. (2023). It remains an active area of research and is not yet a standard production technique, though it has shown strong results in several benchmarks.

The intuition behind contrastive decoding is elegant: a large "expert" model and a smaller "amateur" model share many of the same failure modes (generic, repetitive text), but the expert model also captures higher-quality patterns that the amateur does not. By subtracting the amateur's preferences from the expert's, we amplify what makes the expert special and suppress what is generic.

Formally, the contrastive score for each token is:

score(x) = log P_expert(x) − log P_amateur(x)

Tokens that both models find likely (generic completions) get a low contrastive score because both log-probabilities are high. Tokens that only the expert finds likely get a high contrastive score. An additional constraint (plausibility filter) ensures we only consider tokens where the expert assigns at least some minimum probability, preventing nonsensical tokens from being selected just because the amateur dislikes them.

Figure 5.5: Contrastive decoding amplifies tokens the expert prefers over the amateur (content words) and suppresses tokens both models favor (generic function words).
import torch
import torch.nn.functional as F

def contrastive_decode(expert_logits, amateur_logits,
                       alpha=0.1, beta=0.5):
    """
    Contrastive decoding: amplify expert-specific preferences.
    alpha: plausibility threshold (keep tokens where expert prob > alpha * max_prob)
    beta: weight for the amateur subtraction
    """
    expert_probs = F.softmax(expert_logits, dim=-1)
    amateur_log_probs = F.log_softmax(amateur_logits, dim=-1)
    expert_log_probs = F.log_softmax(expert_logits, dim=-1)

    # Plausibility constraint: only consider tokens the expert finds plausible
    max_expert_prob = expert_probs.max()
    plausible_mask = expert_probs >= alpha * max_expert_prob

    # Contrastive score: expert - beta * amateur
    contrastive_scores = expert_log_probs - beta * amateur_log_probs

    # Apply plausibility mask
    contrastive_scores[~plausible_mask] = float('-inf')

    return contrastive_scores.argmax(dim=-1)

# Simulated example
expert_logits = torch.tensor([5.0, 2.0, 4.5, 1.5, 3.0, 0.5])
amateur_logits = torch.tensor([4.8, 3.5, 1.0, 3.0, 2.8, 0.3])
tokens = ["the", "a", "brilliant", "is", "novel", "xyz"]

expert_probs = F.softmax(expert_logits, dim=-1)
amateur_probs = F.softmax(amateur_logits, dim=-1)
contrastive = torch.log(expert_probs) - 0.5 * torch.log(amateur_probs)

print("Token      | Expert P | Amateur P | Contrastive Score")
for i, t in enumerate(tokens):
    print(f"{t:10s} | {expert_probs[i]:.4f}   | {amateur_probs[i]:.4f}    | {contrastive[i]:.3f}")

selected = contrastive_decode(expert_logits, amateur_logits)
print(f"\nSelected token: '{tokens[selected]}'")
Token      | Expert P | Amateur P | Contrastive Score
the        | 0.5456   | 0.6224    | -0.369
a          | 0.0272   | 0.1696    | -2.719
brilliant  | 0.3309   | 0.0139    | 1.031
is         | 0.0165   | 0.1029    | -2.969
novel      | 0.0738   | 0.0842    | -1.369
xyz        | 0.0061   | 0.0069    | -2.619

Selected token: 'brilliant'

Notice how "brilliant" wins the contrastive selection, even though "the" has the highest expert probability. The expert and amateur agree on "the" (both give it high probability), so it gets a low contrastive score. But "brilliant" is something only the expert strongly favors, making it the contrastive winner.

2. Speculative Decoding: The Core Idea

📝 Note

Speculative decoding is covered in greater depth in Module 08 (Inference Optimization). Here we introduce the concept and its relationship to decoding strategies. The key insight is that speculative decoding does not change what the model generates; it changes how fast it generates.

The bottleneck in autoregressive generation is that each token requires a full forward pass through the model, and tokens must be generated sequentially. Speculative decoding (Leviathan et al., 2023; Chen et al., 2023) speeds this up using a clever trick: a small, fast "draft" model generates several tokens quickly, and then the large "target" model verifies them all in a single forward pass.

The verification step uses a mathematical guarantee: each draft token is accepted with probability min(1, q(x)/p(x)), where q(x) is the target model probability and p(x) is the draft model probability. If a token is rejected, we resample from an adjusted distribution. This ensures that the final output has exactly the same distribution as if the target model had generated it alone.

Figure 5.6: Speculative decoding generates multiple draft tokens cheaply, then verifies them in a single pass of the expensive target model.
💡 Key Insight

Speculative decoding provides a lossless speedup: the output distribution is mathematically identical to standard autoregressive decoding from the target model. This is a rare property; most speedup techniques (quantization, pruning, distillation) involve some quality tradeoff. The speedup factor depends on how well the draft model approximates the target: the better the draft model, the higher the acceptance rate, and the fewer target model forward passes are needed.

Surprising Guarantee: Zero Quality Loss

Speculative decoding makes generation 2 to 3x faster with mathematically identical output. Not "approximately the same." Provably identical. It uses rejection sampling: for each draft token, compute acceptance probability min(1, p_target(x) / p_draft(x)). If accepted, keep the token. If rejected, resample from the residual distribution. Leviathan et al. (2023) proved that this procedure samples from exactly the target distribution. The draft model affects only speed, never correctness.
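The accept/reject rule is short enough to sketch directly. Below is a minimal single-position simulation (a sketch for illustration, not a production implementation; the two distributions and the sample count are made up) showing that the sampled frequencies match the target distribution no matter what the draft proposes:

```python
import numpy as np

def speculative_step(p_target, p_draft, rng):
    """Sample one token whose marginal distribution is exactly p_target."""
    # 1. Draft model proposes a token from its own distribution.
    x = rng.choice(len(p_draft), p=p_draft)
    # 2. Accept with probability min(1, p_target(x) / p_draft(x)).
    if rng.random() < min(1.0, p_target[x] / p_draft[x]):
        return x
    # 3. On rejection, resample from the normalized residual
    #    distribution max(0, p_target - p_draft).
    residual = np.maximum(p_target - p_draft, 0.0)
    residual /= residual.sum()
    return rng.choice(len(residual), p=residual)

rng = np.random.default_rng(0)
p_target = np.array([0.5, 0.3, 0.15, 0.05])   # "large" model
p_draft  = np.array([0.25, 0.25, 0.25, 0.25])  # poor draft model

# Empirically, the sampled frequencies converge to p_target.
samples = [speculative_step(p_target, p_draft, rng) for _ in range(20_000)]
freqs = np.bincount(samples, minlength=4) / len(samples)
print(np.round(freqs, 2))  # close to [0.5, 0.3, 0.15, 0.05]
```

Even with a deliberately bad (uniform) draft, the output distribution is unchanged; a bad draft only lowers the acceptance rate, costing speed rather than quality.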

3. Grammar-Constrained Decoding

One of the most practical advances in text generation is grammar-constrained decoding, which forces the model to produce output that conforms to a formal grammar (JSON, XML, SQL, regular expressions, or any context-free grammar). This is achieved by masking invalid tokens at the logit level before sampling or argmax.

How It Works

At each generation step, a grammar parser tracks the current state of the partially generated output. Based on this state, it computes which tokens from the vocabulary are valid continuations according to the grammar. All other tokens have their logits set to negative infinity, making them impossible to select. The model then samples (or takes the argmax) over only the valid tokens.
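The masking step itself is simple; the engineering effort in real libraries goes into the grammar parser that produces the valid-token set. Here is a minimal sketch of just the masking step, assuming a parser has already handed us the set of currently valid token ids:

```python
import torch
import torch.nn.functional as F

def mask_invalid_tokens(logits, valid_token_ids):
    """Set logits of all tokens outside the grammar's valid set to -inf."""
    masked = torch.full_like(logits, float('-inf'))
    masked[valid_token_ids] = logits[valid_token_ids]
    return masked

logits = torch.tensor([2.0, 0.5, 1.5, 3.0, -1.0])
valid = torch.tensor([0, 2])  # hypothetical: only tokens 0 and 2 fit the grammar
masked = mask_invalid_tokens(logits, valid)

probs = F.softmax(masked, dim=-1)
print(probs)  # invalid tokens have exactly zero probability
```

Because exp(−inf) = 0, invalid tokens receive exactly zero probability after softmax, so any downstream sampler (greedy, temperature, top-p) can only emit grammar-valid tokens.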

# Using the Outlines library for structured generation
import outlines

# Define a JSON schema for the expected output
schema = """{
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer", "minimum": 0, "maximum": 150},
        "city": {"type": "string"},
        "interests": {
            "type": "array",
            "items": {"type": "string"}
        }
    },
    "required": ["name", "age", "city"]
}"""

# Create a generator that enforces the schema
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, schema)

# The model MUST produce valid JSON matching the schema
prompt = "Extract person info: John Smith is a 34-year-old from Chicago who enjoys hiking and photography."
result = generator(prompt)
print(result)
{"name": "John Smith", "age": 34, "city": "Chicago", "interests": ["hiking", "photography"]}

The output is guaranteed to be valid JSON conforming to the schema. Without grammar constraints, a language model might produce almost-correct JSON with missing quotes, trailing commas, or type mismatches. Grammar-constrained decoding eliminates these failure modes entirely.

Tools and Libraries

Library              | Approach                                  | Supported Formats
Outlines             | Finite-state machine based token masking  | JSON Schema, regex, CFG, Pydantic models
Guidance (Microsoft) | Template-based constrained generation     | Custom grammars, JSON, regex
LMQL                 | Query language for LM constraints         | Arbitrary Python constraints, types
llama.cpp grammars   | GBNF grammar specification                | Any context-free grammar
# Using Outlines with regex constraints
import outlines

model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")

# Force output to match a date pattern
date_generator = outlines.generate.regex(
    model,
    r"\d{4}-\d{2}-\d{2}"
)

# Force output to be one of specific choices
sentiment_generator = outlines.generate.choice(
    model,
    ["positive", "negative", "neutral"]
)

date = date_generator("The meeting is scheduled for next Tuesday. Today is 2025-03-20. The meeting date:")
sentiment = sentiment_generator("The movie was absolutely wonderful! Sentiment:")

print(f"Date: {date}")
print(f"Sentiment: {sentiment}")
Date: 2025-03-25
Sentiment: positive

4. Watermarking Generated Text

As LLM-generated text becomes more prevalent, the ability to detect whether text was produced by an AI is increasingly important. Watermarking embeds a statistical signal into the generated text that is invisible to humans but detectable by an algorithm.

The Kirchenbauer et al. (2023) Method

The most influential watermarking approach works as follows:

  1. At each generation step, use the previous token as a seed to a hash function, partitioning the vocabulary into a "green list" and a "red list"
  2. Add a small bias δ to the logits of all green-list tokens before sampling
  3. This nudges (but does not force) the model to prefer green-list tokens

To detect the watermark, apply the same hash function to a piece of text and count how many tokens fall on the green list. Unwatermarked text will land on the green list at roughly the base rate γ, the green-list fraction (50% when the vocabulary is split evenly). Watermarked text will have significantly more green tokens, detectable via a simple z-test.

import torch
import hashlib

def watermark_logits(logits, prev_token_id, vocab_size, delta=2.0, gamma=0.5):
    """Apply watermark bias to logits based on previous token."""
    # Use previous token to seed the green/red list partition
    seed = hashlib.sha256(str(prev_token_id).encode()).hexdigest()
    rng = torch.Generator()
    rng.manual_seed(int(seed[:8], 16))

    # Random permutation determines green list (first gamma fraction)
    perm = torch.randperm(vocab_size, generator=rng)
    green_list_size = int(gamma * vocab_size)
    green_tokens = perm[:green_list_size]

    # Add bias to green-list tokens
    watermarked_logits = logits.clone()
    watermarked_logits[green_tokens] += delta

    return watermarked_logits

def detect_watermark(token_ids, vocab_size, gamma=0.5):
    """Detect watermark by counting green-list tokens."""
    green_count = 0
    total = len(token_ids) - 1  # skip first token (no previous)

    for i in range(1, len(token_ids)):
        prev_id = token_ids[i - 1]
        seed = hashlib.sha256(str(prev_id).encode()).hexdigest()
        rng = torch.Generator()
        rng.manual_seed(int(seed[:8], 16))

        perm = torch.randperm(vocab_size, generator=rng)
        green_set = set(perm[:int(gamma * vocab_size)].tolist())

        if token_ids[i] in green_set:
            green_count += 1

    green_fraction = green_count / total
    # Z-test: under null hypothesis, green fraction ~ gamma
    import math
    z_score = (green_fraction - gamma) / math.sqrt(gamma * (1 - gamma) / total)
    return green_fraction, z_score

# Simulate detection
print("Watermarked text:   green_frac=0.78, z_score=5.6  => WATERMARKED")
print("Human-written text: green_frac=0.51, z_score=0.2  => NOT watermarked")
Watermarked text:   green_frac=0.78, z_score=5.6  => WATERMARKED
Human-written text: green_frac=0.51, z_score=0.2  => NOT watermarked
⚠ Limitations

Watermarks can be removed by paraphrasing the text, translating to another language and back, or simply editing enough words. They also degrade with very short texts (insufficient statistical signal). No current watermarking scheme is robust to a determined adversary. Still, watermarking is a useful first layer of defense and is actively being developed by major AI labs.

5. Minimum Bayes Risk (MBR) Decoding

🔬 Research Topic

MBR decoding has a long history in speech recognition and machine translation. Recent work (Bertsch et al., ICLR 2025) has demonstrated its effectiveness for LLM generation, where it consistently outperforms greedy and beam search across multiple benchmarks.

MBR decoding takes a fundamentally different approach from the methods discussed so far. Instead of using a single decoding strategy to produce one output, MBR generates multiple candidate outputs and then selects the best one according to a quality metric.

The Algorithm

  1. Sample N candidates: Generate N different outputs using any stochastic sampling method (e.g., temperature sampling with T=0.8)
  2. Score each candidate: For each candidate, compute its average "utility" against all other candidates using a metric (ROUGE, BERTScore, or an LLM judge)
  3. Select the best: Return the candidate with the highest average utility
y* = argmax_{y ∈ S} (1/(|S|−1)) Σ_{y′ ∈ S, y′ ≠ y} U(y, y′)

The intuition is that the best candidate is the one that is most "central" among the samples: it is the output that other good outputs most agree with. This is more robust than picking the highest-probability output, which might be an outlier.

import numpy as np

def mbr_decode(candidates, utility_fn):
    """Select the candidate with highest average utility against all others."""
    n = len(candidates)
    scores = np.zeros(n)

    for i in range(n):
        for j in range(n):
            if i != j:
                scores[i] += utility_fn(candidates[i], candidates[j])
        scores[i] /= (n - 1)

    best_idx = np.argmax(scores)
    return candidates[best_idx], scores

# Example with simple word-overlap utility
def word_overlap(a, b):
    """Simple utility: fraction of words in a that also appear in b."""
    words_a = set(a.lower().split())
    words_b = set(b.lower().split())
    if not words_a:
        return 0.0
    return len(words_a & words_b) / len(words_a)

candidates = [
    "The cat sat quietly on the warm mat.",
    "A cat was sitting on a mat in the sun.",
    "The cat sat on the mat.",
    "Purple elephants danced wildly in space.",  # outlier
    "The cat rested on the warm mat nearby.",
]

best, scores = mbr_decode(candidates, word_overlap)
print("MBR Scores:")
for i, (c, s) in enumerate(zip(candidates, scores)):
    marker = " <-- BEST" if c == best else ""
    print(f"  [{s:.3f}] {c}{marker}")
MBR Scores:
  [0.429] The cat sat quietly on the warm mat.
  [0.306] A cat was sitting on a mat in the sun.
  [0.550] The cat sat on the mat. <-- BEST
  [0.042] Purple elephants danced wildly in space.
  [0.393] The cat rested on the warm mat nearby.

The MBR selection correctly identifies the most "central" candidate ("The cat sat on the mat.") and rejects the outlier. In practice, using BERTScore or an LLM-as-judge for the utility function produces much stronger results than simple word overlap. The main cost is computational: N samples times N utility evaluations gives O(N²) cost, though in practice N values of 10 to 50 offer a good tradeoff between quality and compute.
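For a somewhat stronger utility that still needs no external dependencies, a character-level similarity such as the standard library's difflib.SequenceMatcher can stand in; note this is an illustrative stand-in chosen for convenience, not a metric from the MBR literature, where BERTScore or learned metrics are typical:

```python
import difflib
import numpy as np

def char_similarity(a, b):
    """Character-level similarity ratio in [0, 1] via difflib."""
    return difflib.SequenceMatcher(None, a, b).ratio()

def mbr_select(candidates, utility_fn):
    """Return the candidate with the highest average utility vs. the rest."""
    n = len(candidates)
    scores = np.zeros(n)
    for i in range(n):
        scores[i] = np.mean([utility_fn(candidates[i], candidates[j])
                             for j in range(n) if j != i])
    return candidates[int(np.argmax(scores))]

candidates = [
    "The cat sat quietly on the warm mat.",
    "A cat was sitting on a mat in the sun.",
    "The cat sat on the mat.",
    "Purple elephants danced wildly in space.",  # outlier
]
print(mbr_select(candidates, char_similarity))
```

Unlike word overlap, this utility is insensitive to tokenization and punctuation, but it still rewards surface similarity only; swapping in a semantic metric changes nothing about the selection loop itself.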

❓ Section Quiz

1. In contrastive decoding, why do we subtract the amateur model's log-probabilities from the expert's?

Show Answer
Subtracting the amateur's log-probabilities removes the "generic" signal shared by both models (common function words, typical phrases) and amplifies the signal unique to the expert model (more nuanced, higher-quality continuations). Tokens that both models agree on get low contrastive scores, while tokens that only the expert strongly favors get high scores. This encourages more interesting, expert-level text generation.

2. What mathematical property makes speculative decoding "lossless"?

Show Answer
The acceptance/rejection criterion uses the probability ratio min(1, q(x)/p(x)), where q is the target model probability and p is the draft model probability. When a token is rejected, resampling from the normalized residual distribution norm(max(0, q(x) − p(x))) ensures that the final marginal distribution of each token is exactly q(x), identical to what the target model would produce on its own. This acceptance-rejection sampling scheme preserves the exact target distribution.

3. How does grammar-constrained decoding guarantee valid JSON output?

Show Answer
At each generation step, a grammar parser tracks the current state of the partially generated output against the JSON (or other) grammar. It computes which tokens are valid continuations at this point and sets the logits of all invalid tokens to negative infinity. Since impossible tokens have zero probability after softmax, the model can only select valid tokens, making it structurally impossible to produce invalid output. This operates at the logit level, so it works with any sampling method.

4. Why does MBR decoding select the most "central" candidate rather than the highest-probability one?

Show Answer
The highest-probability sequence (as found by beam search) often turns out to be generic and bland. MBR decoding instead selects the candidate that maximizes average utility (similarity/quality) with respect to all other samples. This favors a "consensus" output that captures the most common desirable features across samples while being robust to outliers. It effectively aggregates the diversity of multiple samples to find a robust best output, similar to how ensemble methods improve over individual predictors.

📌 Key Takeaways