Claude-skill-registry causal-tracing

Causal mediation analysis to identify which model components mediate specific behaviors. Use when investigating how information flows through the network and which neurons or layers are causally responsible for outputs.

Install

Source · Clone the upstream repo:
git clone https://github.com/majiayu000/claude-skill-registry

Claude Code · Install into ~/.claude/skills/:
T=$(mktemp -d) && git clone --depth=1 https://github.com/majiayu000/claude-skill-registry "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/data/causal-tracing" ~/.claude/skills/majiayu000-claude-skill-registry-causal-tracing && rm -rf "$T"

Manifest: skills/data/causal-tracing/SKILL.md
source content

Causal Tracing

Causal tracing (causal mediation analysis) identifies which intermediate computations causally mediate the relationship between inputs and outputs. It reveals not just what correlates with behavior, but what causes it.

Core Concepts

Three Types of Causal Effects

  1. Total Effect: Change in output when the input itself is changed
  2. Direct Effect: Change from restoring a component's activation from the clean run into the corrupted run; this measures whether the component is sufficient. (Note: the causal-tracing literature, e.g. ROME, calls this restoration measurement the indirect effect; this document's labels follow the code below.)
  3. Indirect Effect: Change from corrupting a component in an otherwise clean run; this measures whether the component is necessary
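As a toy numeric sketch of how these quantities are read off saved output probabilities (all numbers below are made up for illustration; `p_source` etc. are hypothetical names):

```python
import torch

# Made-up values of P(" Rome") under four runs (illustrative only)
p_source = torch.tensor(0.70)            # clean source run
p_base = torch.tensor(0.02)              # base run (different input)
p_base_restored = torch.tensor(0.55)     # base run, one component patched from source
p_source_corrupted = torch.tensor(0.20)  # source run, one component patched from base

total_effect = p_source - p_base                   # effect of changing the input
restoration_effect = p_base_restored - p_base      # "direct effect" in this document
corruption_effect = p_source - p_source_corrupted  # "indirect effect" in this document
```

A large `restoration_effect` relative to `total_effect` indicates the patched component carries most of the mediated information.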

The Interchange Intervention

Swap activations between two runs to test causal relationships:

  • Source run: Produces the activation value
  • Base run: Receives the swapped activation

Setup

from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

# Factual recall task. Note: the patching below assumes both prompts
# tokenize to the same number of tokens; verify with model.tokenizer.encode.
base_prompt = "The Eiffel Tower is located in"  # Expects: Paris
source_prompt = "The Colosseum is located in"   # Expects: Rome

# Get target tokens (each assumed to be a single BPE token, hence the leading space)
paris_token = model.tokenizer(" Paris")["input_ids"][0]
rome_token = model.tokenizer(" Rome")["input_ids"][0]

Computing Total Effect

with model.trace() as tracer:
    with tracer.invoke(base_prompt):
        base_logits = model.lm_head.output.save()

    with tracer.invoke(source_prompt):
        source_logits = model.lm_head.output.save()

probs_base = torch.softmax(base_logits.value[0, -1], dim=-1)
probs_source = torch.softmax(source_logits.value[0, -1], dim=-1)

base_prob = probs_base[paris_token]     # P(" Paris") on the base prompt
source_prob = probs_source[rome_token]  # P(" Rome") on the source prompt

# Measure the total effect on a single outcome token:
# how much does changing the input change P(" Rome")?
total_effect = source_prob - probs_base[rome_token]

Direct Effect (Restoration)

Does restoring a component from source restore source behavior?

n_layers = len(model.transformer.h)
direct_effects = torch.zeros(n_layers)

# Get source activations
with model.trace(source_prompt):
    source_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# Patch each layer: run base, inject source activation
for layer_idx in range(n_layers):
    with model.trace(base_prompt):
        model.transformer.h[layer_idx].output[0][:] = source_hiddens[layer_idx]
        patched_logits = model.lm_head.output.save()

    prob = torch.softmax(patched_logits.value[0, -1], dim=-1)[rome_token]
    direct_effects[layer_idx] = prob.item()
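Raw patched probabilities are easier to compare across prompts and models when normalized against the unpatched baselines, a convention used in the ROME causal-tracing work. A self-contained sketch with made-up values standing in for the saved probabilities:

```python
import torch

# Hypothetical P(" Rome") after patching each of 12 layers, plus the
# unpatched baselines (all numbers made up for illustration)
direct_effects = torch.tensor([0.02, 0.03, 0.05, 0.20, 0.45, 0.60,
                               0.40, 0.15, 0.08, 0.05, 0.04, 0.03])
p_base, p_source = 0.02, 0.70

# 0 = patch did nothing, 1 = patch fully recovered the source behavior
normalized = (direct_effects - p_base) / (p_source - p_base)
best_layer = int(normalized.argmax())
```

The normalized score makes "fraction of the total effect recovered" the unit, so results from different prompt pairs can sit on the same axis.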

Indirect Effect (Corruption)

Does corrupting a component in source disrupt source behavior?

indirect_effects = torch.zeros(n_layers)

# Get base activations (for corruption)
with model.trace(base_prompt):
    base_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# For each layer: run source, inject base (corrupted) activation
for layer_idx in range(n_layers):
    with model.trace(source_prompt):
        model.transformer.h[layer_idx].output[0][:] = base_hiddens[layer_idx]
        corrupted_logits = model.lm_head.output.save()

    prob = torch.softmax(corrupted_logits.value[0, -1], dim=-1)[rome_token]
    indirect_effects[layer_idx] = source_prob.item() - prob.item()  # Drop from source baseline

Position-Specific Causal Tracing

Identify which token positions carry causal information:

seq_len = len(model.tokenizer.encode(source_prompt))
position_effects = torch.zeros(n_layers, seq_len)

# Get source activations
with model.trace(source_prompt):
    source_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# Patch each layer x position
for layer_idx in range(n_layers):
    for pos_idx in range(seq_len):
        with model.trace(base_prompt):
            # Only patch this specific position
            model.transformer.h[layer_idx].output[0][:, pos_idx, :] = \
                source_hiddens[layer_idx][:, pos_idx, :]
            patched_logits = model.lm_head.output.save()

        prob = torch.softmax(patched_logits.value[0, -1], dim=-1)[rome_token]
        position_effects[layer_idx, pos_idx] = prob.item()
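Once the layer x position grid is filled in, the strongest mediating site can be located directly. A self-contained sketch with a small made-up grid standing in for `position_effects`:

```python
import torch

# Made-up layer x position grid of patched P(" Rome")
# (rows = layers, columns = token positions; values are illustrative)
position_effects = torch.tensor([
    [0.02, 0.03, 0.02, 0.02],
    [0.02, 0.30, 0.04, 0.02],
    [0.02, 0.10, 0.05, 0.40],
])

# Flatten, take the argmax, and unravel back into (layer, position)
flat_idx = int(position_effects.argmax())
layer_idx, pos_idx = divmod(flat_idx, position_effects.shape[1])
```

Here the peak sits at the last layer and final position, the pattern one would expect for prediction formation.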

Noising-Based Causal Tracing

Add noise to corrupt, then restore specific components:

def add_noise(activation, noise_level=0.1):
    return activation + noise_level * torch.randn_like(activation)

window_size = 3  # Restore window of layers around target
restoration_effects = torch.zeros(n_layers)

# Clean run - save activations
with model.trace(source_prompt):
    clean_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# For each layer: noise everything, restore window around this layer
for center_layer in range(n_layers):
    with model.trace(source_prompt):
        for layer_idx, layer in enumerate(model.transformer.h):
            if abs(layer_idx - center_layer) <= window_size // 2:
                # Restore clean
                layer.output[0][:] = clean_hiddens[layer_idx]
            else:
                # Add noise
                layer.output[0][:] = add_noise(layer.output[0])

        restored_logits = model.lm_head.output.save()

    prob = torch.softmax(restored_logits.value[0, -1], dim=-1)[rome_token]
    restoration_effects[center_layer] = prob.item()
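A fixed `noise_level` may be far too small or too large relative to the activations being corrupted; the ROME paper calibrates Gaussian noise to roughly three times the standard deviation of the relevant activations. A self-contained sketch of that calibration, using a random tensor as a stand-in for real hidden states:

```python
import torch

torch.manual_seed(0)
activations = torch.randn(1, 8, 768) * 5.0  # stand-in for real hidden states

# Scale the noise to 3x the empirical std of the activations being corrupted
sigma = activations.std().item()
noise = 3 * sigma * torch.randn_like(activations)
corrupted = activations + noise
```

Calibrating against the empirical std keeps the corruption comparably destructive across layers whose activation scales differ.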

MLP vs Attention Decomposition

Separate contributions of MLP and attention:

mlp_effects = torch.zeros(n_layers)
attn_effects = torch.zeros(n_layers)

# Get source MLP and attention outputs
# (GPT-2's mlp.output is a plain tensor, while attn.output is a tuple whose
# first element is the attention output)
with model.trace(source_prompt):
    source_mlp = [layer.mlp.output.save() for layer in model.transformer.h]
    source_attn = [layer.attn.output[0].save() for layer in model.transformer.h]

# Test MLP contributions
for layer_idx in range(n_layers):
    with model.trace(base_prompt):
        model.transformer.h[layer_idx].mlp.output[:] = source_mlp[layer_idx]
        mlp_logits = model.lm_head.output.save()
    mlp_effects[layer_idx] = torch.softmax(mlp_logits.value[0, -1], dim=-1)[rome_token].item()

# Test attention contributions
for layer_idx in range(n_layers):
    with model.trace(base_prompt):
        model.transformer.h[layer_idx].attn.output[0][:] = source_attn[layer_idx]
        attn_logits = model.lm_head.output.save()
    attn_effects[layer_idx] = torch.softmax(attn_logits.value[0, -1], dim=-1)[rome_token]
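To summarize the decomposition, it can help to mark which sublayer mediates more at each layer. A self-contained sketch with made-up per-layer effect values:

```python
import torch

# Made-up patched P(" Rome") for MLP-only vs attention-only patches
mlp_effects = torch.tensor([0.05, 0.40, 0.55, 0.10])
attn_effects = torch.tensor([0.03, 0.08, 0.20, 0.45])

# Which sublayer mediates more at each layer? (0 = MLP, 1 = attention)
dominant = torch.where(mlp_effects >= attn_effects, 0, 1)
labels = ["mlp" if d == 0 else "attn" for d in dominant.tolist()]
```

A common pattern in factual-recall tasks is MLP dominance at early-to-middle layers and attention dominance later, as attention moves the retrieved information to the final position.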

Visualization

import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Layer-wise effects
axes[0].bar(range(n_layers), direct_effects, alpha=0.7, label='Direct')
axes[0].bar(range(n_layers), indirect_effects, alpha=0.7, label='Indirect')
axes[0].set_xlabel('Layer')
axes[0].set_ylabel('Causal Effect')
axes[0].legend()
axes[0].set_title('Causal Effects by Layer')

# Position x Layer heatmap
input_tokens = model.tokenizer.encode(source_prompt)
token_labels = [model.tokenizer.decode(t) for t in input_tokens]

sns.heatmap(
    position_effects.numpy(),
    ax=axes[1],
    xticklabels=token_labels,
    yticklabels=[f'L{i}' for i in range(n_layers)],
    cmap='viridis'
)
axes[1].set_title('Causal Effect by Position and Layer')
axes[1].set_xlabel('Token Position')
axes[1].set_ylabel('Layer')

plt.tight_layout()
plt.show()

Interpretation Guidelines

  • Early layers + subject position: Often store entity information
  • Middle layers + last subject token: Information extraction/lookup
  • Late layers + final position: Prediction formation
  • High indirect effect: Corrupting the component disrupts the behavior, so the component is necessary for it
  • High direct effect: Restoring the component recovers the behavior, so the component is sufficient to cause it
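The last two guidelines can be turned into a toy classifier over the per-layer effects; the threshold and numbers below are illustrative, not from the source:

```python
import torch

# Made-up per-layer effects (restoration = "direct", corruption drop = "indirect")
direct_effects = torch.tensor([0.05, 0.10, 0.60, 0.08])
indirect_effects = torch.tensor([0.02, 0.55, 0.50, 0.03])
threshold = 0.3  # arbitrary cutoff for this sketch

sufficient = (direct_effects > threshold).tolist()   # restoring recovers behavior
necessary = (indirect_effects > threshold).tolist()  # corrupting destroys behavior

for layer, (s, n) in enumerate(zip(sufficient, necessary)):
    print(f"L{layer}: sufficient={s}, necessary={n}")
```

Components that score high on both measures (layer 2 in this sketch) are the strongest candidates for causal mediators of the behavior.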