Claude-skill-registry causal-tracing
Causal mediation analysis to identify which model components mediate specific behaviors. Use when investigating how information flows through the network and which neurons or layers are causally responsible for outputs.
```bash
# Clone the full registry
git clone https://github.com/majiayu000/claude-skill-registry

# Or copy only this skill into ~/.claude/skills
T=$(mktemp -d) && git clone --depth=1 https://github.com/majiayu000/claude-skill-registry "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/data/causal-tracing" ~/.claude/skills/majiayu000-claude-skill-registry-causal-tracing && rm -rf "$T"
```
skills/data/causal-tracing/SKILL.md

Causal Tracing
Causal tracing (causal mediation analysis) identifies which intermediate computations causally mediate the relationship between inputs and outputs. It reveals not just what correlates with behavior, but what causes it.
Core Concepts
Three Types of Causal Effects
- Total Effect: Change in the output when the input itself is changed (base prompt vs. source prompt)
- Direct Effect: Change in the output when a component's activation from the source run is restored into the base run (restoration)
- Indirect Effect: Change in the output when a component in the source run is overwritten with the base run's activation (corruption)
The Interchange Intervention
Swap activations between two runs to test causal relationships:
- Source run: Produces the activation value
- Base run: Receives the swapped activation
Setup
```python
from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

# Factual recall task
base_prompt = "The Eiffel Tower is located in"    # Expects: Paris
source_prompt = "The Colosseum is located in"     # Expects: Rome

# Get target tokens
paris_token = model.tokenizer(" Paris")["input_ids"][0]
rome_token = model.tokenizer(" Rome")["input_ids"][0]
```
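With the model and prompts in place, the interchange intervention described above can be written as a single trace with two invokes. A minimal sketch; the choice of layer 6 is arbitrary and purely illustrative:

```python
# Minimal interchange intervention: the source invoke produces the activation,
# the base invoke receives it at the same site. Layer 6 is an arbitrary example.
with model.trace() as tracer:
    with tracer.invoke(source_prompt):
        source_act = model.transformer.h[6].output[0]
    with tracer.invoke(base_prompt):
        model.transformer.h[6].output[0][:] = source_act
        swapped_logits = model.lm_head.output.save()

swapped_prob = torch.softmax(swapped_logits.value[0, -1], dim=-1)[rome_token]
```

The sections below repeat this swap systematically over layers, positions, and sublayers.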
Computing Total Effect
```python
with model.trace() as tracer:
    with tracer.invoke(base_prompt):
        base_logits = model.lm_head.output.save()
    with tracer.invoke(source_prompt):
        source_logits = model.lm_head.output.save()

base_prob = torch.softmax(base_logits.value[0, -1], dim=-1)[paris_token]
source_prob = torch.softmax(source_logits.value[0, -1], dim=-1)[rome_token]

# How much does changing the input change the output?
total_effect = base_prob - source_prob
```
Direct Effect (Restoration)
Does restoring a component from source restore source behavior?
```python
n_layers = len(model.transformer.h)
direct_effects = torch.zeros(n_layers)

# Get source activations
with model.trace(source_prompt):
    source_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# Patch each layer: run base, inject source activation
for layer_idx in range(n_layers):
    with model.trace(base_prompt):
        model.transformer.h[layer_idx].output[0][:] = source_hiddens[layer_idx]
        patched_logits = model.lm_head.output.save()
    prob = torch.softmax(patched_logits.value[0, -1], dim=-1)[rome_token]
    direct_effects[layer_idx] = prob.item()
```
Indirect Effect (Corruption)
Does corrupting a component in source disrupt source behavior?
```python
indirect_effects = torch.zeros(n_layers)

# Get base activations (for corruption)
with model.trace(base_prompt):
    base_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# For each layer: run source, inject base (corrupted) activation
for layer_idx in range(n_layers):
    with model.trace(source_prompt):
        model.transformer.h[layer_idx].output[0][:] = base_hiddens[layer_idx]
        corrupted_logits = model.lm_head.output.save()
    prob = torch.softmax(corrupted_logits.value[0, -1], dim=-1)[rome_token]
    indirect_effects[layer_idx] = source_prob - prob.item()  # Drop from source baseline
```
Position-Specific Causal Tracing
Identify which token positions carry causal information:
```python
# Note: this assumes the base and source prompts tokenize to the same number of
# tokens so that positions align across the two runs.
seq_len = len(model.tokenizer.encode(source_prompt))
position_effects = torch.zeros(n_layers, seq_len)

# Get source activations
with model.trace(source_prompt):
    source_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# Patch each layer x position
for layer_idx in range(n_layers):
    for pos_idx in range(seq_len):
        with model.trace(base_prompt):
            # Only patch this specific position
            model.transformer.h[layer_idx].output[0][:, pos_idx, :] = \
                source_hiddens[layer_idx][:, pos_idx, :]
            patched_logits = model.lm_head.output.save()
        prob = torch.softmax(patched_logits.value[0, -1], dim=-1)[rome_token]
        position_effects[layer_idx, pos_idx] = prob.item()
```
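To surface the strongest mediator in the grid, take the argmax over layers and positions. A short follow-up using the tensors computed above:

```python
# Find the (layer, position) cell with the largest restored probability
flat_idx = position_effects.argmax().item()
best_layer, best_pos = divmod(flat_idx, seq_len)
best_token = model.tokenizer.decode(model.tokenizer.encode(source_prompt)[best_pos])
print(f"Strongest effect: layer {best_layer}, position {best_pos} ({best_token!r})")
```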
Noising-Based Causal Tracing
Add noise to corrupt, then restore specific components:
```python
def add_noise(activation, noise_level=0.1):
    return activation + noise_level * torch.randn_like(activation)

window_size = 3  # Restore window of layers around target
restoration_effects = torch.zeros(n_layers)

# Clean run - save activations
with model.trace(source_prompt):
    clean_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# For each layer: noise everything, restore window around this layer
for center_layer in range(n_layers):
    with model.trace(source_prompt):
        for layer_idx, layer in enumerate(model.transformer.h):
            if abs(layer_idx - center_layer) <= window_size // 2:
                # Restore clean
                layer.output[0][:] = clean_hiddens[layer_idx]
            else:
                # Add noise
                layer.output[0][:] = add_noise(layer.output[0])
        restored_logits = model.lm_head.output.save()
    prob = torch.softmax(restored_logits.value[0, -1], dim=-1)[rome_token]
    restoration_effects[center_layer] = prob.item()
```
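Raw restored probabilities are easier to read against a fully corrupted baseline. A small sketch, assuming the variables above (and `source_prob` from the total-effect section), that measures how much of the clean-vs-noised gap each restoration window recovers:

```python
# Fully corrupted baseline: noise every layer, restore nothing
with model.trace(source_prompt):
    for layer in model.transformer.h:
        layer.output[0][:] = add_noise(layer.output[0])
    noised_logits = model.lm_head.output.save()

noised_prob = torch.softmax(noised_logits.value[0, -1], dim=-1)[rome_token]

# Fraction of the clean-vs-noised probability gap recovered by each window
recovered = (restoration_effects - noised_prob) / (source_prob - noised_prob)
```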
MLP vs Attention Decomposition
Separate contributions of MLP and attention:
```python
mlp_effects = torch.zeros(n_layers)
attn_effects = torch.zeros(n_layers)

# Get source MLP and attention outputs.
# Note: for GPT-2, attn.output is a tuple whose first element is the attention
# output tensor, while mlp.output is a tensor, so [0] indexes the single batch element.
with model.trace(source_prompt):
    source_mlp = [layer.mlp.output[0].save() for layer in model.transformer.h]
    source_attn = [layer.attn.output[0].save() for layer in model.transformer.h]

# Test MLP contributions
for layer_idx in range(n_layers):
    with model.trace(base_prompt):
        model.transformer.h[layer_idx].mlp.output[0][:] = source_mlp[layer_idx]
        mlp_logits = model.lm_head.output.save()
    mlp_effects[layer_idx] = torch.softmax(mlp_logits.value[0, -1], dim=-1)[rome_token]

# Test attention contributions
for layer_idx in range(n_layers):
    with model.trace(base_prompt):
        model.transformer.h[layer_idx].attn.output[0][:] = source_attn[layer_idx]
        attn_logits = model.lm_head.output.save()
    attn_effects[layer_idx] = torch.softmax(attn_logits.value[0, -1], dim=-1)[rome_token]
```
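A quick side-by-side of the two profiles shows which sublayer dominates at each depth:

```python
# Print MLP vs attention effects per layer
for layer_idx in range(n_layers):
    print(f"L{layer_idx:02d}  mlp={mlp_effects[layer_idx].item():.3f}  "
          f"attn={attn_effects[layer_idx].item():.3f}")
```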
Visualization
```python
import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Layer-wise effects
axes[0].bar(range(n_layers), direct_effects, alpha=0.7, label='Direct')
axes[0].bar(range(n_layers), indirect_effects, alpha=0.7, label='Indirect')
axes[0].set_xlabel('Layer')
axes[0].set_ylabel('Causal Effect')
axes[0].legend()
axes[0].set_title('Causal Effects by Layer')

# Position x Layer heatmap
input_tokens = model.tokenizer.encode(source_prompt)
token_labels = [model.tokenizer.decode(t) for t in input_tokens]
sns.heatmap(
    position_effects.numpy(),
    ax=axes[1],
    xticklabels=token_labels,
    yticklabels=[f'L{i}' for i in range(n_layers)],
    cmap='viridis'
)
axes[1].set_title('Causal Effect by Position and Layer')
axes[1].set_xlabel('Token Position')
axes[1].set_ylabel('Layer')

plt.tight_layout()
```
Interpretation Guidelines
- Early layers + subject position: Often store entity information
- Middle layers + last subject token: Information extraction/lookup
- Late layers + final position: Prediction formation
- High indirect effect: Component is necessary for behavior
- High direct effect: Component is sufficient to cause behavior
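As a quick application of these guidelines, the sketch below ranks layers by the two scores computed earlier (`top_k` is an arbitrary choice):

```python
# Surface candidate mediators: layers that look sufficient (direct) vs necessary (indirect)
top_k = 3
top_direct = torch.topk(direct_effects, top_k)
top_indirect = torch.topk(indirect_effects, top_k)
print("Highest direct effect (sufficient):", top_direct.indices.tolist())
print("Highest indirect effect (necessary):", top_indirect.indices.tolist())
```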