
Deploy an ML Model with FastAPI

A trained ML model is useless if no one can call it. The standard pattern: wrap the model in a FastAPI endpoint that accepts inputs as JSON and returns predictions.

This is the CampusX tutorial's flagship project: predicting insurance premiums from features such as age, BMI, occupation, and city tier. The toy version below uses age, BMI, a lifestyle-risk score, city tier, and income.

The big picture

1. Train a model offline (scikit-learn, XGBoost, etc.) → save to disk (insurance_model.pkl)
2. FastAPI loads the model once at startup
3. Client POSTs user features → FastAPI runs `model.predict(...)` → returns the prediction
Client → FastAPI → trained model → prediction → response

Step 1 — Train and save a model

For this tutorial we'll train on a tiny toy dataset instead of real data. In practice you'd train in a notebook on the full dataset; either way, the output is the saved .pkl file.

# train.py  (run once, offline)
import pandas as pd
import pickle
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

# Toy dataset
df = pd.DataFrame({
    "age":          [25, 35, 45, 28, 50, 32, 40, 60, 22, 38],
    "bmi":          [22, 28, 30, 24, 33, 26, 29, 27, 21, 31],
    "lifestyle_risk":[1, 2, 3, 1, 3, 2, 2, 3, 1, 3],
    "city_tier":    [1, 2, 3, 1, 2, 1, 2, 3, 1, 2],
    "income_lpa":   [10, 15, 8, 20, 6, 18, 12, 5, 22, 9],
    "premium":      [5000, 8500, 12000, 6000, 14000, 7500, 9500, 16000, 4500, 11000],
})

X = df.drop(columns=["premium"])
y = df["premium"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

print("R² on test:", model.score(X_test, y_test))  # only 2 test rows here, so not a meaningful score

# Save to disk
with open("insurance_model.pkl", "wb") as f:
    pickle.dump(model, f)

print("Model saved to insurance_model.pkl")

Run it: python train.py. Produces insurance_model.pkl.
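Because pickled models are tied to the library versions that produced them (see the production checklist), it can help to write a small metadata file next to the .pkl. A minimal sketch using only the standard library; the helper names and the insurance_model.meta.json naming convention are hypothetical:

```python
import json
import platform
from importlib import metadata
from pathlib import Path


def training_environment(packages=("scikit-learn", "pandas")):
    """Collect the Python version and the installed versions of the given packages."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return {"python": platform.python_version(), "packages": versions}


def write_model_metadata(model_path, extra=None):
    """Write <model stem>.meta.json next to the model file (hypothetical convention)."""
    model_path = Path(model_path)
    meta = training_environment()
    meta["model_file"] = model_path.name
    if extra:
        meta.update(extra)
    meta_path = model_path.with_name(model_path.stem + ".meta.json")
    meta_path.write_text(json.dumps(meta, indent=2))
    return meta_path
```

Calling write_model_metadata("insurance_model.pkl", {"features": list(X.columns)}) at the end of train.py would record the environment alongside the model.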

Step 2 — Define the API schema

Inputs and outputs as Pydantic models:

# schemas.py
from pydantic import BaseModel, Field
from typing import Literal

class InsuranceInput(BaseModel):
    age: int = Field(..., ge=0, le=120, examples=[35])
    bmi: float = Field(..., gt=10, lt=60, examples=[24.5])
    lifestyle_risk: Literal[1, 2, 3] = Field(..., description="1=low, 2=med, 3=high")
    city_tier: Literal[1, 2, 3] = Field(..., description="1=metro, 2=mid, 3=small")
    income_lpa: float = Field(..., gt=0, examples=[15.0])

class InsuranceOutput(BaseModel):
    predicted_premium: float
    confidence: float
    risk_band: str

The schemas serve two jobs:
1. Validate every incoming request (Pydantic).
2. Auto-document the API in /docs.

Step 3 — Load the model at startup

# model_loader.py
import pickle
from pathlib import Path

MODEL_PATH = Path(__file__).parent / "insurance_model.pkl"

def load_model():
    with open(MODEL_PATH, "rb") as f:
        return pickle.load(f)

# Loaded once when this module is imported
model = load_model()
print("✅ insurance model loaded")

Step 4 — Build the prediction endpoint

# main.py
import pandas as pd
from fastapi import FastAPI, HTTPException

from schemas import InsuranceInput, InsuranceOutput
from model_loader import model

app = FastAPI(title="Insurance Premium Predictor")

@app.get("/")
def root():
    return {"message": "Insurance Premium API. POST /predict to use."}

@app.get("/health")
def health():
    return {"status": "ok", "model_loaded": model is not None}

@app.post("/predict", response_model=InsuranceOutput)
def predict(features: InsuranceInput):
    try:
        # 1. Build a DataFrame in the EXACT same column order the model was trained on
        X = pd.DataFrame([{
            "age":            features.age,
            "bmi":            features.bmi,
            "lifestyle_risk": features.lifestyle_risk,
            "city_tier":      features.city_tier,
            "income_lpa":     features.income_lpa,
        }])

        # 2. Predict
        pred = model.predict(X)[0]

        # 3. Optional: compute a confidence score (here: a simple proxy)
        # For a RandomForestRegressor we can use the std-dev of the individual trees' predictions.
        # The trees were fitted on plain arrays, so pass a NumPy array to avoid a feature-name warning.
        all_tree_preds = [tree.predict(X.to_numpy())[0] for tree in model.estimators_]
        std = float(pd.Series(all_tree_preds).std())
        confidence = max(0.0, min(1.0, 1 - std / pred))

        # 4. Risk band based on prediction
        if pred < 7000:    band = "Low"
        elif pred < 12000: band = "Medium"
        else:              band = "High"

        return InsuranceOutput(
            predicted_premium=round(float(pred), 2),
            confidence=round(confidence, 3),
            risk_band=band,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"prediction failed: {e}") from e
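The confidence proxy and risk banding in Step 4 are plain arithmetic, so they can be pulled out of the endpoint and unit-tested without a model. A sketch using only the standard library: statistics.stdev is the sample standard deviation, matching pandas' default Series.std() (ddof=1). The function names are hypothetical, the thresholds are the ones used above, and the guard for non-positive predictions is an addition the original expression does not have.

```python
import statistics


def confidence_from_trees(tree_preds, pred):
    """1 - (std of per-tree predictions / ensemble prediction), clamped to [0, 1]."""
    if pred <= 0:
        return 0.0  # guard: the ratio is meaningless for non-positive predictions
    std = statistics.stdev(tree_preds)  # needs at least two tree predictions
    return max(0.0, min(1.0, 1 - std / pred))


def risk_band(pred):
    """Same thresholds as the /predict endpoint."""
    if pred < 7000:
        return "Low"
    elif pred < 12000:
        return "Medium"
    return "High"
```

Inside predict(), the endpoint could then call confidence_from_trees(all_tree_preds, pred) and risk_band(pred).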

Step 5 — Run it

uvicorn main:app --reload

Try it out

Open http://127.0.0.1:8000/docs, expand POST /predict, click Try it out, and send:

{
  "age": 35,
  "bmi": 24.5,
  "lifestyle_risk": 2,
  "city_tier": 1,
  "income_lpa": 15.0
}

Or with curl:

curl -X POST http://127.0.0.1:8000/predict \
  -H "Content-Type: application/json" \
  -d '{
    "age": 35,
    "bmi": 24.5,
    "lifestyle_risk": 2,
    "city_tier": 1,
    "income_lpa": 15.0
  }'

Response (exact numbers depend on your trained model):

{
  "predicted_premium": 8400.5,
  "confidence": 0.92,
  "risk_band": "Medium"
}
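The same request can be sent from Python for a quick client-side check. A stdlib-only sketch with urllib; the helper names and the default base URL (uvicorn's default 127.0.0.1:8000) are assumptions, and predict_remote needs the server from Step 5 to be running.

```python
import json
import urllib.request


def build_predict_request(payload, base_url="http://127.0.0.1:8000"):
    """Build (but do not send) a POST /predict request with a JSON body."""
    return urllib.request.Request(
        f"{base_url}/predict",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict_remote(payload, base_url="http://127.0.0.1:8000", timeout=5.0):
    """Send the request and decode the JSON response (requires a running server)."""
    req = build_predict_request(payload, base_url)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


sample = {"age": 35, "bmi": 24.5, "lifestyle_risk": 2, "city_tier": 1, "income_lpa": 15.0}
```

For non-2xx responses, such as a 422 validation error, urlopen raises urllib.error.HTTPError instead of returning a body.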

How it executes end-to-end

1. Client POSTs JSON {age, bmi, ...}
2. uvicorn receives the request
3. FastAPI parses and validates the body against InsuranceInput (and auto-rejects it with a 422 error if, e.g., age < 0 or bmi is missing)
4. Your predict() function runs
5. It builds a DataFrame in the correct column order
6. model.predict() runs (in milliseconds, since the model was pre-loaded)
7. It computes the confidence and risk band
8. It returns InsuranceOutput, which FastAPI serializes to JSON
9. The response is sent back to the client

The model is loaded once (at server start). Each request is just a fast .predict() call.

Better — pre-load with a lifespan handler

If your model is big and takes time to load, do it explicitly at startup:

import pickle
from contextlib import asynccontextmanager

from fastapi import FastAPI

model = None       # global placeholder, filled in at startup

@asynccontextmanager
async def lifespan(app: FastAPI):
    global model
    # Code before `yield` runs once, before the server starts accepting requests
    print("🔌 loading model...")
    with open("insurance_model.pkl", "rb") as f:
        model = pickle.load(f)
    print("✅ model loaded")
    yield
    # Code after `yield` runs once, when the server shuts down
    print("🔌 shutting down")

app = FastAPI(lifespan=lifespan)
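To see the order in which a lifespan handler runs, the context manager can be driven by hand with asyncio, which is roughly what the ASGI server does around the whole serving period. A stdlib-only sketch: FakeApp is a hypothetical stand-in for FastAPI, load_model is a placeholder for the pickle load, and the model is stored on app.state rather than in a global (a common alternative).

```python
import asyncio
from contextlib import asynccontextmanager
from types import SimpleNamespace

events = []


def load_model():
    # Placeholder for pickle.load(...) or joblib.load(...)
    events.append("load")
    return {"name": "insurance_model"}


class FakeApp:
    """Hypothetical stand-in for a FastAPI app: only exposes a .state namespace."""
    def __init__(self):
        self.state = SimpleNamespace()


@asynccontextmanager
async def lifespan(app):
    app.state.model = load_model()   # startup: runs before any request is served
    events.append("startup done")
    yield
    events.append("shutdown")        # shutdown: runs after serving stops
    app.state.model = None


async def serve(app, n_requests):
    async with lifespan(app):
        for _ in range(n_requests):
            # every "request" reuses the already-loaded model
            assert app.state.model is not None
            events.append("request")


app = FakeApp()
asyncio.run(serve(app, 3))
```

In a real FastAPI app, an endpoint that takes request: Request as a parameter would read the model via request.app.state.model.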

Production checklist for ML APIs

  • Pre-load the model at startup, not per-request.
  • Validate every input with Pydantic — wrong types crash sklearn.
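  • Pin the columns and order: pd.DataFrame([features.model_dump()]) (features.dict() in Pydantic v1) follows the schema's field order, which may not match the training order. Build the DataFrame explicitly, or reorder it with X[model.feature_names_in_].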
  • Log every prediction (input + output + latency) for monitoring.
  • Add a /health endpoint that returns 200 only if the model is loaded.
  • Pin every dependency — scikit-learn, pandas, pydantic. Mismatches break pickled models.
  • Save the model with the training environment documented. A model trained with sklearn 1.4 may fail to load in 1.6.
  • Containerize it — see next chapter.
  • Rate-limit the endpoint — prediction calls can be expensive.
  • Monitor drift — track input distribution & accuracy over time.
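For the logging item, one minimal approach is to wrap the prediction call so that input, output, and latency land in a single structured log line. A stdlib-only sketch; the decorator name, the logger name "predictions", the JSON field names, and fake_predict are assumptions, not a FastAPI feature.

```python
import functools
import json
import logging
import time

logger = logging.getLogger("predictions")


def log_prediction(fn):
    """Log the input, output and latency (ms) of each call as one JSON line."""
    @functools.wraps(fn)
    def wrapper(features):
        start = time.perf_counter()
        result = fn(features)
        latency_ms = (time.perf_counter() - start) * 1000
        logger.info(json.dumps({
            "input": features,
            "output": result,
            "latency_ms": round(latency_ms, 2),
        }))
        return result
    return wrapper


@log_prediction
def fake_predict(features):
    # Placeholder for model.predict(...): returns a made-up premium
    return {"predicted_premium": 1000.0 + 100 * features["age"]}
```

Configure logging (for example logging.basicConfig(level=logging.INFO)) to see the lines; with Pydantic inputs you would log features.model_dump() rather than the model object.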

Common pitfalls

  • Loading the model inside the endpoint — slow first request, repeated for every call. Load once at startup.
  • Different sklearn versions between training and serving — silent prediction errors or load failures. Pin both.
  • Wrong column order: sklearn only validates column names when both fit and predict receive DataFrames; with plain arrays it has no names to check and can silently use age as BMI. Build DataFrames explicitly.
  • Unbounded prediction time — some models are slow. Add a timeout or use background tasks for long inference.
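  • Using pickle: fragile across Python/sklearn versions. joblib (built on pickle) is more efficient for large NumPy arrays but shares the same fragility, so pin versions either way. For portable long-term storage, consider ONNX or a model serving framework (BentoML, TorchServe, Triton).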
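For the unbounded-prediction-time pitfall, one stdlib way to cap how long a caller waits is to submit inference to an executor and call result(timeout=...). A sketch: the worker thread is not killed on timeout and keeps running in the background, so this bounds the client's wait, not CPU use. The helper name, pool size, and 2-second default are assumptions.

```python
import concurrent.futures
import time

# One shared pool for inference calls (size is an assumption; tune per CPU)
_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)


def predict_with_timeout(predict_fn, features, timeout_s=2.0):
    """Run predict_fn(features) in the pool; raise TimeoutError if it takes too long."""
    future = _executor.submit(predict_fn, features)
    try:
        return future.result(timeout=timeout_s)
    except concurrent.futures.TimeoutError:
        # The worker keeps running; we only stop waiting for it.
        raise TimeoutError(f"prediction exceeded {timeout_s}s") from None


def slow_predict(features):
    time.sleep(features["delay"])  # stands in for a slow model
    return "done"
```

In the endpoint, you might translate the TimeoutError into an HTTPException with status 504.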


Practice

What does this print?

Expected: True

# ML deployment workflow: train → save model → load in API → serve predictions
steps = ["train", "save", "load", "serve"]
print(len(steps) == 4)

Load the model ONCE at startup, not on every request

Expected: True

# Wrong: loading the model inside the handler is slow on every request
def predict(features):
    model = "load_from_disk_each_time"   # bug: re-loads the model on every request
    return model
# Right: load at startup, reuse across requests
use_startup_load = True
print(use_startup_load)

Quiz — Quick check

What you remember

Q1. When should you load the ML model?

  • Once at app startup (via lifespan event or module-level load), reuse across all requests
  • On each request
  • When the user logs in
  • Lazily on first prediction

Why: Loading a model from disk takes seconds. Doing it per request adds that latency to every call. Load once at startup, hold in memory, reuse.

Q2. What's the right way to serialize a scikit-learn model for production?

  • joblib.dump(model, "model.joblib") — efficient for numpy-heavy objects
  • pickle.dump
  • JSON
  • Pandas to_csv

Why: joblib is similar to pickle but more efficient for objects containing large numpy arrays (like trained models). Pickle works too. Never load pickle or joblib files from untrusted sources: both can execute arbitrary code on load.
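Whichever format you pick, only load model files you produced. One extra safeguard, sketched here with the standard library, is to record a SHA-256 digest when saving and refuse to load a file whose digest does not match. This checks integrity (the file is the one you saved); it does not make unpickling untrusted data safe. The function names are hypothetical, and pickle is used so the example needs no third-party packages.

```python
import hashlib
import pickle
from pathlib import Path


def sha256_of(path):
    """Hex SHA-256 digest of a file, read in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()


def save_with_digest(obj, path):
    """Pickle obj to path and return the file's digest (store it somewhere trusted)."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)
    return sha256_of(path)


def load_verified(path, expected_digest):
    """Unpickle only if the file's digest matches the one recorded at save time."""
    if sha256_of(path) != expected_digest:
        raise ValueError(f"digest mismatch for {Path(path).name}")
    with open(path, "rb") as f:
        return pickle.load(f)
```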

Q3. How should you handle a slow ML inference call in a FastAPI endpoint?

  • Run it in a thread pool with run_in_executor to avoid blocking the event loop
  • Always async
  • In a subprocess
  • Block the request

Why: ML inference is usually CPU-bound (not I/O). Wrapping in run_in_executor (or using a sync handler that FastAPI runs in its thread pool automatically) keeps the event loop free for other requests.
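A stdlib-only sketch of the run_in_executor pattern: an async handler offloads a blocking, CPU-bound call to the loop's default thread pool so the event loop stays free for other coroutines. cpu_bound_predict is a hypothetical stand-in for model.predict. Threads give real parallelism only when the heavy code releases the GIL, as much compiled numerical code does; for pure-Python CPU work a process pool may serve better.

```python
import asyncio


def cpu_bound_predict(x):
    # Hypothetical stand-in for model.predict: a synchronous, blocking computation
    total = 0
    for i in range(200_000):
        total += (i * x) % 7
    return total


async def predict_handler(x):
    loop = asyncio.get_running_loop()
    # None = use the loop's default ThreadPoolExecutor
    return await loop.run_in_executor(None, cpu_bound_predict, x)


async def main():
    # Two "requests" handled concurrently; the event loop is not blocked while they run
    return await asyncio.gather(predict_handler(3), predict_handler(5))


preds = asyncio.run(main())
```

Starlette, and FastAPI via fastapi.concurrency, also provide run_in_threadpool for the same purpose.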

Common doubts

Where should I store the trained model file?

Local disk works for small models. For larger models or multi-instance deploys, use object storage (S3, GCS, R2) — download at startup. For really large models, mount an EFS/NFS share. Version your model files (e.g., model-v2.joblib).

How do I version a deployed model?

Tag the model file (model-v3.joblib) and the API version. Return the model version in the response. Maintain backward compatibility — old client code shouldn't break when you ship a new model. MLOps tools like MLflow/Weights & Biases formalize this.
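A minimal sketch of the file-naming side of versioning: pick the highest model-vN.joblib in a directory and report that version alongside each prediction. Stdlib only; the directory layout, the model-vN pattern (taken from the examples above), and the model_version response field are assumptions, and the actual joblib.load call is left to the caller.

```python
import re
from pathlib import Path

_VERSION_RE = re.compile(r"^model-v(\d+)\.joblib$")


def latest_model(model_dir):
    """Return (version, path) of the highest-numbered model-vN.joblib file."""
    found = []
    for p in Path(model_dir).iterdir():
        m = _VERSION_RE.match(p.name)
        if m:
            found.append((int(m.group(1)), p))
    if not found:
        raise FileNotFoundError(f"no model-vN.joblib files in {model_dir}")
    # Compare version numbers numerically, so v10 beats v9
    return max(found, key=lambda item: item[0])


def with_version(prediction, version):
    """Attach the serving model's version to a response payload."""
    return {**prediction, "model_version": f"v{version}"}
```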

What about batching predictions for throughput?

Implement a batch endpoint that accepts a list of inputs and returns a list of predictions. Internally, a single model.predict call on the whole batch is usually much faster than N separate per-item calls, especially for deep models running on a GPU.
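The batching idea in miniature: build all rows in the training column order and call predict once. A stdlib-only sketch; CountingModel is a hypothetical stand-in that records how often predict is called, and COLUMNS mirrors train.py. With the real model you would pass a DataFrame built from the rows, and in FastAPI the request body could be typed as list[InsuranceInput].

```python
COLUMNS = ["age", "bmi", "lifestyle_risk", "city_tier", "income_lpa"]


class CountingModel:
    """Hypothetical stand-in for a fitted model: predicts a simple function of each row."""
    def __init__(self):
        self.calls = 0

    def predict(self, rows):
        self.calls += 1
        return [100.0 * row[0] + 50.0 * row[1] for row in rows]


def predict_batch(model, items):
    """One model.predict call for the whole batch, columns in training order."""
    rows = [[item[col] for col in COLUMNS] for item in items]
    return list(model.predict(rows))
```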
