Chapter 28: ML Compilers and Runtime
Part VII: AI/HPC
"A compiler is a program that translates a program written in one language into a program written in another language." — Alfred Aho
The 50x Speedup That Came from Nowhere
Raj was benchmarking different inference backends for his computer vision model. The model was straightforward—a ResNet-50 for image classification. He expected small differences between backends, maybe 10-20%.
The first run with vanilla PyTorch: 15ms per image.
With TorchScript: 12ms. A modest 20% improvement.
Then he tried TensorRT: 0.3ms.
Raj stared at the numbers. Fifty times faster? He ran it again. Same result. He checked the accuracy—identical to within floating-point tolerance.
"How is this possible?" he asked his colleague Ming, who had experience with ML compilers. "It's the same model. Same GPU. Same input."
Ming smiled. "You just discovered why ML compilers exist. PyTorch is designed for flexibility and debugging. It executes operations one at a time, with Python overhead between each one. TensorRT analyzes the entire graph, fuses operations together, chooses optimal kernel implementations, and preallocates all memory. The math is identical, but the execution is completely different."
"But why doesn't everyone just use TensorRT then?"
"Trade-offs. TensorRT compilation can take 20 minutes. It doesn't support every operation. It's harder to debug. And if your model changes frequently, recompiling is painful. ML compilers are powerful, but they're not free."
This chapter explores the world of ML compilers—how they achieve dramatic speedups, what trade-offs they involve, and how to benchmark systems that use them.
Why ML Compilers Exist
When you train a model in PyTorch or TensorFlow, the framework prioritizes flexibility. Each operation is dispatched independently. Gradient computation is tracked automatically. You can stop, inspect, and modify execution at any point.
This flexibility is essential for research but disastrous for production inference. Every layer of abstraction adds overhead. Every dynamic dispatch adds latency. Every flexibility feature you're not using is still costing you.
ML compilers bridge this gap: they take a high-level model description and produce optimized code for specific hardware, eliminating the flexibility overhead in exchange for performance.
The Complexity They Hide
Consider what's required to run a simple convolutional neural network efficiently:
- Operation Fusion: A Conv → BatchNorm → ReLU sequence should be executed as a single fused kernel, not three separate operations
- Memory Layout: Should the tensor be stored as NCHW or NHWC? Different hardware prefers different layouts
- Precision Selection: Which layers can use FP16? Which need FP32? Where should quantization happen?
- Kernel Selection: For a 3×3 convolution with batch size 32, which of the 47 available kernel implementations is fastest?
- Memory Planning: How should intermediate activations be allocated to minimize fragmentation?
Multiply this by dozens of hardware targets, hundreds of possible operators, and millions of possible configurations. No human can manually optimize this. ML compilers make it tractable.
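The kernel-selection question in the list above is usually answered empirically: run each candidate on the real input shape and keep the fastest. The following is a minimal sketch in plain Python. The two `dot_*` functions are hypothetical stand-ins for alternative kernel implementations, and `select_kernel` is an illustrative helper, not part of any real compiler.

```python
import time

def dot_loop(a, b):
    """Candidate 1: explicit index loop."""
    total = 0.0
    for i in range(len(a)):
        total += a[i] * b[i]
    return total

def dot_zip(a, b):
    """Candidate 2: generator over zipped pairs."""
    return sum(x * y for x, y in zip(a, b))

def time_once(fn, args, repeats=5):
    """Best-of-N wall-clock time for one call, in seconds."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best

def select_kernel(candidates, args):
    """Time every candidate on the given inputs; return (fastest name, all timings)."""
    timings = {name: time_once(fn, args) for name, fn in candidates.items()}
    fastest = min(timings, key=timings.get)
    return fastest, timings

candidates = {"loop": dot_loop, "zip": dot_zip}
a = [float(i) for i in range(10_000)]
b = [1.0] * 10_000
winner, timings = select_kernel(candidates, (a, b))
```

Which candidate wins depends on the machine and the input size, which is exactly why the choice is made by measurement rather than by rule.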
What ML Compilers Do
Input: High-level model description (PyTorch, TensorFlow, ONNX)
↓
┌─────────────────────────────────────────────────────────┐
│ Frontend │
│ - Parse model │
│ - Build computation graph │
│ - Type inference │
└─────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────┐
│ Graph Optimization │
│ - Operator fusion │
│ - Constant folding │
│ - Dead code elimination │
│ - Layout transformation │
└─────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────┐
│ Backend │
│ - Hardware-specific optimization │
│ - Memory planning │
│ - Code generation │
└─────────────────────────────────────────────────────────┘
↓
Output: Optimized executable program
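To make one graph-optimization pass from the diagram concrete, here is a minimal sketch of dead code elimination on a toy computation graph. The dictionary-based graph format and the node names are assumptions chosen for illustration; real compilers use much richer intermediate representations.

```python
def eliminate_dead_code(graph, outputs):
    """Keep only nodes reachable from the requested outputs.

    graph maps a node name to (op, [input names]).
    """
    live = set()
    stack = list(outputs)
    while stack:
        name = stack.pop()
        if name in live:
            continue
        live.add(name)
        _, inputs = graph[name]
        stack.extend(inputs)
    # Preserve the original insertion (topological) order.
    return {name: node for name, node in graph.items() if name in live}

graph = {
    "x":    ("input", []),
    "conv": ("conv2d", ["x"]),
    "aux":  ("dense", ["conv"]),   # auxiliary head whose result is never used
    "relu": ("relu", ["conv"]),
}
optimized = eliminate_dead_code(graph, outputs=["relu"])
```

The pass walks backwards from the outputs, so anything that does not contribute to a requested result, such as the `aux` node here, is dropped before code generation.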
TVM (Apache TVM)
TVM is one of the best-known open-source ML compilers.
TVM Architecture
┌─────────────────────────────────────────────────────────┐
│ Relay (High-level IR) │
│ - Functional IR │
│ - Supports dynamic shapes │
│ - Graph-level optimization │
├─────────────────────────────────────────────────────────┤
│ TIR (Tensor IR) │
│ - Low-level IR │
│ - Loop representation │
│ - Hardware mapping │
├─────────────────────────────────────────────────────────┤
│ Runtime │
│ - Cross-platform execution │
│ - Memory management │
│ - Device abstraction │
└─────────────────────────────────────────────────────────┘
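The "loop representation" level is where transformations such as tiling are expressed. The sketch below is plain Python, not TVM code: it shows a matrix multiply written as a loop nest and the same computation with the two outer loops split into blocks. The function names and the tile size are assumptions for illustration only.

```python
def matmul_naive(a, b, n):
    """C = A @ B for n x n nested lists, plain triple loop."""
    c = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):
                c[i][j] += a[i][k] * b[k][j]
    return c

def matmul_tiled(a, b, n, tile):
    """Same computation with the i and j loops split into tile-sized blocks."""
    c = [[0.0] * n for _ in range(n)]
    for i0 in range(0, n, tile):
        for j0 in range(0, n, tile):
            for i in range(i0, min(i0 + tile, n)):
                for j in range(j0, min(j0 + tile, n)):
                    for k in range(n):
                        c[i][j] += a[i][k] * b[k][j]
    return c

n = 5
a = [[float(i + j) for j in range(n)] for i in range(n)]
b = [[float(i * j) for j in range(n)] for i in range(n)]
reference = matmul_naive(a, b, n)
tiled = matmul_tiled(a, b, n, tile=2)
```

Both versions compute the same result; only the iteration order changes. On real hardware the tiled order can improve cache reuse, and the best tile size is hardware-specific, which is what a tuner searches for.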
Using TVM
import numpy as np
import onnx
import tvm
from tvm import relay
from tvm.contrib import graph_executor
# 1. Load ONNX model
onnx_model = onnx.load("model.onnx")
# 2. Convert to Relay IR
mod, params = relay.frontend.from_onnx(onnx_model)
# 3. Set target hardware
target = tvm.target.Target("cuda")
# 4. Compile
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)
# 5. Execute
dev = tvm.cuda(0)
module = graph_executor.GraphModule(lib["default"](dev))
# Dummy input; the name "input" and the 1x3x224x224 shape are assumptions
# and must match the model's actual input.
input_data = np.random.rand(1, 3, 224, 224).astype("float32")
module.set_input("input", input_data)
module.run()
output = module.get_output(0)
ONNX Runtime
ONNX Runtime is a cross-platform inference engine developed by Microsoft.
ONNX Runtime Features
Advantages:
- Wide hardware support
- Mature and stable
- Easy to integrate
- Supports multiple Execution Providers
Execution Providers:
- CPU (default)
- CUDA
- TensorRT
- DirectML
- OpenVINO
- CoreML
- NNAPI
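Execution providers are given to ONNX Runtime as a priority-ordered list, so a deployment typically lists the fastest provider first and ends with CPU as a fallback. The sketch below illustrates that selection logic in plain Python. `resolve_providers` is a hypothetical helper and the `available` set is hard-coded for illustration; in a real deployment it would come from the runtime (ONNX Runtime exposes `get_available_providers()` for this).

```python
def resolve_providers(preferred, available):
    """Return the preferred providers that are available, in priority order."""
    chosen = [p for p in preferred if p in available]
    if not chosen:
        raise RuntimeError("none of the preferred providers are available")
    return chosen

preferred = [
    "TensorrtExecutionProvider",
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]
# Assumed for illustration: a machine with CUDA but without TensorRT.
available = {"CUDAExecutionProvider", "CPUExecutionProvider"}
chosen = resolve_providers(preferred, available)
```

Filtering the list up front makes the fallback explicit, which helps when benchmarking: a silent fallback to CPU can otherwise look like a slow GPU.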
Using ONNX Runtime
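import onnxruntime as ort
import numpy as np
# Create session; providers are tried in the order listed
session = ort.InferenceSession(
    "model.onnx",
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
# Prepare input (the 1x3x224x224 float32 shape is an assumption; match your model)
input_name = session.get_inputs()[0].name
input_data = np.random.rand(1, 3, 224, 224).astype(np.float32)
# Run inference; passing None returns all model outputs
outputs = session.run(None, {input_name: input_data})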
XLA (Accelerated Linear Algebra)
XLA is an ML compiler developed by Google, primarily used for TensorFlow and JAX.
XLA Features
Design goals:
- Automatically optimize TensorFlow/JAX programs
- Support TPU
- JIT and AOT compilation
Main optimizations:
- Operator fusion
- Memory optimization
- Parallelization
Using XLA in TensorFlow
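import tensorflow as tf
# Example parameters (shapes are illustrative)
w = tf.Variable(tf.random.normal([128, 64]))
b = tf.Variable(tf.zeros([64]))
# Method 1: Use jit_compile
@tf.function(jit_compile=True)
def model_fn(x):
    return tf.nn.relu(tf.matmul(x, w) + b)
# Method 2: Enable auto-clustering globally
tf.config.optimizer.set_jit(True)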
Using XLA in JAX
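import jax
import jax.numpy as jnp
# JAX uses XLA by default
@jax.jit
def model_fn(x, w, b):
    return jax.nn.relu(jnp.dot(x, w) + b)
# Example inputs (shapes are illustrative)
x = jnp.ones((8, 128))
w = jnp.ones((128, 64))
b = jnp.zeros((64,))
# Execute (compiled on the first call, cached afterwards)
result = model_fn(x, w, b)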
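Because JIT compilation happens on the first call, a benchmark that includes that call mixes compile time with run time. The sketch below simulates this in plain Python. `ToyJit` is a hypothetical stand-in for a shape-specialized JIT, with compilation faked by a short sleep; `measure` reports the first call and the steady-state mean separately.

```python
import time

class ToyJit:
    """Caches a 'compiled' function per input length, mimicking a shape-specialized JIT."""

    def __init__(self, fn, compile_cost_s=0.05):
        self.fn = fn
        self.compile_cost_s = compile_cost_s
        self.cache = {}
        self.compile_count = 0

    def __call__(self, xs):
        key = len(xs)                        # stand-in for the input shape
        if key not in self.cache:
            time.sleep(self.compile_cost_s)  # simulated compilation
            self.cache[key] = self.fn
            self.compile_count += 1
        return self.cache[key](xs)

def measure(fn, arg, iters=20):
    """Return (first_call_s, steady_state_mean_s)."""
    start = time.perf_counter()
    fn(arg)                                  # first call doubles as warm-up
    first = time.perf_counter() - start
    start = time.perf_counter()
    for _ in range(iters):
        fn(arg)
    steady = (time.perf_counter() - start) / iters
    return first, steady

relu = ToyJit(lambda xs: [max(x, 0.0) for x in xs])
first, steady = measure(relu, [-1.0, 2.0, -3.0, 4.0])
```

Calling `relu` with a list of a different length triggers a second simulated compile, which mirrors how a new input shape can cause recompilation. With real JAX code, timing also needs to wait for asynchronous execution to finish, for example by calling `block_until_ready()` on the result.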
Performance Comparison
Benchmark Setup
Model: ResNet-50
Hardware: NVIDIA A100
Batch Size: 1, 8, 32
Precision: FP32, FP16
Typical Results
Framework/Compiler     Batch=1    Batch=8    Batch=32
─────────────────────────────────────────────────────
PyTorch (eager)        5.2 ms     8.1 ms     18.5 ms
PyTorch (compile)      3.8 ms     5.2 ms     12.1 ms
ONNX Runtime (CUDA)    3.5 ms     4.8 ms     11.2 ms
TensorRT               2.1 ms     3.2 ms     7.8 ms
TVM (tuned)            2.4 ms     3.5 ms     8.5 ms
Note: Actual values depend on specific configuration
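Latency figures like these are easier to compare once converted to throughput and to speedup over the eager baseline. The sketch below does that arithmetic on the numbers from the table. It assumes each latency is the time for one whole batch, which the table does not state explicitly.

```python
# Latencies in ms, copied from the table above, keyed by batch size.
latency_ms = {
    "PyTorch (eager)":     {1: 5.2, 8: 8.1, 32: 18.5},
    "PyTorch (compile)":   {1: 3.8, 8: 5.2, 32: 12.1},
    "ONNX Runtime (CUDA)": {1: 3.5, 8: 4.8, 32: 11.2},
    "TensorRT":            {1: 2.1, 8: 3.2, 32: 7.8},
    "TVM (tuned)":         {1: 2.4, 8: 3.5, 32: 8.5},
}

def throughput(batch, ms):
    """Images per second, assuming `ms` is the time for one whole batch."""
    return batch / (ms / 1000.0)

def speedup(baseline_ms, ms):
    """How many times faster than the baseline at the same batch size."""
    return baseline_ms / ms

baseline = latency_ms["PyTorch (eager)"]
# report[name][batch] = (images per second, speedup vs. eager)
report = {
    name: {
        b: (round(throughput(b, ms)), round(speedup(baseline[b], ms), 2))
        for b, ms in per_batch.items()
    }
    for name, per_batch in latency_ms.items()
}
```

Under that assumption, TensorRT at batch size 32 processes about 4,100 images per second, roughly 2.4 times the eager baseline at the same batch size.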
Selection Guide
Scenario                 Recommendation
─────────────────────────────────────────────────────
Rapid prototyping        PyTorch eager
Production (NVIDIA)      TensorRT
Cross-platform           ONNX Runtime
Edge devices             TVM, IREE
TPU                      XLA (JAX)
Research/experiments     PyTorch compile
Common Optimization Techniques
Operator Fusion
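Before optimization:
  x = Conv(input)
  y = BatchNorm(x)
  z = ReLU(y)
  3 kernels, each reading its input from memory and writing its output back
After optimization:
  z = FusedConvBNReLU(input)
  1 kernel: input read once, output written once
Effect: Fewer kernel launches and lower memory bandwidth requirements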
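One common way to fuse Conv → BatchNorm → ReLU at inference time is to fold the batch-norm scale and shift into the convolution's weight and bias, so the batch-norm step disappears entirely. The following is a minimal sketch, assuming a single-channel 1×1 convolution so that the conv reduces to `w * x + b`. All parameter values are made up for illustration.

```python
import math

def conv_bn_relu_unfused(x, w, b, gamma, beta, mean, var, eps=1e-5):
    """Three separate steps, each producing an intermediate value."""
    y = w * x + b                                         # "conv" (1x1, one channel)
    z = gamma * (y - mean) / math.sqrt(var + eps) + beta  # batch norm (inference mode)
    return max(z, 0.0)                                    # ReLU

def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold the batch-norm scale and shift into the conv weight and bias."""
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta

def conv_bn_relu_fused(x, w_folded, b_folded):
    """One step: the folded affine transform followed by ReLU."""
    return max(w_folded * x + b_folded, 0.0)

params = dict(w=0.8, b=0.1, gamma=1.5, beta=-0.2, mean=0.3, var=4.0)
w_f, b_f = fold_bn(**params)
inputs = [-2.0, -0.5, 0.0, 1.0, 3.0]
unfused = [conv_bn_relu_unfused(x, **params) for x in inputs]
fused = [conv_bn_relu_fused(x, w_f, b_f) for x in inputs]
```

The folding happens once at compile time; at run time the fused version does a single multiply-add and a max per element, and the results match the unfused version up to floating-point rounding.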
Constant Folding
Before optimization:
a = Constant(2)
b = Constant(3)
c = Add(a, b)
y = Mul(x, c)
After optimization:
y = Mul(x, 5)
Effect: Reduces runtime computation
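The same transformation can be sketched as a small recursive pass over an expression tree. The tuple-based expression format here is an assumption for illustration: a number is a constant, a string is a runtime variable, and a tuple is `(op, lhs, rhs)`.

```python
import operator

OPS = {"add": operator.add, "mul": operator.mul}

def fold_constants(expr):
    """Recursively replace operations on constant operands with their value."""
    if not isinstance(expr, tuple):
        return expr                   # constant or variable: nothing to fold
    op, lhs, rhs = expr
    lhs, rhs = fold_constants(lhs), fold_constants(rhs)
    if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)):
        return OPS[op](lhs, rhs)      # evaluated at compile time
    return (op, lhs, rhs)             # still depends on a runtime value

expr = ("mul", "x", ("add", 2, 3))    # y = Mul(x, Add(2, 3))
folded = fold_constants(expr)         # ("mul", "x", 5)
```

Folding proceeds bottom-up, so a constant subexpression is reduced even when it is nested inside an operation that depends on a runtime input.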
Layout Transformation
Different hardware prefers different data layouts:
CPU: NCHW (batch, channel, height, width)
GPU: NCHW or NHWC
TPU: NHWC
ML compilers automatically insert the necessary layout conversions and try to minimize how many are needed.
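A layout conversion is just a reordering of the same values in memory. The sketch below converts a flat NCHW buffer to NHWC using the standard offset formulas for each layout. The function name and the tiny example tensor are illustrative.

```python
def nchw_to_nhwc(data, n, c, h, w):
    """Reorder a flat NCHW buffer into a flat NHWC buffer."""
    out = [None] * (n * c * h * w)
    for ni in range(n):
        for ci in range(c):
            for hi in range(h):
                for wi in range(w):
                    src = ((ni * c + ci) * h + hi) * w + wi   # NCHW offset
                    dst = ((ni * h + hi) * w + wi) * c + ci   # NHWC offset
                    out[dst] = data[src]
    return out

# One image, 2 channels, 2x2 pixels: channel 0 holds 0..3, channel 1 holds 10..13.
nchw = [0, 1, 2, 3, 10, 11, 12, 13]
nhwc = nchw_to_nhwc(nchw, n=1, c=2, h=2, w=2)
# nhwc == [0, 10, 1, 11, 2, 12, 3, 13]
```

In NCHW each channel is stored as a contiguous plane, while in NHWC the channel values of one pixel sit next to each other. Every conversion touches the whole tensor, which is why compilers try to pick one layout for a long run of operators.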
Memory Planning
Problem:
Intermediate results need memory
How to minimize total memory usage?
Solution:
Analyze tensor lifetimes
Reuse memory that's no longer needed
Similar to compiler register allocation
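Lifetime-based reuse can be sketched as a greedy planner: process tensors in the order they are first needed, and hand each one a buffer whose previous tenant is already dead. This is a simplified illustration, not the algorithm of any particular compiler; the tensor names, sizes, and step numbers are made up.

```python
def plan_memory(tensors):
    """Greedy buffer reuse based on tensor lifetimes.

    tensors: list of (name, size_bytes, first_step, last_step).
    Returns (assignment, buffer_sizes); assignment maps tensor name to buffer index.
    """
    buffer_sizes = []   # size of each allocated buffer
    busy_until = []     # last step at which each buffer is still in use
    assignment = {}
    for name, size, first, last in sorted(tensors, key=lambda t: t[2]):
        # Buffers whose previous tenant is dead and that are large enough.
        free = [i for i, end in enumerate(busy_until)
                if end < first and buffer_sizes[i] >= size]
        if free:
            idx = min(free, key=lambda i: buffer_sizes[i])  # best fit
        else:
            idx = len(buffer_sizes)                         # allocate a new buffer
            buffer_sizes.append(size)
            busy_until.append(-1)
        busy_until[idx] = last
        assignment[name] = idx
    return assignment, buffer_sizes

# A four-layer chain: each activation is produced at one step and consumed at the next.
tensors = [
    ("act0", 400, 0, 1),
    ("act1", 400, 1, 2),
    ("act2", 200, 2, 3),
    ("act3", 200, 3, 4),
]
assignment, buffer_sizes = plan_memory(tensors)
naive_bytes = sum(size for _, size, _, _ in tensors)   # one buffer per tensor
planned_bytes = sum(buffer_sizes)                      # after reuse
```

In this example the four activations fit in two buffers, cutting the footprint from 1,200 to 800 bytes. Because all allocation decisions are made ahead of time, nothing needs to be allocated or freed while the model runs.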
The Compiler That Saved the Project
Remember Aisha's edge deployment problem? After two weeks of manual optimization, she was still 40% short of her latency target.
Then she tried TVM with auto-tuning. She let it run overnight on a representative workload.
The next morning, she had a model that met her latency target with room to spare. The auto-tuner had found optimizations she never would have discovered manually—unusual tile sizes, unexpected operator fusion patterns, memory layouts that seemed counterintuitive but worked perfectly for her specific hardware.
"I spent two weeks doing what the compiler did in eight hours," she admitted. "And it did it better."
But she also learned the limits. When she tried to deploy the same model on a slightly different chip variant, the auto-tuned schedule performed poorly. She had to re-tune for the new target.
"ML compilers aren't magic," she concluded. "They're tools that trade tuning time for performance. For production deployment on known hardware, they're invaluable. For rapid prototyping across many targets, they might slow you down."
The lesson: ML compilers represent a fundamental shift in how we think about optimization—from hand-crafted expertise to automated search. But like any tool, knowing when to use them is as important as knowing how.
Summary
ML compilers are a key technology for modern AI deployment:
Main Tools
- TVM: Open source, auto-tuning, cross-platform
- IREE: Lightweight, suitable for edge devices
- ONNX Runtime: Mature, stable, easy to integrate
- XLA: Backend for TensorFlow/JAX
Core Optimizations
- Operator fusion
- Constant folding
- Layout transformation
- Memory planning
Selection Considerations
- Target hardware
- Performance requirements
- Development efficiency
- Maintenance cost
Performance Analysis
- Use each framework's profiling tools
- Compare performance across compilers
- Consider tuning time vs performance gain
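The trade-off between tuning time and performance gain can be put into numbers with a break-even calculation. The sketch below reuses the figures from the chapter's opening anecdote (a 20-minute compile, 15 ms per image before and 0.3 ms after) purely as an illustration; `break_even_inferences` is a hypothetical helper.

```python
import math

def break_even_inferences(compile_s, baseline_ms, optimized_ms):
    """Number of inferences after which the compile time has been paid back."""
    saved_s = (baseline_ms - optimized_ms) / 1000.0   # time saved per inference
    if saved_s <= 0:
        raise ValueError("optimized path is not faster than the baseline")
    return math.ceil(compile_s / saved_s)

n = break_even_inferences(compile_s=20 * 60, baseline_ms=15.0, optimized_ms=0.3)
# n == 81633
```

With these numbers the compile cost is recovered after about 82,000 images, which a production service reaches quickly. A model that is recompiled after every small change may never reach break-even.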