Claude-skill-registry-data MLflow Patterns
ML experiment tracking, model registry, and deployment with MLflow for reproducible machine learning workflows.
install
source · Clone the upstream repo
git clone https://github.com/majiayu000/claude-skill-registry-data
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/majiayu000/claude-skill-registry-data "$T" && mkdir -p ~/.claude/skills && cp -r "$T/data/mlflow-patterns" ~/.claude/skills/majiayu000-claude-skill-registry-data-mlflow-patterns && rm -rf "$T"
manifest:
data/mlflow-patterns/SKILL.md
MLflow Patterns
Overview
MLflow is an open-source platform for managing the ML lifecycle, covering experiment tracking, model packaging, the model registry, and deployment. It helps data science teams collaborate and deploy models reproducibly.
Why This Matters
- Reproducibility: Track experiments and reproduce results
- Collaboration: Share experiments and models across the team
- Deployment: Package and deploy models consistently
- Governance: Model versioning and approval workflows
Core Concepts
1. Experiment Tracking
import mlflow
import mlflow.sklearn
from xgboost import XGBClassifier
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
)

# Set tracking URI and experiment
mlflow.set_tracking_uri("http://mlflow-server:5000")
mlflow.set_experiment("customer-churn-prediction")

# Enable auto-logging for sklearn-compatible estimators
mlflow.sklearn.autolog()

with mlflow.start_run(run_name="xgboost-v1") as run:
    # Log parameters
    mlflow.log_params({
        "learning_rate": 0.1,
        "max_depth": 6,
        "n_estimators": 100,
        "subsample": 0.8,
    })

    # Train model
    model = XGBClassifier(
        learning_rate=0.1,
        max_depth=6,
        n_estimators=100,
        subsample=0.8,
    )
    model.fit(X_train, y_train)

    # Log metrics
    y_pred = model.predict(X_test)
    mlflow.log_metrics({
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred),
        "recall": recall_score(y_test, y_pred),
        "f1": f1_score(y_test, y_pred),
        "auc_roc": roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]),
    })

    # Log artifacts
    mlflow.log_artifact("feature_importance.png")
    mlflow.log_artifact("confusion_matrix.png")

    # Log model
    mlflow.sklearn.log_model(
        model,
        artifact_path="model",
        registered_model_name="churn-prediction-model",
    )

    # Log dataset info
    mlflow.log_input(
        mlflow.data.from_pandas(X_train, source="s3://data/train.parquet"),
        context="training",
    )

print(f"Run ID: {run.info.run_id}")
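Once runs are logged, they can be queried programmatically. A minimal sketch using mlflow.search_runs against the experiment above; the metric and parameter column names match what the run in the previous example logged:
import mlflow

# Find the best runs in the experiment, ranked by AUC
runs = mlflow.search_runs(
    experiment_names=["customer-churn-prediction"],
    order_by=["metrics.auc_roc DESC"],
    max_results=5,
)
print(runs[["run_id", "metrics.auc_roc", "params.max_depth"]])

# Reload the logged model from the best run for local scoring
best_run_id = runs.iloc[0]["run_id"]
best_model = mlflow.sklearn.load_model(f"runs:/{best_run_id}/model")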
2. Custom Model Wrapper
import json

import joblib
import mlflow
import mlflow.pyfunc
import pandas as pd

class ChurnModelWrapper(mlflow.pyfunc.PythonModel):
    """Custom model wrapper with preprocessing"""

    def load_context(self, context):
        """Load model and artifacts"""
        self.model = joblib.load(context.artifacts["model"])
        self.preprocessor = joblib.load(context.artifacts["preprocessor"])
        # Artifacts resolve to local file paths, so parse the JSON list
        with open(context.artifacts["feature_names"]) as f:
            self.feature_names = json.load(f)

    def predict(self, context, model_input):
        """Predict with preprocessing"""
        # Validate input
        missing = [col for col in self.feature_names if col not in model_input.columns]
        if missing:
            raise ValueError(f"Missing required features: {missing}")

        # Preprocess
        processed = self.preprocessor.transform(model_input[self.feature_names])

        # Predict with probability
        predictions = self.model.predict_proba(processed)[:, 1]
        return pd.DataFrame({
            "churn_probability": predictions,
            "churn_prediction": (predictions > 0.5).astype(int),
        })

# Log custom model
with mlflow.start_run():
    artifacts = {
        "model": "model.joblib",
        "preprocessor": "preprocessor.joblib",
        "feature_names": "features.json",
    }

    # Example output frame used only to infer the output schema
    example_output = pd.DataFrame({
        "churn_probability": [0.5],
        "churn_prediction": [1],
    })

    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=ChurnModelWrapper(),
        artifacts=artifacts,
        conda_env={
            "dependencies": [
                "python=3.10",
                "scikit-learn=1.3.0",
                "xgboost=2.0.0",
                "pandas=2.0.0",
            ]
        },
        signature=mlflow.models.infer_signature(X_test, example_output),
        input_example=X_test.head(5),
    )
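A sketch of loading the wrapped model back as a generic pyfunc and scoring a batch; run_id and the feature columns here are illustrative placeholders:
import mlflow.pyfunc
import pandas as pd

# Load the wrapper by run URI (run_id is a placeholder)
loaded = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")

# Columns must match the feature names logged in features.json
batch = pd.DataFrame([{"tenure": 12, "monthly_charges": 70.5, "total_charges": 846.0}])
print(loaded.predict(batch))  # DataFrame with churn_probability and churn_prediction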
3. Model Registry
import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Register model from a completed run (run_id comes from a previous training run)
model_uri = f"runs:/{run_id}/model"
model_version = mlflow.register_model(model_uri, "churn-prediction-model")

# Add description and tags
client.update_model_version(
    name="churn-prediction-model",
    version=model_version.version,
    description="XGBoost model trained on Q4 2024 data",
)
client.set_model_version_tag(
    name="churn-prediction-model",
    version=model_version.version,
    key="validation_status",
    value="pending",
)

# Transition to staging (after validation)
client.transition_model_version_stage(
    name="churn-prediction-model",
    version=model_version.version,
    stage="Staging",
    archive_existing_versions=False,
)

# Promote to production (after approval)
client.transition_model_version_stage(
    name="churn-prediction-model",
    version=model_version.version,
    stage="Production",
    archive_existing_versions=True,  # Archive old production version
)

# Load production model
model = mlflow.pyfunc.load_model("models:/churn-prediction-model/Production")
predictions = model.predict(new_data)
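Recent MLflow releases deprecate registry stages in favor of model version aliases; a minimal sketch of the alias-based equivalent, assuming MLflow 2.3 or later:
import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Point a "champion" alias at the newly registered version
client.set_registered_model_alias(
    name="churn-prediction-model",
    alias="champion",
    version=model_version.version,
)

# Load by alias instead of stage
model = mlflow.pyfunc.load_model("models:/churn-prediction-model@champion")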
4. Model Validation Pipeline
# validation/validate_model.py
import mlflow
import pandas as pd
from mlflow.exceptions import MlflowException
from mlflow.tracking import MlflowClient
from sklearn.metrics import accuracy_score, roc_auc_score

def validate_model(model_name: str, version: str) -> bool:
    """Validate model before promotion"""
    client = MlflowClient()
    model_uri = f"models:/{model_name}/{version}"

    # Load model
    model = mlflow.pyfunc.load_model(model_uri)

    # Load validation dataset
    val_data = pd.read_parquet("s3://data/validation.parquet")
    X_val, y_val = val_data.drop("target", axis=1), val_data["target"]

    # Run predictions
    predictions = model.predict(X_val)

    # Calculate metrics
    metrics = {
        "val_accuracy": accuracy_score(y_val, predictions["churn_prediction"]),
        "val_auc": roc_auc_score(y_val, predictions["churn_probability"]),
    }

    # Get production model metrics (if a production version exists)
    try:
        prod_model = mlflow.pyfunc.load_model(f"models:/{model_name}/Production")
        prod_predictions = prod_model.predict(X_val)
        prod_metrics = {
            "prod_accuracy": accuracy_score(y_val, prod_predictions["churn_prediction"]),
            "prod_auc": roc_auc_score(y_val, prod_predictions["churn_probability"]),
        }
    except MlflowException:
        prod_metrics = {"prod_accuracy": 0, "prod_auc": 0}

    # Validation rules
    validations = [
        ("accuracy_threshold", metrics["val_accuracy"] >= 0.85),
        ("auc_threshold", metrics["val_auc"] >= 0.80),
        ("accuracy_improvement", metrics["val_accuracy"] >= prod_metrics["prod_accuracy"]),
        ("auc_improvement", metrics["val_auc"] >= prod_metrics["prod_auc"] - 0.01),  # Allow 1% drop
    ]

    # Log validation results
    with mlflow.start_run(run_name=f"validation-{model_name}-v{version}"):
        mlflow.log_metrics(metrics)
        mlflow.log_metrics(prod_metrics)
        for name, passed in validations:
            mlflow.log_metric(f"validation_{name}", int(passed))

    # Update model tags
    all_passed = all(passed for _, passed in validations)
    client.set_model_version_tag(
        name=model_name,
        version=version,
        key="validation_status",
        value="passed" if all_passed else "failed",
    )

    return all_passed
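A sketch of wiring validate_model into the promotion flow from the registry section, so the transition to Production only happens when validation passes (promote_if_valid is a hypothetical helper):
from mlflow.tracking import MlflowClient

def promote_if_valid(model_name: str, version: str) -> bool:
    """Promote a model version to Production only when validation passes."""
    if not validate_model(model_name, version):
        print(f"{model_name} v{version} failed validation; not promoted")
        return False
    client = MlflowClient()
    client.transition_model_version_stage(
        name=model_name,
        version=version,
        stage="Production",
        archive_existing_versions=True,
    )
    print(f"{model_name} v{version} promoted to Production")
    return True

promote_if_valid("churn-prediction-model", "3")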
5. Model Serving
# serve/model_server.py
import mlflow
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# Load model at startup
MODEL_NAME = "churn-prediction-model"
MODEL_STAGE = "Production"
model = None

@app.on_event("startup")
async def load_model():
    global model
    model = mlflow.pyfunc.load_model(f"models:/{MODEL_NAME}/{MODEL_STAGE}")

class PredictionRequest(BaseModel):
    features: dict

class PredictionResponse(BaseModel):
    churn_probability: float
    churn_prediction: int
    model_version: str

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    try:
        input_df = pd.DataFrame([request.features])
        predictions = model.predict(input_df)
        return PredictionResponse(
            churn_probability=float(predictions["churn_probability"].iloc[0]),
            churn_prediction=int(predictions["churn_prediction"].iloc[0]),
            model_version=model.metadata.run_id,
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

@app.get("/health")
async def health():
    return {"status": "healthy", "model_loaded": model is not None}

# Or use MLflow's built-in serving:
# mlflow models serve -m "models:/churn-prediction-model/Production" -p 5001
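A hypothetical client call against the endpoint above, assuming the app runs under uvicorn on port 8000 and using the same illustrative feature names as earlier:
import requests

resp = requests.post(
    "http://localhost:8000/predict",
    json={"features": {"tenure": 12, "monthly_charges": 70.5, "total_charges": 846.0}},
    timeout=10,
)
resp.raise_for_status()
print(resp.json())  # e.g. {"churn_probability": ..., "churn_prediction": ..., "model_version": "..."}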
Quick Start
1. Install MLflow:

pip install mlflow

2. Start a tracking server:

mlflow server --backend-store-uri sqlite:///mlflow.db \
  --default-artifact-root s3://mlflow-artifacts \
  --host 0.0.0.0

3. Set the tracking URI in code:

mlflow.set_tracking_uri("http://localhost:5000")

4. Run an experiment:

with mlflow.start_run():
    mlflow.log_param("param", value)
    mlflow.log_metric("metric", value)
    mlflow.sklearn.log_model(model, "model")

5. View in the UI: open http://localhost:5000
Production Checklist
- Tracking server with persistent backend
- Artifact storage (S3/GCS/Azure Blob)
- Authentication enabled
- Model signature defined
- Input examples logged
- Conda/pip environment specified
- Validation pipeline configured
- Model approval workflow
- Monitoring for model drift
Anti-patterns
- No Experiment Naming: Use meaningful experiment/run names
- Skipping Signatures: Always define model signatures (see the sketch after this list)
- Manual Promotion: Use validation pipeline for stage transitions
- Missing Environment: Always specify dependencies
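For the signature anti-pattern above, a minimal sketch of inferring and attaching a signature at logging time, reusing model and X_train from the experiment-tracking example:
import mlflow
from mlflow.models import infer_signature

# Infer input/output schema from training data and model predictions
signature = infer_signature(X_train, model.predict(X_train))

mlflow.sklearn.log_model(
    model,
    artifact_path="model",
    signature=signature,
    input_example=X_train.head(5),
)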
Integration Points
- Storage: S3, GCS, Azure Blob, HDFS
- Databases: PostgreSQL, MySQL for backend store
- Orchestration: Airflow, Prefect, Dagster
- Serving: SageMaker, Kubernetes, Azure ML