Module 17 · Section 17.4

Explaining Transformers

Attribution methods for transformer predictions: attention rollout, gradient-weighted attention, LRP, and systematic comparison of explanation approaches
★ Big Picture

No single explanation method tells the whole truth about a transformer's predictions. Raw attention weights, gradient-based attribution, attention rollout, and perturbation methods each capture different aspects of how information flows through the network. Understanding the strengths, limitations, and failure modes of each method is essential for choosing the right tool and interpreting results correctly. This section provides a systematic comparison framework for transformer explanation methods, helping practitioners select approaches that match their specific needs.

1. The Explanation Problem

When a transformer model predicts a token, its prediction results from hundreds of attention heads and MLP layers interacting across dozens of layers. Explaining this prediction means answering some form of "which input tokens mattered and why?" Different explanation methods operationalize this question differently, leading to genuinely different (and sometimes contradictory) answers.

The core tension is between faithfulness (does the explanation accurately reflect the model's actual computation?) and plausibility (does the explanation make intuitive sense to a human?). An explanation that perfectly traces the model's computation might be incomprehensible, while an intuitively appealing explanation might not accurately reflect what the model actually did.

2. Attention Rollout

Raw attention weights from a single layer show direct token-to-token attention. But information flows through multiple layers, so a token at position 5 might influence position 10 indirectly by first attending to position 7, which then attends to position 10. Attention rollout (Abnar and Zuidema, 2020) accounts for this multi-hop information flow by multiplying attention matrices across layers.

Figure 17.7: Raw attention shows only direct attention at one layer. Attention rollout traces information flow across all layers by multiplying attention matrices, capturing indirect paths.
# Attention Rollout Implementation
import torch
import numpy as np

def attention_rollout(
    attentions,
    head_fusion="mean",
    discard_ratio=0.0,
):
    """
    Compute attention rollout across all layers.

    Args:
        attentions: tuple of attention tensors, one per layer
                   Each has shape (batch, num_heads, seq_len, seq_len)
        head_fusion: how to combine heads ("mean", "max", "min")
        discard_ratio: fraction of lowest attention weights to zero out

    Returns:
        rollout: (seq_len, seq_len) matrix of accumulated attention
    """
    num_layers = len(attentions)
    seq_len = attentions[0].shape[-1]

    # Start with identity (each token attends to itself)
    rollout = torch.eye(seq_len)

    for layer_idx in range(num_layers):
        attn = attentions[layer_idx].squeeze(0)  # (heads, seq, seq)

        # Fuse attention heads
        if head_fusion == "mean":
            attn_fused = attn.mean(dim=0)
        elif head_fusion == "max":
            attn_fused = attn.max(dim=0).values
        elif head_fusion == "min":
            attn_fused = attn.min(dim=0).values
        else:
            raise ValueError(f"Unknown head_fusion: {head_fusion}")

        # Optionally discard low-attention connections
        if discard_ratio > 0:
            flat = attn_fused.flatten()
            threshold = flat.quantile(discard_ratio)
            attn_fused = attn_fused * (attn_fused >= threshold)
            # Re-normalize rows
            attn_fused = attn_fused / attn_fused.sum(dim=-1, keepdim=True)

        # Add residual connection (identity)
        attn_with_residual = 0.5 * attn_fused + 0.5 * torch.eye(seq_len)

        # Multiply with cumulative rollout
        rollout = attn_with_residual @ rollout

    return rollout

# Usage
from transformers import AutoModelForCausalLM, AutoTokenizer

# Use eager attention so attention weights are returned (recent transformers
# versions default to SDPA, which does not expose them)
model = AutoModelForCausalLM.from_pretrained(
    "gpt2", output_attentions=True, attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

text = "The cat sat on the mat because it was very tired"
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

rollout = attention_rollout(outputs.attentions)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

# Show which tokens "it" attends to after rollout
# (GPT-2's BPE tokenizer marks a leading space with "Ġ", so the token is "Ġit")
it_idx = tokens.index("Ġit") if "Ġit" in tokens else 7
print(f"Attention rollout from '{tokens[it_idx]}':")
for token, score in zip(tokens, rollout[it_idx]):
    bar = "#" * int(score.item() * 40)
    print(f"  {token:10s} {bar} ({score.item():.3f})")

3. Gradient-Weighted Attention

Gradient-weighted attention (also called Attention × Gradient) combines attention weights with gradient information. The intuition is that attention tells us where the model looks, while gradients tell us how sensitive the output is to what the model finds there. Multiplying these signals highlights tokens that the model both attends to and that actually influence the prediction.

# Gradient-Weighted Attention
import torch

def gradient_weighted_attention(
    model,
    tokenizer,
    text,
    target_pos=-1,
):
    """
    Compute gradient-weighted attention for each layer and head.

    Returns attention weights scaled by the gradient of the output
    with respect to the attention weights themselves.
    """
    model.eval()
    inputs = tokenizer(text, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Forward pass with attention outputs
    outputs = model(**inputs, output_attentions=True)
    attentions = outputs.attentions  # tuple of (1, heads, seq, seq)

    # Get the predicted token logit
    logits = outputs.logits[0, target_pos]
    predicted_token_id = logits.argmax()
    target_logit = logits[predicted_token_id]

    # Compute gradient of target logit w.r.t. each attention matrix.
    # The attention tensors must be part of the autograd graph: load the
    # model with attn_implementation="eager" and do not wrap the forward
    # pass in torch.no_grad().
    grad_weighted = []
    for layer_attn in attentions:
        if not layer_attn.requires_grad:
            raise RuntimeError(
                "Attention weights are detached from the autograd graph; "
                "load the model with attn_implementation='eager' and run "
                "the forward pass with gradients enabled."
            )
        grad = torch.autograd.grad(
            target_logit, layer_attn, retain_graph=True
        )[0]
        # Element-wise multiply: attention * gradient
        weighted = (layer_attn * grad).squeeze(0)
        # Average over heads
        weighted = weighted.mean(dim=0).detach()
        grad_weighted.append(weighted)

    # Stack and average across layers
    all_layers = torch.stack(grad_weighted)  # (layers, seq, seq)
    combined = all_layers.mean(dim=0)  # (seq, seq)

    return combined, tokens

# Alternative: enable gradients for attention
# Need to use hooks to capture attention with grad enabled
def compute_attention_gradient_attribution(model, tokenizer, text):
    """Hook-based variant (assumes GPT-2's layer layout). Note that
    grad_output here is the gradient w.r.t. each attention block's output
    hidden states, not the attention probabilities themselves."""
    attention_grads = {}

    def save_attention_grad(name):
        def hook(module, grad_input, grad_output):
            attention_grads[name] = grad_output[0].detach()
        return hook

    # Register backward hooks on attention layers
    hooks = []
    for i, layer in enumerate(model.transformer.h):
        h = layer.attn.register_full_backward_hook(
            save_attention_grad(f"layer_{i}")
        )
        hooks.append(h)

    # Forward + backward
    inputs = tokenizer(text, return_tensors="pt")
    outputs = model(**inputs, output_attentions=True)
    target = outputs.logits[0, -1].max()
    target.backward()

    # Clean up hooks
    for h in hooks:
        h.remove()

    return attention_grads

4. Layer-wise Relevance Propagation (LRP)

Layer-wise Relevance Propagation redistributes the model's output score backward through the network, assigning a relevance value to each neuron at each layer. At the input layer, these relevance values become token-level attributions. LRP satisfies a conservation property: the total relevance at each layer equals the output score, ensuring nothing is lost or created during propagation.

💡 Key Insight

LRP propagates relevance backward using a rule that distributes relevance proportionally to the contribution of each input. For a linear layer y = Wx + b, the relevance assigned to input x_i is proportional to how much x_i contributed to each output y_j, weighted by the relevance of y_j. The specific propagation rule (LRP-0, LRP-ε, LRP-γ) determines how to handle numerical stability and positive vs. negative contributions.
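The conservation property is easy to check numerically on a single linear layer. The following toy sketch (our own example, not from any LRP library) applies the LRP-ε rule to y = Wx and verifies that the input relevance sums to the output relevance:

```python
# Toy check of the LRP-epsilon rule on one linear layer y = W x (bias omitted)
import torch

torch.manual_seed(0)
eps = 1e-6

x = torch.randn(4)      # input activations
W = torch.randn(3, 4)   # weight matrix
z = W @ x               # pre-activations z_j, shape (3,)

R_out = z.clone()       # use the outputs themselves as the initial relevance

# LRP-epsilon: R_i = x_i * sum_j W[j, i] * R_out[j] / (z_j + eps * sign(z_j))
z_stab = z + eps * (2 * (z >= 0).float() - 1)  # sign(0) treated as positive
s = R_out / z_stab
R_in = x * (W.T @ s)

# Conservation: total input relevance matches total output relevance
# (up to the small epsilon perturbation)
print(f"{R_out.sum().item():.6f} vs {R_in.sum().item():.6f}")
```

The check works because summing R_in over i collapses to the sum over j of z_j · R_out[j] / (z_j + ε·sign(z_j)), which approaches the total output relevance as ε → 0.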

# Layer-wise Relevance Propagation for Transformers (simplified)
import torch
import torch.nn as nn

class TransformerLRP:
    """Simplified LRP for transformer models."""

    def __init__(self, model, epsilon=1e-6):
        self.model = model
        self.epsilon = epsilon
        self.activations = {}

    def register_hooks(self):
        """Register forward hooks to capture activations."""
        self.hooks = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                h = module.register_forward_hook(
                    self._save_activation(name)
                )
                self.hooks.append(h)

    def _save_activation(self, name):
        def hook(module, input, output):
            self.activations[name] = {
                "input": input[0].detach(),
                "output": output.detach(),
                "weight": module.weight.detach(),
            }
        return hook

    def propagate_linear(self, relevance, layer_name):
        """LRP-epsilon rule for a linear layer."""
        act = self.activations[layer_name]
        z = act["input"] @ act["weight"].T  # pre-activation
        # Stabilize; treat sign(0) as positive to avoid division by zero
        z = z + self.epsilon * (2 * (z >= 0).float() - 1)
        s = relevance / z
        c = s @ act["weight"]
        relevance_input = act["input"] * c
        return relevance_input

    def attribute(self, text, tokenizer):
        """Compute LRP attribution for input tokens."""
        self.register_hooks()

        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Start with the output logits as initial relevance
        logits = outputs.logits[0, -1]
        relevance = torch.zeros_like(logits)
        relevance[logits.argmax()] = logits.max()

        # Propagate backward through the network
        # (simplified: in practice, need to handle attention specially)
        for name in reversed(list(self.activations.keys())):
            if name.startswith("lm_head") or "mlp" in name:
                relevance = self.propagate_linear(relevance, name)

        # Clean up
        for h in self.hooks:
            h.remove()

        return relevance

5. Perturbation-Based Explanations

Perturbation methods explain predictions by measuring how the output changes when parts of the input are modified or removed. Unlike gradient-based methods (which measure local sensitivity), perturbation methods measure actual counterfactual impact: what would the model predict if this token were absent?

# Perturbation-based attribution methods
import torch
import numpy as np

def leave_one_out_attribution(model, tokenizer, text):
    """
    Measure each token's importance by removing it and
    observing the change in prediction confidence.
    """
    inputs = tokenizer(text, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Get baseline prediction
    with torch.no_grad():
        baseline_logits = model(**inputs).logits[0, -1]
    baseline_probs = torch.softmax(baseline_logits, dim=-1)
    predicted_id = baseline_probs.argmax()
    baseline_prob = baseline_probs[predicted_id].item()

    # Remove each token and measure impact
    attributions = []
    for i in range(len(tokens)):
        # Replace token i. GPT-2 has no pad token, so id 0 is used as a
        # crude stand-in; the replacement is itself off-manifold (see the
        # warning later in this section)
        perturbed_ids = inputs["input_ids"].clone()
        perturbed_ids[0, i] = tokenizer.pad_token_id or 0

        with torch.no_grad():
            perturbed_logits = model(perturbed_ids).logits[0, -1]
        perturbed_prob = torch.softmax(perturbed_logits, dim=-1)[predicted_id].item()

        # Attribution = drop in probability when token is removed
        attribution = baseline_prob - perturbed_prob
        attributions.append(attribution)

    return np.array(attributions), tokens


def sliding_window_occlusion(
    model, tokenizer, text, window_size=3
):
    """
    Occlude a sliding window of tokens to find important regions.
    More robust than single-token removal for capturing multi-token patterns.
    """
    inputs = tokenizer(text, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    seq_len = len(tokens)

    with torch.no_grad():
        baseline_logits = model(**inputs).logits[0, -1]
    predicted_id = baseline_logits.argmax()
    baseline_score = baseline_logits[predicted_id].item()

    region_scores = []
    for start in range(seq_len - window_size + 1):
        perturbed_ids = inputs["input_ids"].clone()
        for j in range(start, start + window_size):
            perturbed_ids[0, j] = tokenizer.pad_token_id or 0

        with torch.no_grad():
            perturbed_logits = model(perturbed_ids).logits[0, -1]
        score = perturbed_logits[predicted_id].item()

        impact = baseline_score - score
        region_scores.append({
            "start": start,
            "end": start + window_size,
            "tokens": tokens[start:start + window_size],
            "impact": impact,
        })

    # Sort by impact
    region_scores.sort(key=lambda x: x["impact"], reverse=True)
    return region_scores
⚠ Warning

Perturbation methods have a fundamental limitation: removing a token creates an out-of-distribution input. The model was never trained on inputs with random tokens in the middle of coherent text, so its behavior on perturbed inputs may not reflect what it would do if the token were genuinely absent. This is sometimes called the "off-manifold" problem. Methods like SHAP (Section 17.3) partially address this by marginalizing over replacements rather than using a fixed perturbation.
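As a rough illustration of that mitigation (the helper name and sampling scheme below are our own, not the SHAP algorithm itself), one can average the probability drop over several plausible replacement tokens instead of a single fixed pad id:

```python
# Sketch: average occlusion impact over several replacement tokens rather
# than one fixed pad id, crudely approximating marginalization over
# replacements. (Illustrative helper; not SHAP itself.)
import torch

def marginalized_occlusion(model, input_ids, position, predicted_id,
                           replacement_ids, baseline_prob):
    """Average drop in the predicted token's probability when the token at
    `position` is replaced by each candidate in `replacement_ids`."""
    drops = []
    for rep in replacement_ids:
        perturbed = input_ids.clone()
        perturbed[0, position] = rep
        with torch.no_grad():
            logits = model(perturbed).logits[0, -1]
        prob = torch.softmax(logits, dim=-1)[predicted_id].item()
        drops.append(baseline_prob - prob)
    return sum(drops) / len(drops)
```

Sampling replacements from the model's own vocabulary distribution at that position (rather than arbitrary ids) keeps the perturbed inputs closer to the data manifold.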

6. Comparing Explanation Methods

Different explanation methods can produce substantially different attributions for the same prediction. Choosing the right method requires understanding what each method measures and what properties matter for your use case.

Figure 17.8: Approximate positioning of explanation methods on faithfulness vs. computational cost. Perturbation-based methods are most faithful but most expensive; raw attention is cheapest but least faithful.
| Method | What It Measures | Strengths | Weaknesses |
| --- | --- | --- | --- |
| Raw attention | Where the model "looks" | Fast, intuitive, no extra computation | Not faithful; ignores values and MLPs |
| Attention rollout | Multi-layer information flow | Captures indirect paths across layers | Still ignores value vectors and MLPs |
| Grad x Attention | Gradient-weighted attention flow | Combines where model looks with sensitivity | Gradients can be noisy; local approximation |
| Integrated Gradients | Path-integrated sensitivity | Axiom-satisfying, theoretically grounded | Baseline choice matters; can be expensive |
| LRP | Backward relevance propagation | Conservation property, layer-specific | Complex to implement for attention layers |
| Perturbation/SHAP | Counterfactual impact of removal | Most direct measure of importance | Off-manifold problem; very expensive |
📝 Note

Recent work suggests that no attribution method consistently outperforms others across all evaluation metrics. The best choice depends on the application: for quick debugging, raw attention or attention rollout provides fast insight; for regulatory compliance requiring faithful explanations, Integrated Gradients or SHAP is more appropriate; for mechanistic understanding, activation patching (Section 17.2) provides the strongest causal evidence.

# Unified comparison framework for explanation methods
import torch
import numpy as np
from typing import Dict, Callable

def get_prediction(model, tokenizer, text: str) -> str:
    """Return the model's predicted next token as a string."""
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits[0, -1]
    return tokenizer.decode(logits.argmax())

def compare_attribution_methods(
    model,
    tokenizer,
    text: str,
    methods: Dict[str, Callable],
) -> Dict[str, np.ndarray]:
    """
    Run multiple attribution methods on the same input
    and compare their outputs.
    """
    results = {}
    tokens = tokenizer.convert_ids_to_tokens(
        tokenizer(text, return_tensors="pt")["input_ids"][0]
    )

    for name, method_fn in methods.items():
        attributions = method_fn(model, tokenizer, text)
        # Normalize to [0, 1] for comparison
        attr_min = attributions.min()
        attr_max = attributions.max()
        if attr_max > attr_min:
            normalized = (attributions - attr_min) / (attr_max - attr_min)
        else:
            normalized = np.zeros_like(attributions)
        results[name] = normalized

    # Compute agreement metrics
    method_names = list(results.keys())
    print(f"Attribution comparison for: '{text}'")
    print(f"Predicted next token: {get_prediction(model, tokenizer, text)}")
    print()

    # Rank correlation between methods
    from scipy.stats import spearmanr
    print("Spearman rank correlations:")
    for i, name_a in enumerate(method_names):
        for name_b in method_names[i+1:]:
            corr, pval = spearmanr(results[name_a], results[name_b])
            print(f"  {name_a} vs {name_b}: rho={corr:.3f} (p={pval:.3f})")

    # Top-3 tokens per method
    print("\nTop-3 most important tokens per method:")
    for name, attrs in results.items():
        top_3 = np.argsort(attrs)[-3:][::-1]
        top_tokens = [(tokens[i], attrs[i]) for i in top_3]
        top_str = ", ".join(f"'{t}' ({s:.2f})" for t, s in top_tokens)
        print(f"  {name:20s}: {top_str}")

    return results

# Run comparison. extract_raw_attention, compute_rollout, and
# integrated_gradients are assumed to be defined earlier in the module;
# leave_one_out_attribution is defined above.
methods = {
    "raw_attention": lambda m, t, x: extract_raw_attention(m, t, x),
    "rollout": lambda m, t, x: compute_rollout(m, t, x),
    "integrated_gradients": lambda m, t, x: integrated_gradients(m, t, x)[0],
    "leave_one_out": lambda m, t, x: leave_one_out_attribution(m, t, x)[0],
}

results = compare_attribution_methods(
    model, tokenizer,
    "The Eiffel Tower is located in the city of",
    methods
)
💡 Key Insight

When different attribution methods disagree about which tokens are important, this is informative rather than problematic. Disagreement typically occurs because the methods measure different things: attention rollout captures information flow regardless of what is done with it, gradients measure local sensitivity, and perturbation methods measure actual counterfactual impact. Using multiple methods together provides a more complete picture than any single method alone.

7. Evaluation of Explanation Quality

How do we know if an explanation is "good"? Several metrics have been proposed to evaluate explanation quality, each capturing different desirable properties.

| Metric | What It Measures | How to Compute |
| --- | --- | --- |
| Faithfulness (Sufficiency) | Can the top-k tokens reproduce the prediction? | Keep only the top-k attributed tokens, measure prediction change |
| Faithfulness (Comprehensiveness) | Do the top-k tokens account for the prediction? | Remove the top-k tokens, measure prediction drop |
| Plausibility | Do explanations match human intuition? | Compare attributions to human annotation of important words |
| Consistency | Do similar inputs get similar explanations? | Measure attribution similarity for paraphrased inputs |
| Sparsity | How concentrated is the attribution? | Entropy or Gini coefficient of the attribution distribution |
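The sparsity metrics in the last row can be computed directly from an attribution vector; the helpers below are our own illustrative implementations, not from a standard library:

```python
# Sparsity metrics for an attribution vector: Shannon entropy and Gini
# coefficient of the normalized (absolute-value) attribution distribution.
import numpy as np

def attribution_entropy(attributions):
    """Shannon entropy of the normalized attribution distribution.
    Lower entropy = more concentrated (sparser) attribution."""
    p = np.abs(attributions)
    p = p / p.sum()
    p = p[p > 0]  # 0 * log(0) contributes nothing
    return float(-(p * np.log(p)).sum())

def attribution_gini(attributions):
    """Gini coefficient: 0 = perfectly uniform, near 1 = all mass on one token."""
    a = np.sort(np.abs(attributions))   # ascending
    n = len(a)
    cum = np.cumsum(a)
    return float((n + 1 - 2 * (cum / cum[-1]).sum()) / n)

sparse = np.array([0.0, 0.0, 0.0, 1.0])
uniform = np.array([0.25, 0.25, 0.25, 0.25])
print(attribution_entropy(sparse), attribution_entropy(uniform))  # 0.0 vs log(4)
print(attribution_gini(sparse), attribution_gini(uniform))        # 0.75 vs 0.0
```

Note that the maximum Gini value for n tokens is (n-1)/n, which is why the fully concentrated 4-token example yields 0.75 rather than 1.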
# Faithfulness evaluation for attribution methods
import torch
import numpy as np

def evaluate_faithfulness(
    model,
    tokenizer,
    text,
    attributions,
    k_values=[1, 3, 5],
):
    """
    Evaluate faithfulness of attributions using sufficiency and comprehensiveness.
    """
    inputs = tokenizer(text, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    with torch.no_grad():
        baseline_logits = model(**inputs).logits[0, -1]
    predicted_id = baseline_logits.argmax()
    baseline_prob = torch.softmax(baseline_logits, dim=-1)[predicted_id].item()

    sorted_indices = np.argsort(attributions)[::-1]
    results = {}

    for k in k_values:
        top_k = sorted_indices[:k]

        # Sufficiency: keep only top-k tokens
        sufficient_ids = inputs["input_ids"].clone()
        mask = torch.ones(len(tokens), dtype=torch.bool)
        mask[list(top_k)] = False
        sufficient_ids[0, mask] = tokenizer.pad_token_id or 0

        with torch.no_grad():
            suf_logits = model(sufficient_ids).logits[0, -1]
        suf_prob = torch.softmax(suf_logits, dim=-1)[predicted_id].item()
        sufficiency = suf_prob / baseline_prob  # closer to 1 = better

        # Comprehensiveness: remove top-k tokens
        comp_ids = inputs["input_ids"].clone()
        for idx in top_k:
            comp_ids[0, idx] = tokenizer.pad_token_id or 0

        with torch.no_grad():
            comp_logits = model(comp_ids).logits[0, -1]
        comp_prob = torch.softmax(comp_logits, dim=-1)[predicted_id].item()
        comprehensiveness = 1 - (comp_prob / baseline_prob)  # closer to 1 = better

        results[f"k={k}"] = {
            "sufficiency": sufficiency,
            "comprehensiveness": comprehensiveness,
        }

    return results

📝 Section Quiz

1. What does attention rollout capture that raw attention does not?
Show Answer
Attention rollout captures indirect information flow across multiple layers. Raw attention shows only the direct attention at a single layer. If token A attends to token B at layer 3, and token B attended to token C at layer 1, then A has indirect access to C. Rollout traces these multi-hop paths by multiplying attention matrices across layers, accounting for the residual connection at each step.
2. Why does gradient-weighted attention (Attention x Gradient) combine two signals?
Show Answer
Attention weights show where the model looks, but high attention to a token does not mean that token is important for the prediction. The gradient tells us how sensitive the output is to changes in the attention pattern. Multiplying them highlights tokens that receive high attention AND that changing the attention would significantly affect the output. This eliminates "attention sinks" (tokens receiving high attention without influencing the prediction).
3. What is the "off-manifold" problem with perturbation-based explanations?
Show Answer
When we remove or replace a token to test its importance, we create an input that is unlike anything the model saw during training (e.g., a sentence with a random token in the middle of coherent text). The model's behavior on this unnatural input may not reflect what it would do if the token were genuinely absent. The model might be confused by the perturbation itself rather than responding to the absence of the information carried by that token.
4. What is the difference between faithfulness and plausibility in explanation evaluation?
Show Answer
Faithfulness measures whether the explanation accurately reflects the model's actual computation (does removing the highlighted tokens actually change the prediction?). Plausibility measures whether the explanation matches human intuition about what should be important. These can diverge: a model might make its prediction based on unexpected features (like punctuation patterns) that are faithful but implausible, while humans might expect certain keywords to be important even if the model does not rely on them.
5. When should you use multiple explanation methods rather than picking one?
Show Answer
Use multiple methods when: (1) the stakes are high and you need confidence in the explanation, (2) you want to distinguish between what the model attends to (rollout) versus what actually drives the prediction (IG, perturbation), (3) different methods disagree and you need to understand why, or (4) you need both quick insight (raw attention) and rigorous attribution (SHAP) for different audiences. Agreement across methods increases confidence; disagreement reveals the complexity of the model's decision process.

✅ Key Takeaways