Module 17 · Section 17.2

Mechanistic Interpretability

Reverse-engineering neural networks: circuits, features, sparse autoencoders, and the quest to understand what models compute
★ Big Picture

Mechanistic interpretability aims to fully reverse-engineer neural networks into human-understandable components. Rather than treating models as black boxes and probing their behavior from the outside, this research program seeks to identify the fundamental features that neurons compute, the circuits that connect them, and the algorithms these circuits implement. The approach draws from reverse-engineering traditions in biology and electrical engineering, treating a trained neural network as an artifact to be disassembled and understood. Sparse autoencoders (SAEs) have emerged as the key tool for extracting interpretable features from the superposed representations that transformers learn.

1. The Residual Stream View

Mechanistic interpretability starts with a conceptual reframing of the transformer architecture. Instead of viewing transformers as a sequence of layers that process data in order, the "residual stream" view treats the residual connection as a central communication channel. Each attention head and MLP layer reads from this stream, performs a computation, and writes its result back to the stream by adding it to the running total.

Figure 17.3: The residual stream view of a transformer. Attention heads and MLP layers read from and write to a shared stream, accumulating their contributions through addition.
💡 Key Insight

The residual stream view has a powerful implication: because all components communicate by adding to the same stream, their contributions can be analyzed independently. The output of the entire network is the sum of the embedding, every attention head output, and every MLP output. This linearity (in the residual connections) makes it possible to attribute the model's behavior to specific components by measuring the effect of removing or modifying each contribution.
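This additivity can be demonstrated with a toy sketch (the "components" here are illustrative stand-ins, not real attention heads or MLPs): the final stream equals the embedding plus the sum of every component's write, regardless of the order in which the writes happened.

```python
import torch

torch.manual_seed(0)
d_model = 16

# Toy stand-ins for components: each reads the stream and writes additively.
embed = torch.randn(d_model)
components = [torch.nn.Linear(d_model, d_model) for _ in range(4)]

# Forward pass: the stream accumulates every component's write.
stream = embed.clone()
writes = []
for comp in components:
    w = comp(stream)      # read from the stream, compute
    writes.append(w)
    stream = stream + w   # write back by addition

# Additivity: final stream = embedding + sum of all component writes.
reconstructed = embed + torch.stack(writes).sum(dim=0)
print(torch.allclose(stream, reconstructed, atol=1e-5))  # True
```

This is why "direct logit attribution" works: each write's contribution to the final logits can be measured in isolation by projecting it through the unembedding.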

2. Superposition and Polysemanticity

A central challenge in interpreting neural networks is that individual neurons are often polysemantic: a single neuron may activate for multiple unrelated concepts. For example, a neuron might fire for both "legal terminology" and "the color blue." This happens because of superposition: the model encodes more features than it has neurons by representing features as directions in activation space rather than individual neurons.

Superposition is an efficient compression strategy. If a model has 4096 neurons per layer but needs to track 100,000 features, it can represent each feature as a direction in the 4096-dimensional space. As long as features rarely co-occur, they can share neurons with minimal interference. This means the "true" features of the model are not individual neurons but rather directions in activation space, which are linear combinations of neurons.
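A minimal numerical sketch of this idea (the dimensions and feature indices are arbitrary choices for illustration): assign each of many features a random, near-orthogonal direction in a much smaller neuron space, then check that a sparse set of active features can still be read off with limited interference.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_neurons, n_features = 64, 1000

# Each feature is a random (near-orthogonal) unit direction in neuron space.
feature_dirs = F.normalize(torch.randn(n_features, d_neurons), dim=1)

# A sparse input: only 3 of the 1000 features are active.
active = [17, 402, 913]
activation = feature_dirs[active].sum(dim=0)

# Read features off by dot product: active features score near 1,
# inactive ones near 0 (small interference from imperfect orthogonality).
scores = feature_dirs @ activation
print(scores[active])       # each close to 1.0
print(scores.abs().mean())  # small average interference
```

The interference grows as features co-occur more often or as the ratio of features to neurons increases, which is exactly the regime where individual neurons become polysemantic.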

Figure 17.4: Without superposition, each neuron represents one feature (clean but limited). With superposition, multiple features share neurons, creating polysemantic neurons that are harder to interpret.

3. Sparse Autoencoders (SAEs)

Sparse autoencoders are the primary tool for disentangling superposed representations. An SAE takes the model's activations as input, encodes them into a much higher-dimensional but sparse representation, and then decodes back to the original activation space. The key constraint is sparsity: only a small fraction of the SAE's latent dimensions are active for any given input. Each active dimension corresponds to a single interpretable feature.

# Sparse Autoencoder for Mechanistic Interpretability
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    """
    Sparse Autoencoder for extracting interpretable features
    from transformer activations.

    Architecture:
    - Encoder: project d_model -> d_sae (expansion)
    - ReLU sparsity
    - Decoder: project d_sae -> d_model (reconstruction)

    The expansion factor (d_sae / d_model) is typically 4x to 64x.
    """

    def __init__(self, d_model, d_sae, l1_coeff=5e-3):
        super().__init__()
        self.d_model = d_model
        self.d_sae = d_sae
        self.l1_coeff = l1_coeff

        # Encoder and decoder
        self.encoder = nn.Linear(d_model, d_sae, bias=True)
        self.decoder = nn.Linear(d_sae, d_model, bias=True)

        # Initialize decoder weights to unit norm
        with torch.no_grad():
            self.decoder.weight.data = F.normalize(
                self.decoder.weight.data, dim=0
            )

    def encode(self, x):
        """Encode activations into sparse feature space."""
        # Subtract decoder bias for centering
        x_centered = x - self.decoder.bias
        # Encode and apply ReLU for sparsity
        z = F.relu(self.encoder(x_centered))
        return z

    def decode(self, z):
        """Reconstruct activations from sparse features."""
        return self.decoder(z)

    def forward(self, x):
        z = self.encode(x)  # sparse features
        x_hat = self.decode(z)  # reconstruction
        return x_hat, z

    def loss(self, x):
        """Combined reconstruction + sparsity loss."""
        x_hat, z = self.forward(x)

        # Reconstruction loss (MSE)
        recon_loss = F.mse_loss(x_hat, x)

        # Sparsity loss (L1 on activations)
        l1_loss = z.abs().mean()

        return recon_loss + self.l1_coeff * l1_loss, {
            "reconstruction": recon_loss.item(),
            "sparsity": l1_loss.item(),
            "alive_features": (z > 0).any(dim=0).sum().item(),
            "avg_active": (z > 0).float().mean().item(),
        }

# Training an SAE on GPT-2 MLP activations
def train_sae_on_model(
    text_batches,  # iterable of lists of strings to harvest activations from
    model_name="gpt2",
    layer_idx=6,
    component="mlp",
    expansion_factor=8,
    lr=3e-4,
):
    """Train a sparse autoencoder on transformer activations."""
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token
    d_model = model.config.n_embd
    d_sae = d_model * expansion_factor

    sae = SparseAutoencoder(d_model, d_sae)
    optimizer = torch.optim.Adam(sae.parameters(), lr=lr)

    # Collect activations from the target layer using a forward hook
    activations_buffer = []

    def hook_fn(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        # Flatten (batch, seq, d_model) -> (batch * seq, d_model)
        activations_buffer.append(out.detach().reshape(-1, out.shape[-1]))

    # Register the hook on the target component of the chosen block
    block = model.transformer.h[layer_idx]
    target = block.mlp if component == "mlp" else block.attn
    handle = target.register_forward_hook(hook_fn)

    # Training loop (simplified: padding positions are kept for brevity)
    for step, texts in enumerate(text_batches):
        activations_buffer.clear()
        tokens = tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        )
        with torch.no_grad():
            model(**tokens)  # populates the activation buffer via the hook
        batch_acts = torch.cat(activations_buffer, dim=0)

        loss, metrics = sae.loss(batch_acts)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Project decoder columns back to unit norm after each step
        with torch.no_grad():
            sae.decoder.weight.data = F.normalize(
                sae.decoder.weight.data, dim=0
            )

        if step % 100 == 0:
            print(
                f"Step {step}: recon={metrics['reconstruction']:.4f}, "
                f"l1={metrics['sparsity']:.4f}, "
                f"alive={metrics['alive_features']}/{d_sae}"
            )

    handle.remove()
    return sae
⚠ Warning

A major practical challenge with SAEs is "dead features": latent dimensions that never activate after initialization. With expansion factors of 16x or higher, 20% to 50% of features may die during training. Techniques like resampling dead features, using a ghost gradient for the encoder bias, or TopK activation functions (instead of ReLU) help mitigate this issue. Always monitor the fraction of alive features during SAE training.
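The TopK alternative mentioned above can be sketched as a drop-in replacement for the ReLU in the SAE's `encode` step (the `k=32` value here is an arbitrary illustrative choice): keep only the k largest pre-activations per input and zero everything else, so sparsity is enforced by construction rather than by an L1 penalty.

```python
import torch

def topk_activation(pre_acts, k=32):
    """Keep the k largest pre-activations per example; zero out the rest.

    Unlike ReLU + L1, this guarantees at most k active features per input,
    so no separate sparsity coefficient needs tuning, and features cannot
    be killed by an overly aggressive L1 penalty.
    """
    values, indices = pre_acts.topk(k, dim=-1)
    z = torch.zeros_like(pre_acts)
    # Clamp negative top-k values to 0 as well, mirroring ReLU
    z.scatter_(-1, indices, torch.relu(values))
    return z

# Example: batch of 4 inputs, 256 SAE latents, at most k active each
pre = torch.randn(4, 256)
z = topk_activation(pre, k=32)
print((z != 0).sum(dim=-1))  # at most 32 nonzero entries per row
```

With TopK, the sparsity term can be dropped from the loss entirely; the reconstruction loss alone drives training.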

4. Activation Patching

Activation patching (also called causal tracing or interchange intervention) is a technique for identifying which components of a model are responsible for a specific behavior. The idea is to run the model on two inputs (a "clean" input that triggers the behavior and a "corrupted" input that does not), and then selectively replace activations from one run with activations from the other. If replacing a specific component's activation restores the original behavior, that component is causally important.

# Activation Patching with TransformerLens
from transformer_lens import HookedTransformer
import torch

# Load model with TransformerLens
model = HookedTransformer.from_pretrained("gpt2-small")

# Example: Which components know that "The Eiffel Tower is in" -> "Paris"?
clean_prompt = "The Eiffel Tower is in"
corrupted_prompt = "The Colosseum is in"  # different answer: "Rome"

# Get clean logits for the target token
clean_logits, clean_cache = model.run_with_cache(clean_prompt)
target_token = model.to_single_token(" Paris")
clean_logit_diff = (
    clean_logits[0, -1, target_token]
    - clean_logits[0, -1, model.to_single_token(" Rome")]
).item()

# Get corrupted cache
_, corrupted_cache = model.run_with_cache(corrupted_prompt)

# Patch each attention head and measure the effect
results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        # Define hook that patches this head's output
        def patch_hook(activation, hook, layer=layer, head=head):
            # Replace corrupted activation with clean activation
            activation[:, :, head, :] = clean_cache[
                hook.name
            ][:, :, head, :]
            return activation

        # Run corrupted input with this one head patched
        hook_name = f"blocks.{layer}.attn.hook_z"
        patched_logits = model.run_with_hooks(
            corrupted_prompt,
            fwd_hooks=[(hook_name, patch_hook)],
        )

        # Measure recovery of clean behavior
        patched_diff = (
            patched_logits[0, -1, target_token]
            - patched_logits[0, -1, model.to_single_token(" Rome")]
        ).item()

        results[layer, head] = patched_diff

print("Top 5 most important heads for 'Eiffel Tower -> Paris':")
flat = results.flatten()
top_indices = flat.argsort(descending=True)[:5]
for idx in top_indices:
    layer = (idx // model.cfg.n_heads).item()
    head = (idx % model.cfg.n_heads).item()
    print(f"  L{layer}H{head}: logit diff = {flat[idx]:.3f}")

5. TransformerLens and nnsight

Two primary libraries support mechanistic interpretability research on transformers. TransformerLens provides a re-implementation of common models with full hook access at every computational step. nnsight offers a different approach, allowing intervention on any PyTorch model without re-implementation.

| Feature | TransformerLens | nnsight |
| --- | --- | --- |
| Approach | Re-implements models with hooks at every step | Wraps existing PyTorch models with proxy access |
| Supported models | GPT-2, GPT-Neo, Pythia, Llama, Mistral, Gemma | Any PyTorch model |
| Hook granularity | Every attention sub-computation (Q, K, V, patterns, z) | Any module input/output |
| Caching | Built-in activation caching | Proxy-based lazy evaluation |
| Best for | Detailed mechanistic analysis of supported models | Quick experiments on any architecture |
| Learning curve | Moderate (custom API) | Lower (familiar PyTorch patterns) |
# nnsight: Intervening on any PyTorch model
from nnsight import LanguageModel

# Wrap any HuggingFace model
model = LanguageModel("gpt2", device_map="auto")

# Use the tracing context to inspect and modify activations
# Use the tracing context to inspect and modify activations
with model.trace("The cat sat on the") as tracer:
    # Access any module's output and save it for inspection after the trace
    layer_5_output = model.transformer.h[5].output[0].save()

    # You can also modify activations in-place
    # model.transformer.h[5].mlp.output[:] *= 0  # ablate MLP

# Access saved activations after the trace
print(f"Layer 5 output shape: {layer_5_output.value.shape}")

# Activation patching with nnsight
clean_text = "The Eiffel Tower is in"
corrupt_text = "The Colosseum is in"

# Get clean activations
with model.trace(clean_text) as tracer:
    clean_resid = model.transformer.h[8].output[0].save()

# Patch corrupted run with clean activations at layer 8
# (both prompts must tokenize to the same length for this in-place copy)
with model.trace(corrupt_text) as tracer:
    model.transformer.h[8].output[0][:] = clean_resid.value
    patched_logits = model.lm_head.output.save()

pred_id = patched_logits.value[0, -1].argmax().item()
print(f"Patched prediction: {model.tokenizer.decode(pred_id)}")
📝 Note

Anthropic's interpretability research program has produced several landmark results using SAEs at scale. Their work on Claude models has identified millions of interpretable features, including features for specific concepts (Golden Gate Bridge, code bugs, deception), safety-relevant behaviors (refusal, harmful content detection), and abstract reasoning patterns. This demonstrates that SAE-based mechanistic interpretability can scale to production-sized models.

💡 Key Insight

Mechanistic interpretability is not just an academic exercise. Practical applications include: (1) understanding why a model produces specific outputs, enabling targeted debugging; (2) identifying and removing undesirable behaviors like deception or sycophancy; (3) verifying that safety training has actually modified the model's internal computations rather than just its surface behavior; and (4) steering model behavior by amplifying or suppressing specific features at inference time.
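The fourth application, feature steering, can be sketched in a few lines. This toy uses a randomly initialized decoder as a stand-in for a trained SAE, and the feature index and strength (`feature_idx=7`, `alpha=4.0`) are arbitrary illustrative choices: steering means adding a multiple of a feature's decoder direction to the activations at inference time.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_sae = 64, 512

# Stand-in for a trained SAE decoder: columns are unit-norm feature directions.
decoder = nn.Linear(d_sae, d_model, bias=False)
with torch.no_grad():
    decoder.weight.data = F.normalize(decoder.weight.data, dim=0)

def steer(activations, feature_idx, alpha):
    """Add alpha times one feature's decoder direction to every position."""
    direction = decoder.weight[:, feature_idx]  # (d_model,)
    return activations + alpha * direction

acts = torch.randn(1, 10, d_model)  # (batch, seq, d_model)
steered = steer(acts, feature_idx=7, alpha=4.0)

# The change projects onto the feature direction with magnitude alpha,
# since the direction has unit norm.
direction = decoder.weight[:, 7]
shift = (steered - acts) @ direction
print(shift.mean())  # ≈ 4.0
```

In practice the steering vector comes from a trained SAE's decoder, and the addition is applied inside the model (e.g. via a forward hook on the residual stream) rather than to a standalone tensor.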

📝 Section Quiz

1. What is the "residual stream" view of transformers, and why does it matter for interpretability?
Show Answer
The residual stream view treats the residual connection as a central communication channel. Each attention head and MLP reads from and writes to this shared stream via addition. Because the contributions are additive, the final output is a sum of all component outputs, making it possible to attribute behavior to specific components. This linearity is the foundation for techniques like activation patching and direct logit attribution.
2. What is superposition, and why does it make interpretation difficult?
Show Answer
Superposition occurs when a model encodes more conceptual features than it has neurons by representing features as directions in activation space rather than individual neurons. This makes individual neurons polysemantic (responding to multiple unrelated concepts), which means you cannot interpret the model by looking at individual neurons. SAEs address this by projecting into a higher-dimensional space where each dimension corresponds to a single feature.
3. How does a sparse autoencoder extract interpretable features?
Show Answer
An SAE maps model activations (d_model dimensions) into a much larger space (d_sae dimensions, typically 4x to 64x larger) with a sparsity constraint (ReLU or TopK). The sparsity ensures that only a few dimensions are active for any input. Each active dimension ideally corresponds to a single interpretable feature. The decoder weights define the directions in activation space that each feature corresponds to. The SAE is trained to minimize reconstruction error while maintaining sparsity.
4. What is activation patching, and what does it reveal?
Show Answer
Activation patching runs the model on a clean input and a corrupted input, then selectively swaps activations between the two runs at specific components. If replacing a component's corrupted activation with its clean activation restores the original behavior, that component is causally responsible. This provides stronger evidence than correlation-based methods because it tests a counterfactual: what would happen if this component computed differently?
5. What is the "dead features" problem in SAE training, and how is it addressed?
Show Answer
Dead features are SAE latent dimensions that never activate after initialization, effectively wasting capacity. With large expansion factors, 20% to 50% of features may die. Solutions include: (1) resampling dead features by reinitializing their weights using poorly reconstructed examples, (2) ghost gradients that provide training signal to the encoder even for inactive features, and (3) TopK activation functions that guarantee a fixed number of active features per input rather than relying on ReLU thresholding.

✅ Key Takeaways