Interpretability is not just a research curiosity; it is a practical toolkit for building better models. Feature attribution methods explain individual predictions, representation engineering steers model behavior without retraining, and model editing techniques (ROME, MEMIT) surgically modify specific knowledge stored in weights. These tools help practitioners debug hallucinations, remove unwanted biases, update stale information, and verify that models behave as intended before deployment.
1. Feature Attribution Methods
Feature attribution methods assign an importance score to each input token, answering the question: "How much did each token contribute to the model's prediction?" Unlike attention visualization (which shows where the model looks), attribution methods estimate how each input actually influences the output through the entire network.
1.1 Integrated Gradients
Integrated Gradients (Sundararajan et al., 2017) computes attribution by integrating the gradient of the output with respect to the input along a straight-line path from a baseline (typically the zero embedding) to the actual input. The method satisfies two desirable axioms: sensitivity (if changing an input changes the output, that input gets non-zero attribution) and implementation invariance (the attribution depends only on the function, not its implementation details).
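In symbols (with $x$ the input embedding, $x'$ the baseline, and $F$ the target output), the attribution for dimension $i$ is the path integral of the gradient, approximated in practice by a Riemann sum over $m$ interpolation steps:

$$\mathrm{IG}_i(x) = (x_i - x'_i)\int_0^1 \frac{\partial F\big(x' + \alpha(x - x')\big)}{\partial x_i}\,d\alpha \;\approx\; (x_i - x'_i)\,\frac{1}{m}\sum_{k=1}^{m} \frac{\partial F\big(x' + \tfrac{k}{m}(x - x')\big)}{\partial x_i}$$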
```python
# Integrated Gradients for Token Attribution
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def integrated_gradients(
    model,
    tokenizer,
    text,
    target_token_idx=-1,
    n_steps=50,
    internal_batch_size=5,
):
    """
    Compute Integrated Gradients attribution for each input token.

    Args:
        model: language model
        tokenizer: tokenizer
        text: input text
        target_token_idx: which output position to explain (-1 = last)
        n_steps: number of interpolation steps (higher = more accurate)

    Returns:
        attributions: per-token attribution scores
        tokens: list of input tokens
    """
    model.eval()
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    # Get embeddings
    embedding_layer = model.get_input_embeddings()
    input_embeds = embedding_layer(input_ids).detach()

    # Baseline: zero embeddings
    baseline = torch.zeros_like(input_embeds)

    # Fix the target token once (the model's prediction on the real input),
    # so every interpolation step explains the same output
    with torch.no_grad():
        target_id = model(inputs_embeds=input_embeds).logits[
            0, target_token_idx
        ].argmax()

    # Interpolate between baseline and input
    alphas = torch.linspace(0, 1, n_steps).unsqueeze(1).unsqueeze(2)
    interpolated = baseline + alphas * (input_embeds - baseline)
    # Shape: (n_steps, seq_len, hidden_dim)

    # Compute gradients at each interpolation point
    all_grads = []
    for i in range(0, n_steps, internal_batch_size):
        batch = interpolated[i:i + internal_batch_size].clone()
        batch.requires_grad_(True)
        logits = model(inputs_embeds=batch).logits
        # Logit of the fixed target token at the target position
        target_logit = logits[:, target_token_idx, target_id]
        target_logit.sum().backward()
        all_grads.append(batch.grad.detach())
    grads = torch.cat(all_grads, dim=0)  # (n_steps, seq_len, hidden)

    # Integrate: average gradients, then multiply by (input - baseline)
    avg_grads = grads.mean(dim=0)  # (seq_len, hidden)
    ig = (input_embeds.squeeze(0) - baseline.squeeze(0)) * avg_grads

    # Sum over hidden dimension to get per-token scores
    attributions = ig.sum(dim=-1).numpy()
    return attributions, tokens


# Example usage
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
attrs, tokens = integrated_gradients(
    model, tokenizer, "The capital of France is"
)
print("Token attributions for next-token prediction:")
for token, score in zip(tokens, attrs):
    bar = "+" * int(abs(score) * 20)
    direction = "+" if score > 0 else "-"
    print(f"  {token:15s} {direction}{bar} ({score:.4f})")
```
1.2 SHAP for Language Models
SHAP (SHapley Additive exPlanations) adapts Shapley values from cooperative game theory to feature attribution. Each token's SHAP value represents the average marginal contribution of that token across all possible subsets of input tokens. While computationally expensive for long sequences, SHAP provides theoretically grounded attributions with strong guarantees.
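Formally, for a set $N$ of $n$ tokens and a value function $v(S)$ (the model's output when only the subset $S$ of tokens is present), token $i$'s Shapley value is:

$$\phi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!\,(n - |S| - 1)!}{n!}\,\big(v(S \cup \{i\}) - v(S)\big)$$

The factorial weights average the marginal contribution $v(S \cup \{i\}) - v(S)$ over all orderings in which token $i$ can join the subset, which is why exact computation requires evaluating all $2^n$ subsets.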
```python
# SHAP-based attribution for language models
import numpy as np
import shap
import torch


def explain_with_shap(model, tokenizer, texts, max_evals=500):
    """Use SHAP to explain model predictions."""

    def model_predict(texts_list):
        """Prediction function for SHAP: next-token probabilities."""
        results = []
        for text in texts_list:
            inputs = tokenizer(text, return_tensors="pt", truncation=True)
            with torch.no_grad():
                logits = model(**inputs).logits[0, -1]
            probs = torch.softmax(logits, dim=-1)
            results.append(probs.numpy())
        return np.array(results)

    # Partition explainer works well for text; the tokenizer serves as
    # the masker. output_names must be a list of labels ordered by id.
    vocab = tokenizer.get_vocab()
    output_names = sorted(vocab, key=vocab.get)
    explainer = shap.Explainer(
        model_predict,
        tokenizer,
        output_names=output_names,
    )

    # Compute SHAP values
    shap_values = explainer(texts, max_evals=max_evals)
    return shap_values


# Visualize SHAP attributions
shap_vals = explain_with_shap(
    model, tokenizer,
    ["The movie was absolutely wonderful and I loved every moment"]
)
shap.plots.text(shap_vals[0])
```
For production use, Integrated Gradients is typically preferred over SHAP for language models: its cost is a fixed number of forward-and-backward passes (n_steps), independent of input length, while exact SHAP values require exponentially many model evaluations (2^n subsets for n tokens). Approximate SHAP methods reduce this cost but introduce estimation noise.
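To make the cost gap concrete, here is a quick back-of-the-envelope comparison (a standalone sketch; `n_steps=50` matches the Integrated Gradients example above):

```python
# Model evaluations needed to attribute an n-token input
def ig_cost(n_steps=50):
    # IG: one forward/backward pass per interpolation step,
    # independent of sequence length
    return n_steps

def exact_shap_cost(n_tokens):
    # Exact Shapley values: the value function must be evaluated
    # on every subset of the input tokens
    return 2 ** n_tokens

for n in (10, 20, 30):
    print(f"{n} tokens: IG = {ig_cost()} evals, "
          f"exact SHAP = {exact_shap_cost(n):,} evals")
```

Even at 30 tokens, exact SHAP needs over a billion evaluations, which is why practical SHAP explainers rely on sampling and partition heuristics.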
2. Representation Engineering
Representation engineering (RepE) steers model behavior by modifying internal representations at inference time. Instead of retraining the model, you identify a "control vector" in activation space that corresponds to a specific behavior (such as honesty, verbosity, or formality) and add or subtract this vector during generation to increase or decrease that behavior.
```python
# Representation Engineering: Control Vectors
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def extract_control_vector(
    model,
    tokenizer,
    positive_prompts,
    negative_prompts,
    layer_idx,
):
    """
    Extract a control vector by contrasting positive and negative prompts.
    The control vector is the mean activation difference at a specific layer.
    """

    def get_mean_activation(prompts, layer_idx):
        activations = []
        for prompt in prompts:
            inputs = tokenizer(prompt, return_tensors="pt")
            with torch.no_grad():
                outputs = model(**inputs, output_hidden_states=True)
            # Take the last token's hidden state at the target layer
            hidden = outputs.hidden_states[layer_idx][0, -1, :]
            activations.append(hidden)
        return torch.stack(activations).mean(dim=0)

    pos_mean = get_mean_activation(positive_prompts, layer_idx)
    neg_mean = get_mean_activation(negative_prompts, layer_idx)
    control_vector = pos_mean - neg_mean
    # Normalize to unit length
    control_vector = control_vector / control_vector.norm()
    return control_vector


# Example: honesty control vector
honest_prompts = [
    "I need to give an honest, truthful answer to this question:",
    "Let me think carefully and give an accurate response:",
    "I want to be straightforward and factual:",
]
dishonest_prompts = [
    "I should make up a convincing-sounding answer:",
    "Let me say whatever sounds good regardless of truth:",
    "I want to tell them what they want to hear:",
]

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
honesty_vector = extract_control_vector(
    model, tokenizer, honest_prompts, dishonest_prompts, layer_idx=8
)


def generate_with_steering(
    model, tokenizer, prompt, control_vector, layer_idx,
    alpha=1.5, max_new_tokens=100,
):
    """Generate text while adding the control vector at each step."""

    def steering_hook(module, input, output):
        # Add the control vector to the last position's hidden states
        hidden = output[0]
        hidden[:, -1, :] += alpha * control_vector
        return (hidden,) + output[1:]

    # model.transformer.h is GPT-2's block list; adjust for other architectures
    handle = model.transformer.h[layer_idx].register_forward_hook(steering_hook)
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    handle.remove()
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
```
3. Model Editing: ROME and MEMIT
Model editing techniques surgically modify specific factual associations stored in model weights without affecting other knowledge. ROME (Rank-One Model Editing) targets a single feed-forward layer to update one fact. MEMIT (Mass-Editing Memory In a Transformer) extends this to edit thousands of facts simultaneously.
ROME is based on the discovery that factual associations are primarily stored in the MLP layers of transformers, specifically in the key-value matrices of the feed-forward network. The MLP acts as an associative memory where the first linear layer (the "key") matches patterns and the second linear layer (the "value") stores the associated information. ROME modifies the value matrix with a rank-one update that changes exactly one fact while preserving all others.
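The associative-memory picture can be made concrete. In a deliberately simplified form (ignoring ROME's covariance-weighted least-squares objective, which minimizes disturbance to other keys), inserting a new key-value pair (k*, v*) into a linear layer W is a rank-one update chosen so that the edited layer maps k* to v* exactly:

```python
import torch

def rank_one_edit(W: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Return W' = W + ((v - W k) / (k . k)) k^T, so that W' k = v.

    Simplified sketch: real ROME chooses the update using key covariance
    statistics so that unrelated associations are disturbed as little
    as possible.
    """
    residual = v - W @ k  # what the layer currently gets wrong for key k
    update = torch.outer(residual / k.dot(k), k)  # rank-one correction
    return W + update

# Toy demo: edit one association in a random "value" matrix
torch.manual_seed(0)
W = torch.randn(8, 16)
k_star, v_star = torch.randn(16), torch.randn(8)
W_new = rank_one_edit(W, k_star, v_star)
print(torch.allclose(W_new @ k_star, v_star, atol=1e-4))  # True
# Keys orthogonal to k_star are untouched: W_new @ k == W @ k for k ⟂ k_star
```

Because the update lies entirely along k*, any key orthogonal to k* passes through the edited layer unchanged, which is the intuition behind ROME's specificity.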
```python
# Model Editing with ROME, using the reference implementation
# (github.com/kmeng01/rome). Note: the "rome" package on PyPI is
# unrelated; API names below follow the repo and may vary by version.
from rome import ROMEHyperParams, apply_rome_to_model
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")

# Before editing
prompt = "The president of the United States is"
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=10)
print(f"Before: {tokenizer.decode(output[0])}")

# Define the edit (illustrative schema; see the repo for the exact format)
edit_request = {
    "prompt": "The president of the United States is",
    "subject": "The president of the United States",
    "target_new": " Elon Musk",  # hypothetical edit
}

# Apply ROME
hparams = ROMEHyperParams.from_name("gpt2-xl")
edited_model, _ = apply_rome_to_model(
    model, tokenizer, [edit_request], hparams
)

# After editing
output = edited_model.generate(**inputs, max_new_tokens=10)
print(f"After: {tokenizer.decode(output[0])}")

# Verify specificity: other knowledge should be preserved
other_prompt = "The capital of France is"
other_inputs = tokenizer(other_prompt, return_tensors="pt")
other_output = edited_model.generate(**other_inputs, max_new_tokens=10)
print(f"Other fact: {tokenizer.decode(other_output[0])}")
# Should still say "Paris"
```
# Should still say "Paris"
| Method | Edits per Run | Target Component | Preservation | Scalability |
|---|---|---|---|---|
| ROME | 1 | Single MLP layer (rank-1 update) | Good for single edits | Slow for many edits (sequential) |
| MEMIT | 1,000+ | Multiple MLP layers (distributed) | Good even with many edits | Handles batch edits efficiently |
| Fine-tuning | Unlimited | All parameters | Poor (catastrophic forgetting) | Good but destroys other knowledge |
| GRACE | Unlimited | Adapter codebook | Good (no weight changes) | Inference overhead grows with edits |
Model editing is powerful but fragile. Edits can have unintended side effects: changing "The president is X" might also change answers to related questions like "Who lives in the White House?" in unpredictable ways. The "ripple effect" of knowledge edits is an active research area. Always validate edits against a comprehensive test suite that includes related and unrelated facts.
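A minimal validation harness along these lines might look as follows. It is model-agnostic: `generate_fn` is any prompt-to-completion callable, the fact lists are hypothetical examples, and the stub stands in for an edited model:

```python
def validate_edit(generate_fn, edited_facts, unrelated_facts):
    """Check that edited facts changed and unrelated facts survived.

    Each fact is a (prompt, expected_substring) pair; generate_fn maps
    a prompt to the model's completion string. Returns a list of failures.
    """
    failures = []
    for group, facts in (("edited", edited_facts), ("unrelated", unrelated_facts)):
        for prompt, expected in facts:
            completion = generate_fn(prompt)
            if expected not in completion:
                failures.append((group, prompt, expected, completion))
    return failures

# Demo with a stub standing in for the edited model
def stub_generate(prompt):
    answers = {
        "The president of the United States is": " Elon Musk",  # the edit
        "The capital of France is": " Paris",                   # preserved
    }
    return answers.get(prompt, "")

failures = validate_edit(
    stub_generate,
    edited_facts=[("The president of the United States is", "Elon Musk")],
    unrelated_facts=[("The capital of France is", "Paris")],
)
print(f"{len(failures)} validation failures")  # 0 means the edit passed
```

In practice the `unrelated_facts` list should also include paraphrases and neighboring facts ("Who lives in the White House?") to catch ripple effects.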
4. Concept Erasure
Concept erasure removes specific information from model representations, ensuring the model cannot use that information for any downstream task. Unlike model editing (which changes a fact to a different value), concept erasure eliminates the information entirely. Applications include removing protected attributes (gender, race) from embeddings to prevent discriminatory predictions.
```python
# Concept Erasure with LEACE
# pip install concept-erasure
import torch
from concept_erasure import LeaceEraser


def erase_concept(
    hidden_states: torch.Tensor,
    concept_labels: torch.Tensor,
) -> torch.Tensor:
    """
    Erase a binary concept from hidden states using LEACE.

    LEACE (LEAst-squares Concept Erasure) finds the linear subspace
    that encodes the concept and projects it out, guaranteeing
    that no linear classifier can recover the concept from the
    resulting representations.
    """
    eraser = LeaceEraser.fit(hidden_states, concept_labels)
    return eraser(hidden_states)


# Example: erase gender information from embeddings
# hidden_states: (num_samples, hidden_dim)
# gender_labels: (num_samples,) binary labels
erased_states = erase_concept(hidden_states, gender_labels)

# Verify: a linear probe should drop to chance on the erased states
from sklearn.linear_model import LogisticRegression

X, X_erased = hidden_states.numpy(), erased_states.numpy()
y = gender_labels.numpy()
probe_before = LogisticRegression().fit(X, y)
probe_after = LogisticRegression().fit(X_erased, y)
print(f"Gender accuracy before erasure: {probe_before.score(X, y):.3f}")
print(f"Gender accuracy after erasure: {probe_after.score(X_erased, y):.3f}")
# After erasure, accuracy should be ~50% (random chance)
```
5. Interpretability for Debugging
Beyond research, interpretability tools serve as practical debugging instruments. When a model produces incorrect or unexpected outputs, these tools help diagnose the root cause by identifying which components contributed to the error and what information the model relied on.
In practice, the most common interpretability-based debugging pattern is: (1) identify a failure case, (2) use Integrated Gradients to find which input tokens are driving the incorrect output, (3) use logit lens to see which layers introduce the error, (4) decide whether to fix via prompt engineering, representation steering, model editing, or targeted fine-tuning. This workflow often reveals that hallucinations are caused by specific attention patterns that retrieve incorrect context.
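The logit-lens step of this workflow can be sketched with stand-in tensors. Random weights are used here purely to illustrate the projection; with a real GPT-2 you would take `outputs.hidden_states` from a forward pass and reuse the model's own `ln_f` and `lm_head` modules:

```python
import torch

def logit_lens(hidden_states, ln_f, unembed):
    """Project each layer's hidden state through the final layer norm and
    the unembedding matrix to see what the model would predict at that depth.

    hidden_states: list of (hidden_dim,) tensors, one per layer.
    Returns the top predicted token id at each layer.
    """
    top_ids = []
    for h in hidden_states:
        logits = unembed(ln_f(h))  # (vocab_size,)
        top_ids.append(logits.argmax().item())
    return top_ids

# Stand-in tensors: 12 layers, hidden size 64, vocab size 100
torch.manual_seed(0)
hidden_states = [torch.randn(64) for _ in range(12)]
ln_f = torch.nn.LayerNorm(64)
unembed = torch.nn.Linear(64, 100, bias=False)
per_layer_top = logit_lens(hidden_states, ln_f, unembed)
print(per_layer_top)
```

The layer at which the correct answer first becomes the top prediction (or at which a wrong answer displaces it) localizes where in the network the error is introduced.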
✅ Key Takeaways
- Integrated Gradients and SHAP provide principled, axiom-satisfying methods for attributing model predictions to specific input tokens.
- Representation engineering steers model behavior at inference time by adding learned control vectors, offering a lightweight alternative to retraining.
- ROME and MEMIT enable surgical editing of specific facts in model weights, but ripple effects on related knowledge require careful validation.
- Concept erasure (LEACE) provides mathematical guarantees that specific information is removed from representations, enabling provably fair predictions.
- Interpretability tools form a practical debugging workflow: attribution identifies contributing inputs, activation patching localizes responsible components, and editing or steering applies the fix.
- The choice between these tools depends on the use case: attribution for explaining predictions, steering for behavior modification, and editing for knowledge correction.