SciAgent-Skills shap-model-explainability
git clone https://github.com/jaechang-hits/SciAgent-Skills
T=$(mktemp -d) && git clone --depth=1 https://github.com/jaechang-hits/SciAgent-Skills "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/biostatistics/shap-model-explainability" ~/.claude/skills/jaechang-hits-sciagent-skills-shap-model-explainability && rm -rf "$T"
skills/biostatistics/shap-model-explainability/SKILL.md

SHAP Model Explainability
Overview
SHAP (SHapley Additive exPlanations) is a unified framework for explaining machine learning model predictions using Shapley values from cooperative game theory. It quantifies each feature's contribution to individual predictions and provides both local (per-instance) and global (dataset-level) explanations with theoretical guarantees of consistency and additivity.
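For reference, the Shapley value that SHAP builds on (per the Lundberg & Lee paper cited below) averages a feature's marginal contribution over all subsets of the remaining features. In the notation used here (an assumption of this write-up, not the skill's own), $F$ is the full feature set and $f_S$ denotes the model evaluated with only the features in $S$ present:

$$
\phi_i \;=\; \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(|F|-|S|-1)!}{|F|!}\,\bigl[f_{S \cup \{i\}}(x_{S \cup \{i\}}) - f_S(x_S)\bigr]
$$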
When to Use
- Explaining which features drive a model's predictions (global importance)
- Understanding why a model made a specific prediction (local explanation)
- Debugging model behavior and identifying data leakage
- Analyzing model fairness across demographic groups
- Comparing feature importance across multiple models
- Generating interpretable model explanations for stakeholders
- For tree-based model interpretation, prefer SHAP over permutation importance or Gini importance (more accurate, instance-level)
- For deep learning interpretation on images, consider GradCAM; use SHAP for tabular/structured data
Prerequisites
pip install shap matplotlib # Optional: xgboost lightgbm tensorflow torch (depending on model)
Quick Start
```python
import shap
import xgboost as xgb
from sklearn.model_selection import train_test_split

# Load example data
X, y = shap.datasets.adult()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Train model
model = xgb.XGBClassifier(n_estimators=100).fit(X_train, y_train)

# Explain: select explainer → compute → visualize
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)

shap.plots.beeswarm(shap_values)      # Global importance
shap.plots.waterfall(shap_values[0])  # Single prediction

print(f"Base value: {shap_values.base_values[0]:.3f}")
print(f"SHAP values shape: {shap_values.values.shape}")  # (n_samples, n_features)
```
Workflow
Step 1: Select the Right Explainer
Choose based on model type:
| Model Type | Explainer | Speed | Exactness |
|---|---|---|---|
| Tree-based (XGBoost, LightGBM, RF, CatBoost) | `TreeExplainer` | Fast | Exact |
| Linear (LogReg, GLM, Ridge) | `LinearExplainer` | Instant | Exact |
| Deep learning (TensorFlow, PyTorch) | `DeepExplainer` | Fast | Approximate |
| Deep learning (gradient-based) | `GradientExplainer` | Fast | Approximate |
| Any model (black-box) | `KernelExplainer` | Slow | Approximate |
| Any model (permutation-based) | `PermutationExplainer` | Very slow | Exact |
| Unsure? | `shap.Explainer` | Auto | Auto |
```python
# Tree-based models (most common)
explainer = shap.TreeExplainer(model)

# Linear models
explainer = shap.LinearExplainer(model, X_train)

# Deep learning
explainer = shap.DeepExplainer(model, X_train[:100])

# Any model (model-agnostic, slower)
explainer = shap.KernelExplainer(model.predict, shap.kmeans(X_train, 50))

# Auto-select
explainer = shap.Explainer(model, X_train)
```
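For a model with no specialized explainer, the model-agnostic path typically wraps a prediction function. A minimal sketch, assuming the `X_train`/`X_test`/`y_train` split from the Quick Start and a hypothetical k-NN classifier standing in for any black-box model:

```python
import shap
from sklearn.neighbors import KNeighborsClassifier

# Any black-box classifier; k-NN is only a stand-in here
clf = KNeighborsClassifier(n_neighbors=15).fit(X_train, y_train)

# Summarize the background to 50 centroids to keep KernelExplainer tractable
background = shap.kmeans(X_train, 50)

# Explain the positive-class probability; nsamples trades accuracy for runtime
kernel_explainer = shap.KernelExplainer(lambda X: clf.predict_proba(X)[:, 1], background)
kernel_shap = kernel_explainer.shap_values(X_test.iloc[:20], nsamples=200)
print(kernel_shap.shape)  # (20, n_features)
```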
Step 2: Compute SHAP Values
```python
shap_values = explainer(X_test)

# shap_values object contains:
#   .values      — SHAP values array (n_samples, n_features)
#   .base_values — Expected model output (baseline)
#   .data        — Original feature values

# Verify additivity: prediction = base_value + sum(SHAP values)
print(f"{shap_values.base_values[0]:.3f} + {shap_values.values[0].sum():.3f} = "
      f"{shap_values.base_values[0] + shap_values.values[0].sum():.3f}")
```
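To check additivity for every row at once rather than a single sample, you can compare the reconstruction against the model's raw margin output. This is a sketch assuming the XGBoost classifier from the Quick Start, which TreeExplainer explains in log-odds space by default:

```python
import numpy as np

# Rebuild each prediction from its SHAP decomposition: baseline + row-wise sum
reconstructed = shap_values.base_values + shap_values.values.sum(axis=1)

# XGBoost's raw margin (log-odds) should match within floating-point tolerance
margins = model.predict(X_test, output_margin=True)
print(np.allclose(reconstructed, margins, atol=1e-4))  # should print True
```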
Step 3: Global Explanations
```python
# Beeswarm: feature importance + value distributions (most informative)
shap.plots.beeswarm(shap_values, max_display=15)

# Bar: clean mean |SHAP| importance
shap.plots.bar(shap_values)
```
Step 4: Local Explanations (Individual Predictions)
```python
# Waterfall: detailed breakdown of one prediction
shap.plots.waterfall(shap_values[0])

# Force: additive force visualization
shap.plots.force(shap_values[0])
```
Step 5: Feature Relationships
```python
# Scatter: how a feature affects predictions
shap.plots.scatter(shap_values[:, "Age"])

# Colored by interaction feature
shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Education-Num"])
```
Step 6: Advanced Visualizations
```python
# Heatmap: multi-sample SHAP grid
shap.plots.heatmap(shap_values[:100])

# Decision plot: cumulative SHAP paths
shap.plots.decision(shap_values.base_values[0], shap_values.values[:10],
                    feature_names=X_test.columns.tolist())

# Cohort comparison
import numpy as np
mask_a = X_test["Age"] < 40
shap.plots.bar({
    "Under 40": shap_values[mask_a],
    "40+": shap_values[~mask_a],
})
```
Key Parameters
| Parameter | Explainer/Function | Default | Effect |
|---|---|---|---|
| `feature_perturbation` | TreeExplainer | `"interventional"` | `"interventional"` gives a causal interpretation but requires background data; without background data the explainer falls back to `"tree_path_dependent"` |
| `model_output` | TreeExplainer | `"raw"` | Set to `"probability"` to explain probabilities instead of log-odds |
| `data` (background) | KernelExplainer, DeepExplainer | Required | 100-1000 representative samples; use `shap.kmeans()` for efficiency |
| `nsamples` | KernelExplainer | `"auto"` | Higher = more accurate but slower; minimum 2×features |
| `max_display` | All plot functions | 10 | Number of features shown in plots |
| `alpha` | scatter/beeswarm | 1.0 | Point transparency for dense datasets |
| `show` | All plot functions | True | Set to `False` to get the matplotlib figure for saving |
| `clustering` | beeswarm | None | Pass a feature clustering (e.g., from `shap.utils.hclust`) to group correlated features |
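As a concrete use of the two TreeExplainer parameters above, here is a sketch that reuses the Quick Start model and data (the background sample size is an arbitrary choice): explaining probabilities requires the interventional perturbation mode plus background data.

```python
# Explain probabilities instead of log-odds; interventional mode needs background data
prob_explainer = shap.TreeExplainer(
    model,
    data=X_train.sample(200, random_state=0),
    feature_perturbation="interventional",
    model_output="probability",
)
prob_shap = prob_explainer(X_test.iloc[:100])

# Baseline + contributions now live in probability space
print(prob_shap.base_values[0] + prob_shap.values[0].sum())
```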
Key Concepts
SHAP Value Properties
SHAP values have three theoretical guarantees (unique among explanation methods):
- Additivity: `prediction = base_value + sum(SHAP values)` — exact decomposition
- Consistency: If a feature becomes more important in the model, its SHAP value increases
- Missingness: Features not present receive zero attribution
Interpretation: Positive SHAP → pushes prediction higher; Negative → lower; Magnitude → strength of impact.
Model Output Types
Understand what your model outputs — SHAP explains the output space:
- Regression: SHAP values in target units (e.g., dollars, temperature)
- Classification (log-odds): Default for tree classifiers. Use `model_output="probability"` for probability explanations (see the sketch after this list)
- Classification (probability): SHAP values sum to the probability's deviation from the baseline
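To make the log-odds vs. probability relationship concrete, a small sketch using the default log-odds explanation from the Quick Start (SciPy is assumed to be available, as SHAP already depends on it): the sigmoid of the additive reconstruction recovers the predicted probability.

```python
from scipy.special import expit  # logistic sigmoid

# In log-odds space the SHAP decomposition is exactly additive...
log_odds = shap_values.base_values[0] + shap_values.values[0].sum()

# ...and passing that sum through the sigmoid recovers the model's probability
print(f"sigmoid(base + sum SHAP) = {expit(log_odds):.3f}")
print(f"model.predict_proba      = {model.predict_proba(X_test.iloc[[0]])[0, 1]:.3f}")
```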
SHAP vs Other Methods
| Method | Local | Global | Consistent | Model-agnostic |
|---|---|---|---|---|
| SHAP | Yes | Yes | Yes | Yes |
| Permutation importance | No | Yes | No | Yes |
| Gini/split importance | No | Yes | No | Trees only |
| LIME | Yes | No | No | Yes |
| Integrated Gradients | Yes | No | Partial | NN only |
Interaction Values (TreeExplainer only)
```python
shap_interaction = explainer.shap_interaction_values(X_test)

# Shape: (n_samples, n_features, n_features)
# Diagonal = main effects; off-diagonal = pairwise interactions
```
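A quick consistency check on the array above (a sketch, assuming `shap_values` from Step 2 was computed with the same explainer): summing each sample's interaction matrix over one feature axis should give back that sample's ordinary SHAP values.

```python
import numpy as np

# Main effects sit on the diagonal; each pairwise interaction is split
# symmetrically across its two off-diagonal entries, so collapsing one
# feature axis recovers the per-feature SHAP values.
recovered = shap_interaction.sum(axis=2)
print(np.allclose(recovered, shap_values.values, atol=1e-4))  # should print True
```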
Background Data Selection
Background data establishes the baseline (expected model output). Selection affects SHAP magnitudes but not relative importance.
- Random sample from training data: 100-500 samples
- Use `shap.kmeans(X_train, 50)` for efficient summarization
- For TreeExplainer with `tree_path_dependent`: no background data needed (uses the tree structure)
- For DeepExplainer/KernelExplainer: 100-1000 samples balance accuracy vs. speed
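To see the baseline shift in practice, here is a sketch that reuses the Quick Start model with two hypothetical background samples of different sizes; the expected values differ somewhat while the relative feature importances stay essentially stable.

```python
# Two different background samples give slightly different baselines
bg_small = X_train.sample(100, random_state=0)
bg_large = X_train.sample(500, random_state=1)

exp_small = shap.TreeExplainer(model, data=bg_small).expected_value
exp_large = shap.TreeExplainer(model, data=bg_large).expected_value
print("baseline with 100 background rows:", exp_small)
print("baseline with 500 background rows:", exp_large)
```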
Common Recipes
Recipe: Model Debugging
```python
import numpy as np

# Find misclassified samples
predictions = model.predict(X_test)
errors = predictions != y_test
error_indices = np.where(errors)[0]

# Explain errors
for idx in error_indices[:3]:
    # np.asarray works whether y_test is a numpy array or a pandas Series
    print(f"Sample {idx}: predicted={predictions[idx]}, actual={np.asarray(y_test)[idx]}")
    shap.plots.waterfall(shap_values[idx])

# Check for data leakage: unexpected high-importance features
mean_abs_shap = np.abs(shap_values.values).mean(0)
top_features = X_test.columns[mean_abs_shap.argsort()[-5:]]
print(f"Top features (check for leakage): {list(top_features)}")
```
Recipe: Fairness Analysis
```python
import numpy as np

# Compare SHAP distributions across groups
group_a = shap_values[X_test["Sex"] == 0]
group_b = shap_values[X_test["Sex"] == 1]
shap.plots.bar({"Female": group_a, "Male": group_b})

# Check protected attribute importance
sex_importance = np.abs(shap_values[:, "Sex"].values).mean()
# Total = sum of per-feature mean |SHAP|, so the ratio is a true share
total_importance = np.abs(shap_values.values).mean(0).sum()
print(f"Sex contribution: {sex_importance / total_importance:.1%} of total importance")
```
Recipe: Production Caching
```python
import joblib
import numpy as np

# Save explainer for reuse
joblib.dump(explainer, 'explainer.pkl')
explainer = joblib.load('explainer.pkl')

# Batch computation for API responses
def explain_batch(X_batch, explainer, top_n=5):
    sv = explainer(X_batch)
    results = []
    for i in range(len(X_batch)):
        top_idx = np.abs(sv.values[i]).argsort()[-top_n:]
        results.append({
            'prediction': sv.base_values[i] + sv.values[i].sum(),
            'top_features': {X_batch.columns[j]: sv.values[i][j] for j in top_idx},
        })
    return results
```
Recipe: MLflow Integration
```python
import mlflow
import numpy as np
import matplotlib.pyplot as plt

with mlflow.start_run():
    model = xgb.XGBClassifier().fit(X_train, y_train)
    explainer = shap.TreeExplainer(model)
    shap_values = explainer(X_test)

    # Log the beeswarm plot as a figure artifact
    shap.plots.beeswarm(shap_values, show=False)
    mlflow.log_figure(plt.gcf(), "shap_beeswarm.png")
    plt.close()

    # Log per-feature global importance as metrics
    for feat, imp in zip(X_test.columns, np.abs(shap_values.values).mean(0)):
        mlflow.log_metric(f"shap_{feat}", imp)
```
Expected Outputs
| Output | Type | Description |
|---|---|---|
| `shap_values` | `shap.Explanation` object | Object with `.values`, `.base_values` (baseline), `.data` (input features) |
| Waterfall plot | matplotlib figure | Single-instance explanation showing feature contributions from base value to prediction |
| Beeswarm plot | matplotlib figure | Global summary: feature importance × direction for all samples |
| Bar plot | matplotlib figure | Mean absolute SHAP values per feature (global importance ranking) |
| Force plot | HTML/matplotlib | Interactive or static visualization of a single prediction |
| Mean \|SHAP\| importance | numpy array | Per-feature mean absolute SHAP value for ranking and reporting |
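The last row in the table is an array you typically compute yourself; a minimal sketch, reusing `shap_values` and `X_test` from the workflow:

```python
import numpy as np
import pandas as pd

# Global importance: mean absolute SHAP value per feature, ranked
importance = pd.Series(
    np.abs(shap_values.values).mean(axis=0), index=X_test.columns
).sort_values(ascending=False)
print(importance.head(10))
```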
Troubleshooting
| Problem | Cause | Solution |
|---|---|---|
| Very slow computation | Using KernelExplainer for a tree model | Use `TreeExplainer` for tree-based models |
| Slow on large dataset | Computing all samples at once | Sample a subset (e.g., `X_test.sample(1000)`) or compute in batches |
| SHAP values don't sum to prediction | Wrong model output type | Check the `model_output` parameter; verify additivity |
| Log-odds vs probability confusion | Tree classifier defaults to log-odds | Use `model_output="probability"` |
| Plots too cluttered | Too many features shown | Set `max_display` or use feature clustering |
| DeepExplainer error | Background data too small | Use 100-1000 background samples |
| Memory error | Large dataset + many features | Reduce background data with `shap.kmeans()` |
| Force plot not rendering | Missing JS in notebook | Run `shap.initjs()` at notebook start |
| Inconsistent importance across runs | KernelExplainer sampling variance | Increase `nsamples` or use a deterministic explainer |
| Negative importance for relevant feature | Feature interactions or correlations | Use `shap_interaction_values()` or scatter plots |
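For the slow-dataset and memory rows above, one workable pattern (a minimal sketch, not part of the original skill; the helper name is made up) is to explain the data in chunks and stack the resulting arrays:

```python
import numpy as np

def shap_in_batches(explainer, X, batch_size=1000):
    """Compute SHAP values chunk by chunk to bound peak memory."""
    parts = [explainer(X.iloc[i:i + batch_size]).values
             for i in range(0, len(X), batch_size)]
    return np.vstack(parts)

values = shap_in_batches(explainer, X_test, batch_size=500)
print(values.shape)  # (n_samples, n_features)
```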
Bundled Resources
- `references/theory.md` — Mathematical foundations: Shapley value formula, key properties (additivity, symmetry, dummy, monotonicity), computation algorithms (Tree SHAP, Kernel SHAP, Deep SHAP, Linear SHAP), conditional expectations (interventional vs observational), comparison with LIME/DeepLIFT/LRP/Integrated Gradients, interaction values, theoretical limitations
- Not migrated from original: `references/explainers.md` (340 lines) — detailed constructor parameters, methods, and performance benchmarks for each explainer class. Explainer selection guide and common usage are covered inline in Workflow Step 1 and Key Parameters.
- Not migrated from original: `references/plots.md` (508 lines) — comprehensive parameter reference for all 9 plot types with advanced customization (violin, decision, feature clustering). Main plot types are covered inline in Workflow Steps 3-6.
- Not migrated from original: `references/workflows.md` (606 lines) — detailed step-by-step workflows for feature engineering, model comparison, deep learning explanation, production deployment, and time series. Core patterns are covered in Common Recipes; consult the original for extended workflows.
Best Practices
- Choose specialized explainers first — `TreeExplainer` > `LinearExplainer` > `DeepExplainer` > `KernelExplainer`. Only use model-agnostic explainers when no specialized one exists
- Start global, then go local — begin with beeswarm/bar for overall importance, then waterfall/scatter for individual predictions and feature relationships
- Use multiple visualizations — different plots reveal different insights; combine global (beeswarm) + local (waterfall) + relationship (scatter)
- Select appropriate background data — 100-500 representative samples from training data; use `shap.kmeans()` for efficiency
- Validate with domain knowledge — unexpectedly high feature importance may indicate data leakage, not true predictive power
- Remember SHAP shows association, not causation — a feature's high SHAP importance means the model uses it, not that it causally affects the outcome
- Consider feature correlations — correlated features share SHAP importance; use `feature_perturbation="interventional"` for causal interpretation or feature clustering for grouped importance
References
- Lundberg & Lee (2017). "A Unified Approach to Interpreting Model Predictions" (NeurIPS)
- Lundberg et al. (2020). "From local explanations to global understanding with explainable AI for trees" (Nature Machine Intelligence)
- Official documentation: https://shap.readthedocs.io/
- GitHub: https://github.com/shap/shap
Related Skills
- scikit-learn-machine-learning — model training for SHAP analysis
- matplotlib-scientific-plotting — custom SHAP plot styling and export
- statistical-analysis — statistical testing to complement SHAP interpretation