Trading Precision for Speed: Implementing Model Quantization and Format Conversion for Clinical Medical Imaging

In clinical environments, deploying deep learning models for medical imaging—brain tumors, strokes, cardiac anomalies—presents a unique constraint: the inference latency and memory footprint directly impact patient outcomes.

Published

Mon Oct 13 2025

Technologies Used

Python PyTorch ONNX scikit-learn
Advanced 41 minutes

Shrinking a Medical AI Model Without Losing Diagnostic Accuracy

A radiologist’s workflow can’t wait five seconds for model inference. In a busy emergency department, that latency gap between detecting a brain bleed on a CT scan and delivering a result matters. But the same model that runs beautifully on a GPU-accelerated workstation in a major hospital needs to run on a CPU-only server in a rural clinic — where nobody’s budget stretches to NVIDIA hardware.

The VitalCheck API solves this with two techniques applied together: model quantization (converting 32-bit float weights to 8-bit integers) and format standardization (exporting from PyTorch to ONNX). The result is a roughly 4x smaller model (int8 weights take one byte instead of four), a 2–3x inference speedup, and hardware-agnostic deployment, with diagnostic accuracy checked against the full-precision baseline before the quantized model ships.

This tutorial walks through the quantization and deployment pipeline in app/ml/imaging.py and scripts/train_image_model.py.

What You Need Before Diving Into This

Knowledge:

  • Neural network basics: forward passes, weight tensors, what model.eval() does versus model.train()
  • PyTorch: torch.nn.Module, state_dict, how to load and run a model
  • Floating-point representation: why 0.1 + 0.2 != 0.3 in IEEE 754
  • Healthcare ML context: models in clinical settings need to hit much higher reliability bars than consumer apps — the FDA cares, and false negatives in tumor detection have real consequences

Dependencies (from pyproject.toml):

Python 3.10+
PyTorch 2.0.1+
ONNX 1.14.0+
ONNXRuntime 1.15.0+
FastAPI 0.100.0+
NumPy 1.24+

How the Pipeline Fits Together

I think of quantization like pharmaceutical dilution. Full-precision training produces medicine at full concentration — expensive to manufacture, stable. Quantization dilutes it to a clinically effective concentration — cheaper to deploy, faster to administer, same clinical outcome if done correctly. Calibration is the trial that proves the diluted version works. ONNX export standardizes the format so any hospital (Intel, ARM, GPU) can use it. The fallback model is the original formula on the shelf, in case the diluted version shows an unexpected reaction.

The data flow in the pipeline:

graph TD
    A["Raw Medical Images<br/>Brain MRI DICOM"] -->|Preprocessing| B["Normalized Tensors<br/>224x224x3 RGB"]
    B -->|Full-Precision Model<br/>32-bit float| C["Baseline Logits<br/>Confidence Scores"]
    
    D["Training Dataset<br/>Calibration Subset"] -->|Quantization-Aware Training| E["QAT Model<br/>Simulated int8"]
    E -->|ONNX Export| F["ONNX Format<br/>Quantized Weights"]
    
    C -->|Label Encoding| G["Reference Predictions<br/>4-class brain tumors"]
    F -->|ONNXRuntime Inference| H["Quantized Predictions<br/>int8 computation"]
    
    G -->|Comparison| I{"Accuracy ≥<br/>Baseline - 1%?"}
    H -->|Comparison| I
    
    I -->|✓ Approved| J["model_int8.onnx<br/>4.2x Compressed"]
    I -->|✗ Rejected| K["Fallback to<br/>model_fp32.onnx"]
    
    J -->|Runtime Selection| L["FastAPI Endpoint<br/>Choose best format"]
    K -->|Runtime Selection| L
    
    L -->|REST Request| M["Inference<br/>2-3x Faster"]
    M -->|JSON Response| N["Radiologist UI<br/>Tumor Class + Confidence"]

We keep both quantized (int8) and full-precision (fp32) ONNX models deployed at all times. The endpoint selects based on hardware, and falls back automatically if the quantized model fails or produces unexpected results.

Stage 1: Setting Up the Quantization Configuration

Quantization maps 32-bit floats to 8-bit integers. With a scale factor of roughly 0.00916, for example, a weight of 0.3847 (4 bytes) becomes the integer 42 (1 byte). This works because neural network weights cluster in small ranges, and the relative magnitude between weights matters more than absolute precision.
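To make the mapping concrete, here is a minimal sketch of symmetric int8 quantization in plain Python. The scale of 0.00916 is an assumed value, chosen so that 0.3847 lands on 42; in the real pipeline the observers derive the scale from the data.

```python
def quantize(x: float, scale: float, zero_point: int = 0,
             qmin: int = -128, qmax: int = 127) -> int:
    """Map a float to an int8 code: round(x / scale) + zero_point, clamped."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q: int, scale: float, zero_point: int = 0) -> float:
    """Recover the approximate float value from an int8 code."""
    return (q - zero_point) * scale


SCALE = 0.00916  # assumed scale, for illustration only

weights = [0.3847, -0.0121, 1.5]
codes = [quantize(w, SCALE) for w in weights]        # [42, -1, 127]
restored = [dequantize(q, SCALE) for q in codes]
```

The third weight shows clipping: 1.5 falls outside the representable range and saturates at 127, so it comes back as about 1.16. That is why the range estimate matters as much as the bit width.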

For medical models, I use a more conservative configuration than the defaults:

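import torch
import torch.quantization as tq

model.qconfig = tq.QConfig(
    # Histogram-based range estimate for activations, full 8-bit range
    activation=tq.HistogramObserver.with_args(reduce_range=False),
    # One scale per output channel; symmetric qint8 matches PyTorch's
    # default per-channel weight observer
    weight=tq.PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric
    ),
)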

HistogramObserver calibrates activation ranges by analyzing the full distribution across calibration samples — not just min/max. This matters for medical networks because outlier activations from high-contrast tumor boundaries can throw off a simpler range estimate.

PerChannelMinMaxObserver treats each output channel separately. Different convolutional filters have different weight distributions. Per-channel quantization gives each filter its own scale factor, which preserves accuracy that per-layer quantization would lose.

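reduce_range=False keeps the full 8-bit range for activations ([0, 255] for the observer’s default quint8 dtype) instead of the reduced 7-bit range ([0, 127]). For healthcare models, we can’t afford to throw away bit precision. One caveat: PyTorch’s fbgemm defaults reduce the range to avoid overflow on x86 CPUs without VNNI, so validate the full-range setting on the hardware you actually deploy to.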
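The per-channel argument is easy to check numerically. The sketch below uses two made-up filters with very different weight ranges and compares the round-trip error of the narrow filter under a shared (per-tensor) scale and under its own (per-channel) scale. The filter values and helper names are hypothetical.

```python
def max_abs_scale(values, qmax=127):
    """Symmetric scale so the largest magnitude maps to qmax."""
    return max(abs(v) for v in values) / qmax


def roundtrip_error(values, scale, qmin=-128, qmax=127):
    """Mean absolute error after quantizing and dequantizing."""
    total = 0.0
    for v in values:
        q = max(qmin, min(qmax, round(v / scale)))
        total += abs(v - q * scale)
    return total / len(values)


# Two hypothetical filters with very different weight ranges
wide_filter = [-2.0, -1.0, 0.5, 1.5, 2.0]
narrow_filter = [-0.02, -0.01, 0.005, 0.015, 0.02]

# Per-tensor: one scale shared by both filters, dominated by the wide one
shared_scale = max_abs_scale(wide_filter + narrow_filter)
err_shared = roundtrip_error(narrow_filter, shared_scale)

# Per-channel: the narrow filter gets a scale fitted to its own range
own_scale = max_abs_scale(narrow_filter)
err_own = roundtrip_error(narrow_filter, own_scale)
```

With the shared scale, the narrow filter collapses onto only three integer codes (-1, 0, 1), so its error is orders of magnitude larger than with its own scale.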

Stage 2: Calibrating on Representative Medical Data

Quantization requires a calibration pass — not training, just observation. The model runs on representative images while PyTorch records the range of activations at each layer.

from tqdm import tqdm

model.eval()
model_fp32 = model.to('cpu')

# Keep the custom qconfig from Stage 1: reassigning get_default_qconfig('fbgemm')
# here would silently discard it. Eager-mode quantization also assumes the model
# wraps its forward pass in QuantStub/DeQuantStub.
# inplace=False attaches observers to a copy, so model_fp32 stays a clean float
# model for the fallback export in Stage 3.
model_prepared = torch.quantization.prepare(model_fp32, inplace=False)

with torch.no_grad():
    for images, _ in tqdm(calibration_loader, desc="Calibrating quantization"):
        model_prepared(images)  # forward pass only; observers record activation ranges

The calibration dataset matters enormously here. If you calibrate only on high-quality, high-contrast scans, the quantization ranges will be tuned for those. Real-world MRIs include motion artifacts, scanner noise, and subtle early-stage tumors. A missed tumor due to a miscalibrated model is a clinical failure, not just a metric regression.

I build a balanced calibration set that deliberately includes all four difficulty levels: obvious tumors, subtle boundaries, artifact-heavy scans, and normal controls. About 100–500 images total. If your calibration set is biased, your quantized model will perform differently than your test metrics suggest.

Stage 3: Converting to int8 and Exporting to ONNX

After calibration, converting the model is a single call:

model_int8 = torch.quantization.convert(model_prepared, inplace=False)

The weights are now int8 internally, but PyTorch’s eager-mode quantized operators run only on CPU backends (fbgemm on x86, qnnpack on ARM) and still depend on the PyTorch runtime. That’s why we export to ONNX. ONNXRuntime has dedicated int8 execution paths that map to CPU SIMD instructions (AVX2/AVX-512 on Intel, NEON on ARM).

dummy_input = torch.randn(1, 3, 224, 224, dtype=torch.float32)

torch.onnx.export(
    model_int8,
    dummy_input,
    "models/brain_tumor_4class_int8.onnx",
    input_names=["image"],
    output_names=["logits"],
    opset_version=14,  # opset 13+ supports per-channel QuantizeLinear/DequantizeLinear
    do_constant_folding=True,
    dynamic_axes={'image': {0: 'batch_size'}}
)

# Export full-precision version as fallback
torch.onnx.export(
    model_fp32, dummy_input, "models/brain_tumor_4class.onnx",
    input_names=["image"], output_names=["logits"],
    opset_version=14, dynamic_axes={'image': {0: 'batch_size'}}
)

ONNX export traces one forward pass through the model, records every operation in the computation graph, and serializes everything to a protobuf file. At inference time, ONNXRuntime executes this static graph in its native engine: no per-operator Python overhead, no GIL contention inside the graph, no dynamic graph construction. That’s where much of the speedup comes from.

Modern CPUs have dedicated int8 instructions. Intel’s VNNI extension multiplies four pairs of int8 values and accumulates them into each 32-bit lane in a single instruction, so one 512-bit register handles 64 pairs at once, versus 16 values for float32. ONNXRuntime also fuses operations: a convolution followed by its activation, for example, can run as a single kernel with no intermediate buffer allocation. The net result is typically 3–4x faster than PyTorch float32 on CPU, though the exact figure depends on the model and the hardware.
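Speedup figures like these vary by CPU, so it is worth measuring on the target machine. Below is a small timing harness, a sketch rather than the project's actual benchmark; the function names and the warmup and run counts are my own choices.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(fn: Callable[[], object], warmup: int = 5, runs: int = 50) -> Dict[str, float]:
    """Time repeated calls to fn and report latency in milliseconds."""
    for _ in range(warmup):  # warm caches and lazy initialisation before timing
        fn()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    samples.sort()
    return {
        "median_ms": statistics.median(samples),
        "p95_ms": samples[min(len(samples) - 1, int(0.95 * len(samples)))],
    }


def speedup(baseline: Dict[str, float], candidate: Dict[str, float]) -> float:
    """Ratio of median latencies; values above 1.0 mean the candidate is faster."""
    return baseline["median_ms"] / candidate["median_ms"]
```

In this pipeline, fn would be something like lambda: session_int8.run(None, {"image": image_np}), with the fp32 session as the baseline. Reporting the median and p95 rather than the mean keeps one slow outlier from skewing the comparison.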

Stage 4: Proving the Quantized Model Didn’t Lose Accuracy

For consumer ML, dropping from 95% to 94% accuracy is often acceptable. In clinical ML, that percentage point difference can mean missing 1% of brain tumors. You have to prove the quantized model meets the same clinical threshold as the full-precision model.

import numpy as np
import onnxruntime as ort
from sklearn.metrics import accuracy_score
from tqdm import tqdm

session_int8 = ort.InferenceSession("models/brain_tumor_4class_int8.onnx",
    providers=['CPUExecutionProvider'])
session_fp32 = ort.InferenceSession("models/brain_tumor_4class.onnx",
    providers=['CPUExecutionProvider'])

predictions_int8, predictions_fp32, ground_truth = [], [], []

for images, labels in tqdm(validation_loader, desc="Validating quantization"):
    image_np = images.numpy().astype(np.float32)

    logits_int8 = session_int8.run(None, {"image": image_np})[0]
    logits_fp32 = session_fp32.run(None, {"image": image_np})[0]

    # extend() keeps every sample in the batch, not just the first one
    predictions_int8.extend(np.argmax(logits_int8, axis=1).tolist())
    predictions_fp32.extend(np.argmax(logits_fp32, axis=1).tolist())
    ground_truth.extend(labels.numpy().tolist())

accuracy_int8 = accuracy_score(ground_truth, predictions_int8)
accuracy_fp32 = accuracy_score(ground_truth, predictions_fp32)

# Gate: the quantized model may lose at most one percentage point of accuracy
if accuracy_fp32 - accuracy_int8 <= 0.01:
    print("APPROVED: Quantized model meets clinical accuracy requirements")
else:
    print("REJECTED: Accuracy drop exceeds 1%. Use full-precision model only.")

I validate on held-out test data, not the calibration set. And I check precision/recall/F1, not just accuracy. For brain tumor classification, high recall (don’t miss tumors) matters more than high precision (don’t flag healthy tissue). Both thresholds need to be met.
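Here is a sketch of what the recall check can look like next to the accuracy gate. The function names are hypothetical, and reusing the 1% margin per class is my assumption, not a clinical standard; scikit-learn's recall_score or classification_report would give the same per-class numbers.

```python
from collections import Counter
from typing import Dict, Sequence


def per_class_recall(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[int, float]:
    """Recall per class: correctly predicted samples / all samples of that class."""
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {cls: hits[cls] / n for cls, n in totals.items()}


def recall_gate(y_true, pred_fp32, pred_int8, max_drop: float = 0.01) -> bool:
    """Approve only if no class loses more than max_drop recall after quantization."""
    base = per_class_recall(y_true, pred_fp32)
    quant = per_class_recall(y_true, pred_int8)
    return all(base[cls] - quant[cls] <= max_drop for cls in base)


# Synthetic labels: class 3 is rare, and the int8 model misses half of it
y_true = [0] * 330 + [1] * 330 + [2] * 330 + [3] * 10
pred_fp32 = list(y_true)
pred_int8 = list(y_true)
pred_int8[-5:] = [0] * 5

accuracy_drop = sum(t != p for t, p in zip(y_true, pred_int8)) / len(y_true)
approved = recall_gate(y_true, pred_fp32, pred_int8)
```

In this synthetic case the overall accuracy drop is only 0.5%, so the accuracy gate alone would approve the model, while the recall gate rejects it because one class lost half its recall.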

The Adaptive Model Selector

In a hospital network, different devices have different capabilities. An ARM edge device in an exam room might only fit the int8 model in its 2 GB RAM. A radiology server with 64 GB can run the full fp32 model. The inference endpoint needs to handle this without manual configuration per device.

import logging

import numpy as np
import onnxruntime as ort

logger = logging.getLogger(__name__)


class AdaptiveModelSelector:
    def __init__(self, model_int8_path: str, model_fp32_path: str):
        self.model_int8_path = model_int8_path
        self.model_fp32_path = model_fp32_path
        self.session_int8 = None
        self.session_fp32 = None
        self.selected_model = None

    def initialize(self, prefer_quantized: bool = True) -> str:
        try:
            self.session_int8 = ort.InferenceSession(
                self.model_int8_path,
                providers=['CPUExecutionProvider', 'CUDAExecutionProvider']
            )
        except Exception as e:
            logger.warning("Failed to load int8 model: %s", e)

        try:
            self.session_fp32 = ort.InferenceSession(
                self.model_fp32_path,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
        except Exception as e:
            logger.warning("Failed to load fp32 model: %s", e)

        # Honor prefer_quantized, but use whichever model actually loaded.
        # Only fail when neither session is available.
        if self.session_int8 and (prefer_quantized or not self.session_fp32):
            self.selected_model = 'int8'
        elif self.session_fp32:
            self.selected_model = 'fp32'
        else:
            raise RuntimeError("No inference models available!")

        return self.selected_model

    def infer(self, image_np: np.ndarray, validate: bool = False):
        try:
            if self.selected_model == 'int8' and self.session_int8:
                logits = self.session_int8.run(None, {"image": image_np})[0]

                if validate and self.session_fp32:
                    logits_fp32 = self.session_fp32.run(None, {"image": image_np})[0]
                    # Compare the predicted class for every item in the batch
                    if not np.array_equal(np.argmax(logits, axis=1),
                                          np.argmax(logits_fp32, axis=1)):
                        logger.warning("int8/fp32 predictions disagree; switching to fp32")
                        self.selected_model = 'fp32'
                        # Report the model that actually produced the returned logits
                        return logits_fp32, 'fp32 (disagreement)'

                return logits, 'int8'

            elif self.selected_model == 'fp32' and self.session_fp32:
                logits = self.session_fp32.run(None, {"image": image_np})[0]
                return logits, 'fp32'

        except Exception as e:
            logger.error("Inference error with %s: %s", self.selected_model, e)

        # Try the other model as fallback
        if self.selected_model == 'int8' and self.session_fp32:
            logits = self.session_fp32.run(None, {"image": image_np})[0]
            self.selected_model = 'fp32'
            return logits, 'fp32 (fallback)'

        raise RuntimeError("Complete inference failure: no models available")

The validate=True flag is worth highlighting. In production, you can run both models on the same input and compare predictions. If they disagree, fall back to fp32 and log the discrepancy. A disagreement rate above 5% is a signal that something is wrong with your int8 model — miscalibration, hardware issues, or a batch size mismatch that’s changing activation statistics.
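Tracking that disagreement rate needs a little state. One way to do it is a rolling window over recent predictions, sketched below. The class name, window size, and minimum sample count are hypothetical; only the 5% threshold comes from the rule of thumb above.

```python
from collections import deque


class DisagreementMonitor:
    """Rolling disagreement rate between int8 and fp32 predictions."""

    def __init__(self, window: int = 1000, threshold: float = 0.05):
        self.results = deque(maxlen=window)  # True where the two models disagreed
        self.threshold = threshold

    def record(self, class_int8: int, class_fp32: int) -> None:
        self.results.append(class_int8 != class_fp32)

    @property
    def rate(self) -> float:
        return sum(self.results) / len(self.results) if self.results else 0.0

    def should_alert(self, min_samples: int = 100) -> bool:
        """Alert once enough samples exist and the rate exceeds the threshold."""
        return len(self.results) >= min_samples and self.rate > self.threshold
```

The min_samples guard keeps a single early disagreement from triggering an alert before the rate means anything.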

Three Quantization Pitfalls That Will Hurt You in Clinical Deployment

Post-training quantization without QAT. If you train a model normally then quantize it after the fact (PTQ), the model never learned to operate with quantized activations. The result can be a 3–5% accuracy drop for sensitive architectures. Quantization-Aware Training (QAT) inserts fake quantization operations during training itself, so the model learns to compensate. Assign a QAT qconfig (for example torch.quantization.get_default_qat_qconfig('fbgemm')), put the model in training mode, call torch.quantization.prepare_qat(model) before your training loop, train normally, then switch to eval mode and convert. The extra training time pays off in preserved accuracy.
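The QAT flow, sketched end to end on a toy network. TinyNet, the random training data, and the three-step loop are placeholders I made up to keep the example self-contained; a real run would use the actual model and training loader. It also assumes a CPU build of PyTorch with the fbgemm backend available.

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real 4-class classifier."""

    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()      # float -> quantized at the input
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 4)
        self.dequant = tq.DeQuantStub()  # quantized -> float at the output

    def forward(self, x):
        x = self.quant(x)
        x = self.pool(self.relu(self.conv(x)))
        x = self.fc(torch.flatten(x, 1))
        return self.dequant(x)


torch.manual_seed(0)
model = TinyNet().train()                        # prepare_qat requires training mode
model.qconfig = tq.get_default_qat_qconfig("fbgemm")
tq.prepare_qat(model, inplace=True)              # inserts fake-quantization modules

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for _ in range(3):                               # placeholder loop on random data
    images = torch.randn(4, 3, 32, 32)
    labels = torch.randint(0, 4, (4,))
    loss = nn.functional.cross_entropy(model(images), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

model.eval()
model_int8 = tq.convert(model, inplace=False)    # swap in real int8 modules
out = model_int8(torch.randn(1, 3, 32, 32))
```

The converted model takes and returns ordinary float tensors; the quantize and dequantize steps happen inside the stubs at the boundaries.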

Biased calibration data. If you calibrate only on clean, high-contrast scans, the quantization scale factors will be tuned for easy cases. Subtle tumors and artifact-heavy scans will fall outside the calibrated range, causing clipping errors. Build your calibration set to include the same distribution of difficulty levels as your validation and real-world deployment data.
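A small sampler makes the target distribution explicit instead of leaving it to chance. This is a sketch: the function name, the difficulty labels, and the file names are hypothetical, and the equal shares are just one possible choice.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def sample_calibration_set(items: Sequence[Tuple[str, str]],
                           shares: Dict[str, float],
                           total: int,
                           seed: int = 0) -> List[str]:
    """Draw a calibration set with a fixed share per difficulty label.

    items holds (image_path, difficulty_label) pairs; shares maps each
    label to its fraction of the calibration set.
    """
    groups: Dict[str, List[str]] = defaultdict(list)
    for path, label in items:
        groups[label].append(path)

    rng = random.Random(seed)  # fixed seed keeps the calibration set reproducible
    selected: List[str] = []
    for label, share in shares.items():
        want = round(total * share)
        pool = groups.get(label, [])
        if len(pool) < want:
            # Fail loudly rather than quietly under-representing a group
            raise ValueError(f"only {len(pool)} '{label}' images, need {want}")
        selected.extend(rng.sample(pool, want))
    return selected


# Hypothetical pool: 80 labelled scans per difficulty level
labels = ["obvious", "subtle", "artifact", "normal"]
pool = [(f"{label}_{i:03d}.dcm", label) for label in labels for i in range(80)]
calib = sample_calibration_set(pool, {label: 0.25 for label in labels}, total=200)
```

Raising an error when a group is too small is deliberate here: a calibration set that silently skips the artifact-heavy scans is exactly the bias this pitfall describes.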

Batch size sensitivity with dynamic axes. If your model was exported with dynamic_axes but ONNXRuntime produces different predictions depending on batch size, batch normalization is the usual culprit. Test your ONNX model explicitly with batch sizes of 1, 2, and 4 on the same input; predictions should match within floating-point tolerance regardless of batch size. If they don’t, the model was most likely exported in training mode, so a normalization layer is computing statistics from the current batch instead of using its stored running statistics. Call model.eval() before export.
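The batch-size test can be written once and reused for both models. The helper below is a sketch that takes any callable mapping a batch to logits, so it does not depend on a particular runtime; the function name and the tolerance are assumptions.

```python
import numpy as np
from typing import Callable, Sequence


def batch_size_consistent(run_fn: Callable[[np.ndarray], np.ndarray],
                          sample: np.ndarray,
                          batch_sizes: Sequence[int] = (1, 2, 4),
                          atol: float = 1e-4) -> bool:
    """Check that one sample yields the same logits at every batch size.

    sample has shape (1, C, H, W); run_fn maps a batch of shape (N, C, H, W)
    to logits of shape (N, classes).
    """
    reference = run_fn(sample)[0]
    for n in batch_sizes:
        batch = np.repeat(sample, n, axis=0)  # n copies of the same image
        logits = run_fn(batch)
        # Every row must match the single-image reference
        if not np.allclose(logits, reference, atol=atol):
            return False
    return True
```

With ONNXRuntime, run_fn would be lambda x: session_int8.run(None, {"image": x})[0]. Repeating a single image isolates batch-size effects from differences between images.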
