Claude-skill-registry activation-patching
Causal intervention via activation patching to identify important model components. Use when determining which layers, heads, or positions are causally responsible for model behavior.
Install

Source · Clone the upstream repo

```sh
git clone https://github.com/majiayu000/claude-skill-registry
```

Claude Code · Install into ~/.claude/skills/

```sh
T=$(mktemp -d) && git clone --depth=1 https://github.com/majiayu000/claude-skill-registry "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/data/activation-patching" ~/.claude/skills/majiayu000-claude-skill-registry-activation-patching && rm -rf "$T"
```

Manifest: skills/data/activation-patching/SKILL.md
Activation Patching
Activation patching is a causal intervention technique that identifies which model components are responsible for specific behaviors by swapping activations between different inputs.
Core Concept
- Clean run: Run the model on a prompt that produces the desired behavior
- Corrupted run: Run the model on a modified prompt that changes the behavior
- Patch: Re-run on the corrupted prompt, replace selected corrupted activations with clean ones, and measure whether the behavior is restored
If patching a component restores the clean behavior, that component is causally important.
Basic Setup
```python
from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

# Indirect Object Identification (IOI) task
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

# Target tokens
correct_token = model.tokenizer(" John")["input_ids"][0]    # Clean answer
incorrect_token = model.tokenizer(" Mary")["input_ids"][0]  # Corrupted answer
```
Metric: Logit Difference
```python
def logit_diff(logits, correct_idx, incorrect_idx):
    """Measure how much the model prefers the correct over the incorrect token."""
    return (logits[0, -1, correct_idx] - logits[0, -1, incorrect_idx]).item()
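```

Raw logit differences vary in scale from prompt to prompt, so patching effects are usually rescaled against the clean and corrupted baselines. A minimal sketch of that normalization as a standalone helper (the name `patching_effect` is ours, not part of nnsight):

```python
def patching_effect(patched_diff, clean_diff, corrupted_diff):
    """Rescale a patched logit diff: 0 = unchanged from the corrupted
    baseline, 1 = clean behavior fully restored."""
    return (patched_diff - corrupted_diff) / (clean_diff - corrupted_diff)
```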
Three-Run Patching Pattern
```python
n_layers = len(model.transformer.h)
results = torch.zeros(n_layers)

# Run 1: Clean - save activations
with model.trace(clean_prompt):
    clean_hiddens = [layer.output[0].save() for layer in model.transformer.h]
    clean_logits = model.lm_head.output.save()

# Run 2: Corrupted baseline
with model.trace(corrupted_prompt):
    corrupted_logits = model.lm_head.output.save()

# Runs 3+: Patch each layer (separate forward passes)
for layer_idx in range(n_layers):
    with model.trace(corrupted_prompt):
        # Replace corrupted activation with clean
        model.transformer.h[layer_idx].output[0][:] = clean_hiddens[layer_idx]
        patched_logits = model.lm_head.output.save()
    results[layer_idx] = logit_diff(patched_logits.value, correct_token, incorrect_token)

# Normalize results
clean_diff = logit_diff(clean_logits.value, correct_token, incorrect_token)
corrupted_diff = logit_diff(corrupted_logits.value, correct_token, incorrect_token)
normalized = (results - corrupted_diff) / (clean_diff - corrupted_diff)
```
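To read the scan off without a plot, rank layers by their normalized restoration effect. This sketch uses only standard PyTorch ops on the `normalized` tensor computed above:

```python
# Report the layers whose patching most restores the clean behavior
top_vals, top_layers = torch.topk(normalized, k=3)
for val, layer in zip(top_vals.tolist(), top_layers.tolist()):
    print(f"Layer {layer}: restores {val:.2f} of the clean logit diff")
```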
Position-Specific Patching
Patch only specific token positions:
```python
seq_len = len(model.tokenizer.encode(clean_prompt))
results = torch.zeros(n_layers, seq_len)

# Clean run - save activations
with model.trace(clean_prompt):
    clean_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# Patch each layer x position (separate forward passes)
for layer_idx in range(n_layers):
    for pos_idx in range(seq_len):
        with model.trace(corrupted_prompt):
            # Patch only this position
            model.transformer.h[layer_idx].output[0][:, pos_idx, :] = \
                clean_hiddens[layer_idx][:, pos_idx, :]
            patched_logits = model.lm_head.output.save()
        results[layer_idx, pos_idx] = logit_diff(
            patched_logits.value, correct_token, incorrect_token
        )
```
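When reading a layer × position map, labeling positions with the actual tokens is clearer than bare indices. `convert_ids_to_tokens` is a standard Hugging Face tokenizer method; these labels can replace the generic position ticks in the heatmap below:

```python
# Human-readable position labels for printouts and heatmap axes
token_labels = model.tokenizer.convert_ids_to_tokens(
    model.tokenizer.encode(clean_prompt)
)
print(list(enumerate(token_labels)))
```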
Attention Head Patching
Patch individual attention heads:
```python
n_heads = model.config.n_head
head_dim = model.config.n_embd // n_heads
results = torch.zeros(n_layers, n_heads)

# Clean run - save attention outputs (before projection)
with model.trace(clean_prompt):
    clean_attn = [layer.attn.c_proj.input[0][0].save() for layer in model.transformer.h]

# Patch each layer x head (separate forward passes)
for layer_idx in range(n_layers):
    for head_idx in range(n_heads):
        with model.trace(corrupted_prompt):
            # Patch single head's output
            start = head_idx * head_dim
            end = (head_idx + 1) * head_dim
            model.transformer.h[layer_idx].attn.c_proj.input[0][0][:, :, start:end] = \
                clean_attn[layer_idx][:, :, start:end]
            patched_logits = model.lm_head.output.save()
        results[layer_idx, head_idx] = logit_diff(
            patched_logits.value, correct_token, incorrect_token
        )
```
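A short sketch for pulling the strongest heads out of the layer × head grid, again using only standard torch ops:

```python
# Flatten the layer x head grid and recover (layer, head) indices of the top entries
top_vals, top_idxs = torch.topk(results.flatten(), k=5)
for val, idx in zip(top_vals.tolist(), top_idxs.tolist()):
    layer, head = divmod(idx, n_heads)
    print(f"L{layer}H{head}: patched logit diff {val:.2f}")
```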
Noising (Reverse Patching)
Instead of restoring clean activations in a corrupted run (denoising), inject corrupted activations into a clean run and measure how much the behavior degrades:
```python
# Corrupted run - save activations
with model.trace(corrupted_prompt):
    corrupted_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# For each layer, inject corrupted activation into clean run
noising_results = torch.zeros(n_layers)
for layer_idx in range(n_layers):
    with model.trace(clean_prompt):
        # Inject corrupted activation into clean run
        model.transformer.h[layer_idx].output[0][:] = corrupted_hiddens[layer_idx]
        noised_logits = model.lm_head.output.save()
    noising_results[layer_idx] = logit_diff(noised_logits.value, correct_token, incorrect_token)
```
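Noising results can be normalized analogously, reusing `clean_diff` and `corrupted_diff` from the three-run section. Under this convention, 0 means the clean behavior survived noising and 1 means it was fully degraded to the corrupted baseline (a sketch, not part of the original skill):

```python
# Fraction of the clean behavior destroyed by noising each layer
noising_normalized = (noising_results - clean_diff) / (corrupted_diff - clean_diff)
print(noising_normalized)
```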
Visualization
```python
import matplotlib.pyplot as plt
import seaborn as sns

# Heatmap of the layer x position grid from position-specific patching
plt.figure(figsize=(12, 8))
sns.heatmap(
    results.numpy(),
    xticklabels=[f"Pos {i}" for i in range(seq_len)],
    yticklabels=[f"Layer {i}" for i in range(n_layers)],
    cmap="RdBu_r",
    center=0,
    annot=True,
    fmt=".2f",
)
plt.title("Activation Patching Results")
plt.xlabel("Token Position")
plt.ylabel("Layer")
plt.tight_layout()
plt.show()
```
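For the 1-D per-layer scans, a line plot of the normalized effect can be easier to read than a heatmap. A sketch reusing `normalized` from the three-run section:

```python
plt.figure(figsize=(8, 4))
plt.plot(range(n_layers), normalized.numpy(), marker="o")
plt.axhline(0, color="gray", linestyle="--", linewidth=0.8)
plt.title("Per-Layer Patching Effect (Normalized)")
plt.xlabel("Layer")
plt.ylabel("Fraction of clean behavior restored")
plt.tight_layout()
plt.show()
```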
Interpretation
- High positive values: Component is important for correct behavior
- Values near 0: Component doesn't affect this behavior
- Negative values: Component actively pushes toward wrong answer
- Clusters of importance: Suggest circuits or computational stages
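One way to turn these readings into candidate circuit components is to normalize the layer × head grid the same way as the layer scan and keep entries above a cutoff. This sketch reuses the `results` tensor from the attention-head section; the 0.1 threshold is an illustrative choice of ours, not a standard value:

```python
# Hypothetical filter: keep heads whose patching restores a sizable
# fraction of the clean behavior (0.1 is an illustrative cutoff)
head_normalized = (results - corrupted_diff) / (clean_diff - corrupted_diff)
for layer, head in (head_normalized > 0.1).nonzero().tolist():
    print(f"Candidate circuit component: L{layer}H{head}")
```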