Attention patterns and probing classifiers are the most accessible tools for understanding what transformers learn. Attention weights reveal which tokens the model considers when making predictions, while probing classifiers test what information (syntax, semantics, world knowledge) is encoded in hidden states at each layer. Combined with the logit lens, which projects intermediate representations into vocabulary space, these tools provide a layered view of how a transformer turns input tokens into output predictions.
1. Attention Visualization
Every transformer layer computes attention weights: a matrix that specifies how much each token attends to every other token. Visualizing these weights provides an immediate, intuitive window into the model's computation. However, interpreting attention requires care: attention weights do not directly indicate which tokens are "important" for the prediction (as we discuss in Section 17.4).
```python
# Extracting and visualizing attention patterns
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
import numpy as np

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, output_attentions=True
)

text = "The cat sat on the mat because it was tired"
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions is a tuple of (num_layers,) tensors
# Each tensor has shape (batch, num_heads, seq_len, seq_len)
attentions = outputs.attentions
print(f"Number of layers: {len(attentions)}")
print(f"Attention shape per layer: {attentions[0].shape}")

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

def plot_attention_head(attention_matrix, tokens, layer, head):
    """Plot a single attention head as a heatmap."""
    fig, ax = plt.subplots(figsize=(8, 6))
    attn = attention_matrix[0, head].numpy()  # (seq_len, seq_len)
    im = ax.imshow(attn, cmap="Blues", vmin=0, vmax=1)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha="right", fontsize=8)
    ax.set_yticklabels(tokens, fontsize=8)
    ax.set_xlabel("Key (attending to)")
    ax.set_ylabel("Query (attending from)")
    ax.set_title(f"Layer {layer}, Head {head}")
    plt.colorbar(im)
    plt.tight_layout()
    return fig

# Visualize a specific head
fig = plot_attention_head(attentions[5], tokens, layer=5, head=1)
plt.savefig("attention_head_L5_H1.png", dpi=150)
```
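GPT-2 small alone has 144 heads, so it helps to triage heads with a summary statistic before plotting them one by one. One simple statistic is the entropy of each head's attention rows: focused heads (e.g. previous-token heads) have near-zero entropy, while diffuse heads approach the uniform maximum. This is a minimal sketch on a synthetic attention tensor standing in for one element of `outputs.attentions`; the statistic is a common heuristic, not something prescribed by a specific library.

```python
import torch

def attention_entropy(attn: torch.Tensor) -> torch.Tensor:
    """Mean entropy (in nats) of each head's attention rows.

    attn: (batch, num_heads, seq_len, seq_len), rows summing to 1.
    Returns a (num_heads,) tensor; low entropy = sharply focused head.
    """
    eps = 1e-9  # avoid log(0)
    ent = -(attn * (attn + eps).log()).sum(dim=-1)  # (batch, heads, seq)
    return ent.mean(dim=(0, 2))

# Sanity check on synthetic attention: a one-hot (fully focused) head
# has entropy ~0, while a uniform head has entropy log(seq_len).
seq = 8
focused = torch.eye(seq).view(1, 1, seq, seq)
uniform = torch.full((1, 1, seq, seq), 1.0 / seq)
print(attention_entropy(focused).item())  # ~0.0
print(attention_entropy(uniform).item())  # ~log(8) = 2.079
```

Sorting heads by entropy (in either direction) is a quick way to pick which `(layer, head)` pairs are worth passing to `plot_attention_head`.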
1.1 Common Attention Patterns
Research has identified several recurring attention patterns across transformer models. These patterns appear consistently regardless of model size, training data, or architecture variant, suggesting they represent fundamental computational primitives.
| Pattern | Layer Position | Function | Importance |
|---|---|---|---|
| Previous-token | Early (L0-L2) | Local context aggregation | Foundation for n-gram statistics |
| Induction heads | Early-mid (L1-L6) | Pattern copying from context | Core mechanism for in-context learning |
| Positional/sink | All layers | Attend to BOS or delimiters | Default when no specific pattern matches |
| Duplicate-token | Mid layers | Flag repeated tokens | Important for copy and repetition tasks |
| Semantic | Late layers | Attend to related meaning | Task-specific information retrieval |
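Some rows of this table can be scored automatically rather than found by eyeballing heatmaps. The sketch below scores every head for the previous-token and attention-sink patterns using simple positional statistics; the scoring heuristics are illustrative assumptions, and the demo input is a hand-built synthetic attention tensor rather than real model output.

```python
import torch

def head_pattern_scores(attn: torch.Tensor) -> dict:
    """Score each head for two patterns from the table above.

    attn: (batch, num_heads, seq_len, seq_len) attention weights.
    Returns (num_heads,) scores in [0, 1] per pattern:
      prev_token -- mean weight placed on position i-1 by position i
      sink       -- mean weight placed on position 0 by later positions
    """
    # Sub-diagonal: weight that query i puts on key i-1
    prev = attn.diagonal(offset=-1, dim1=-2, dim2=-1)  # (batch, heads, seq-1)
    # First column: weight that queries 1..n-1 put on the first token
    sink = attn[..., 1:, 0]                            # (batch, heads, seq-1)
    return {
        "prev_token": prev.mean(dim=(0, 2)),
        "sink": sink.mean(dim=(0, 2)),
    }

# A hand-built "previous-token head" scores 1.0 on that pattern
seq = 6
prev_attn = torch.zeros(1, 1, seq, seq)
prev_attn[0, 0, 0, 0] = 1.0            # query 0 can only see itself
for i in range(1, seq):
    prev_attn[0, 0, i, i - 1] = 1.0
scores = head_pattern_scores(prev_attn)
print(scores["prev_token"].item())  # 1.0
print(scores["sink"].item())        # 0.2 (query 1's previous token is position 0)
```

Induction heads need a slightly richer test (attention to the token that followed an earlier occurrence of the current token), but the same score-then-inspect workflow applies.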
Attention weights show where the model "looks" but not what it "sees." High attention to a token does not necessarily mean that token is important for the prediction. The value vectors determine what information is extracted, and the subsequent feed-forward layers further transform the representation. Use attention visualization as a starting point for investigation, not as definitive evidence of model reasoning.
2. Probing Classifiers
Probing classifiers provide a more rigorous way to test what information is encoded in a model's hidden representations. The idea is simple: extract hidden states from a specific layer, freeze them, and train a lightweight classifier on top to predict some property of interest (part of speech, syntactic dependency, semantic role, entity type). If the classifier succeeds, the property is encoded in the hidden states.
```python
# Probing classifier for linguistic properties
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics import accuracy_score
import numpy as np

class LinearProbe(nn.Module):
    """A simple linear probe for testing representation content."""

    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, hidden_states):
        return self.classifier(hidden_states)

class MLPProbe(nn.Module):
    """A nonlinear probe with one hidden layer."""

    def __init__(self, hidden_dim, num_classes, probe_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, probe_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(probe_dim, num_classes),
        )

    def forward(self, hidden_states):
        return self.net(hidden_states)

def extract_hidden_states(model, tokenizer, texts, layer_idx):
    """Extract hidden states from a specific layer."""
    model.eval()
    all_hidden = []
    for text in texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Get hidden states from the specified layer
        # Shape: (1, seq_len, hidden_dim)
        hidden = outputs.hidden_states[layer_idx]
        all_hidden.append(hidden.squeeze(0))
    return all_hidden

def train_probe(
    hidden_states,  # list of (seq_len, hidden_dim) tensors
    labels,         # list of (seq_len,) label tensors
    num_classes,
    probe_type="linear",
    epochs=10,
    lr=1e-3,
):
    """Train a probing classifier on frozen hidden states."""
    # Flatten all tokens into a single dataset
    X = torch.cat(hidden_states, dim=0)
    y = torch.cat(labels, dim=0)
    hidden_dim = X.shape[1]
    if probe_type == "linear":
        probe = LinearProbe(hidden_dim, num_classes)
    else:
        probe = MLPProbe(hidden_dim, num_classes)
    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=256, shuffle=True)
    for epoch in range(epochs):
        total_loss = 0
        for batch_x, batch_y in loader:
            logits = probe(batch_x)
            loss = criterion(logits, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}: loss={total_loss/len(loader):.4f}")
    return probe

# Example: Probe for part-of-speech tags across layers.
# Assumes a labeled dataset (train_texts, pos_labels, num_pos_tags), a
# held-out split (val_hidden, val_labels), and an evaluate_probe helper
# are defined elsewhere.
model = AutoModel.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# For each layer, train a probe and measure accuracy
layer_accuracies = {}
for layer in range(model.config.num_hidden_layers + 1):
    hidden = extract_hidden_states(model, tokenizer, train_texts, layer)
    probe = train_probe(hidden, pos_labels, num_pos_tags, "linear")
    acc = evaluate_probe(probe, val_hidden, val_labels)
    layer_accuracies[layer] = acc
    print(f"Layer {layer}: POS accuracy = {acc:.3f}")
```
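The loop above calls an `evaluate_probe` helper that the snippet never defines. A minimal version, consistent with the data layout `train_probe` uses (lists of per-sentence tensors with token-level labels), might look like this:

```python
import torch

def evaluate_probe(probe, hidden_states, labels):
    """Token-level accuracy of a trained probe on held-out data.

    hidden_states: list of (seq_len, hidden_dim) tensors
    labels:        list of (seq_len,) label tensors
    """
    probe.eval()
    X = torch.cat(hidden_states, dim=0)
    y = torch.cat(labels, dim=0)
    with torch.no_grad():
        preds = probe(X).argmax(dim=-1)
    return (preds == y).float().mean().item()
```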
2.1 Control Tasks
A common criticism of probing is that a powerful enough probe might learn the task itself rather than reflecting what the model has learned. Control tasks (Hewitt and Liang, 2019) address this by training the same probe on a random labeling of the data. If the probe achieves high accuracy on both the real task and the control task, the probe is too powerful, and the result is not meaningful.
The selectivity of a probe (real accuracy minus control accuracy) is a better measure than raw accuracy. A linear probe that achieves 90% on POS tagging but 30% on random labels has selectivity of 60%, indicating the representation genuinely encodes POS information. An MLP probe that achieves 95% on POS tagging but 85% on random labels has selectivity of only 10%, suggesting the probe itself is doing most of the work.
```python
# Control task for validating probe results
import torch

def run_probing_experiment_with_control(
    model, tokenizer, texts, labels, num_classes, layer_idx
):
    """Run probing with a control task to measure selectivity.

    Assumes a held-out split (val_hidden, val_labels, val_control_labels)
    and the evaluate_probe helper are defined elsewhere.
    """
    hidden = extract_hidden_states(model, tokenizer, texts, layer_idx)
    # Real task probe
    real_probe = train_probe(hidden, labels, num_classes, "linear")
    real_acc = evaluate_probe(real_probe, val_hidden, val_labels)
    # Control task: random labels. (Hewitt and Liang assign one random
    # label per word type; per-token randint is a simpler variant.)
    control_labels = [
        torch.randint(0, num_classes, label.shape)
        for label in labels
    ]
    control_probe = train_probe(hidden, control_labels, num_classes, "linear")
    control_acc = evaluate_probe(control_probe, val_hidden, val_control_labels)
    selectivity = real_acc - control_acc
    return {
        "real_accuracy": real_acc,
        "control_accuracy": control_acc,
        "selectivity": selectivity,
        "meaningful": selectivity > 0.1,  # rough threshold
    }
```
3. The Logit Lens
The logit lens (nostalgebraist, 2020) is a technique for inspecting what a transformer "thinks" at each intermediate layer. The idea is to take the hidden state at any layer and project it through the model's final unembedding matrix (the same matrix used to produce the output logits). This reveals the model's current "best guess" for the next token at each point in the computation.
```python
# Logit lens implementation
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

def logit_lens(model, tokenizer, text, top_k=5):
    """
    Apply the logit lens to see what the model predicts
    at each intermediate layer.
    """
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    hidden_states = outputs.hidden_states  # (num_layers+1,) tuple
    # Get the unembedding matrix
    unembed = model.lm_head.weight  # (vocab_size, hidden_dim)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    last_pos = len(tokens) - 1  # predict next token after last
    print(f"Input: {text}")
    print(f"Predicting token after: '{tokens[last_pos]}'")
    print("-" * 50)
    with torch.no_grad():
        for layer_idx, hidden in enumerate(hidden_states):
            # hidden shape: (1, seq_len, hidden_dim)
            h = hidden[0, last_pos]
            # Intermediate states have not been through the final LayerNorm;
            # apply it before unembedding (the last state already includes
            # it). ln_f is the GPT-2 module name; other models differ.
            if layer_idx < len(hidden_states) - 1:
                h = model.transformer.ln_f(h)
            # Project the hidden state through the unembedding
            logits = h @ unembed.T  # (vocab_size,)
            probs = F.softmax(logits, dim=-1)
            top_probs, top_ids = probs.topk(top_k)
            top_tokens = [tokenizer.decode(tid) for tid in top_ids]
            top_str = ", ".join(
                f"'{t}' ({p:.3f})" for t, p in zip(top_tokens, top_probs)
            )
            print(f"Layer {layer_idx:2d}: {top_str}")

# Run the logit lens on GPT-2
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
logit_lens(model, tokenizer, "The Eiffel Tower is in")
```
3.1 The Tuned Lens
The tuned lens (Belrose et al., 2023) improves on the logit lens by training a learned affine transformation for each layer. The raw logit lens assumes that intermediate representations are approximately in the same space as the final layer, which is only roughly true. The tuned lens trains a small per-layer probe to account for the differences in representation spaces across layers.
The tuned lens provides cleaner, more interpretable results than the raw logit lens, especially in early layers where representations are furthest from the output space. The training cost is minimal (a single linear layer per transformer layer), and pre-trained tuned lens parameters are available for popular models via the tuned-lens Python package.
```python
# Using the tuned-lens package
# pip install tuned-lens
import torch
import torch.nn.functional as F
from tuned_lens import TunedLens
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Load a pre-trained tuned lens for GPT-2
tuned = TunedLens.from_model_and_pretrained(model)

text = "The capital of France is"
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states

# Apply the tuned lens at each layer
with torch.no_grad():
    for layer_idx in range(len(hidden_states) - 1):
        # The tuned lens applies a learned affine transform per layer
        logits = tuned(hidden_states[layer_idx], layer_idx)
        probs = F.softmax(logits[0, -1], dim=-1)
        top_token = tokenizer.decode(probs.argmax())
        top_prob = probs.max().item()
        print(f"Layer {layer_idx:2d}: '{top_token}' (p={top_prob:.3f})")
```
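For intuition about what a tuned lens learns, the toy sketch below trains a single per-layer affine "translator" so that translated hidden states, pushed through a frozen unembedding, match the final-layer next-token distribution under a KL loss. All tensors and dimensions here are synthetic stand-ins, not GPT-2's, and this is a simplified caricature of the package's actual training loop.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_dim, vocab_size, n_tokens = 16, 50, 128
unembed = torch.randn(vocab_size, hidden_dim)            # frozen W_U stand-in
h_layer = torch.randn(n_tokens, hidden_dim)              # hidden states at layer l
h_final = h_layer @ torch.randn(hidden_dim, hidden_dim)  # pretend final-layer states

# One learned affine translator (A_l, b_l) for this layer
translator = torch.nn.Linear(hidden_dim, hidden_dim)
opt = torch.optim.Adam(translator.parameters(), lr=1e-2)
target = F.log_softmax(h_final @ unembed.T, dim=-1)      # fixed training target

losses = []
for step in range(200):
    log_pred = F.log_softmax(translator(h_layer) @ unembed.T, dim=-1)
    # KL(final-layer distribution || translated prediction)
    loss = F.kl_div(log_pred, target, log_target=True, reduction="batchmean")
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())
print(f"KL: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

Because `h_final` here is an exact linear function of `h_layer`, the translator can in principle drive the KL toward zero; in a real model the fit is only approximate.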
✅ Key Takeaways
- Attention visualization reveals recurring patterns (previous-token, induction, positional, semantic heads) that represent fundamental computational primitives in transformers.
- Attention weights show where the model looks but not what it learns; use them as starting points, not definitive explanations.
- Probing classifiers test what information is encoded in hidden states. Control tasks are essential to validate that probes measure representation content rather than probe capacity.
- The logit lens projects intermediate hidden states into vocabulary space, revealing how predictions are refined incrementally across layers.
- The tuned lens improves on the logit lens by learning per-layer affine transformations that account for differences in representation spaces.
- Together, these tools provide a layered view of transformer computation: what the model attends to, what it encodes, and what it predicts at each stage.