Claude-skill-registry attribution-patching

Gradient-based approximation to activation patching for scalable circuit analysis. Use when activation patching is too slow or when analyzing many components simultaneously.

install
source · Clone the upstream repo
git clone https://github.com/majiayu000/claude-skill-registry
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/majiayu000/claude-skill-registry "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/data/attribution-patching" ~/.claude/skills/majiayu000-claude-skill-registry-attribution-patching && rm -rf "$T"
manifest: skills/data/attribution-patching/SKILL.md

Attribution Patching

Attribution patching uses gradients to approximate activation patching results in a single backward pass, making it practical to analyze thousands of components simultaneously.

Core Idea

Instead of running separate forward passes for each component:

  1. Run clean and corrupted forward passes
  2. Compute gradients of the metric w.r.t. corrupted activations
  3. Multiply gradients by (clean - corrupted) activation differences

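This is a first-order (linear) approximation of the patching effect, so it is most accurate when the clean and corrupted activations are close. It degrades where the metric is strongly nonlinear in the patched activation.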
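The three steps can be tried on a toy problem before touching a real model. The sketch below is illustrative only: `metric` and `grad_metric` are made-up stand-ins for a model's metric and its backward pass, and the activation values are arbitrary numbers, not measurements.

```python
import math

def metric(acts):
    """Toy nonlinear 'model output' computed from three component activations."""
    a, b, c = acts
    return math.tanh(0.5 * a) + a * b - 0.3 * c ** 2

def grad_metric(acts):
    """Analytic gradient of `metric` with respect to each activation."""
    a, b, c = acts
    return [0.5 * (1 - math.tanh(0.5 * a) ** 2) + b, a, -0.6 * c]

# Arbitrary illustrative activations (not taken from any real model).
clean = [1.00, 0.50, -0.20]
corrupted = [0.90, 0.45, -0.10]

base = metric(corrupted)
grads = grad_metric(corrupted)  # one "backward pass" covers every component

# Attribution patching: gradient times activation difference.
attributions = [g * (cl - co) for g, cl, co in zip(grads, clean, corrupted)]

# Activation patching: one extra evaluation per component.
exact = []
for i in range(len(clean)):
    patched = list(corrupted)
    patched[i] = clean[i]
    exact.append(metric(patched) - base)

for i, (approx, true) in enumerate(zip(attributions, exact)):
    print(f"component {i}: attribution={approx:+.5f}  patching={true:+.5f}")
```

The second component enters the toy metric linearly, so its attribution matches the patched value up to floating-point error. The other two differ slightly because of curvature.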

Mathematical Formula

attribution(component) = grad_corrupted(metric) * (clean_activation - corrupted_activation)
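One way to see where this formula comes from is a first-order Taylor expansion of the metric around the corrupted activation. The notation here is introduced for illustration and is not from the source: $m$ is the metric viewed as a function of a single component's activation $a$, with everything else held at its corrupted value.

```latex
\begin{aligned}
m(a_{\text{clean}}) - m(a_{\text{corrupted}})
  &\approx \nabla_a m(a_{\text{corrupted}})^{\top}
     \bigl(a_{\text{clean}} - a_{\text{corrupted}}\bigr) \\
  &= \sum_j \left.\frac{\partial m}{\partial a_j}\right|_{a_{\text{corrupted}}}
     \bigl(a_{\text{clean},j} - a_{\text{corrupted},j}\bigr)
\end{aligned}
```

The left-hand side is what activation patching measures for that component. The right-hand side is the attribution score, and the neglected terms are second order in the activation difference.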

Setup

from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

correct_token = model.tokenizer(" John")["input_ids"][0]
incorrect_token = model.tokenizer(" Mary")["input_ids"][0]

def logit_diff(logits):
    return logits[0, -1, correct_token] - logits[0, -1, incorrect_token]
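Because the attribution multiplies gradients by an elementwise activation difference, the two prompts need to tokenize to the same length. A minimal sketch of a sanity check follows. `differing_positions` is a hypothetical helper (not part of nnsight), and whitespace-split words stand in for real token ids so the example runs without a tokenizer.

```python
def differing_positions(clean_ids, corrupted_ids):
    """Return the positions where two equal-length token sequences differ."""
    if len(clean_ids) != len(corrupted_ids):
        raise ValueError(
            f"length mismatch: {len(clean_ids)} vs {len(corrupted_ids)} tokens"
        )
    return [i for i, (a, b) in enumerate(zip(clean_ids, corrupted_ids)) if a != b]

# Stand-in "tokens": whitespace-split words instead of real tokenizer output.
clean_words = "After John and Mary went to the store, Mary gave a bottle of milk to".split()
corrupted_words = "After John and Mary went to the store, John gave a bottle of milk to".split()

changed = differing_positions(clean_words, corrupted_words)
print(changed)
```

With the real tokenizer, the same check would take `model.tokenizer(clean_prompt)["input_ids"]` and the corrupted equivalent.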

Basic Attribution Patching

n_layers = len(model.transformer.h)

clean_acts = []
corrupted_acts = []
corrupted_grads = []

# Clean forward pass - save activations
with model.trace(clean_prompt):
    for layer in model.transformer.h:
        act = layer.output[0]
        clean_acts.append(act.save())

# Corrupted forward + backward pass
with model.trace(corrupted_prompt):
    # Register intermediate values in forward order
    for layer in model.transformer.h:
        act = layer.output[0]
        act.requires_grad = True
        corrupted_acts.append(act.save())

    # Compute metric
    logits = model.lm_head.output
    metric = logit_diff(logits)

    # Access gradients in REVERSE order within backward context
    with metric.backward():
        for layer in reversed(model.transformer.h):
            corrupted_grads.insert(0, layer.output[0].grad.save())

# Compute attributions. Saved values are plain tensors once the trace exits
# (nnsight >= 0.5); older releases that return proxies need `.value` on each.
attributions = []
for i in range(n_layers):
    clean = clean_acts[i]
    corrupted = corrupted_acts[i]
    grad = corrupted_grads[i]

    # Attribution = grad * (clean - corrupted), summed over batch, position, hidden
    attr = (grad * (clean - corrupted)).sum()
    attributions.append(attr.item())

attributions = torch.tensor(attributions)

Per-Position Attribution

# Saved values are used as plain tensors here (nnsight >= 0.5)
seq_len = clean_acts[0].shape[1]
position_attrs = torch.zeros(n_layers, seq_len)

for layer_idx in range(n_layers):
    clean = clean_acts[layer_idx]
    corrupted = corrupted_acts[layer_idx]
    grad = corrupted_grads[layer_idx]

    # Sum over hidden dimension only, keep position
    diff = clean - corrupted
    # squeeze(0) drops only the batch dimension, even when seq_len == 1
    attr = (grad * diff).sum(dim=-1).squeeze(0)  # [seq_len]
    position_attrs[layer_idx] = attr.detach()

Attention Head Attribution

from einops import rearrange

n_heads = model.config.n_head
head_dim = model.config.n_embd // n_heads
head_attrs = torch.zeros(n_layers, n_heads)

# Collect clean attention outputs
# `.input` is the first positional argument of the module (nnsight >= 0.3);
# the older `.input[0][0]` form would index into the tensor itself here.
clean_attn = []
with model.trace(clean_prompt):
    for layer in model.transformer.h:
        # Concatenated per-head outputs before projection: [batch, seq, n_embd]
        attn_out = layer.attn.c_proj.input
        clean_attn.append(attn_out.save())

# Collect corrupted attention outputs and gradients
corrupted_attn = []
attn_grads = []

with model.trace(corrupted_prompt):
    # Register intermediate values in forward order
    for layer in model.transformer.h:
        attn_out = layer.attn.c_proj.input
        attn_out.requires_grad = True
        corrupted_attn.append(attn_out.save())

    metric = logit_diff(model.lm_head.output)

    # Access gradients in REVERSE order within backward context
    with metric.backward():
        for layer in reversed(model.transformer.h):
            attn_grads.insert(0, layer.attn.c_proj.input.grad.save())

# Compute per-head attributions (saved values used as plain tensors, nnsight >= 0.5)
for layer_idx in range(n_layers):
    clean = clean_attn[layer_idx]
    corrupted = corrupted_attn[layer_idx]
    grad = attn_grads[layer_idx]

    # Reshape to [batch, seq, heads, head_dim]
    clean_heads = rearrange(clean, 'b s (h d) -> b s h d', h=n_heads)
    corrupted_heads = rearrange(corrupted, 'b s (h d) -> b s h d', h=n_heads)
    grad_heads = rearrange(grad, 'b s (h d) -> b s h d', h=n_heads)

    # Attribution per head
    diff = clean_heads - corrupted_heads
    attr = (grad_heads * diff).sum(dim=(0, 1, 3))  # Sum batch, seq, head_dim
    head_attrs[layer_idx] = attr.detach()
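Once per-head scores exist, the usual next step is to shortlist the heads with the largest absolute attribution. A minimal sketch, where `top_heads` is a hypothetical helper and the toy grid holds placeholder numbers rather than real results:

```python
def top_heads(head_scores, k=5):
    """Rank (layer, head) pairs by absolute attribution, largest first."""
    flat = [
        (layer, head, score)
        for layer, row in enumerate(head_scores)
        for head, score in enumerate(row)
    ]
    flat.sort(key=lambda item: abs(item[2]), reverse=True)
    return flat[:k]

# Placeholder scores for a 2-layer, 3-head toy grid (not real results).
toy_scores = [
    [0.10, -0.80, 0.05],
    [0.40, 0.00, -0.30],
]
for layer, head, score in top_heads(toy_scores, k=3):
    print(f"L{layer}H{head}: {score:+.2f}")
```

For the tensor computed above, the call would be `top_heads(head_attrs.tolist(), k=10)`. Ranking by absolute value keeps heads that hurt the metric as well as heads that help it.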

Efficient Batched Version

Process both prompts in a single forward pass using batching:

# Batch both prompts together in a single trace
all_acts = []
all_grads = []

with model.trace([clean_prompt, corrupted_prompt]):
    # Register intermediate values in forward order
    for layer in model.transformer.h:
        act = layer.output[0]
        act.requires_grad = True
        all_acts.append(act.save())

    logits = model.lm_head.output
    # Metric on corrupted (index 1)
    metric = logit_diff(logits[1:2])

    # Access gradients in REVERSE order within backward context
    with metric.backward():
        for layer in reversed(model.transformer.h):
            all_grads.insert(0, layer.output[0].grad.save())

# Split clean/corrupted and compute attributions
# (saved values used as plain tensors, nnsight >= 0.5)
attributions = []
for i in range(n_layers):
    acts = all_acts[i]
    grads = all_grads[i]

    clean = acts[0:1]
    corrupted = acts[1:2]
    grad = grads[1:2]  # Metric depends only on the corrupted row, so use its gradient

    attr = (grad * (clean - corrupted)).sum()
    attributions.append(attr.item())

attributions = torch.tensor(attributions)

Comparison with Activation Patching

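| Aspect   | Activation Patching      | Attribution Patching    |
|----------|--------------------------|-------------------------|
| Accuracy | Exact                    | Approximation           |
| Speed    | O(n_components) forwards | O(1) forward + backward |
| Memory   | Lower per run            | Higher (stores grads)   |
| Best for | Few components           | Many components         |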
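To make the speed row concrete, the pass counts can be worked out for GPT-2 small (12 layers, 12 heads per layer). This is a back-of-the-envelope sketch: `activation_patching_runs` is a hypothetical helper, and the prompt length of 16 tokens is an assumption, not a measured value.

```python
def activation_patching_runs(n_layers, n_heads, seq_len=1):
    """Patched forward passes needed: one per (layer, head, position) target."""
    return n_layers * n_heads * seq_len

# Attribution patching: clean forward, corrupted forward, one backward pass.
ATTRIBUTION_PATCHING_PASSES = 3

n_layers, n_heads = 12, 12   # GPT-2 small
seq_len = 16                 # assumed prompt length in tokens

print("per head:", activation_patching_runs(n_layers, n_heads))
print("per head and position:", activation_patching_runs(n_layers, n_heads, seq_len))
print("attribution patching:", ATTRIBUTION_PATCHING_PASSES)
```

The activation patching count grows with every extra layer, head, or position, while the attribution patching count stays fixed.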

Validation

Compare attribution results against ground truth patching:

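# Scatter plot: attribution vs actual patching effect
import matplotlib.pyplot as plt
import torch

# Assumes `actual_patching_results` holds the measured activation-patching
# effect for each component, in the same order as `attributions`.
x = torch.as_tensor(attributions, dtype=torch.float32)
y = torch.as_tensor(actual_patching_results, dtype=torch.float32)

plt.scatter(x, y)
plt.xlabel("Attribution Score")
plt.ylabel("Actual Patching Effect")
plt.title("Attribution vs Patching Correlation")
correlation = torch.corrcoef(torch.stack([x, y]))[0, 1].item()
plt.text(0.1, 0.9, f"r = {correlation:.3f}", transform=plt.gca().transAxes)
plt.show()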

When to Use

  • Use attribution patching: Initial exploration, many components, large models
  • Use activation patching: Validating specific components, exact measurements needed
  • Combine both: Attribution for screening, patching for confirmation