Module 17 · Section 17.1

Attention Analysis & Probing

Visualizing attention patterns, probing hidden representations, and using the logit lens to inspect how transformers process information
★ Big Picture

Attention patterns and probing classifiers are the most accessible tools for understanding what transformers learn. Attention weights reveal which tokens the model considers when making predictions, while probing classifiers test what information (syntax, semantics, world knowledge) is encoded in hidden states at each layer. Combined with the logit lens, which projects intermediate representations into vocabulary space, these tools provide a layered view of how transformers transform input tokens into output predictions.

1. Attention Visualization

Every transformer layer computes attention weights: a matrix that specifies how much each token attends to every other token. Visualizing these weights provides an immediate, intuitive window into the model's computation. However, interpreting attention requires care: attention weights do not directly indicate which tokens are "important" for the prediction (as we discuss in Section 17.4).

# Extracting and visualizing attention patterns
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
import numpy as np

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, output_attentions=True
)

text = "The cat sat on the mat because it was tired"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions is a tuple of (num_layers,) tensors
# Each tensor has shape (batch, num_heads, seq_len, seq_len)
attentions = outputs.attentions
print(f"Number of layers: {len(attentions)}")
print(f"Attention shape per layer: {attentions[0].shape}")

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

def plot_attention_head(attention_matrix, tokens, layer, head):
    """Plot a single attention head as a heatmap."""
    fig, ax = plt.subplots(figsize=(8, 6))
    attn = attention_matrix[0, head].numpy()  # (seq_len, seq_len)
    im = ax.imshow(attn, cmap="Blues", vmin=0, vmax=1)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha="right", fontsize=8)
    ax.set_yticklabels(tokens, fontsize=8)
    ax.set_xlabel("Key (attending to)")
    ax.set_ylabel("Query (attending from)")
    ax.set_title(f"Layer {layer}, Head {head}")
    plt.colorbar(im)
    plt.tight_layout()
    return fig

# Visualize a specific head
fig = plot_attention_head(attentions[5], tokens, layer=5, head=1)
plt.savefig("attention_head_L5_H1.png", dpi=150)

1.1 Common Attention Patterns

Research has identified several recurring attention patterns across transformer models. These patterns appear consistently regardless of model size, training data, or architecture variant, suggesting they represent fundamental computational primitives.

[Figure: four recurring head types. Previous-token heads: each token attends to the immediately preceding token (early layers). Induction heads: copy patterns from earlier context, enabling in-context learning. Positional heads: attend to fixed positions such as BOS or punctuation, serving as a "no-op" sink. Semantic heads: attend to semantically related tokens (later layers, task-specific).]
Figure 17.1: Common attention head types observed across transformer models. Each pattern represents a distinct computational role.
| Pattern         | Layer Position    | Function                     | Importance                               |
|-----------------|-------------------|------------------------------|------------------------------------------|
| Previous-token  | Early (L0-L2)     | Local context aggregation    | Foundation for n-gram statistics         |
| Induction heads | Early-mid (L1-L6) | Pattern copying from context | Core mechanism for in-context learning   |
| Positional/sink | All layers        | Attend to BOS or delimiters  | Default when no specific pattern matches |
| Duplicate-token | Mid layers        | Flag repeated tokens         | Important for copy and repetition tasks  |
| Semantic        | Late layers       | Attend to related meaning    | Task-specific information retrieval      |
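As a sketch of how the first rows of this table can be found programmatically, the heuristic below scores every GPT-2 head for previous-token and positional/sink behavior by averaging its attention to the immediately preceding position and to the first position. The 0.5 threshold is an illustrative choice, not a standard value.

```python
# Sketch: heuristically flag previous-token and sink heads in GPT-2
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", output_attentions=True)
model.eval()

inputs = tokenizer(
    "The cat sat on the mat because it was tired", return_tensors="pt"
)
with torch.no_grad():
    attentions = model(**inputs).attentions  # tuple of (1, heads, S, S)

for layer, attn in enumerate(attentions):
    a = attn[0]  # (num_heads, seq_len, seq_len)
    # Average attention from each query to the immediately preceding key
    prev = a.diagonal(offset=-1, dim1=-2, dim2=-1).mean(dim=-1)
    # Average attention to position 0 (skip query 0, which must attend there)
    sink = a[:, 1:, 0].mean(dim=-1)
    for head in range(a.shape[0]):
        if prev[head] > 0.5:
            print(f"L{layer} H{head}: previous-token ({prev[head]:.2f})")
        if sink[head] > 0.5:
            print(f"L{layer} H{head}: positional/sink ({sink[head]:.2f})")
```

Because attention rows sum to 1, each score lies in [0, 1]; a head scoring near 1 spends almost all of its attention budget on that single pattern.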
⚠ Warning

Attention weights show where the model "looks" but not what it "sees." High attention to a token does not necessarily mean that token is important for the prediction. The value vectors determine what information is extracted, and the subsequent feed-forward layers further transform the representation. Use attention visualization as a starting point for investigation, not as definitive evidence of model reasoning.
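One way to go beyond eyeballing heatmaps is a behavioral test. Induction heads, for example, can be detected by feeding a repeated random token sequence and measuring how much each head's queries in the second copy attend to the token that followed the previous occurrence. A minimal sketch, assuming GPT-2 (the sequence length and top-5 reporting are illustrative):

```python
# Sketch: score every GPT-2 head for induction behaviour on repeated tokens
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", output_attentions=True)
model.eval()

torch.manual_seed(0)
n = 20
seq = torch.randint(0, model.config.vocab_size, (1, n))
input_ids = torch.cat([seq, seq], dim=1)  # [t1 .. tn, t1 .. tn]

with torch.no_grad():
    attentions = model(input_ids=input_ids).attentions

# For a query at position n+i (second copy), an induction head attends to
# position i+1: the token that followed the previous occurrence.
idx_q = torch.arange(n, 2 * n - 1)
idx_k = idx_q - n + 1
scores = {}
for layer, attn in enumerate(attentions):
    per_head = attn[0][:, idx_q, idx_k].mean(dim=-1)  # (num_heads,)
    for head, s in enumerate(per_head):
        scores[(layer, head)] = s.item()

for layer, head in sorted(scores, key=scores.get, reverse=True)[:5]:
    print(f"L{layer} H{head}: induction score {scores[(layer, head)]:.2f}")
```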

2. Probing Classifiers

Probing classifiers provide a more rigorous way to test what information is encoded in a model's hidden representations. The idea is simple: extract hidden states from a specific layer, freeze them, and train a lightweight classifier on top to predict some property of interest (part of speech, syntactic dependency, semantic role, entity type). If the classifier succeeds, the property is encoded in the hidden states.

# Probing classifier for linguistic properties
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics import accuracy_score
import numpy as np

class LinearProbe(nn.Module):
    """A simple linear probe for testing representation content."""
    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, hidden_states):
        return self.classifier(hidden_states)

class MLPProbe(nn.Module):
    """A nonlinear probe with one hidden layer."""
    def __init__(self, hidden_dim, num_classes, probe_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, probe_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(probe_dim, num_classes),
        )

    def forward(self, hidden_states):
        return self.net(hidden_states)

def extract_hidden_states(model, tokenizer, texts, layer_idx):
    """Extract hidden states from a specific layer."""
    model.eval()
    all_hidden = []

    for text in texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        # Get hidden states from the specified layer
        # Shape: (1, seq_len, hidden_dim)
        hidden = outputs.hidden_states[layer_idx]
        all_hidden.append(hidden.squeeze(0))

    return all_hidden

def train_probe(
    hidden_states,     # list of (seq_len, hidden_dim) tensors
    labels,            # list of (seq_len,) label tensors
    num_classes,
    probe_type="linear",
    epochs=10,
    lr=1e-3,
):
    """Train a probing classifier on frozen hidden states."""
    # Flatten all tokens into a single dataset
    X = torch.cat(hidden_states, dim=0)
    y = torch.cat(labels, dim=0)

    hidden_dim = X.shape[1]
    if probe_type == "linear":
        probe = LinearProbe(hidden_dim, num_classes)
    else:
        probe = MLPProbe(hidden_dim, num_classes)

    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=256, shuffle=True)

    for epoch in range(epochs):
        total_loss = 0
        for batch_x, batch_y in loader:
            logits = probe(batch_x)
            loss = criterion(logits, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}: loss={total_loss/len(loader):.4f}")

    return probe

# Example: probe for part-of-speech tags across layers
# (train_texts, pos_labels, num_pos_tags, val_hidden, val_labels, and
# evaluate_probe are placeholders assumed to be defined elsewhere)
model = AutoModel.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# For each layer, train a probe and measure accuracy
layer_accuracies = {}
for layer in range(model.config.num_hidden_layers + 1):
    hidden = extract_hidden_states(model, tokenizer, train_texts, layer)
    probe = train_probe(hidden, pos_labels, num_pos_tags, "linear")
    acc = evaluate_probe(probe, val_hidden, val_labels)
    layer_accuracies[layer] = acc
    print(f"Layer {layer}: POS accuracy = {acc:.3f}")
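The listing above calls an `evaluate_probe` helper that it never defines. A minimal sketch of one possible implementation, assuming held-out data in the same list-of-tensors format that `train_probe` uses:

```python
# Sketch: token-level accuracy of a trained probe on held-out hidden states
import torch

def evaluate_probe(probe, hidden_states, labels):
    """hidden_states: list of (seq_len, hidden_dim); labels: list of (seq_len,)."""
    probe.eval()
    X = torch.cat(hidden_states, dim=0)  # (total_tokens, hidden_dim)
    y = torch.cat(labels, dim=0)         # (total_tokens,)
    with torch.no_grad():
        preds = probe(X).argmax(dim=-1)
    return (preds == y).float().mean().item()
```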

2.1 Control Tasks

A common criticism of probing is that a powerful enough probe might learn the task itself rather than reflecting what the model has learned. Control tasks (Hewitt and Liang, 2019) address this by training the same probe on a random labeling of the data. If the probe achieves high accuracy on both the real task and the control task, the probe is too powerful, and the result is not meaningful.

💡 Key Insight

The selectivity of a probe (real accuracy minus control accuracy) is a better measure than raw accuracy. A linear probe that achieves 90% on POS tagging but 30% on random labels has selectivity of 60%, indicating the representation genuinely encodes POS information. An MLP probe that achieves 95% on POS tagging but 85% on random labels has selectivity of only 10%, suggesting the probe itself is doing most of the work.

# Control task for validating probe results

def run_probing_experiment_with_control(
    model, tokenizer, texts, labels, num_classes, layer_idx
):
    """Run probing with a control task to measure selectivity.

    Assumes evaluate_probe, val_hidden, val_labels, and val_control_labels
    are defined elsewhere, as in the probing example above.
    """

    hidden = extract_hidden_states(model, tokenizer, texts, layer_idx)

    # Real task probe
    real_probe = train_probe(hidden, labels, num_classes, "linear")
    real_acc = evaluate_probe(real_probe, val_hidden, val_labels)

    # Control task: random label assignment. Note that Hewitt and Liang
    # assign a fixed random label per word *type* (so a high-capacity probe
    # can memorize the mapping); the per-token randomization below is a
    # simplification.
    control_labels = [
        torch.randint(0, num_classes, label.shape)
        for label in labels
    ]
    control_probe = train_probe(hidden, control_labels, num_classes, "linear")
    control_acc = evaluate_probe(control_probe, val_hidden, val_control_labels)

    selectivity = real_acc - control_acc

    return {
        "real_accuracy": real_acc,
        "control_accuracy": control_acc,
        "selectivity": selectivity,
        "meaningful": selectivity > 0.1,  # threshold
    }

3. The Logit Lens

The logit lens (nostalgebraist, 2020) is a technique for inspecting what a transformer "thinks" at each intermediate layer. The idea is to take the hidden state at any layer and project it through the model's final unembedding matrix (the same matrix used to produce the output logits). This reveals the model's current "best guess" for the next token at each point in the computation.

[Figure: GPT-2 hidden states for "The Eiffel Tower is in" at layers 0, 4, 8, and 11, each projected through the unembedding matrix W_U. The logit-lens top prediction sharpens across layers: "the" (noisy) at layer 0, "France" (weak) at layer 4, "Paris" (strong) at layer 8, and "Paris" at the final layer 11.]
Figure 17.2: The logit lens projects hidden states from each layer through the unembedding matrix. Earlier layers produce noisy predictions that sharpen as computation progresses.
# Logit Lens Implementation
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

def logit_lens(model, tokenizer, text, top_k=5):
    """
    Apply the logit lens to see what the model predicts
    at each intermediate layer.
    """
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    hidden_states = outputs.hidden_states  # (num_layers+1,) tuple
    # Get the unembedding matrix and the final LayerNorm; the logit lens
    # applies ln_f before unembedding, mirroring the model's output path
    unembed = model.lm_head.weight  # (vocab_size, hidden_dim)
    ln_f = model.transformer.ln_f

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    last_pos = len(tokens) - 1  # predict next token after last

    print(f"Input: {text}")
    print(f"Predicting token after: '{tokens[last_pos]}'")
    print("-" * 50)

    for layer_idx, hidden in enumerate(hidden_states):
        # Apply the final LayerNorm, then project through the unembedding
        # hidden shape: (1, seq_len, hidden_dim)
        logits = ln_f(hidden[0, last_pos]) @ unembed.T  # (vocab_size,)
        probs = F.softmax(logits, dim=-1)

        top_probs, top_ids = probs.topk(top_k)
        top_tokens = [tokenizer.decode(tid) for tid in top_ids]

        top_str = ", ".join(
            f"'{t}' ({p:.3f})" for t, p in zip(top_tokens, top_probs)
        )
        print(f"Layer {layer_idx:2d}: {top_str}")

# Run logit lens on GPT-2
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
logit_lens(model, tokenizer, "The Eiffel Tower is in")
Input: The Eiffel Tower is in
Predicting token after: 'in'
--------------------------------------------------
Layer  0: 'the' (0.042), 'a' (0.031), 'of' (0.028)
Layer  3: 'the' (0.089), 'France' (0.045), 'a' (0.033)
Layer  6: 'Paris' (0.082), 'France' (0.071), 'the' (0.055)
Layer  9: 'Paris' (0.215), 'France' (0.098), 'the' (0.044)
Layer 11: 'Paris' (0.412), 'France' (0.087), 'the' (0.031)

3.1 The Tuned Lens

The tuned lens (Belrose et al., 2023) improves on the logit lens by training a learned affine transformation for each layer. The raw logit lens assumes that intermediate representations are approximately in the same space as the final layer, which is only roughly true. The tuned lens trains a small per-layer probe to account for the differences in representation spaces across layers.

📝 Note

The tuned lens provides cleaner, more interpretable results than the raw logit lens, especially in early layers where representations are furthest from the output space. The training cost is minimal (a single linear layer per transformer layer), and pre-trained tuned lens parameters are available for popular models via the tuned-lens Python package.

# Using the tuned lens package
# pip install tuned-lens
import torch
import torch.nn.functional as F
from tuned_lens import TunedLens
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Load pre-trained tuned lens for GPT-2
tuned = TunedLens.from_model_and_pretrained(model)

text = "The capital of France is"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

hidden_states = outputs.hidden_states

# Apply tuned lens at each layer
for layer_idx in range(len(hidden_states) - 1):
    # Tuned lens applies a learned affine transform
    logits = tuned(hidden_states[layer_idx], layer_idx)
    probs = F.softmax(logits[0, -1], dim=-1)
    top_token = tokenizer.decode(probs.argmax())
    top_prob = probs.max().item()
    print(f"Layer {layer_idx:2d}: '{top_token}' (p={top_prob:.3f})")
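To make the idea concrete, the per-layer transform the package learns can be sketched from scratch: one identity-initialized affine map per layer, trained so that its lensed predictions match the model's final-layer distribution under a KL objective. This is a toy sketch under strong assumptions: the two training sentences and step count are illustrative only, and a real tuned lens is trained on a large corpus.

```python
# Toy sketch of training tuned-lens-style per-layer affine translators
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model.eval()
for p in model.parameters():
    p.requires_grad_(False)  # only the translators are trained

num_layers, d = model.config.n_layer, model.config.n_embd
# One identity-initialized affine "translator" per layer
translators = nn.ModuleList(nn.Linear(d, d) for _ in range(num_layers))
for t in translators:
    nn.init.eye_(t.weight)
    nn.init.zeros_(t.bias)
opt = torch.optim.Adam(translators.parameters(), lr=1e-3)

def lens_logits(h, layer_idx):
    # Translate, then reuse the model's own final LayerNorm and unembedding
    return model.lm_head(model.transformer.ln_f(translators[layer_idx](h)))

texts = ["The Eiffel Tower is in Paris.", "The capital of France is Paris."]
for step in range(10):  # toy loop; illustrative data and step count
    for text in texts:
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            out = model(**inputs, output_hidden_states=True)
        target = F.log_softmax(out.logits, dim=-1)  # final-layer distribution
        loss = sum(
            F.kl_div(
                F.log_softmax(lens_logits(out.hidden_states[i], i), dim=-1),
                target,
                log_target=True,
                reduction="batchmean",
            )
            for i in range(num_layers)
        )
        opt.zero_grad()
        loss.backward()
        opt.step()
```

The identity initialization means the sketch starts out as the raw logit lens and only learns a per-layer correction, which is the core design choice of the tuned lens.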

📝 Section Quiz

1. Why should attention weights not be interpreted as "importance" scores?
Show Answer
Attention weights indicate how much information flows between token positions, but they do not account for the value vectors (what information is actually transferred) or the subsequent feed-forward layer transformations. High attention to a token means the model "looks" at it, not that it is important for the prediction. Attention can also serve functional roles (like "no-op" sink heads attending to BOS) unrelated to semantic importance.
2. What is the purpose of a control task in probing experiments?
Show Answer
A control task trains the same probe architecture on randomly shuffled labels. It measures how much accuracy comes from the probe's own capacity (memorizing the mapping) versus the information actually encoded in the representations. The selectivity (real accuracy minus control accuracy) gives a more honest measure of representation quality. If both real and control accuracy are high, the probe is too powerful.
3. What does the logit lens reveal about how transformers process information?
Show Answer
The logit lens shows that transformers build up their predictions incrementally across layers. Early layers produce noisy, uncertain predictions. Middle layers begin to converge on the correct answer. Late layers refine the final prediction. This reveals that transformer computation is a gradual refinement process, not a single-step computation. The "residual stream" view (Section 17.2) formalizes this as iterative updates to a shared representation.
4. What are induction heads, and why are they important?
Show Answer
Induction heads are attention heads that implement a copying mechanism: when they see a pattern [A][B]...[A], they predict [B] will follow. They are a two-head circuit (previous-token head + induction head) and are believed to be the primary mechanism underlying in-context learning. They enable transformers to recognize and continue patterns from the context window without any gradient updates.
5. How does the tuned lens improve on the standard logit lens?
Show Answer
The standard logit lens assumes all layers share approximately the same representation space, which is only roughly true. The tuned lens trains a small affine transformation per layer to map each layer's representation into the output space more accurately. This produces cleaner predictions, especially in early layers where the representation space differs most from the final layer. The per-layer affine transform has minimal parameters and can be pre-trained offline.

✅ Key Takeaways