SciAgent-Skills shap-model-explainability

install
source · Clone the upstream repo
git clone https://github.com/jaechang-hits/SciAgent-Skills
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/jaechang-hits/SciAgent-Skills "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/biostatistics/shap-model-explainability" ~/.claude/skills/jaechang-hits-sciagent-skills-shap-model-explainability && rm -rf "$T"
manifest: skills/biostatistics/shap-model-explainability/SKILL.md
source content

SHAP Model Explainability

Overview

SHAP (SHapley Additive exPlanations) is a unified framework for explaining machine learning model predictions using Shapley values from cooperative game theory. It quantifies each feature's contribution to individual predictions and provides both local (per-instance) and global (dataset-level) explanations with theoretical guarantees of consistency and additivity.

When to Use

  • Explaining which features drive a model's predictions (global importance)
  • Understanding why a model made a specific prediction (local explanation)
  • Debugging model behavior and identifying data leakage
  • Analyzing model fairness across demographic groups
  • Comparing feature importance across multiple models
  • Generating interpretable model explanations for stakeholders
  • For tree-based model interpretation, prefer SHAP over permutation importance or Gini importance (more accurate, instance-level)
  • For deep learning interpretation on images, consider GradCAM; use SHAP for tabular/structured data

Prerequisites

pip install shap matplotlib
# Optional: xgboost lightgbm tensorflow torch (depending on model)

Quick Start

import shap
import xgboost as xgb
from sklearn.model_selection import train_test_split

# Load example data
X, y = shap.datasets.adult()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Train model
model = xgb.XGBClassifier(n_estimators=100).fit(X_train, y_train)

# Explain: select explainer → compute → visualize
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)

shap.plots.beeswarm(shap_values)   # Global importance
shap.plots.waterfall(shap_values[0])  # Single prediction
print(f"Base value: {shap_values.base_values[0]:.3f}")
print(f"SHAP values shape: {shap_values.values.shape}")  # (n_samples, n_features)

Workflow

Step 1: Select the Right Explainer

Choose based on model type:

| Model Type | Explainer | Speed | Exactness |
| --- | --- | --- | --- |
| Tree-based (XGBoost, LightGBM, RF, CatBoost) | TreeExplainer | Fast | Exact |
| Linear (LogReg, GLM, Ridge) | LinearExplainer | Instant | Exact |
| Deep learning (TensorFlow, PyTorch) | DeepExplainer | Fast | Approximate |
| Deep learning (gradient-based) | GradientExplainer | Fast | Approximate |
| Any model (black-box) | KernelExplainer | Slow | Approximate |
| Any model (permutation-based) | PermutationExplainer | Very slow | Exact |
| Unsure? | shap.Explainer | Auto | Auto |

# Tree-based models (most common)
explainer = shap.TreeExplainer(model)

# Linear models
explainer = shap.LinearExplainer(model, X_train)

# Deep learning
explainer = shap.DeepExplainer(model, X_train[:100])

# Any model (model-agnostic, slower)
explainer = shap.KernelExplainer(model.predict, shap.kmeans(X_train, 50))

# Auto-select
explainer = shap.Explainer(model, X_train)

Step 2: Compute SHAP Values

shap_values = explainer(X_test)

# shap_values object contains:
# .values      — SHAP values array (n_samples, n_features)
# .base_values — Expected model output (baseline)
# .data        — Original feature values

# Verify additivity: prediction = base_value + sum(SHAP values)
print(f"  {shap_values.base_values[0]:.3f} + {shap_values.values[0].sum():.3f} = "
      f"{shap_values.base_values[0] + shap_values.values[0].sum():.3f}")

Step 3: Global Explanations

# Beeswarm: feature importance + value distributions (most informative)
shap.plots.beeswarm(shap_values, max_display=15)

# Bar: clean mean |SHAP| importance
shap.plots.bar(shap_values)
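
For the plain numeric ranking behind these plots (the mean_abs_shap series listed under Expected Outputs), a minimal sketch assuming X_test is the pandas DataFrame from the Quick Start:

import numpy as np
import pandas as pd

# Mean absolute SHAP value per feature, sorted for reporting
mean_abs_shap = pd.Series(np.abs(shap_values.values).mean(axis=0), index=X_test.columns)
print(mean_abs_shap.sort_values(ascending=False).head(10))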

Step 4: Local Explanations (Individual Predictions)

# Waterfall: detailed breakdown of one prediction
shap.plots.waterfall(shap_values[0])

# Force: additive force visualization
shap.plots.force(shap_values[0])
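
To save these figures instead of displaying them, pass show=False (see Key Parameters) and grab the current matplotlib figure; a minimal sketch with an assumed output filename:

import matplotlib.pyplot as plt

shap.plots.waterfall(shap_values[0], show=False)        # draw without displaying
plt.savefig("waterfall_sample0.png", dpi=200, bbox_inches="tight")
plt.close()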

Step 5: Feature Relationships

# Scatter: how a feature affects predictions
shap.plots.scatter(shap_values[:, "Age"])

# Colored by interaction feature
shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Education-Num"])

Step 6: Advanced Visualizations

# Heatmap: multi-sample SHAP grid
shap.plots.heatmap(shap_values[:100])

# Decision plot: cumulative SHAP paths
shap.plots.decision(shap_values.base_values[0], shap_values.values[:10],
                     feature_names=X_test.columns.tolist())

# Cohort comparison
import numpy as np
mask_a = X_test["Age"] < 40
shap.plots.bar({
    "Under 40": shap_values[mask_a],
    "40+": shap_values[~mask_a]
})

Key Parameters

| Parameter | Explainer/Function | Default | Effect |
| --- | --- | --- | --- |
| feature_perturbation | TreeExplainer | "tree_path_dependent" | "interventional" for causal interpretation (requires background data) |
| model_output | TreeExplainer | "raw" | "probability" to explain probabilities instead of log-odds |
| data (background) | KernelExplainer, DeepExplainer | Required | 100-1000 representative samples; use shap.kmeans(X, 50) for efficiency |
| nsamples | KernelExplainer | "auto" | Higher = more accurate but slower; minimum 2 × number of features |
| max_display | All plot functions | 10 | Number of features shown in plots |
| alpha | scatter/beeswarm | 1.0 | Point transparency for dense datasets |
| show | All plot functions | True | Set False to get the matplotlib figure for saving |
| clustering | beeswarm | None | shap.utils.hclust(...) to cluster correlated features |
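
The TreeExplainer parameters above interact: explaining probabilities requires interventional perturbation, which in turn needs background data. A minimal sketch, assuming the model and data frames from the Quick Start:

background = X_train.sample(200, random_state=0)         # representative background sample
prob_explainer = shap.TreeExplainer(
    model,
    data=background,
    feature_perturbation="interventional",
    model_output="probability",
)
prob_values = prob_explainer(X_test.iloc[:500])           # values now sum to probability deviations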

Key Concepts

SHAP Value Properties

SHAP values have three theoretical guarantees (unique among explanation methods):

  • Additivity: prediction = base_value + sum(SHAP values) — an exact decomposition
  • Consistency: if a model changes so that a feature's contribution grows or stays the same, its attributed importance does not decrease
  • Missingness: features not present receive zero attribution

Interpretation: Positive SHAP → pushes prediction higher; Negative → lower; Magnitude → strength of impact.

Model Output Types

Understand what your model outputs — SHAP explains the output space:

  • Regression: SHAP values are in target units (e.g., dollars, temperature)
  • Classification (log-odds): the default for tree classifiers; use model_output="probability" for probability-space explanations
  • Classification (probability): SHAP values sum to the deviation of the predicted probability from the baseline probability

SHAP vs Other Methods

| Method | Local | Global | Consistent | Model-agnostic |
| --- | --- | --- | --- | --- |
| SHAP | Yes | Yes | Yes | Yes |
| Permutation importance | No | Yes | No | Yes |
| Gini/split importance | No | Yes | No | Trees only |
| LIME | Yes | No | No | Yes |
| Integrated Gradients | Yes | No | Partial | NN only |

Interaction Values (TreeExplainer only)

shap_interaction = explainer.shap_interaction_values(X_test)
# Shape: (n_samples, n_features, n_features)
# Diagonal = main effects; off-diagonal = pairwise interactions
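
To locate the strongest pairwise interaction in that array, average the absolute interaction values and ignore the diagonal; a minimal sketch, assuming a single-output (binary) tree model:

import numpy as np

inter = np.abs(shap_interaction).mean(axis=0)    # mean |interaction| per feature pair
np.fill_diagonal(inter, 0)                       # drop main effects on the diagonal
i, j = np.unravel_index(inter.argmax(), inter.shape)
print(f"Strongest interaction: {X_test.columns[i]} x {X_test.columns[j]}")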

Background Data Selection

Background data establishes the baseline (expected model output). Selection affects SHAP magnitudes but not relative importance.

  • Random sample from the training data: 100-500 samples
  • Use shap.kmeans(X_train, 50) for efficient summarization (see the sketch after this list)
  • For TreeExplainer with tree_path_dependent: no background data needed (the tree structure is used)
  • For DeepExplainer/KernelExplainer: 100-1000 samples balance accuracy and speed
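
A sketch of the kmeans option above, assuming a model exposing predict_proba (like the Quick Start classifier): summarizing the background keeps KernelExplainer tractable.

background = shap.kmeans(X_train, 50)                            # 50 weighted centroids
kernel_explainer = shap.KernelExplainer(model.predict_proba, background)
kernel_sv = kernel_explainer.shap_values(X_test.iloc[:50], nsamples=200)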

Common Recipes

Recipe: Model Debugging

import numpy as np

# Find misclassified samples
predictions = model.predict(X_test)
errors = predictions != y_test
error_indices = np.where(errors)[0]

# Explain errors
for idx in error_indices[:3]:
    print(f"Sample {idx}: predicted={predictions[idx]}, actual={y_test.iloc[idx]}")
    shap.plots.waterfall(shap_values[idx])

# Check for data leakage: unexpected high-importance features
mean_abs_shap = np.abs(shap_values.values).mean(0)
top_features = X_test.columns[mean_abs_shap.argsort()[-5:]]
print(f"Top features (check for leakage): {list(top_features)}")

Recipe: Fairness Analysis

# Compare SHAP distributions across groups
group_a = shap_values[X_test["Sex"] == 0]
group_b = shap_values[X_test["Sex"] == 1]

shap.plots.bar({"Female": group_a, "Male": group_b})

# Check protected attribute importance
sex_importance = np.abs(shap_values[:, "Sex"].values).mean()
total_importance = np.abs(shap_values.values).mean(0).sum()  # sum of per-feature mean |SHAP|
print(f"Sex contribution: {sex_importance/total_importance:.1%} of total importance")

Recipe: Production Caching

import joblib

# Save explainer for reuse
joblib.dump(explainer, 'explainer.pkl')
explainer = joblib.load('explainer.pkl')

# Batch computation for API responses
def explain_batch(X_batch, explainer, top_n=5):
    sv = explainer(X_batch)
    results = []
    for i in range(len(X_batch)):
        top_idx = np.abs(sv.values[i]).argsort()[-top_n:]
        results.append({
            'prediction': sv.base_values[i] + sv.values[i].sum(),
            'top_features': {X_batch.columns[j]: sv.values[i][j] for j in top_idx}
        })
    return results
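
Hypothetical usage of the helper above with the Quick Start objects:

for item in explain_batch(X_test.iloc[:3], explainer):
    print(item)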

Recipe: MLflow Integration

import mlflow
import matplotlib.pyplot as plt

with mlflow.start_run():
    model = xgb.XGBClassifier().fit(X_train, y_train)
    explainer = shap.TreeExplainer(model)
    shap_values = explainer(X_test)

    shap.plots.beeswarm(shap_values, show=False)
    mlflow.log_figure(plt.gcf(), "shap_beeswarm.png")
    plt.close()

    for feat, imp in zip(X_test.columns, np.abs(shap_values.values).mean(0)):
        mlflow.log_metric(f"shap_{feat}", imp)

Expected Outputs

| Output | Type | Description |
| --- | --- | --- |
| shap_values | shap.Explanation | Object with .values (n_samples, n_features), .base_values (baseline), and .data (input features) |
| Waterfall plot | matplotlib figure | Single-instance explanation showing feature contributions from base value to prediction |
| Beeswarm plot | matplotlib figure | Global summary: feature importance and direction across all samples |
| Bar plot | matplotlib figure | Mean absolute SHAP values per feature (global importance ranking) |
| Force plot | HTML/matplotlib | Interactive or static visualization of a single prediction |
| mean_abs_shap | pd.Series | Per-feature mean absolute SHAP value for ranking and reporting |

Troubleshooting

| Problem | Cause | Solution |
| --- | --- | --- |
| Very slow computation | Using KernelExplainer for a tree model | Use TreeExplainer for tree-based models |
| Slow on a large dataset | Computing all samples at once | Sample a subset (explainer(X_test[:1000])) or batch |
| SHAP values don't sum to prediction | Wrong model output type | Check the model_output parameter; verify additivity |
| Log-odds vs probability confusion | Tree classifiers default to log-odds | Use TreeExplainer(model, model_output="probability") |
| Plots too cluttered | Too many features shown | Set max_display=10 or use feature clustering |
| DeepExplainer error | Background data too small | Use 100-1000 background samples |
| Memory error | Large dataset + many features | Reduce background data with shap.kmeans(X, 50) |
| Force plot not rendering | Missing JS in notebook | Run shap.initjs() at notebook start |
| Inconsistent importance across runs | KernelExplainer sampling variance | Increase nsamples or use a deterministic explainer |
| Negative importance for a relevant feature | Feature interactions or correlations | Use feature_perturbation="interventional" or inspect scatter plots |

Bundled Resources

  • references/theory.md
    — Mathematical foundations: Shapley value formula, key properties (additivity, symmetry, dummy, monotonicity), computation algorithms (Tree SHAP, Kernel SHAP, Deep SHAP, Linear SHAP), conditional expectations (interventional vs observational), comparison with LIME/DeepLIFT/LRP/Integrated Gradients, interaction values, theoretical limitations

Not migrated from original:

  • references/explainers.md (340 lines) — detailed constructor parameters, methods, and performance benchmarks for each explainer class. The explainer selection guide and common usage are covered inline in Workflow Step 1 and Key Parameters.
  • references/plots.md (508 lines) — comprehensive parameter reference for all 9 plot types with advanced customization (violin, decision, feature clustering). The main plot types are covered inline in Workflow Steps 3-6.
  • references/workflows.md (606 lines) — detailed step-by-step workflows for feature engineering, model comparison, deep learning explanation, production deployment, and time series. Core patterns are covered in Common Recipes; consult the original for extended workflows.

Best Practices

  1. Choose specialized explainers first — TreeExplainer > LinearExplainer > DeepExplainer > KernelExplainer; only fall back to a model-agnostic explainer when no specialized one exists
  2. Start global, then go local — begin with beeswarm/bar for overall importance, then waterfall/scatter for individual predictions and feature relationships
  3. Use multiple visualizations — different plots reveal different insights; combine global (beeswarm) + local (waterfall) + relationship (scatter)
  4. Select appropriate background data — 100-500 representative samples from the training data; use shap.kmeans() for efficiency
  5. Validate with domain knowledge — unexpectedly high feature importance may indicate data leakage, not true predictive power
  6. Remember SHAP shows association, not causation — a feature's high SHAP importance means the model uses it, not that it causally affects the outcome
  7. Consider feature correlations — correlated features share SHAP importance; use feature_perturbation="interventional" for a causal interpretation, or feature clustering for grouped importance (see the sketch after this list)
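
A sketch of practice 7, assuming the Quick Start data: cluster correlated features so the bar plot groups their shared importance.

clustering = shap.utils.hclust(X_test, y_test)   # hierarchical clustering of features
shap.plots.bar(shap_values, clustering=clustering, clustering_cutoff=0.5)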

References

  • Lundberg & Lee (2017). "A Unified Approach to Interpreting Model Predictions" (NeurIPS)
  • Lundberg et al. (2020). "From local explanations to global understanding with explainable AI for trees" (Nature Machine Intelligence)
  • Official documentation: https://shap.readthedocs.io/
  • GitHub: https://github.com/shap/shap

Related Skills

  • scikit-learn-machine-learning — model training for SHAP analysis
  • matplotlib-scientific-plotting — custom SHAP plot styling and export
  • statistical-analysis — statistical testing to complement SHAP interpretation