Deploy an ML Model with FastAPI¶
A trained ML model is useless if no one can call it. The standard pattern: wrap the model in a FastAPI endpoint that accepts inputs as JSON and returns predictions.
This is the CampusX tutorial's flagship project — predicting insurance premiums from features like age, BMI, occupation, city tier.
The big picture¶
1. Train a model offline (scikit-learn, XGBoost, etc.) → save to disk (insurance_model.pkl)
↓
2. FastAPI loads the model once at startup
↓
3. Client POSTs user features → FastAPI runs `model.predict(...)` → returns the prediction
Step 1 — Train and save a model¶
For this tutorial we'll fake the training step. In reality, you'd have a notebook that does it. The output is the saved .pkl file.
# train.py (run once, offline)
import pandas as pd
import pickle
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
# Toy dataset
df = pd.DataFrame({
    "age":            [25, 35, 45, 28, 50, 32, 40, 60, 22, 38],
    "bmi":            [22, 28, 30, 24, 33, 26, 29, 27, 21, 31],
    "lifestyle_risk": [1, 2, 3, 1, 3, 2, 2, 3, 1, 3],
    "city_tier":      [1, 2, 3, 1, 2, 1, 2, 3, 1, 2],
    "income_lpa":     [10, 15, 8, 20, 6, 18, 12, 5, 22, 9],
    "premium":        [5000, 8500, 12000, 6000, 14000, 7500, 9500, 16000, 4500, 11000],
})
X = df.drop(columns=["premium"])
y = df["premium"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
print("R² on test:", model.score(X_test, y_test))
# Save to disk
with open("insurance_model.pkl", "wb") as f:
    pickle.dump(model, f)
print("Model saved to insurance_model.pkl")
Run it: python train.py. Produces insurance_model.pkl.
Step 2 — Define the API schema¶
Inputs and outputs as Pydantic models:
# schemas.py
from pydantic import BaseModel, Field
from typing import Literal
class InsuranceInput(BaseModel):
    age: int = Field(..., ge=0, le=120, examples=[35])
    bmi: float = Field(..., gt=10, lt=60, examples=[24.5])
    lifestyle_risk: Literal[1, 2, 3] = Field(..., description="1=low, 2=med, 3=high")
    city_tier: Literal[1, 2, 3] = Field(..., description="1=metro, 2=mid, 3=small")
    income_lpa: float = Field(..., gt=0, examples=[15.0])
class InsuranceOutput(BaseModel):
    predicted_premium: float
    confidence: float
    risk_band: str
The schemas serve two jobs:
1. Validate every incoming request (Pydantic).
2. Auto-document the API in /docs.
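To see the validation job in action without starting a server, the sketch below re-declares the input schema and feeds it one valid and one invalid payload. It assumes Pydantic is installed; both payloads are made up for illustration.

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class InsuranceInput(BaseModel):
    age: int = Field(..., ge=0, le=120)
    bmi: float = Field(..., gt=10, lt=60)
    lifestyle_risk: Literal[1, 2, 3]
    city_tier: Literal[1, 2, 3]
    income_lpa: float = Field(..., gt=0)


good = {"age": 35, "bmi": 24.5, "lifestyle_risk": 2, "city_tier": 1, "income_lpa": 15.0}
# Negative age, out-of-range lifestyle_risk, and income_lpa missing entirely.
bad = {"age": -4, "bmi": 24.5, "lifestyle_risk": 7, "city_tier": 1}

parsed = InsuranceInput(**good)
print(parsed.age, parsed.bmi)

failed_fields = []
try:
    InsuranceInput(**bad)
except ValidationError as exc:
    # Each error entry records the location of the failing field.
    failed_fields = sorted(str(err["loc"][0]) for err in exc.errors())
    print("rejected fields:", failed_fields)
```

In a FastAPI endpoint the same errors are returned to the client as a 422 response, so the handler body never sees the invalid payload.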
Step 3 — Load the model at startup¶
# model_loader.py
import pickle
from pathlib import Path
MODEL_PATH = Path(__file__).parent / "insurance_model.pkl"
def load_model():
    with open(MODEL_PATH, "rb") as f:
        return pickle.load(f)
# Loaded once when this module is imported
model = load_model()
print("✅ insurance model loaded")
Step 4 — Build the prediction endpoint¶
# main.py
import pandas as pd
from fastapi import FastAPI, HTTPException
from schemas import InsuranceInput, InsuranceOutput
from model_loader import model
app = FastAPI(title="Insurance Premium Predictor")
@app.get("/")
def root():
    return {"message": "Insurance Premium API. POST /predict to use."}
@app.get("/health")
def health():
    return {"status": "ok", "model_loaded": model is not None}
@app.post("/predict", response_model=InsuranceOutput)
def predict(features: InsuranceInput):
    try:
        # 1. Build a DataFrame in the EXACT same column order the model was trained on
        X = pd.DataFrame([{
            "age": features.age,
            "bmi": features.bmi,
            "lifestyle_risk": features.lifestyle_risk,
            "city_tier": features.city_tier,
            "income_lpa": features.income_lpa,
        }])
        # 2. Predict
        pred = float(model.predict(X)[0])
        # 3. Optional: compute a confidence score (here: a simple proxy)
        # For a RandomForestRegressor we can use the std-dev of trees' predictions.
        # The individual trees were fitted on plain arrays, so pass X.to_numpy()
        # to avoid sklearn's feature-name warning.
        all_tree_preds = [tree.predict(X.to_numpy())[0] for tree in model.estimators_]
        std = float(pd.Series(all_tree_preds).std())
        # Guard against dividing by a zero or negative prediction
        confidence = max(0.0, min(1.0, 1 - std / pred)) if pred > 0 else 0.0
        # 4. Risk band based on prediction
        if pred < 7000: band = "Low"
        elif pred < 12000: band = "Medium"
        else: band = "High"
        return InsuranceOutput(
            predicted_premium=round(pred, 2),
            confidence=round(confidence, 3),
            risk_band=band,
        )
    except Exception as e:
        raise HTTPException(500, f"prediction failed: {e}")
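The confidence proxy and risk banding from steps 3 and 4 can also live in plain functions, which makes them testable without a model or a server. The sketch below is one way to do that. The helper names `confidence_from_spread` and `risk_band` are hypothetical, the thresholds are copied from the endpoint above, and the guard for a non-positive prediction is there to avoid dividing by zero.

```python
from statistics import stdev


def confidence_from_spread(tree_preds, pred):
    """Map the spread of per-tree predictions to a 0..1 score.

    This is a rough proxy, not a calibrated probability.
    """
    if pred <= 0 or len(tree_preds) < 2:
        return 0.0
    spread = stdev(tree_preds)  # sample standard deviation, like pandas' Series.std()
    return max(0.0, min(1.0, 1 - spread / pred))


def risk_band(pred):
    """Bucket a predicted premium using the same cut-offs as the endpoint."""
    if pred < 7000:
        return "Low"
    if pred < 12000:
        return "Medium"
    return "High"


# Made-up per-tree predictions; a random forest's output is their mean.
tree_preds = [7800.0, 8200.0, 8000.0, 8400.0, 7600.0]
pred = sum(tree_preds) / len(tree_preds)
print(round(confidence_from_spread(tree_preds, pred), 3), risk_band(pred))
```

Trees that agree closely give a score near 1, while widely scattered trees push it toward 0.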
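Step 5 — Run it¶
Start the server with uvicorn, pointing it at the app object in main.py:
uvicorn main:app --reload
By default the API is served at http://127.0.0.1:8000.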
Try it out¶
Open /docs, expand POST /predict, click Try it out, and send the same JSON body used in the curl command below.
Or with curl:
curl -X POST http://127.0.0.1:8000/predict \
  -H "Content-Type: application/json" \
  -d '{
    "age": 35,
    "bmi": 24.5,
    "lifestyle_risk": 2,
    "city_tier": 1,
    "income_lpa": 15.0
  }'
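Response: a JSON object with the three InsuranceOutput fields (predicted_premium, confidence, risk_band). The exact values depend on the trained model.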
How it executes end-to-end¶
1. Client POSTs JSON {age, bmi, ...}
↓
2. uvicorn receives request
↓
3. FastAPI parses + validates against InsuranceInput
↓
4. (auto-rejects if e.g. age < 0 or bmi missing)
↓
5. Your predict() function runs
↓
6. Build DataFrame in correct column order
↓
7. model.predict() runs (in milliseconds — model was pre-loaded)
↓
8. Compute confidence + risk band
↓
9. Return InsuranceOutput → FastAPI serializes to JSON
↓
10. Response sent back to client
The model is loaded once (at server start). Each request is just a fast .predict() call.
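The difference between loading once and loading per request can be made visible with a counter. The sketch below is an illustration only: it uses `functools.lru_cache` as a stand-in for the module-level load shown earlier (a swapped-in technique, not what the tutorial code does), and a pickled dict in a temp directory as a hypothetical model artifact.

```python
import pickle
import tempfile
from functools import lru_cache
from pathlib import Path

load_calls = 0  # counts how often the file is actually read from disk

# Stand-in artifact (hypothetical); a real app would point at insurance_model.pkl.
artifact_path = Path(tempfile.mkdtemp()) / "stand_in_model.pkl"
artifact_path.write_bytes(pickle.dumps({"kind": "stand-in"}))


@lru_cache(maxsize=1)
def get_model():
    """Read the pickle on the first call only; later calls reuse the cached object."""
    global load_calls
    load_calls += 1
    with open(artifact_path, "rb") as f:
        return pickle.load(f)


def handle_request():
    model = get_model()  # cheap: returns the already-loaded object
    return model["kind"]


get_model()  # warm the cache at "startup" so the first request is not slow
results = [handle_request() for _ in range(1000)]
print("requests:", len(results), "disk loads:", load_calls)
```

A thousand simulated requests trigger a single disk read, which is the property the flow above relies on.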
Better — pre-load with a lifespan handler¶
If your model is big and takes time to load, do it explicitly at startup:
import pickle
from contextlib import asynccontextmanager
from fastapi import FastAPI
model = None  # global placeholder
@asynccontextmanager
async def lifespan(app: FastAPI):
    global model
    print("🔌 loading model...")
    with open("insurance_model.pkl", "rb") as f:
        model = pickle.load(f)
    print("✅ model loaded")
    yield
    print("🔌 shutting down")
app = FastAPI(lifespan=lifespan)
Production checklist for ML APIs¶
- Pre-load the model at startup, not per-request.
- Validate every input with Pydantic: wrong types crash sklearn.
- Pin the columns and their order. X = pd.DataFrame([features.dict()]) (model_dump() in Pydantic v2) relies on dict ordering and may shuffle on Python <3.7.
- Log every prediction (input + output + latency) for monitoring.
- Add a /health endpoint that returns 200 only if the model is loaded.
- Pin every dependency: scikit-learn, pandas, pydantic. Mismatches break pickled models.
- Save the model with the training environment documented. A model trained with sklearn 1.4 may fail to load in 1.6.
- Containerize it (see next chapter).
- Rate-limit the endpoint: prediction calls can be expensive.
- Monitor drift: track input distribution & accuracy over time.
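For the logging item in the checklist, a small decorator is one way to capture input, output, and latency in a single structured line. The sketch below uses only the standard library. The names `log_prediction` and `fake_predict` are hypothetical, and the stand-in formula is made up.

```python
import json
import logging
import time
from functools import wraps

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger("predictions")


def log_prediction(fn):
    """Wrap a predict-style function and emit one JSON log line per call."""
    @wraps(fn)
    def wrapper(features: dict):
        start = time.perf_counter()
        output = fn(features)
        latency_ms = (time.perf_counter() - start) * 1000
        logger.info(json.dumps({
            "input": features,
            "output": output,
            "latency_ms": round(latency_ms, 3),
        }))
        return output
    return wrapper


@log_prediction
def fake_predict(features: dict) -> dict:
    # Stand-in for model.predict (made-up formula, for illustration only).
    return {"predicted_premium": 100.0 * features["age"]}


result = fake_predict({"age": 35, "bmi": 24.5})
```

One JSON object per line is easy to ship to a log aggregator and to query later when checking for drift.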
Common pitfalls¶
- ❗ Loading the model inside the endpoint: every call pays the full load time. Load once at startup.
- ❗ Different sklearn versions between training and serving: load failures, or predictions that silently differ. Pin the same version in both.
- ❗ Wrong column order: recent sklearn versions compare feature names only when the model was fitted on a DataFrame and also receives a DataFrame at predict time. With NumPy arrays or lists there is nothing to compare, so the model can silently use age as BMI. Build the input explicitly in training order.
- ❗ Unbounded prediction time: some models are slow. Add a timeout or use background tasks for long inference.
- ❗ Using pickle: fragile across Python/sklearn versions. joblib builds on pickle and shares this limitation. For long-term storage, prefer a version-independent format such as ONNX, or a model serving framework (BentoML, TorchServe, Triton).
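One way to avoid the column-order pitfall is to keep the training order in a single constant and build every input row from it. The sketch below is plain Python. `FEATURE_ORDER` and `to_row` are hypothetical names, and the column list is taken from the training script earlier in this tutorial.

```python
FEATURE_ORDER = ["age", "bmi", "lifestyle_risk", "city_tier", "income_lpa"]


def to_row(payload: dict) -> list:
    """Return feature values in training order, failing loudly on missing or unexpected keys."""
    missing = [name for name in FEATURE_ORDER if name not in payload]
    extra = [name for name in payload if name not in FEATURE_ORDER]
    if missing or extra:
        raise ValueError(f"missing={missing} unexpected={extra}")
    return [payload[name] for name in FEATURE_ORDER]


# Keys arrive in a different order than the model was trained on.
payload = {"income_lpa": 15.0, "city_tier": 1, "bmi": 24.5, "age": 35, "lifestyle_risk": 2}
row = to_row(payload)
print(row)  # [35, 24.5, 2, 1, 15.0]
```

The same constant can be passed as `columns=FEATURE_ORDER` when building a DataFrame, so the order is defined in exactly one place.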
Quiz — Quick check¶
What you remember
Q1. When should you load the ML model?
- Once at app startup (via lifespan event or module-level load), reuse across all requests
- On each request
- When the user logs in
- Lazily on first prediction
Why: Loading a model from disk takes seconds. Doing it per request adds that latency to every call. Load once at startup, hold in memory, reuse.
Q2. What's the right way to serialize a scikit-learn model for production?
- joblib.dump(model, "model.joblib"): efficient for numpy-heavy objects
- pickle.dump
- JSON
- Pandas to_csv
Why: joblib is similar to pickle but more efficient for objects containing large numpy arrays (like trained models). Pickle works too. Avoid loading pickle files from untrusted sources, since unpickling can execute arbitrary code.
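A minimal round trip with joblib might look like the sketch below. It assumes joblib and NumPy are installed (both come with scikit-learn), and it uses a dict of arrays as a stand-in for a fitted estimator, with a temp directory as a throwaway location.

```python
import tempfile
from pathlib import Path

import joblib
import numpy as np

# Stand-in for a fitted estimator: any picklable object holding NumPy arrays.
artifact = {"coef": np.arange(5, dtype=float), "version": "demo"}

path = Path(tempfile.mkdtemp()) / "model.joblib"
joblib.dump(artifact, path)    # write the object to disk
restored = joblib.load(path)   # read it back, as the serving side would

print(np.array_equal(restored["coef"], artifact["coef"]))
```

The same trust rule applies as with pickle: only load joblib files you produced yourself.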
Q3. How should you handle a slow ML inference call in a FastAPI endpoint?
- Run it in a thread pool with run_in_executor to avoid blocking the event loop
- Always async
- In a subprocess
- Block the request
Why: ML inference is usually CPU-bound (not I/O). Wrapping in run_in_executor (or using a sync handler that FastAPI runs in its thread pool automatically) keeps the event loop free for other requests.
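The thread-pool approach can be sketched with the standard library alone. In the example below, `slow_predict` is a hypothetical stand-in that sleeps instead of running a model, and the pool size of 4 is an arbitrary choice.

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


def slow_predict(x: float) -> float:
    # Stand-in for blocking inference (hypothetical): sleeps instead of computing.
    time.sleep(0.2)
    return x * 2


async def predict_async(executor, x: float) -> float:
    loop = asyncio.get_running_loop()
    # Hand the blocking call to a worker thread; the event loop stays responsive.
    return await loop.run_in_executor(executor, slow_predict, x)


async def main():
    with ThreadPoolExecutor(max_workers=4) as executor:
        start = time.perf_counter()
        results = await asyncio.gather(
            *(predict_async(executor, x) for x in [1.0, 2.0, 3.0])
        )
        elapsed = time.perf_counter() - start
    return results, elapsed


results, elapsed = asyncio.run(main())
print(results, f"{elapsed:.2f}s")
```

The three calls overlap instead of running back to back. With real inference the gain depends on whether the model code releases the GIL: many NumPy-backed operations do, while pure-Python loops do not.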
Common doubts¶
Where should I store the trained model file?
Local disk works for small models. For larger models or multi-instance deploys, use object storage (S3, GCS, R2) — download at startup. For really large models, mount an EFS/NFS share. Version your model files (e.g., model-v2.joblib).
How do I version a deployed model?
Tag the model file (model-v3.joblib) and the API version. Return the model version in the response. Maintain backward compatibility — old client code shouldn't break when you ship a new model. MLOps tools like MLflow/Weights & Biases formalize this.
What about batching predictions for throughput?
Implement a batch endpoint that accepts a list of inputs and returns a list of predictions. Internally, model.predict on the whole batch is much faster than N calls to model.predict per item — especially for GPU-bound deep models.
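The core of a batch endpoint is building all rows first and calling predict once. The sketch below shows that shape in plain Python. `StandInModel` is a hypothetical class with a scikit-learn-style `predict(rows)` method and a made-up formula, and `predict_batch` is an illustrative helper name.

```python
FEATURE_ORDER = ["age", "bmi", "lifestyle_risk", "city_tier", "income_lpa"]


class StandInModel:
    """Hypothetical model exposing a scikit-learn-style predict(rows) method."""

    def __init__(self):
        self.predict_calls = 0

    def predict(self, rows):
        self.predict_calls += 1
        # Made-up formula, for illustration only.
        return [200.0 * row[0] + 100.0 * row[1] for row in rows]


def predict_batch(model, items):
    """Build every row first, then call predict once for the whole batch."""
    if not items:
        return []
    rows = [[item[name] for name in FEATURE_ORDER] for item in items]
    return [float(p) for p in model.predict(rows)]


items = [
    {"age": 25, "bmi": 22.0, "lifestyle_risk": 1, "city_tier": 1, "income_lpa": 10.0},
    {"age": 35, "bmi": 28.0, "lifestyle_risk": 2, "city_tier": 2, "income_lpa": 15.0},
    {"age": 45, "bmi": 30.0, "lifestyle_risk": 3, "city_tier": 3, "income_lpa": 8.0},
]
model = StandInModel()
predictions = predict_batch(model, items)
print(predictions, "predict calls:", model.predict_calls)
```

In a FastAPI app the request body would be a list of InsuranceInput objects and the response a list of InsuranceOutput objects, with the same single predict call in between.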