Claude-skill-registry activation-patching

Causal intervention via activation patching to identify important model components. Use when determining which layers, heads, or positions are causally responsible for model behavior.

install
source · Clone the upstream repo
git clone https://github.com/majiayu000/claude-skill-registry
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/majiayu000/claude-skill-registry "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/data/activation-patching" ~/.claude/skills/majiayu000-claude-skill-registry-activation-patching && rm -rf "$T"
manifest: skills/data/activation-patching/SKILL.md
source content

Activation Patching

Activation patching is a causal intervention technique that identifies which model components are responsible for specific behaviors by swapping activations between different inputs.

Core Concept

  1. Clean run: Run the model on a prompt that produces the desired behavior
  2. Corrupted run: Run it on a modified prompt that changes the behavior
  3. Patch: Re-run on the corrupted prompt with selected activations replaced by clean ones, and measure whether the behavior is restored

If patching a component restores the clean behavior, that component is causally important.
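The restoration logic in step 3 reduces to a single normalized score. A minimal sketch in plain Python, using toy numbers rather than real model outputs:

```python
def restoration_score(patched, clean, corrupted):
    """Fraction of the clean behavior a patch restores.

    1.0 means the patch fully recovers the clean metric;
    0.0 means it leaves the corrupted run unchanged.
    """
    return (patched - corrupted) / (clean - corrupted)

# Toy values: clean logit diff = 4.0, corrupted = -2.0.
# A patch that lifts the metric to 2.5 restores 75% of the behavior.
print(restoration_score(2.5, clean=4.0, corrupted=-2.0))  # 0.75
```

This is the same normalization applied to the real logit differences later in the pattern.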

Basic Setup

from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

# Indirect Object Identification (IOI) task
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

# Target tokens (note the leading space; each name is a single GPT-2 token)
correct_token = model.tokenizer(" John")["input_ids"][0]   # Clean answer
incorrect_token = model.tokenizer(" Mary")["input_ids"][0]  # Corrupted answer

Metric: Logit Difference

def logit_diff(logits, correct_idx, incorrect_idx):
    """Measure how much model prefers correct over incorrect token."""
    return (logits[0, -1, correct_idx] - logits[0, -1, incorrect_idx]).item()
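The metric can be sanity-checked without a model by running it on a dummy logits tensor. A self-contained sketch (hypothetical 5-token vocabulary; the function is restated so the snippet runs on its own):

```python
import torch

def logit_diff(logits, correct_idx, incorrect_idx):
    """Measure how much model prefers correct over incorrect token."""
    return (logits[0, -1, correct_idx] - logits[0, -1, incorrect_idx]).item()

# Dummy (batch=1, seq=7, vocab=5) logits favoring token 2 over token 4
logits = torch.zeros(1, 7, 5)
logits[0, -1, 2] = 3.0
logits[0, -1, 4] = 0.5

print(logit_diff(logits, correct_idx=2, incorrect_idx=4))  # 2.5
```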

Three-Run Patching Pattern

n_layers = len(model.transformer.h)
results = torch.zeros(n_layers)

# Run 1: Clean - save activations
with model.trace(clean_prompt):
    clean_hiddens = [layer.output[0].save() for layer in model.transformer.h]
    clean_logits = model.lm_head.output.save()

# Run 2: Corrupted baseline
with model.trace(corrupted_prompt):
    corrupted_logits = model.lm_head.output.save()

# Runs 3+: Patch each layer (separate forward passes)
for layer_idx in range(n_layers):
    with model.trace(corrupted_prompt):
        # Replace corrupted activation with clean
        model.transformer.h[layer_idx].output[0][:] = clean_hiddens[layer_idx]
        patched_logits = model.lm_head.output.save()
    results[layer_idx] = logit_diff(patched_logits.value, correct_token, incorrect_token)

# Normalize: 1.0 = patch fully restores clean behavior, 0.0 = no effect
clean_diff = logit_diff(clean_logits.value, correct_token, incorrect_token)
corrupted_diff = logit_diff(corrupted_logits.value, correct_token, incorrect_token)
normalized = (results - corrupted_diff) / (clean_diff - corrupted_diff)

Position-Specific Patching

Patch only specific token positions. Note that the clean and corrupted prompts must tokenize to the same length (here both names are single tokens):

seq_len = len(model.tokenizer.encode(clean_prompt))
results = torch.zeros(n_layers, seq_len)

# Clean run - save activations
with model.trace(clean_prompt):
    clean_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# Patch each layer x position (separate forward passes)
for layer_idx in range(n_layers):
    for pos_idx in range(seq_len):
        with model.trace(corrupted_prompt):
            # Patch only this position
            model.transformer.h[layer_idx].output[0][:, pos_idx, :] = \
                clean_hiddens[layer_idx][:, pos_idx, :]
            patched_logits = model.lm_head.output.save()
        results[layer_idx, pos_idx] = logit_diff(
            patched_logits.value, correct_token, incorrect_token
        )

Attention Head Patching

Patch individual attention heads:

n_heads = model.config.n_head
head_dim = model.config.n_embd // n_heads
results = torch.zeros(n_layers, n_heads)

# Clean run - save the concatenated per-head outputs (the input to the output projection c_proj)
with model.trace(clean_prompt):
    clean_attn = [layer.attn.c_proj.input[0][0].save()
                  for layer in model.transformer.h]

# Patch each layer x head (separate forward passes)
for layer_idx in range(n_layers):
    for head_idx in range(n_heads):
        with model.trace(corrupted_prompt):
            # Patch single head's output
            start = head_idx * head_dim
            end = (head_idx + 1) * head_dim
            model.transformer.h[layer_idx].attn.c_proj.input[0][0][:, :, start:end] = \
                clean_attn[layer_idx][:, :, start:end]
            patched_logits = model.lm_head.output.save()
        results[layer_idx, head_idx] = logit_diff(
            patched_logits.value, correct_token, incorrect_token
        )

Noising (Reverse Patching)

Instead of restoring clean activations in the corrupted run (denoising), inject corrupted activations into the clean run. Denoising highlights components sufficient to restore the behavior; noising highlights components necessary for it:

# Corrupted run - save activations
with model.trace(corrupted_prompt):
    corrupted_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# For each layer, inject corrupted activation into clean run
noising_results = torch.zeros(n_layers)
for layer_idx in range(n_layers):
    with model.trace(clean_prompt):
        # Inject corrupted activation into clean run
        model.transformer.h[layer_idx].output[0][:] = corrupted_hiddens[layer_idx]
        noised_logits = model.lm_head.output.save()
    noising_results[layer_idx] = logit_diff(noised_logits.value, correct_token, incorrect_token)

Visualization

import matplotlib.pyplot as plt
import seaborn as sns

# Visualize the layer x position results from Position-Specific Patching
plt.figure(figsize=(12, 8))
sns.heatmap(
    results.numpy(),
    xticklabels=[f"Pos {i}" for i in range(seq_len)],
    yticklabels=[f"Layer {i}" for i in range(n_layers)],
    cmap="RdBu_r",
    center=0,
    annot=True,
    fmt=".2f"
)
plt.title("Activation Patching Results")
plt.xlabel("Token Position")
plt.ylabel("Layer")
plt.tight_layout()
plt.show()

Interpretation

  • Values near 1 (normalized): Patching this component restores most of the clean behavior, so it is causally important
  • Values near 0: Component doesn't affect this behavior
  • Negative values: Component actively pushes toward the wrong answer
  • Clusters of importance: Suggest circuits or distinct computational stages
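To turn a sweep into a ranked list of candidate components, `torch.topk` on the flattened score matrix works well. A sketch with made-up normalized scores (not real measurements):

```python
import torch

# Hypothetical normalized scores for a 4-layer x 6-position sweep
normalized = torch.tensor([
    [0.02, 0.01, 0.05, 0.03, 0.00, 0.04],
    [0.01, 0.10, 0.85, 0.02, 0.01, 0.06],
    [0.00, 0.03, 0.07, 0.92, 0.05, 0.02],
    [0.01, 0.02, 0.04, 0.08, 0.03, 0.71],
])

# Rank (layer, position) pairs by how much clean behavior they restore
top_vals, top_idx = normalized.flatten().topk(3)
n_pos = normalized.shape[1]
coords = [(i // n_pos, i % n_pos) for i in top_idx.tolist()]
print(coords)  # [(2, 3), (1, 2), (3, 5)]
```

The same flatten-and-topk pattern applies to the layer x head matrix from attention head patching.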