
Serving ML models in production requires optimizing latency, throughput, and reliability. This guide covers deployment architectures and optimization techniques.
Deploy PyTorch models with production-grade serving using TorchServe and a custom handler:
# model_handler.py - Custom TorchServe handler
import torch
import torch.nn.functional as F
from ts.torch_handler.base_handler import BaseHandler
import json
import logging
class CustomModelHandler(BaseHandler):
"""
Production-ready model handler with monitoring and safety checks
"""
def __init__(self):
super().__init__()
self.initialized = False
self.latency_threshold_ms = 100 # ⚠️ SLA requirement
self.error_count = 0
self.error_threshold = 10
def initialize(self, context):
"""Load model and preprocessing"""
self.manifest = context.manifest
properties = context.system_properties
model_dir = properties.get("model_dir")
# Load model
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = f"{model_dir}/model.pt"
self.model = torch.jit.load(model_path, map_location=self.device)
self.model.eval()
# Load preprocessing config
with open(f"{model_dir}/config.json", 'r') as f:
self.config = json.load(f)
self.initialized = True
logging.info(f"Model loaded successfully on {self.device}")
def preprocess(self, data):
"""Preprocess input data"""
import time
start = time.time()
# Extract input from request
inputs = []
for row in data:
input_data = row.get("data") or row.get("body")
# Validate input
if input_data is None:
raise ValueError("No input data provided")
# Convert to tensor
            # Decode JSON payloads (TorchServe may pass str, bytes, or bytearray)
            if isinstance(input_data, (str, bytes, bytearray)):
                input_data = json.loads(input_data)
tensor = torch.tensor(input_data, dtype=torch.float32)
inputs.append(tensor)
batch = torch.stack(inputs).to(self.device)
preprocess_time = (time.time() - start) * 1000
if preprocess_time > 10:
logging.warning(f"Slow preprocessing: {preprocess_time:.2f}ms")
return batch
def inference(self, batch):
"""Run model inference with timing"""
import time
start = time.time()
with torch.no_grad():
outputs = self.model(batch)
inference_time = (time.time() - start) * 1000
# ⚠️ Monitor latency SLA
if inference_time > self.latency_threshold_ms:
logging.error(f"SLA violation: {inference_time:.2f}ms > {self.latency_threshold_ms}ms")
self.error_count += 1
if self.error_count > self.error_threshold:
logging.critical("Error threshold exceeded, may need to scale out")
logging.info(f"Inference latency: {inference_time:.2f}ms")
return outputs
def postprocess(self, outputs):
"""Convert model outputs to response format"""
# Get predictions
probabilities = F.softmax(outputs, dim=1)
predictions = outputs.argmax(dim=1)
# Format response
results = []
for pred, probs in zip(predictions, probabilities):
results.append({
'prediction': pred.item(),
'confidence': probs[pred].item(),
'probabilities': probs.tolist()
})
return results
# Deploy model with TorchServe
# 1. Archive model
# torch-model-archiver --model-name my_model \
#     --version 1.0 \
#     --model-file model.py \
#     --serialized-file model.pt \
#     --handler model_handler.py \
#     --extra-files config.json \
#     --export-path model_store/
# 2. Start TorchServe
# torchserve --start --model-store model_store --models my_model=my_model.mar
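Once TorchServe is running, clients call its REST inference API. Below is a minimal client sketch; it assumes TorchServe's default inference port (8080) and the model name registered above, and the payload shape is purely illustrative and must match whatever your handler expects.
# query_model.py - minimal client sketch (assumes defaults: inference port 8080,
# model registered as "my_model"; payload is illustrative)
import json
import requests

payload = json.dumps([0.1, 0.2, 0.3, 0.4])  # must match the handler's expected input
response = requests.post(
    "http://localhost:8080/predictions/my_model",
    data=payload,
    headers={"Content-Type": "application/json"},
    timeout=5,
)
response.raise_for_status()
print(response.json())  # e.g. {'prediction': ..., 'confidence': ..., 'probabilities': [...]}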
Optimize the model for inference with quantization, TorchScript, ONNX export, and TensorRT (see the sketch after the example below):
import torch
from torch.quantization import quantize_dynamic, prepare, convert
import torch.nn as nn
class OptimizedModel:
"""Optimize model for production inference"""
def __init__(self, model):
self.model = model
def dynamic_quantization(self):
"""Post-training dynamic quantization (easy, CPU-friendly)"""
# Convert linear layers to int8
quantized_model = quantize_dynamic(
self.model,
{nn.Linear}, # Quantize linear layers
dtype=torch.qint8
)
# Measure speedup
original_size = self._get_model_size(self.model)
quantized_size = self._get_model_size(quantized_model)
print(f"Model size: {original_size:.2f}MB → {quantized_size:.2f}MB")
print(f"Compression: {original_size / quantized_size:.2f}x")
return quantized_model
def static_quantization(self, calibration_loader):
"""Post-training static quantization (more accurate)"""
# Fuse layers (conv + bn + relu)
self.model.eval()
model_fused = torch.quantization.fuse_modules(
self.model,
            [['conv', 'bn', 'relu']]  # Fusion patterns: names must match this model's submodule names
)
        # Prepare for quantization (the model must include QuantStub/DeQuantStub,
        # e.g. via torch.quantization.QuantWrapper, for end-to-end static quantization)
        model_fused.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_prepared = prepare(model_fused)
# Calibrate with representative data
with torch.no_grad():
for data, _ in calibration_loader:
model_prepared(data)
# Convert to quantized model
model_quantized = convert(model_prepared)
return model_quantized
def to_torchscript(self):
"""Convert to TorchScript for deployment"""
self.model.eval()
# Trace model
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(self.model, example_input)
# Optimize for inference
traced_model = torch.jit.optimize_for_inference(traced_model)
return traced_model
def to_onnx(self, output_path="model.onnx"):
"""Export to ONNX for cross-platform deployment"""
example_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
self.model,
example_input,
output_path,
export_params=True,
opset_version=11,
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size'},
'output': {0: 'batch_size'}
}
)
print(f"Exported to {output_path}")
def _get_model_size(self, model):
"""Calculate model size in MB"""
import io
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
return buffer.tell() / 1e6
# Example usage
import torchvision.models as models
model = models.resnet18(pretrained=True)
optimizer = OptimizedModel(model)
quantized = optimizer.dynamic_quantization()
scripted = optimizer.to_torchscript()
optimizer.to_onnx("resnet18.onnx")
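On NVIDIA GPUs, the model can additionally be compiled with TensorRT. Below is a minimal sketch using the torch_tensorrt package; it assumes torch-tensorrt and TensorRT are installed alongside a CUDA-capable GPU, and the input shape and precision are illustrative.
import torch
import torch_tensorrt  # assumes torch-tensorrt and TensorRT are installed

def compile_with_tensorrt(model, input_shape=(1, 3, 224, 224)):
    """Compile an eval-mode model into a TensorRT-optimized module (sketch)."""
    model = model.eval().cuda()
    trt_model = torch_tensorrt.compile(
        model,
        inputs=[torch_tensorrt.Input(input_shape)],  # static shape for illustration
        enabled_precisions={torch.float16},          # allow FP16 kernels
    )
    return trt_model

# Usage (illustrative):
# trt_resnet = compile_with_tensorrt(models.resnet18(pretrained=True))
# with torch.no_grad():
#     output = trt_resnet(torch.randn(1, 3, 224, 224, device="cuda"))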
Optimize throughput with batching:
import asyncio
import time
from collections import deque

import torch
import torchvision.models as models
class BatchInferenceServer:
"""Dynamically batch requests for GPU efficiency"""
def __init__(self, model, max_batch_size=32, max_latency_ms=50):
self.model = model
self.max_batch_size = max_batch_size
self.max_latency_ms = max_latency_ms / 1000 # Convert to seconds
self.request_queue = deque()
self.running = False
async def predict(self, input_data):
"""Async prediction request"""
# Create future for this request
future = asyncio.Future()
# Add to queue with timestamp
self.request_queue.append({
'input': input_data,
'future': future,
'timestamp': time.time()
})
# Wait for result
return await future
async def batch_processor(self):
"""Process requests in batches"""
while self.running:
            # Wait for requests (yield to the event loop while idle)
            while not self.request_queue:
                if not self.running:
                    return
                await asyncio.sleep(0.001)
# Collect batch
batch = []
futures = []
# Batch until max_batch_size or max_latency
start_time = time.time()
while len(batch) < self.max_batch_size and self.request_queue:
request = self.request_queue.popleft()
batch.append(request['input'])
futures.append(request['future'])
# Check latency deadline
oldest_request_age = time.time() - request['timestamp']
if oldest_request_age > self.max_latency_ms:
break
if batch:
# Run batch inference
batch_tensor = torch.stack(batch)
with torch.no_grad():
outputs = self.model(batch_tensor)
# Distribute results
for future, output in zip(futures, outputs):
future.set_result(output)
batch_time = (time.time() - start_time) * 1000
print(f"Processed batch of {len(batch)}, latency: {batch_time:.2f}ms")
def start(self):
"""Start batch processing"""
self.running = True
asyncio.create_task(self.batch_processor())
def stop(self):
"""Stop batch processing"""
self.running = False
# Usage
async def main():
model = models.resnet18().eval()
server = BatchInferenceServer(model, max_batch_size=32, max_latency_ms=50)
server.start()
# Simulate concurrent requests
tasks = []
for i in range(100):
input_tensor = torch.randn(3, 224, 224)
task = server.predict(input_tensor)
tasks.append(task)
results = await asyncio.gather(*tasks)
print(f"Processed {len(results)} requests")
server.stop()
# asyncio.run(main())
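A quick way to confirm that batching pays off is to compare throughput at different batch sizes. The sketch below is a rough measurement only; the numbers will vary with hardware and model.
import time
import torch
import torchvision.models as models

def throughput(model, batch_size, n_batches=20, shape=(3, 224, 224)):
    """Rough images-per-second measurement at a given batch size (illustrative)."""
    model.eval()
    x = torch.randn(batch_size, *shape)
    with torch.no_grad():
        model(x)  # warm-up
        start = time.time()
        for _ in range(n_batches):
            model(x)
    elapsed = time.time() - start
    return (batch_size * n_batches) / elapsed

# m = models.resnet18().eval()
# print(f"batch=1:  {throughput(m, 1):.1f} img/s")
# print(f"batch=32: {throughput(m, 32):.1f} img/s")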
Safely roll out new models with traffic splitting and canary releases:
import random
import time
class ModelRouter:
"""Route traffic between model versions"""
def __init__(self):
self.models = {} # version -> model
self.traffic_split = {} # version -> percentage
self.metrics = {} # version -> {latency, accuracy, errors}
def register_model(self, version: str, model, traffic_pct: float = 0.0):
"""Register new model version"""
self.models[version] = model
self.traffic_split[version] = traffic_pct
self.metrics[version] = {
'requests': 0,
'latency_ms': [],
'errors': 0
}
print(f"Registered model {version} with {traffic_pct}% traffic")
def predict(self, input_data):
"""Route request to model version based on split"""
# Select model version
rand = random.random() * 100
cumulative = 0
selected_version = None
for version, pct in self.traffic_split.items():
cumulative += pct
if rand < cumulative:
selected_version = version
break
if selected_version is None:
selected_version = list(self.models.keys())[0]
# Run inference
model = self.models[selected_version]
start = time.time()
try:
result = model(input_data)
latency = (time.time() - start) * 1000
# Record metrics
self.metrics[selected_version]['requests'] += 1
self.metrics[selected_version]['latency_ms'].append(latency)
return result, selected_version
except Exception as e:
self.metrics[selected_version]['errors'] += 1
raise e
def get_metrics(self):
"""Compare model versions"""
for version, metrics in self.metrics.items():
if metrics['requests'] > 0:
avg_latency = sum(metrics['latency_ms']) / len(metrics['latency_ms'])
error_rate = metrics['errors'] / metrics['requests']
print(f"\n{version}:")
print(f" Requests: {metrics['requests']}")
print(f" Avg latency: {avg_latency:.2f}ms")
print(f" Error rate: {error_rate:.2%}")
def canary_rollout(self, new_version: str, steps=5):
"""Gradually increase traffic to new version"""
old_version = [v for v in self.models.keys() if v != new_version][0]
for step in range(1, steps + 1):
new_pct = (step / steps) * 100
old_pct = 100 - new_pct
self.traffic_split[new_version] = new_pct
self.traffic_split[old_version] = old_pct
print(f"Step {step}/{steps}: {old_version}={old_pct}%, {new_version}={new_pct}%")
# In production: monitor metrics, rollback if issues detected
# if error_rate_increase > threshold:
# self.rollback(new_version)
# Example (model_v1 and model_v2 are previously loaded model objects)
router = ModelRouter()
router.register_model("v1.0", model_v1, traffic_pct=100)
router.register_model("v1.1", model_v2, traffic_pct=0)
# Canary deployment: gradually shift traffic
router.canary_rollout("v1.1", steps=5)
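The rollback referenced in the comments above is not defined on ModelRouter; here is a minimal sketch of what such a helper might look like (a hypothetical method, with the policy and thresholds left to you):
# Hypothetical rollback helper (not part of the original ModelRouter):
# shift all traffic back to the surviving version when the canary misbehaves.
def rollback(self, bad_version: str):
    fallback = [v for v in self.models if v != bad_version][0]
    self.traffic_split[bad_version] = 0
    self.traffic_split[fallback] = 100
    print(f"Rolled back {bad_version}; all traffic routed to {fallback}")

# Attach for illustration: ModelRouter.rollback = rollback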
Cascading Failures: When one model service fails, traffic redirects to healthy instances, potentially overloading them; see the load-shedding sketch after this list. The 2033 "Inference Cascade" took down global recommendation systems.
Model Staleness: Production models degrade as data distributions shift. Monitor performance continuously.
Resource Exhaustion: Memory leaks in inference servers accumulate slowly. The 2035 "OOM Pandemic" crashed services worldwide after weeks of operation.
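A common mitigation for the cascading-failure pattern is load shedding: cap the number of in-flight requests per instance and reject the overflow quickly instead of queueing it. The sketch below is minimal and illustrative; the limit and the fail-fast behavior are assumptions, not tuned recommendations.
import asyncio

class LoadSheddingGuard:
    """Reject excess requests so redirected traffic cannot overload this instance (sketch)."""
    def __init__(self, max_in_flight=64):
        self.semaphore = asyncio.Semaphore(max_in_flight)  # illustrative limit

    async def call(self, infer_fn, input_data):
        if self.semaphore.locked():
            # At capacity: fail fast so upstream load balancers can retry elsewhere
            raise RuntimeError("Overloaded, shedding request")
        async with self.semaphore:
            return await infer_fn(input_data)

# Usage (illustrative):
# guard = LoadSheddingGuard(max_in_flight=64)
# result = await guard.call(server.predict, input_tensor)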
Related Chronicles: The Inference Apocalypse (2033) - Cascading ML service failures
Tools: TorchServe, TensorFlow Serving, NVIDIA Triton, BentoML, Seldon Core
Research: Model compression, neural architecture search for efficient models