Shrinking a Medical AI Model Without Losing Diagnostic Accuracy
A radiologist’s workflow can’t wait five seconds for model inference. In a busy emergency department, that latency gap between detecting a brain bleed on a CT scan and delivering a result matters. But the same model that runs beautifully on a GPU-accelerated workstation in a major hospital needs to run on a CPU-only server in a rural clinic — where nobody’s budget stretches to NVIDIA hardware.
The VitalCheck API solves this with two techniques applied together: model quantization (converting 32-bit float weights to 8-bit integers) and format standardization (exporting from PyTorch to ONNX). The result is a roughly 4x smaller model (8-bit weights in place of 32-bit ones), a 2–3x inference speedup, and hardware-agnostic deployment, without sacrificing the diagnostic precision that clinical settings require.
This tutorial walks through the quantization and deployment pipeline in app/ml/imaging.py and scripts/train_image_model.py.
What You Need Before Diving Into This
Knowledge:
- Neural network basics: forward passes, weight tensors, what model.eval() does versus model.train()
- PyTorch: torch.nn.Module, state_dict, how to load and run a model
- Floating-point representation: why 0.1 + 0.2 != 0.3 in IEEE 754
- Healthcare ML context: models in clinical settings need to hit much higher reliability bars than consumer apps. The FDA cares, and false negatives in tumor detection have real consequences
Dependencies (from pyproject.toml):
Python 3.10+
PyTorch 2.0.1+
ONNX 1.14.0+
ONNXRuntime 1.15.0+
FastAPI 0.100.0+
NumPy 1.24+
How the Pipeline Fits Together
I think of quantization like pharmaceutical dilution. Full-precision training produces medicine at full concentration — expensive to manufacture, stable. Quantization dilutes it to a clinically effective concentration — cheaper to deploy, faster to administer, same clinical outcome if done correctly. Calibration is the trial that proves the diluted version works. ONNX export standardizes the format so any hospital (Intel, ARM, GPU) can use it. The fallback model is the original formula on the shelf, in case the diluted version shows an unexpected reaction.
The data flow in the pipeline:
graph TD
A["Raw Medical Images<br/>Brain MRI DICOM"] -->|Preprocessing| B["Normalized Tensors<br/>224x224x3 RGB"]
B -->|Full-Precision Model<br/>32-bit float| C["Baseline Logits<br/>Confidence Scores"]
D["Training Dataset<br/>Calibration Subset"] -->|Quantization-Aware Training| E["QAT Model<br/>Simulated int8"]
E -->|ONNX Export| F["ONNX Format<br/>Quantized Weights"]
C -->|Label Encoding| G["Reference Predictions<br/>4-class brain tumors"]
F -->|ONNXRuntime Inference| H["Quantized Predictions<br/>int8 computation"]
G -->|Comparison| I{"Accuracy ≥<br/>Baseline - 1%?"}
H -->|Comparison| I
I -->|✓ Approved| J["model_int8.onnx<br/>4.2x Compressed"]
I -->|✗ Rejected| K["Fallback to<br/>model_fp32.onnx"]
J -->|Runtime Selection| L["FastAPI Endpoint<br/>Choose best format"]
K -->|Runtime Selection| L
L -->|REST Request| M["Inference<br/>2-3x Faster"]
M -->|JSON Response| N["Radiologist UI<br/>Tumor Class + Confidence"]
We keep both quantized (int8) and full-precision (fp32) ONNX models deployed at all times. The endpoint selects based on hardware, and falls back automatically if the quantized model fails or produces unexpected results.
Stage 1: Setting Up the Quantization Configuration
Quantization maps 32-bit floats to 8-bit integers. A weight of 0.3847 (4 bytes) becomes 42 (1 byte). This works because neural network weights cluster around small ranges, and the relative magnitude between weights matters more than absolute precision.
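To make the arithmetic concrete, here is a minimal sketch of symmetric int8 quantization. The scale of 0.00916 is an assumption chosen so that the example weight 0.3847 lands on 42; in practice the scale comes from the observed weight range. The helper names quantize and dequantize are illustrative, not PyTorch APIs.

```python
def quantize(value: float, scale: float, qmin: int = -128, qmax: int = 127) -> int:
    """Map a float to a signed 8-bit integer: round(value / scale), then clamp."""
    q = round(value / scale)
    return max(qmin, min(qmax, q))


def dequantize(q: int, scale: float) -> float:
    """Recover the approximate float value from the stored integer."""
    return q * scale


# Assumed scale: one integer step represents 0.00916 in float space.
scale = 0.00916
weight = 0.3847

q_weight = quantize(weight, scale)          # the 1-byte value that gets stored
restored = dequantize(q_weight, scale)      # what the runtime effectively computes with
rounding_error = abs(weight - restored)     # never more than half a step for in-range values

print(q_weight, restored, rounding_error)
```

The rounding error is bounded by half the scale for any value inside the representable range, which is why choosing the scale well (the job of calibration) matters more than the bit width alone.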
For medical models, I use a more conservative configuration than the defaults:
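import torch
import torch.quantization as tq
# Assumes `model` is the trained float32 network, already loaded
model.qconfig = tq.QConfig(
    # Histogram-based activation ranges, full 8-bit range
    activation=tq.HistogramObserver.with_args(reduce_range=False),
    # One scale per output channel, symmetric signed int8 weights
    weight=tq.PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric
    ),
)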
HistogramObserver calibrates activation ranges by analyzing the full distribution across calibration samples — not just min/max. This matters for medical networks because outlier activations from high-contrast tumor boundaries can throw off a simpler range estimate.
PerChannelMinMaxObserver treats each output channel separately. Different convolutional filters have different weight distributions, so per-channel quantization gives each filter its own scale factor. That preserves accuracy that per-tensor quantization (one scale shared by the whole layer) would lose.
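reduce_range=False means we use the full 8-bit range (for signed values, [-128, 127]) instead of the 7-bit range ([-64, 63]). For healthcare models, we can’t afford to throw away bit precision. One caveat: PyTorch’s default fbgemm configuration reduces the range to avoid overflow in intermediate accumulation on some x86 CPUs, so confirm the full-range model on the actual deployment hardware.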
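The per-channel argument above can be sketched with plain arithmetic. This is a toy comparison, not what PerChannelMinMaxObserver does internally: the two filters and their weights are made up, and the scale is a simple symmetric max-abs / 127.

```python
def max_abs(values):
    return max(abs(v) for v in values)


def roundtrip_error(values, scale):
    """Worst-case error after quantizing to int8 and dequantizing again."""
    worst = 0.0
    for v in values:
        q = max(-128, min(127, round(v / scale)))
        worst = max(worst, abs(v - q * scale))
    return worst


# Two hypothetical output channels (filters) with very different weight ranges.
channels = [
    [0.012, -0.008, 0.015, -0.011],  # small-magnitude filter
    [1.90, -2.40, 0.75, -1.10],      # large-magnitude filter
]

# Per-tensor: one scale for every channel, dictated by the largest weight overall.
tensor_scale = max_abs([w for ch in channels for w in ch]) / 127
per_tensor_errors = [roundtrip_error(ch, tensor_scale) for ch in channels]

# Per-channel: each filter derives its scale from its own range.
per_channel_errors = [roundtrip_error(ch, max_abs(ch) / 127) for ch in channels]

print(per_tensor_errors, per_channel_errors)
```

With a shared scale, the small filter’s weights collapse onto only two or three integer levels. With its own scale, the same filter spreads across the full int8 range.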
Stage 2: Calibrating on Representative Medical Data
Quantization requires a calibration pass — not training, just observation. The model runs on representative images while PyTorch records the range of activations at each layer.
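import torch
from tqdm import tqdm
model.eval()
model_fp32 = model.to('cpu')
# Keep the conservative qconfig from Stage 1. Assigning
# get_default_qconfig('fbgemm') here would silently overwrite it.
torch.backends.quantized.engine = 'fbgemm'
torch.quantization.prepare(model_fp32, inplace=True)  # inserts observers
with torch.no_grad():
    for images, _ in tqdm(calibration_loader, desc="Calibrating quantization"):
        model_fp32(images)  # forward pass only; observers record activation ranges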
The calibration dataset matters enormously here. If you calibrate only on high-quality, high-contrast scans, the quantization will optimize for those. Real-world MRIs include motion artifacts, scanner noise, and subtle early-stage tumors. A missed tumor due to a miscalibrated model is a clinical failure, not just a metric regression.
I build a balanced calibration set that deliberately includes all four difficulty levels: obvious tumors, subtle boundaries, artifact-heavy scans, and normal controls. About 100–500 images total. If your calibration set is biased, your quantized model will perform differently than your test metrics suggest.
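One way to assemble such a set is stratified sampling over difficulty tags. The sketch below assumes each image already carries a category tag from dataset curation; the tag names, file names, and the build_calibration_set helper are hypothetical.

```python
import random
from collections import defaultdict


def build_calibration_set(samples, per_category, seed=0):
    """Pick the same number of images from each difficulty category.

    `samples` is a list of (image_path, category) pairs.
    """
    by_category = defaultdict(list)
    for path, category in samples:
        by_category[category].append(path)

    rng = random.Random(seed)  # fixed seed keeps the calibration set reproducible
    selected = []
    for category in sorted(by_category):
        paths = by_category[category]
        if len(paths) < per_category:
            raise ValueError(
                f"Only {len(paths)} images tagged '{category}', need {per_category}"
            )
        selected.extend((path, category) for path in rng.sample(paths, per_category))
    return selected


# Toy catalog: file names and tags are made up for illustration.
categories = ["obvious_tumor", "subtle_boundary", "artifact_heavy", "normal_control"]
catalog = [(f"scan_{c}_{i}.dcm", c) for c in categories for i in range(60)]

calibration = build_calibration_set(catalog, per_category=50)
print(len(calibration))
```

Raising an error when a category is short is deliberate: silently calibrating on fewer hard cases is exactly the bias this step is meant to prevent.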
Stage 3: Converting to int8 and Exporting to ONNX
After calibration, converting the model is a single call:
model_int8 = torch.quantization.convert(model_fp32, inplace=False)
The weights are now int8 internally, but PyTorch’s eager-mode quantized kernels run only on CPU backends (fbgemm on x86, qnnpack on ARM), so the converted model can’t use a GPU and stays tied to the PyTorch runtime. That’s why we export to ONNX. ONNXRuntime has dedicated int8 execution paths that map to CPU SIMD instructions (AVX2/AVX-512 on Intel, NEON on ARM).
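import torch
dummy_input = torch.randn(1, 3, 224, 224, dtype=torch.float32)
torch.onnx.export(
    model_int8,
    dummy_input,
    "models/brain_tumor_4class_int8.onnx",
    input_names=["image"],
    output_names=["logits"],
    opset_version=14,  # Version 14 has robust int8 support
    do_constant_folding=True,
    # Mark the batch dimension as dynamic on both input and output
    dynamic_axes={'image': {0: 'batch_size'}, 'logits': {0: 'batch_size'}}
)
# Export full-precision version as fallback.
# Note: model_fp32 still holds the observers inserted by prepare(inplace=True);
# if they interfere with export, export a copy of the float model made before prepare().
torch.onnx.export(
    model_fp32, dummy_input, "models/brain_tumor_4class.onnx",
    input_names=["image"], output_names=["logits"],
    opset_version=14,
    dynamic_axes={'image': {0: 'batch_size'}, 'logits': {0: 'batch_size'}}
)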
ONNX export traces one forward pass through the model, records every operation in the computation graph, and serializes everything to a protobuf file. At inference time, ONNXRuntime uses this static graph — no Python runtime, no GIL, no dynamic graph construction. That’s where the speedup comes from.
Modern CPUs have dedicated int8 instructions. Intel’s VNNI, for example, multiplies four pairs of 8-bit values and accumulates the result into each 32-bit lane in a single instruction, so a 512-bit register handles 64 multiply-accumulates at once versus 16 for float32. ONNXRuntime also fuses operations (for example, a convolution followed by ReLU) into a single kernel with no intermediate buffer allocations. The net result is typically 3–4x faster than PyTorch float32 on CPU.
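Speedup figures like these vary with the CPU, so it is worth measuring on the target machine. Below is a small timing harness as a sketch; measure_latency_ms is a hypothetical helper, and the lambda is a stand-in workload where a real run would wrap a session.run(...) call.

```python
import statistics
import time


def measure_latency_ms(run_once, warmup=5, repeats=50):
    """Return (median, p95) latency in milliseconds for a zero-argument callable."""
    for _ in range(warmup):  # warm caches and any lazy initialization
        run_once()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        run_once()
        timings.append((time.perf_counter() - start) * 1000.0)
    timings.sort()
    p95_index = min(len(timings) - 1, int(0.95 * len(timings)))
    return statistics.median(timings), timings[p95_index]


# Stand-in workload for illustration only.
median_ms, p95_ms = measure_latency_ms(lambda: sum(range(10_000)))
print(f"median={median_ms:.3f} ms, p95={p95_ms:.3f} ms")
```

Reporting the 95th percentile alongside the median matters for clinical workflows, where the slow tail is what a radiologist actually notices.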
Stage 4: Proving the Quantized Model Didn’t Lose Accuracy
For consumer ML, dropping from 95% to 94% accuracy is often acceptable. In clinical ML, that percentage point difference can mean missing 1% of brain tumors. You have to prove the quantized model meets the same clinical threshold as the full-precision model.
import numpy as np
import onnxruntime as ort
from sklearn.metrics import accuracy_score
from tqdm import tqdm

session_int8 = ort.InferenceSession("models/brain_tumor_4class_int8.onnx",
                                    providers=['CPUExecutionProvider'])
session_fp32 = ort.InferenceSession("models/brain_tumor_4class.onnx",
                                    providers=['CPUExecutionProvider'])
predictions_int8, predictions_fp32, ground_truth = [], [], []
for images, labels in tqdm(validation_loader, desc="Validating quantization"):
    image_np = images.numpy().astype(np.float32)
    logits_int8 = session_int8.run(None, {"image": image_np})[0]
    logits_fp32 = session_fp32.run(None, {"image": image_np})[0]
    # Keep every prediction in the batch, not just the first one
    predictions_int8.extend(np.argmax(logits_int8, axis=1).tolist())
    predictions_fp32.extend(np.argmax(logits_fp32, axis=1).tolist())
    ground_truth.extend(labels.numpy().tolist())
accuracy_int8 = accuracy_score(ground_truth, predictions_int8)
accuracy_fp32 = accuracy_score(ground_truth, predictions_fp32)
# Approve only if int8 is within one percentage point of the fp32 baseline
if accuracy_fp32 - accuracy_int8 <= 0.01:
    print("APPROVED: Quantized model meets clinical accuracy requirements")
else:
    print("REJECTED: Accuracy drop exceeds 1%. Use full-precision model only.")
I validate on held-out test data, not the calibration set. And I check precision/recall/F1, not just accuracy. For brain tumor classification, high recall (don’t miss tumors) matters more than high precision (don’t flag healthy tissue). Both thresholds need to be met.
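The recall check can be sketched as a second gate next to the accuracy one. The 1% per-class tolerance below is an assumption that mirrors the accuracy threshold, and the class encoding and toy predictions are made up. The example is built so that overall accuracy drops only 0.5% while one rare class loses 10% recall.

```python
def per_class_recall(y_true, y_pred, num_classes):
    """Recall per class: correct predictions / actual members (None if class absent)."""
    recalls = []
    for cls in range(num_classes):
        actual = sum(1 for t in y_true if t == cls)
        hits = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
        recalls.append(hits / actual if actual else None)
    return recalls


def passes_recall_gate(y_true, pred_fp32, pred_int8, num_classes, max_drop=0.01):
    """Approve only if no class loses more than `max_drop` recall versus fp32."""
    base = per_class_recall(y_true, pred_fp32, num_classes)
    quant = per_class_recall(y_true, pred_int8, num_classes)
    return all(b - q <= max_drop for b, q in zip(base, quant) if b is not None)


# Toy validation set: class 3 is rare (10 of 200 samples).
y_true = [0] * 150 + [1] * 20 + [2] * 20 + [3] * 10
pred_fp32 = list(y_true)          # pretend fp32 is perfect
pred_int8 = list(y_true)
pred_int8[-1] = 0                 # int8 misses a single class-3 case

accuracy_drop = sum(t != p for t, p in zip(y_true, pred_int8)) / len(y_true)
approved = passes_recall_gate(y_true, pred_fp32, pred_int8, num_classes=4)
print(accuracy_drop, approved)
```

Here the accuracy gate alone would approve the model, while the per-class recall gate rejects it. That is the case the paragraph above warns about.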
The Adaptive Model Selector
In a hospital network, different devices have different capabilities. An ARM edge device in an exam room might only fit the int8 model in its 2 GB RAM. A radiology server with 64 GB can run the full fp32 model. The inference endpoint needs to handle this without manual configuration per device.
import numpy as np
import onnxruntime as ort


class AdaptiveModelSelector:
    def __init__(self, model_int8_path: str, model_fp32_path: str):
        self.model_int8_path = model_int8_path
        self.model_fp32_path = model_fp32_path
        self.session_int8 = None
        self.session_fp32 = None
        self.selected_model = None

    def initialize(self, prefer_quantized: bool = True) -> str:
        # Load each model independently so one failure does not hide the other
        try:
            self.session_int8 = ort.InferenceSession(
                self.model_int8_path,
                providers=['CPUExecutionProvider', 'CUDAExecutionProvider']
            )
        except Exception as e:
            print(f"Failed to load int8 model: {e}")
        try:
            self.session_fp32 = ort.InferenceSession(
                self.model_fp32_path,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
        except Exception as e:
            print(f"Failed to load fp32 model: {e}")
        # Honor the preference when both loaded; otherwise use whichever is available
        if self.session_int8 and (prefer_quantized or not self.session_fp32):
            self.selected_model = 'int8'
        elif self.session_fp32:
            self.selected_model = 'fp32'
        else:
            raise RuntimeError("No inference models available!")
        return self.selected_model

    def infer(self, image_np: np.ndarray, validate: bool = False):
        try:
            if self.selected_model == 'int8' and self.session_int8:
                logits = self.session_int8.run(None, {"image": image_np})[0]
                if validate and self.session_fp32:
                    logits_fp32 = self.session_fp32.run(None, {"image": image_np})[0]
                    # Compare the predicted class for every image in the batch
                    if not np.array_equal(np.argmax(logits, axis=1),
                                          np.argmax(logits_fp32, axis=1)):
                        print("int8/fp32 disagreement; switching to fp32")
                        self.selected_model = 'fp32'
                        return logits_fp32, 'fp32 (validation fallback)'
                return logits, 'int8'
            elif self.selected_model == 'fp32' and self.session_fp32:
                logits = self.session_fp32.run(None, {"image": image_np})[0]
                return logits, 'fp32'
        except Exception as e:
            print(f"Inference error with {self.selected_model}: {e}")
            # Try the other model as fallback
            if self.selected_model == 'int8' and self.session_fp32:
                logits = self.session_fp32.run(None, {"image": image_np})[0]
                self.selected_model = 'fp32'
                return logits, 'fp32 (fallback)'
        raise RuntimeError("Complete inference failure: no models available")
The validate=True flag is worth highlighting. In production, you can run both models on the same input and compare predictions. If they disagree, fall back to fp32 and log the discrepancy. A disagreement rate above 5% is a signal that something is wrong with your int8 model — miscalibration, hardware issues, or a batch size mismatch that’s changing activation statistics.
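Tracking that disagreement rate needs a small amount of state. Here is a minimal sketch using a sliding window; DisagreementMonitor is a hypothetical helper, the window size is an arbitrary choice, and the 5% alert threshold comes from the paragraph above.

```python
from collections import deque


class DisagreementMonitor:
    """Track how often int8 and fp32 predictions differ over a sliding window."""

    def __init__(self, window_size=200, alert_rate=0.05):
        self.results = deque(maxlen=window_size)  # True means the models disagreed
        self.alert_rate = alert_rate

    def record(self, int8_class, fp32_class):
        self.results.append(int8_class != fp32_class)

    @property
    def rate(self):
        return sum(self.results) / len(self.results) if self.results else 0.0

    def should_alert(self):
        # Wait for a full window so one early disagreement does not trip the alarm.
        return len(self.results) == self.results.maxlen and self.rate > self.alert_rate


monitor = DisagreementMonitor(window_size=100)
for i in range(100):
    # Simulated stream: every 10th request disagrees (a 10% rate).
    monitor.record(int8_class=1 if i % 10 == 0 else 0, fp32_class=0)

print(monitor.rate, monitor.should_alert())
```

A sliding window keeps the signal current: a burst of disagreements after a hardware or driver change shows up quickly instead of being averaged away by weeks of clean history.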
Three Quantization Pitfalls That Will Hurt You in Clinical Deployment
Post-training quantization without QAT. If you train a model normally then quantize it after the fact (PTQ), the model never learned to operate with quantized activations. For sensitive architectures, that can cost a 3–5% accuracy drop. Quantization-Aware Training (QAT) inserts fake quantization operations during training itself, so the model learns to compensate. Assign a QAT qconfig (for example torch.quantization.get_default_qat_qconfig('fbgemm')), call torch.quantization.prepare_qat(model) on the model in training mode before your training loop, train normally, then convert. The extra training time pays off in preserved accuracy.
Biased calibration data. If you calibrate only on clean, high-contrast scans, the quantization scale factors will be tuned for easy cases. Subtle tumors and artifact-heavy scans will fall outside the calibrated range, causing clipping errors. Build your calibration set to include the same distribution of difficulty levels as your validation and real-world deployment data.
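The clipping effect is easy to reproduce with plain numbers. This sketch uses a simple max-abs range rather than the histogram-based observer from Stage 1, and the activation values are made up to stand in for clean versus artifact-heavy scans.

```python
def calibrate_scale(activations, qmax=127):
    """Symmetric scale from the largest magnitude seen during calibration."""
    return max(abs(a) for a in activations) / qmax


def quantize_dequantize(value, scale, qmin=-128, qmax=127):
    """Round-trip a value through int8; anything beyond the range is clipped."""
    q = max(qmin, min(qmax, round(value / scale)))
    return q * scale


# Hypothetical activation magnitudes: clean scans stay small, hard scans spike.
clean_only = [0.2, 0.9, 1.4, 2.0]
with_hard_cases = clean_only + [5.5, 7.8]

hard_activation = 7.8  # seen at inference time on an artifact-heavy scan
biased = quantize_dequantize(hard_activation, calibrate_scale(clean_only))
balanced = quantize_dequantize(hard_activation, calibrate_scale(with_hard_cases))

print(biased, balanced)
```

Calibrated on clean scans only, the 7.8 activation is clamped to the largest value the scale can represent (about 2.0). With hard cases in the calibration set it survives nearly intact, at the cost of coarser steps for small values.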
Batch size sensitivity with dynamic axes. If your model was exported with dynamic_axes but ONNXRuntime produces different predictions depending on batch size, batch normalization is the usual culprit. Test your ONNX model explicitly with batch sizes of 1, 2, and 4 on the same input: predictions should match (within floating-point tolerance) regardless of batch size. If they don’t, a normalization layer is probably using per-batch statistics instead of its stored running statistics, for example because the model was exported in training mode rather than after model.eval().
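That batch-size test can be automated. The sketch below takes any callable that maps a float32 batch to logits, so it is independent of the runtime; check_batch_consistency and the two stand-in models are hypothetical, and the tolerance is an assumption to tune for your model.

```python
import numpy as np


def check_batch_consistency(run_model, sample, batch_sizes=(1, 2, 4), atol=1e-4):
    """Run one sample at several batch sizes; return the sizes whose logits differ.

    `run_model` takes a float32 array of shape (N, ...) and returns logits (N, C).
    """
    reference = None
    mismatches = []
    for n in batch_sizes:
        batch = np.repeat(sample[np.newaxis, ...], n, axis=0).astype(np.float32)
        logits = np.asarray(run_model(batch))[0]  # logits for the first copy
        if reference is None:
            reference = logits
        elif not np.allclose(logits, reference, atol=atol):
            mismatches.append(n)
    return mismatches


def batch_independent(batch):
    # Stand-in for a healthy model: each row is processed on its own.
    return batch.reshape(len(batch), -1)[:, :4] * 2.0


def batch_dependent(batch):
    # Stand-in for a faulty model: output changes with the batch size.
    flat = batch.reshape(len(batch), -1)[:, :4]
    return flat - flat.mean(axis=0) + len(batch)


sample = np.arange(3 * 8 * 8, dtype=np.float32).reshape(3, 8, 8)
print(check_batch_consistency(batch_independent, sample))
print(check_batch_consistency(batch_dependent, sample))
```

With ONNXRuntime, run_model could be something like lambda b: session.run(None, {"image": b})[0], using the input name from the export step.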