claude-skill-registry · attribution-patching
Gradient-based approximation to activation patching for scalable circuit analysis. Use when activation patching is too slow or when analyzing many components simultaneously.
install
source · Clone the upstream repo
```bash
git clone https://github.com/majiayu000/claude-skill-registry
```
Claude Code · Install into ~/.claude/skills/
```bash
T=$(mktemp -d) && git clone --depth=1 https://github.com/majiayu000/claude-skill-registry "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/data/attribution-patching" ~/.claude/skills/majiayu000-claude-skill-registry-attribution-patching && rm -rf "$T"
```
manifest: `skills/data/attribution-patching/SKILL.md`
Attribution Patching
Attribution patching uses gradients to approximate the results of activation patching for every component at once: a single clean forward pass plus one corrupted forward-and-backward pass is enough to estimate the patching effect of thousands of components simultaneously.
Core Idea
Instead of running separate forward passes for each component:
- Run clean and corrupted forward passes
- Compute gradients of the metric w.r.t. corrupted activations
- Multiply gradients by (clean - corrupted) activation differences
This linear approximation is accurate when the clean and corrupted activations are close, and degrades where the metric responds nonlinearly to the activation.
Mathematical Formula
```
attribution(component) = grad_corrupted(metric) * (clean_activation - corrupted_activation)
```
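Concretely, this is a first-order Taylor expansion of the metric around the corrupted run. Writing $m$ for the metric and $a$ for a component's activation:

$$
m(a_{\text{clean}}) - m(a_{\text{corrupted}}) \;\approx\; \left.\nabla_a m\right|_{a_{\text{corrupted}}} \cdot \left(a_{\text{clean}} - a_{\text{corrupted}}\right)
$$

The left-hand side is exactly the metric change that activation patching would measure when the clean activation is patched into the corrupted run, which is why the dot product on the right serves as its one-pass estimate.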
Setup
```python
from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

correct_token = model.tokenizer(" John")["input_ids"][0]
incorrect_token = model.tokenizer(" Mary")["input_ids"][0]

def logit_diff(logits):
    return logits[0, -1, correct_token] - logits[0, -1, incorrect_token]
```
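One check worth running before any patching (an addition, not part of the original snippet): the clean and corrupted prompts must tokenize to the same length, otherwise positions will not align between the two runs.

```python
# Sanity check: patching compares activations position-by-position,
# so the two prompts must tokenize to the same number of tokens.
clean_ids = model.tokenizer(clean_prompt)["input_ids"]
corrupted_ids = model.tokenizer(corrupted_prompt)["input_ids"]
assert len(clean_ids) == len(corrupted_ids), "prompts must tokenize to the same length"
```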
Basic Attribution Patching
```python
n_layers = len(model.transformer.h)

clean_acts = []
corrupted_acts = []
corrupted_grads = []

# Clean forward pass - save activations
with model.trace(clean_prompt):
    for layer in model.transformer.h:
        act = layer.output[0]
        clean_acts.append(act.save())

# Corrupted forward + backward pass
with model.trace(corrupted_prompt):
    # Register intermediate values in forward order
    for layer in model.transformer.h:
        act = layer.output[0]
        act.requires_grad = True
        corrupted_acts.append(act.save())

    # Compute metric
    logits = model.lm_head.output
    metric = logit_diff(logits)

    # Access gradients in REVERSE order within backward context
    with metric.backward():
        for layer in reversed(model.transformer.h):
            corrupted_grads.insert(0, layer.output[0].grad.save())

# Compute attributions
attributions = []
for i in range(n_layers):
    clean = clean_acts[i].value
    corrupted = corrupted_acts[i].value
    grad = corrupted_grads[i].value
    # Attribution = grad * (clean - corrupted)
    attr = (grad * (clean - corrupted)).sum()
    attributions.append(attr.item())

attributions = torch.tensor(attributions)
```
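A quick way to read off the result, assuming the `attributions` tensor computed above: rank layers by absolute attribution, since large effects of either sign are candidates for follow-up.

```python
# Rank layers by the magnitude of their approximate patching effect
top_layers = attributions.abs().argsort(descending=True)[:5]
for rank, layer_idx in enumerate(top_layers.tolist(), start=1):
    print(f"{rank}. layer {layer_idx}: attribution = {attributions[layer_idx]:+.4f}")
```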
Per-Position Attribution
```python
seq_len = clean_acts[0].value.shape[1]
position_attrs = torch.zeros(n_layers, seq_len)

for layer_idx in range(n_layers):
    clean = clean_acts[layer_idx].value
    corrupted = corrupted_acts[layer_idx].value
    grad = corrupted_grads[layer_idx].value

    # Sum over hidden dimension only, keep position
    diff = clean - corrupted
    attr = (grad * diff).sum(dim=-1).squeeze()  # [seq_len]
    position_attrs[layer_idx] = attr
```
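Per-position scores are easiest to read as a layer-by-position heatmap. A minimal matplotlib sketch (token labels taken from the clean prompt; `detach()` is needed because the saved activations carried gradients):

```python
import matplotlib.pyplot as plt

tokens = [model.tokenizer.decode(t) for t in model.tokenizer(clean_prompt)["input_ids"]]

plt.figure(figsize=(10, 6))
plt.imshow(position_attrs.detach(), aspect="auto", cmap="RdBu")
plt.colorbar(label="attribution")
plt.xticks(range(len(tokens)), tokens, rotation=90)
plt.xlabel("Position")
plt.ylabel("Layer")
plt.title("Per-position attribution")
plt.tight_layout()
plt.show()
```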
Attention Head Attribution
```python
from einops import rearrange

n_heads = model.config.n_head
head_dim = model.config.n_embd // n_heads
head_attrs = torch.zeros(n_layers, n_heads)

# Collect clean attention outputs
clean_attn = []
with model.trace(clean_prompt):
    for layer in model.transformer.h:
        attn_out = layer.attn.c_proj.input[0][0]  # Before projection
        clean_attn.append(attn_out.save())

# Collect corrupted attention outputs and gradients
corrupted_attn = []
attn_grads = []
with model.trace(corrupted_prompt):
    # Register intermediate values in forward order
    for layer in model.transformer.h:
        attn_out = layer.attn.c_proj.input[0][0]
        attn_out.requires_grad = True
        corrupted_attn.append(attn_out.save())

    metric = logit_diff(model.lm_head.output)

    # Access gradients in REVERSE order within backward context
    with metric.backward():
        for layer in reversed(model.transformer.h):
            attn_grads.insert(0, layer.attn.c_proj.input[0][0].grad.save())

# Compute per-head attributions
for layer_idx in range(n_layers):
    clean = clean_attn[layer_idx].value
    corrupted = corrupted_attn[layer_idx].value
    grad = attn_grads[layer_idx].value

    # Reshape to [batch, seq, heads, head_dim]
    clean_heads = rearrange(clean, 'b s (h d) -> b s h d', h=n_heads)
    corrupted_heads = rearrange(corrupted, 'b s (h d) -> b s h d', h=n_heads)
    grad_heads = rearrange(grad, 'b s (h d) -> b s h d', h=n_heads)

    # Attribution per head
    diff = clean_heads - corrupted_heads
    attr = (grad_heads * diff).sum(dim=(0, 1, 3))  # Sum batch, seq, head_dim
    head_attrs[layer_idx] = attr
```
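To surface the most influential heads model-wide, a small convenience on top of `head_attrs`:

```python
# Report the top-5 heads by |attribution| as (layer, head) pairs
top_vals, top_idx = torch.topk(head_attrs.detach().abs().flatten(), k=5)
for val, idx in zip(top_vals.tolist(), top_idx.tolist()):
    layer_idx, head_idx = divmod(idx, n_heads)
    print(f"L{layer_idx}H{head_idx}: |attribution| = {val:.4f}")
```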
Efficient Batched Version
Process both prompts in a single forward pass using batching:
```python
# Batch both prompts together in a single trace
all_acts = []
all_grads = []

with model.trace([clean_prompt, corrupted_prompt]):
    # Register intermediate values in forward order
    for layer in model.transformer.h:
        act = layer.output[0]
        act.requires_grad = True
        all_acts.append(act.save())

    logits = model.lm_head.output

    # Metric on corrupted (index 1)
    metric = logit_diff(logits[1:2])

    # Access gradients in REVERSE order within backward context
    with metric.backward():
        for layer in reversed(model.transformer.h):
            all_grads.insert(0, layer.output[0].grad.save())

# Split clean/corrupted and compute attributions
attributions = []
for i in range(n_layers):
    acts = all_acts[i].value
    grads = all_grads[i].value

    clean = acts[0:1]
    corrupted = acts[1:2]
    grad = grads[1:2]  # Gradient is only for corrupted

    attr = (grad * (clean - corrupted)).sum()
    attributions.append(attr.item())
```
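Note that only the corrupted batch element feeds the metric, and GPT-2 applies no cross-example operations (layer norm acts per example), so the clean element's gradient is exactly zero; slicing `grads[1:2]` therefore isolates the gradient of the corrupted run.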
Comparison with Activation Patching
| Aspect | Activation Patching | Attribution Patching |
|---|---|---|
| Accuracy | Exact | Approximation |
| Speed | O(n_components) forwards | O(1) forward + backward |
| Memory | Lower per run | Higher (stores grads) |
| Best for | Few components | Many components |
Validation
Compare attribution results against ground truth patching:
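The snippet below assumes an `actual_patching_results` tensor that the original does not define. A minimal sketch of producing it with true activation patching over the same layers, reusing `clean_acts` from the basic example (effects are measured relative to the unpatched corrupted run, which is the quantity the attribution score approximates):

```python
# Baseline metric on the unpatched corrupted run
with model.trace(corrupted_prompt):
    corrupted_logits = model.lm_head.output.save()
baseline = logit_diff(corrupted_logits.value).item()

# Ground truth: patch each layer's clean activation into the corrupted run
actual_patching_results = []
for layer_idx in range(n_layers):
    with model.trace(corrupted_prompt):
        # Overwrite this layer's residual-stream output with the clean activation
        model.transformer.h[layer_idx].output[0][:] = clean_acts[layer_idx].value.detach()
        patched_logits = model.lm_head.output.save()
    actual_patching_results.append(logit_diff(patched_logits.value).item() - baseline)

actual_patching_results = torch.tensor(actual_patching_results)
```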
```python
# Scatter plot: attribution score vs. actual patching effect
import matplotlib.pyplot as plt

plt.scatter(attributions, actual_patching_results)
plt.xlabel("Attribution Score")
plt.ylabel("Actual Patching Effect")
plt.title("Attribution vs Patching Correlation")

correlation = torch.corrcoef(torch.stack([attributions, actual_patching_results]))[0, 1]
plt.text(0.1, 0.9, f"r = {correlation:.3f}", transform=plt.gca().transAxes)
plt.show()
```
When to Use
- Use attribution patching: Initial exploration, many components, large models
- Use activation patching: Validating specific components, exact measurements needed
- Combine both: Attribution for screening, patching for confirmation