Mechanistic interpretability aims to fully reverse-engineer neural networks into human-understandable components. Rather than treating models as black boxes and probing their behavior from the outside, this research program seeks to identify the fundamental features that neurons compute, the circuits that connect them, and the algorithms these circuits implement. The approach draws from reverse-engineering traditions in biology and electrical engineering, treating a trained neural network as an artifact to be disassembled and understood. Sparse autoencoders (SAEs) have emerged as the key tool for extracting interpretable features from the superposed representations that transformers learn.
1. The Residual Stream View
Mechanistic interpretability starts with a conceptual reframing of the transformer architecture. Instead of viewing transformers as a sequence of layers that process data in order, the "residual stream" view treats the residual connection as a central communication channel. Each attention head and MLP layer reads from this stream, performs a computation, and writes its result back to the stream by adding it to the running total.
The residual stream view has a powerful implication: because all components communicate by adding to the same stream, their contributions can be analyzed independently. The output of the entire network is the sum of the embedding, every attention head output, and every MLP output. This linearity (in the residual connections) makes it possible to attribute the model's behavior to specific components by measuring the effect of removing or modifying each contribution.
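This additivity can be checked with a toy tensor sketch (ignoring LayerNorm, which complicates exact attribution in real models). The component writes and the unembedding direction below are random stand-ins, not weights from any actual model:

```python
import torch

torch.manual_seed(0)
d_model = 16
embed = torch.randn(d_model)
head_writes = [torch.randn(d_model) for _ in range(4)]  # stand-ins for attention head outputs
mlp_writes = [torch.randn(d_model) for _ in range(2)]   # stand-ins for MLP outputs

# The final residual stream is the sum of every component's write
final_stream = embed + sum(head_writes) + sum(mlp_writes)

# Direct logit attribution: project each write onto an unembedding direction
unembed_dir = torch.randn(d_model)  # stand-in direction for one vocabulary token
total_logit = final_stream @ unembed_dir
per_component = [w @ unembed_dir for w in [embed, *head_writes, *mlp_writes]]
print(f"total {total_logit:.4f} == sum of parts {sum(per_component):.4f}")
```

Because the dot product distributes over the sum, each component's contribution to any output logit can be read off independently; this is the basis of "direct logit attribution" analyses.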
2. Superposition and Polysemanticity
A central challenge in interpreting neural networks is that individual neurons are often polysemantic: a single neuron may activate for multiple unrelated concepts. For example, a neuron might fire for both "legal terminology" and "the color blue." This happens because of superposition: the model encodes more features than it has neurons by representing features as directions in activation space rather than individual neurons.
Superposition is an efficient compression strategy. If a model has 4096 neurons per layer but needs to track 100,000 features, it can represent each feature as a direction in the 4096-dimensional space. As long as features rarely co-occur, they can share neurons with minimal interference. This means the "true" features of the model are not individual neurons but rather directions in activation space, which are linear combinations of neurons.
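A quick numerical sketch shows why this works: random unit directions in a modest-dimensional space are nearly orthogonal, so many more features than neurons can coexist with little interference. The dimensions here are illustrative, not taken from any particular model:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_neurons = 256    # hypothetical layer width
n_features = 2048  # eight times more features than neurons

# Assign each feature a random unit direction in neuron space
dirs = F.normalize(torch.randn(n_features, d_neurons), dim=1)

# Interference between features = off-diagonal dot products of the Gram matrix
gram = dirs @ dirs.T
interference = (gram - torch.eye(n_features)).abs()
print(f"mean interference: {interference.mean():.4f}")  # roughly 1/sqrt(d_neurons)
print(f"max interference:  {interference.max():.4f}")
```

The average overlap between feature directions shrinks as roughly one over the square root of the neuron count, which is why sparsity matters: interference only corrupts a readout when overlapping features fire together.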
3. Sparse Autoencoders (SAEs)
Sparse autoencoders are the primary tool for disentangling superposed representations. An SAE takes the model's activations as input, encodes them into a much higher-dimensional but sparse representation, and then decodes back to the original activation space. The key constraint is sparsity: only a small fraction of the SAE's latent dimensions are active for any given input. Each active dimension ideally corresponds to a single interpretable feature.
```python
# Sparse Autoencoder for Mechanistic Interpretability
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseAutoencoder(nn.Module):
    """
    Sparse autoencoder for extracting interpretable features
    from transformer activations.

    Architecture:
    - Encoder: project d_model -> d_sae (expansion)
    - ReLU sparsity
    - Decoder: project d_sae -> d_model (reconstruction)

    The expansion factor (d_sae / d_model) is typically 4x to 64x.
    """

    def __init__(self, d_model, d_sae, l1_coeff=5e-3):
        super().__init__()
        self.d_model = d_model
        self.d_sae = d_sae
        self.l1_coeff = l1_coeff
        # Encoder and decoder
        self.encoder = nn.Linear(d_model, d_sae, bias=True)
        self.decoder = nn.Linear(d_sae, d_model, bias=True)
        # Initialize decoder weight columns (feature directions) to unit norm
        with torch.no_grad():
            self.decoder.weight.data = F.normalize(
                self.decoder.weight.data, dim=0
            )

    def encode(self, x):
        """Encode activations into the sparse feature space."""
        # Subtract the decoder bias so the reconstruction is centered
        x_centered = x - self.decoder.bias
        # Encode and apply ReLU for sparsity
        return F.relu(self.encoder(x_centered))

    def decode(self, z):
        """Reconstruct activations from sparse features."""
        return self.decoder(z)

    def forward(self, x):
        z = self.encode(x)      # sparse features
        x_hat = self.decode(z)  # reconstruction
        return x_hat, z

    def loss(self, x):
        """Combined reconstruction + sparsity loss."""
        x_hat, z = self.forward(x)
        # Reconstruction loss (MSE)
        recon_loss = F.mse_loss(x_hat, x)
        # Sparsity penalty (L1 on feature activations)
        l1_loss = z.abs().mean()
        return recon_loss + self.l1_coeff * l1_loss, {
            "reconstruction": recon_loss.item(),
            "sparsity": l1_loss.item(),
            "alive_features": (z > 0).any(dim=0).sum().item(),
            "avg_active": (z > 0).float().mean().item(),
        }
```
```python
# Training an SAE on GPT-2 MLP activations
def train_sae_on_model(
    model_name="gpt2",
    layer_idx=6,
    component="mlp",
    expansion_factor=8,
    num_tokens=10_000_000,
    batch_size=4096,
    lr=3e-4,
):
    """Train a sparse autoencoder on transformer activations."""
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    d_model = model.config.n_embd
    d_sae = d_model * expansion_factor
    sae = SparseAutoencoder(d_model, d_sae)
    optimizer = torch.optim.Adam(sae.parameters(), lr=lr)

    # Collect activations using a forward hook
    activations_buffer = []

    def hook_fn(module, input, output):
        # HF attention modules return a tuple; take the hidden states
        acts = output[0] if isinstance(output, tuple) else output
        # Flatten (batch, seq, d_model) -> (batch * seq, d_model)
        activations_buffer.append(acts.detach().reshape(-1, acts.shape[-1]))

    # Register the hook on the target component
    if component == "mlp":
        handle = model.transformer.h[layer_idx].mlp.register_forward_hook(hook_fn)
    else:
        handle = model.transformer.h[layer_idx].attn.register_forward_hook(hook_fn)

    # Training loop (simplified: get_activation_batch is a placeholder that
    # should run the model over a text corpus and drain activations_buffer)
    for step in range(num_tokens // batch_size):
        batch_acts = get_activation_batch(model, tokenizer, batch_size)
        loss, metrics = sae.loss(batch_acts)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Re-normalize decoder feature directions after each step
        with torch.no_grad():
            sae.decoder.weight.data = F.normalize(
                sae.decoder.weight.data, dim=0
            )
        if step % 100 == 0:
            print(
                f"Step {step}: recon={metrics['reconstruction']:.4f}, "
                f"l1={metrics['sparsity']:.4f}, "
                f"alive={metrics['alive_features']}/{d_sae}"
            )
    handle.remove()
    return sae
```
A major practical challenge with SAEs is "dead features": latent dimensions that never activate after initialization. With expansion factors of 16x or higher, 20% to 50% of features may die during training. Techniques like resampling dead features, using a ghost gradient for the encoder bias, or TopK activation functions (instead of ReLU) help mitigate this issue. Always monitor the fraction of alive features during SAE training.
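The TopK idea can be sketched as follows (this follows the general technique, not any specific library's implementation): keep only the k largest pre-activations per example and zero the rest, so sparsity is exact by construction and no L1 penalty is needed.

```python
import torch

def topk_activation(pre_acts, k=32):
    """Keep only the k largest pre-activations per example; zero the rest.
    Sparsity is exact by construction, so the L1 term can be dropped, and
    in practice fewer features die than with ReLU + L1."""
    values, indices = pre_acts.topk(k, dim=-1)
    z = torch.zeros_like(pre_acts)
    z.scatter_(-1, indices, values.clamp(min=0.0))  # keep ReLU semantics
    return z

z = topk_activation(torch.randn(4, 1024), k=32)
print((z != 0).sum(dim=-1))  # at most k active features per example
```

Swapping this in for the ReLU in the `encode` method turns the L1-penalized SAE into a TopK SAE; the k parameter then directly controls the number of active features per input.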
4. Activation Patching
Activation patching (also called causal tracing or interchange intervention) is a technique for identifying which components of a model are responsible for a specific behavior. The idea is to run the model on two inputs (a "clean" input that triggers the behavior and a "corrupted" input that does not), and then selectively replace activations from one run with activations from the other. If replacing a specific component's activation restores the original behavior, that component is causally important.
```python
# Activation Patching with TransformerLens
import torch
from transformer_lens import HookedTransformer

# Load model with TransformerLens
model = HookedTransformer.from_pretrained("gpt2-small")

# Example: which components know that "The Eiffel Tower is in" -> "Paris"?
clean_prompt = "The Eiffel Tower is in"
corrupted_prompt = "The Colosseum is in"  # different answer: "Rome"

# Baseline: clean run, measuring the Paris-vs-Rome logit difference
clean_logits, clean_cache = model.run_with_cache(clean_prompt)
paris = model.to_single_token(" Paris")
rome = model.to_single_token(" Rome")
clean_logit_diff = (
    clean_logits[0, -1, paris] - clean_logits[0, -1, rome]
).item()
print(f"Clean logit diff (Paris - Rome): {clean_logit_diff:.3f}")

# Patch each attention head and measure the effect
results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        # Hook that overwrites this head's output with the clean activation
        def patch_hook(activation, hook, layer=layer, head=head):
            # activation shape: (batch, pos, head, d_head)
            activation[:, :, head, :] = clean_cache[hook.name][:, :, head, :]
            return activation

        # Run the corrupted input with this one head patched
        hook_name = f"blocks.{layer}.attn.hook_z"
        patched_logits = model.run_with_hooks(
            corrupted_prompt,
            fwd_hooks=[(hook_name, patch_hook)],
        )
        # Measure recovery of the clean behavior
        results[layer, head] = (
            patched_logits[0, -1, paris] - patched_logits[0, -1, rome]
        ).item()

print("Top 5 most important heads for 'Eiffel Tower -> Paris':")
flat = results.flatten()
for idx in flat.argsort(descending=True)[:5].tolist():
    layer, head = divmod(idx, model.cfg.n_heads)
    print(f"  L{layer}H{head}: logit diff = {flat[idx]:.3f}")
```
5. TransformerLens and nnsight
Two primary libraries support mechanistic interpretability research on transformers. TransformerLens provides a re-implementation of common models with full hook access at every computational step. nnsight offers a different approach, allowing intervention on any PyTorch model without re-implementation.
| Feature | TransformerLens | nnsight |
|---|---|---|
| Approach | Re-implements models with hooks at every step | Wraps existing PyTorch models with proxy access |
| Supported models | GPT-2, GPT-Neo, Pythia, Llama, Mistral, Gemma | Any PyTorch model |
| Hook granularity | Every attention sub-computation (Q, K, V, patterns, z) | Any module input/output |
| Caching | Built-in activation caching | Proxy-based lazy evaluation |
| Best for | Detailed mechanistic analysis of supported models | Quick experiments on any architecture |
| Learning curve | Moderate (custom API) | Lower (familiar PyTorch patterns) |
```python
# nnsight: Intervening on any PyTorch model
from nnsight import LanguageModel

# Wrap any HuggingFace model
model = LanguageModel("gpt2", device_map="auto")

# Use the tracing context to inspect and modify activations
with model.trace("The cat sat on the"):
    # Access any module's output and save it for use outside the trace
    layer_5_output = model.transformer.h[5].output[0].save()
    # You can also modify activations in-place:
    # model.transformer.h[5].mlp.output[:] *= 0  # ablate the MLP

# Saved activations are available after the trace exits
print(f"Layer 5 output shape: {layer_5_output.value.shape}")

# Activation patching with nnsight
clean_text = "The Eiffel Tower is in"
corrupt_text = "The Colosseum is in"

# Get clean activations at layer 8
with model.trace(clean_text):
    clean_resid = model.transformer.h[8].output[0].save()

# Patch the corrupted run with clean activations at layer 8
with model.trace(corrupt_text):
    model.transformer.h[8].output[0][:] = clean_resid.value
    patched_logits = model.lm_head.output.save()

print(f"Patched prediction: {patched_logits.value[0, -1].argmax().item()}")
```
Anthropic's interpretability research program has produced several landmark results using SAEs at scale. Their work on Claude models has identified millions of interpretable features, including features for specific concepts (Golden Gate Bridge, code bugs, deception), safety-relevant behaviors (refusal, harmful content detection), and abstract reasoning patterns. This demonstrates that SAE-based mechanistic interpretability can scale to production-sized models.
Mechanistic interpretability is not just an academic exercise. Practical applications include: (1) understanding why a model produces specific outputs, enabling targeted debugging; (2) identifying and removing undesirable behaviors like deception or sycophancy; (3) verifying that safety training has actually modified the model's internal computations rather than just its surface behavior; and (4) steering model behavior by amplifying or suppressing specific features at inference time.
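The steering idea in point (4) amounts to adding a scaled feature direction to the residual stream at inference time. A minimal sketch, where `feature_dir` is a random stand-in for what would in practice be a decoder column of a trained SAE:

```python
import torch

def steer(resid, feature_dir, alpha=4.0):
    """Add a scaled, unit-norm feature direction to residual activations.
    In practice feature_dir would come from a trained SAE's decoder;
    here it is a random stand-in."""
    unit = feature_dir / feature_dir.norm()
    return resid + alpha * unit

torch.manual_seed(0)
resid = torch.randn(1, 8, 768)   # (batch, seq, d_model) residual activations
feature_dir = torch.randn(768)   # hypothetical SAE feature direction
steered = steer(resid, feature_dir, alpha=4.0)

# The projection onto the feature direction rises by alpha at every position
delta = (steered - resid) @ (feature_dir / feature_dir.norm())
```

Applied inside a forward hook (or an nnsight trace), this shifts every position's activation toward the feature, amplifying the corresponding behavior; a negative `alpha` suppresses it.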
✅ Key Takeaways
- The residual stream view frames transformers as a shared communication channel where each component makes additive contributions, enabling component-level analysis.
- Superposition allows models to represent far more features than neurons by encoding features as directions in activation space, creating polysemantic neurons.
- Sparse autoencoders (SAEs) decompose superposed representations into interpretable features by projecting into a higher-dimensional sparse space.
- Activation patching provides causal (not just correlational) evidence for which model components are responsible for specific behaviors.
- TransformerLens and nnsight are the two primary tools for mechanistic analysis, offering different tradeoffs between depth of access and model coverage.
- Mechanistic interpretability has practical applications in debugging, safety verification, and behavior steering beyond pure research interest.