# ML Fail-Fast Validation

POC validation patterns to catch issues before committing to long-running ML experiments.

Triggers: fail-fast, POC validation, preflight check, experiment validation, schema validation, gradient check, sanity check, smoke test.

Install:

```bash
git clone https://github.com/terrylica/cc-skills
# Or copy just this skill into ~/.claude/skills:
T=$(mktemp -d) && git clone --depth=1 https://github.com/terrylica/cc-skills "$T" && mkdir -p ~/.claude/skills && cp -r "$T/plugins/devops-tools/skills/ml-failfast-validation" ~/.claude/skills/terrylica-cc-skills-ml-failfast-validation && rm -rf "$T"
```

Source: plugins/devops-tools/skills/ml-failfast-validation/SKILL.md
**Self-Evolving Skill**: This skill improves through use. If instructions are wrong, parameters drifted, or a workaround was needed — fix this file immediately, don't defer. Only update for real, reproducible issues.
## When to Use This Skill
Use this skill when:
- Starting a new ML experiment that will run for hours
- Validating model architecture before full training
- Checking gradient flow and data pipeline integrity
- Implementing POC validation checklists
- Debugging prediction collapse or gradient explosion issues
## 1. Why Fail-Fast?
| Without Fail-Fast | With Fail-Fast |
|---|---|
| Discover crash 4 hours in | Catch in 30 seconds |
| Debug from cryptic error | Clear error message |
| Lose GPU time | Validate before commit |
| Silent data issues | Explicit schema checks |
**Principle**: Validate everything that can go wrong BEFORE the expensive computation.
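The principle reduces to a small guard: run every cheap check first and abort the process before any expensive work starts. A minimal sketch (`cheap_check` and the stand-in lambdas are illustrative, not part of this skill's API):

```python
import sys
import time

def cheap_check(name, fn):
    """Run one fast validation step; abort the process on the first failure."""
    t0 = time.perf_counter()
    try:
        fn()
    except Exception as e:
        print(f"FAIL [{name}] after {time.perf_counter() - t0:.1f}s: {e}", flush=True)
        sys.exit(1)
    print(f"PASS [{name}] in {time.perf_counter() - t0:.1f}s", flush=True)

# Seconds of guarding before hours of training: the expensive call
# is only reached if every cheap check passes.
cheap_check("model_forward", lambda: None)   # stand-in for a real forward pass
cheap_check("gradient_flow", lambda: None)   # stand-in for a real backward pass
# run_full_experiment()  # your long-running job goes here
```

Because the guard exits with a non-zero status, a wrapping shell script or CI job stops before any GPU time is spent.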
## 2. POC Validation Checklist

### Minimum Viable POC (5 Checks)

```python
def run_poc_validation():
    """Fast validation before full experiment."""
    print("=" * 60)
    print("FAIL-FAST POC VALIDATION")
    print("=" * 60)

    # [1/5] Model instantiation
    print("\n[1/5] Model instantiation...")
    model = create_model(architecture, input_size=n_features)
    x = torch.randn(32, seq_len, n_features).to(device)
    out = model(x)
    assert out.shape == (32, 1), f"Output shape wrong: {out.shape}"
    print(f"  Input: (32, {seq_len}, {n_features}) -> Output: {out.shape}")
    print("  Status: PASS")

    # [2/5] Gradient flow
    print("\n[2/5] Gradient flow...")
    y = torch.randn(32, 1).to(device)
    loss = F.mse_loss(out, y)
    loss.backward()
    grad_norms = [p.grad.norm().item() for p in model.parameters() if p.grad is not None]
    assert len(grad_norms) > 0, "No gradients!"
    assert all(np.isfinite(g) for g in grad_norms), "NaN/Inf gradients!"
    print(f"  Max grad norm: {max(grad_norms):.4f}")
    print("  Status: PASS")

    # [3/5] NDJSON artifact validation
    print("\n[3/5] NDJSON artifact validation...")
    log_path = output_dir / "experiment.jsonl"
    with open(log_path, "a") as f:
        f.write(json.dumps({"phase": "poc_start", "timestamp": datetime.now().isoformat()}) + "\n")
    assert log_path.exists(), "Log file not created"
    print(f"  Log file: {log_path}")
    print("  Status: PASS")

    # [4/5] Epoch selector variation
    print("\n[4/5] Epoch selector variation...")
    epochs = []
    for seed in [1, 2, 3]:
        selector = create_selector()
        # Simulate different validation results
        for e in range(10, 201, 10):
            selector.record(epoch=e, sortino=np.random.randn() * 0.1, sparsity=np.random.rand())
        epochs.append(selector.select())
    print(f"  Selected epochs: {epochs}")
    # With random inputs per seed, a working selector should not pick the
    # same epoch every time.
    assert len(set(epochs)) > 1, "Selector not varying"
    print("  Status: PASS")

    # [5/5] Mini training (10 epochs)
    print("\n[5/5] Mini training (10 epochs)...")
    model = create_model(architecture, input_size=n_features).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
    initial_loss = None
    for epoch in range(10):
        loss = train_one_epoch(model, train_loader, optimizer)
        if initial_loss is None:
            initial_loss = loss
    print(f"  Initial loss: {initial_loss:.4f}")
    print(f"  Final loss: {loss:.4f}")
    print("  Status: PASS")

    print("\n" + "=" * 60)
    print("POC RESULT: ALL 5 CHECKS PASSED")
    print("=" * 60)
```
### Extended POC (10 Checks)

Add these for comprehensive validation:

```python
# [6/10] Data loading
print("\n[6/10] Data loading...")
df = fetch_data(symbol, threshold)
assert len(df) > min_required_bars, f"Insufficient data: {len(df)} bars"
print(f"  Loaded: {len(df):,} bars")
print("  Status: PASS")

# [7/10] Schema validation
print("\n[7/10] Schema validation...")
validate_schema(df, required_columns, "raw_data")
print("  Status: PASS")

# [8/10] Feature computation
print("\n[8/10] Feature computation...")
df = compute_features(df)
validate_schema(df, feature_columns, "features")
print(f"  Features: {len(feature_columns)}")
print("  Status: PASS")

# [9/10] Prediction sanity
print("\n[9/10] Prediction sanity...")
preds = model(X_test).detach().cpu().numpy()
pred_std = preds.std()
target_std = y_test.std()
pred_ratio = pred_std / target_std
assert pred_ratio > 0.005, f"Predictions collapsed: ratio={pred_ratio:.4f}"
print(f"  Pred std ratio: {pred_ratio:.2%}")
print("  Status: PASS")

# [10/10] Checkpoint save/load
print("\n[10/10] Checkpoint save/load...")
torch.save(model.state_dict(), checkpoint_path)
model2 = create_model(architecture, input_size=n_features)
model2.load_state_dict(torch.load(checkpoint_path))
print("  Status: PASS")
```
## 3. Schema Validation Pattern

### The Problem

```python
# BAD: Cryptic error 2 hours into experiment
KeyError: 'returns_vs'
# Which file? Which function? What columns exist?
```

### The Solution

```python
def validate_schema(df, required: list[str], stage: str) -> None:
    """Fail-fast schema validation with actionable error messages."""
    # Handle both DataFrame columns and DatetimeIndex
    available = list(df.columns)
    if hasattr(df.index, 'name') and df.index.name:
        available.append(df.index.name)
    missing = [c for c in required if c not in available]
    if missing:
        raise ValueError(
            f"[{stage}] Missing columns: {missing}\n"
            f"Available: {sorted(available)}\n"
            f"DataFrame shape: {df.shape}"
        )
    print(f"  Schema validation PASSED ({stage}): {len(required)} columns", flush=True)

# Usage at pipeline boundaries
REQUIRED_RAW = ["open", "high", "low", "close", "volume"]
REQUIRED_FEATURES = ["returns_vs", "momentum_z", "atr_pct", "volume_z",
                    "rsi_14", "bb_pct_b", "vol_regime", "return_accel", "pv_divergence"]

df = fetch_data(symbol)
validate_schema(df, REQUIRED_RAW, "raw_data")
df = compute_features(df)
validate_schema(df, REQUIRED_FEATURES, "features")
```
## 4. Gradient Health Checks

### Basic Gradient Check

```python
def check_gradient_health(model: nn.Module, sample_input: torch.Tensor) -> dict:
    """Verify gradients flow correctly through model."""
    model.train()
    out = model(sample_input)
    loss = out.sum()
    loss.backward()

    stats = {"total_params": 0, "params_with_grad": 0, "grad_norms": []}
    for name, param in model.named_parameters():
        stats["total_params"] += 1
        if param.grad is not None:
            stats["params_with_grad"] += 1
            norm = param.grad.norm().item()
            stats["grad_norms"].append(norm)
            # Check for issues
            if not np.isfinite(norm):
                raise ValueError(f"Non-finite gradient in {name}: {norm}")
            if norm > 100:
                print(f"  WARNING: Large gradient in {name}: {norm:.2f}")

    stats["max_grad"] = max(stats["grad_norms"]) if stats["grad_norms"] else 0
    stats["mean_grad"] = np.mean(stats["grad_norms"]) if stats["grad_norms"] else 0
    return stats
```

### Architecture-Specific Checks

```python
def check_lstm_gradients(model: nn.Module) -> dict:
    """Check LSTM-specific gradient patterns."""
    stats = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        # Check forget gate bias (should not be too negative)
        if "bias_hh" in name or "bias_ih" in name:
            # LSTM bias layout: [i, f, g, o] gates
            hidden_size = param.shape[0] // 4
            forget_bias = param.grad[hidden_size:2 * hidden_size]
            stats["forget_bias_grad_mean"] = forget_bias.mean().item()
        # Check hidden-to-hidden weights
        if "weight_hh" in name:
            stats["hh_weight_grad_norm"] = param.grad.norm().item()
    return stats
```
## 5. Prediction Sanity Checks

### Collapse Detection

```python
def check_prediction_sanity(preds: np.ndarray, targets: np.ndarray) -> dict:
    """Detect prediction collapse or explosion."""
    stats = {
        "pred_mean": preds.mean(),
        "pred_std": preds.std(),
        "pred_min": preds.min(),
        "pred_max": preds.max(),
        "target_std": targets.std(),
    }
    # Relative threshold (not absolute!)
    stats["pred_std_ratio"] = stats["pred_std"] / stats["target_std"]

    # Collapse detection
    if stats["pred_std_ratio"] < 0.005:  # < 0.5% of target std
        raise ValueError(
            f"Predictions collapsed!\n"
            f"  pred_std: {stats['pred_std']:.6f}\n"
            f"  target_std: {stats['target_std']:.6f}\n"
            f"  ratio: {stats['pred_std_ratio']:.4%}"
        )

    # Explosion detection
    if stats["pred_std_ratio"] > 100:  # > 100x target std
        raise ValueError(
            f"Predictions exploded!\n"
            f"  pred_std: {stats['pred_std']:.2f}\n"
            f"  target_std: {stats['target_std']:.6f}\n"
            f"  ratio: {stats['pred_std_ratio']:.1f}x"
        )

    # Unique value check
    stats["unique_values"] = len(np.unique(np.round(preds, 6)))
    if stats["unique_values"] < 10:
        print(f"  WARNING: Only {stats['unique_values']} unique prediction values")
    return stats
```

### Correlation Check

```python
def check_prediction_correlation(preds: np.ndarray, targets: np.ndarray) -> float:
    """Check if predictions have any correlation with targets."""
    corr = np.corrcoef(preds.flatten(), targets.flatten())[0, 1]
    if not np.isfinite(corr):
        print("  WARNING: Correlation is NaN (likely collapsed predictions)")
        return 0.0
    # Note: negative correlation may still be useful (short signal)
    print(f"  Prediction-target correlation: {corr:.4f}")
    return corr
```
## 6. NDJSON Logging Validation

### Required Event Types

```python
REQUIRED_EVENTS = {
    "experiment_start": ["architecture", "features", "config"],
    "fold_start": ["fold_id", "train_size", "val_size", "test_size"],
    "epoch_complete": ["epoch", "train_loss", "val_loss"],
    "fold_complete": ["fold_id", "test_sharpe", "test_sortino"],
    "experiment_complete": ["total_folds", "mean_sharpe", "elapsed_seconds"],
}

def validate_ndjson_schema(log_path: Path) -> None:
    """Validate NDJSON log has all required events and fields."""
    events = {}
    with open(log_path) as f:
        for line in f:
            event = json.loads(line)
            phase = event.get("phase", "unknown")
            if phase not in events:
                events[phase] = []
            events[phase].append(event)

    for phase, required_fields in REQUIRED_EVENTS.items():
        if phase not in events:
            raise ValueError(f"Missing event type: {phase}")
        sample = events[phase][0]
        missing = [f for f in required_fields if f not in sample]
        if missing:
            raise ValueError(f"Event '{phase}' missing fields: {missing}")

    print(f"  NDJSON schema valid: {len(events)} event types")
```
## 7. POC Timing Guide
| Check | Typical Time | Max Time | Action if Exceeded |
|---|---|---|---|
| Model instantiation | < 1s | 5s | Check device, reduce model size |
| Gradient flow | < 2s | 10s | Check batch size |
| Schema validation | < 0.1s | 1s | Check data loading |
| Mini training (10 epochs) | < 30s | 2min | Reduce batch, check data loader |
| Full POC (10 checks) | < 2min | 5min | Something is wrong |
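These budgets can be enforced mechanically rather than eyeballed. A hedged sketch that wraps each check with a timer (the `BUDGETS` values mirror the Max Time column above; `timed_check` is an illustrative helper, not an existing function in this skill):

```python
import time

# Max-time budgets from the table above, in seconds.
BUDGETS = {
    "model_instantiation": 5,
    "gradient_flow": 10,
    "schema_validation": 1,
    "mini_training": 120,
}

def timed_check(name, fn):
    """Run a POC check; raise if it blows its time budget."""
    t0 = time.perf_counter()
    result = fn()  # run the check itself
    elapsed = time.perf_counter() - t0
    budget = BUDGETS[name]  # KeyError here means an unregistered check name
    if elapsed > budget:
        raise TimeoutError(f"[{name}] took {elapsed:.1f}s (budget {budget}s)")
    print(f"  [{name}] {elapsed:.2f}s (budget {budget}s)", flush=True)
    return result
```

Usage: `timed_check("gradient_flow", lambda: check_gradient_health(model, x))`. A blown budget then fails the POC with the check's name attached, instead of the run silently crawling.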
## 8. Failure Response Guide
| Failure | Likely Cause | Fix |
|---|---|---|
| Shape mismatch | Wrong input_size or seq_len | Check feature count |
| NaN gradients | LR too high, bad init | Reduce LR, check init |
| Zero gradients | Dead layers, missing params | Check model architecture |
| Predictions collapsed | Normalizer issue, bad loss | Check sLSTM normalizer |
| Predictions exploded | Gradient explosion | Add/tighten gradient clipping |
| Schema missing columns | Wrong data source | Check fetch function |
| Checkpoint load fails | State dict key mismatch | Check model architecture match |
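For the prediction-explosion row, the usual remedy is norm-based clipping before each optimizer step. A minimal sketch using PyTorch's `torch.nn.utils.clip_grad_norm_` (the toy `nn.Linear` model and the 1.0 max-norm are illustrative choices, not tuned values):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

x = torch.randn(32, 8)
y = torch.randn(32, 1)

optimizer.zero_grad()
loss = F.mse_loss(model(x), y)
loss.backward()

# Clip the global gradient norm BEFORE stepping. The call returns the
# pre-clip norm, which is worth logging to see whether clipping engaged.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
print(f"pre-clip grad norm: {float(total_norm):.4f}", flush=True)
```

Logging the returned pre-clip norm over time is itself a fail-fast signal: a norm that keeps hitting the ceiling suggests the learning rate, not the clip, is the real problem.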
## 9. Integration Example

```python
def main():
    # Parse args, set up output dir...

    # PHASE 1: Fail-fast POC
    print("=" * 60)
    print("FAIL-FAST POC VALIDATION")
    print("=" * 60)
    try:
        run_poc_validation()
    except Exception as e:
        print(f"\n{'=' * 60}")
        print(f"POC FAILED: {type(e).__name__}")
        print(f"{'=' * 60}")
        print(f"Error: {e}")
        print("\nFix the issue before running full experiment.")
        sys.exit(1)

    # PHASE 2: Full experiment (only if POC passes)
    print("\n" + "=" * 60)
    print("STARTING FULL EXPERIMENT")
    print("=" * 60)
    run_full_experiment()
```
## 10. Anti-Patterns to Avoid

### DON'T: Skip validation to "save time"

```python
# BAD: "I'll just run it and see"
run_full_experiment()  # 4 hours later: crash
```

### DON'T: Use absolute thresholds for relative quantities

```python
# BAD: Absolute threshold
assert pred_std > 1e-4  # Meaningless for returns ~0.001

# GOOD: Relative threshold
assert pred_std / target_std > 0.005  # 0.5% of target std
```

### DON'T: Catch all exceptions silently

```python
# BAD: Hides real issues
try:
    result = risky_operation()
except Exception:
    result = default_value  # What went wrong?

# GOOD: Catch specific exceptions
try:
    result = risky_operation()
except (ValueError, RuntimeError) as e:
    logger.error(f"Operation failed: {e}")
    raise
```

### DON'T: Print without flush

```python
# BAD: Output buffered, can't see progress
print(f"Processing fold {i}...")

# GOOD: See output immediately
print(f"Processing fold {i}...", flush=True)
```
## Troubleshooting
| Issue | Cause | Solution |
|---|---|---|
| NaN gradients in POC | Learning rate too high | Reduce LR by 10x, check weight initialization |
| Zero gradients | Dead layers or missing params | Check model architecture, verify requires_grad=True |
| Predictions collapsed | Normalizer issue or bad loss | Check target normalization, verify loss function |
| Predictions exploded | Gradient explosion | Add gradient clipping, reduce learning rate |
| Schema missing columns | Wrong data source or transform | Verify fetch function returns expected columns |
| Checkpoint load fails | State dict key mismatch | Ensure model architecture matches saved checkpoint |
| POC timeout (>5 min) | Data loading or model too large | Reduce batch size, check DataLoader num_workers |
| Mini training no progress | Learning rate too low or frozen | Increase LR, verify optimizer updates all parameters |
| NDJSON validation fails | Missing required event types | Check all phases emit expected fields |
| Shape mismatch error | Wrong input_size or seq_len | Verify feature count matches model input dimension |
## Post-Execution Reflection
After this skill completes, check before closing:
- Did the command succeed? — If not, fix the instruction or error table that caused the failure.
- Did parameters or output change? — If the underlying tool's interface drifted, update Usage examples and Parameters table to match.
- Was a workaround needed? — If you had to improvise (different flags, extra steps), update this SKILL.md so the next invocation doesn't need the same workaround.
Only update if the issue is real and reproducible — not speculative.