GPU precision and TF32
This is an awareness note. If you only run MIME’s built-in experiments, no action is needed — the relevant internal paths already force full precision. Read on if you write your own GPU code.
What TF32 is
On Ampere and newer NVIDIA GPUs, JAX’s default float32 matmul precision is TF32: a reduced format that keeps float32’s 8-bit exponent but only a 10-bit mantissa instead of 23 bits (~3 decimal digits, ~1e-3 relative error). Any float32 matmul dispatched to GPU tensor cores silently has its inputs rounded to TF32 unless you ask for full precision explicitly.
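To get a feel for what a 10-bit mantissa costs, the rounding can be emulated on the CPU. The following is a minimal sketch in NumPy, not how the GPU or JAX implements TF32: the helper name `round_to_tf32` is hypothetical, and it assumes round-to-nearest on the 13 dropped mantissa bits, which may differ from the hardware in how ties are handled.

```python
import numpy as np

def round_to_tf32(x):
    """Emulate TF32 input rounding (assumption: round-to-nearest, ties away from zero)."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    # float32 has 23 explicit mantissa bits and TF32 keeps 10, so 13 are dropped.
    # Adding half of the dropped range (1 << 12) before masking rounds to nearest.
    rounded = (bits + np.uint32(1 << 12)) & np.uint32(0xFFFFE000)
    return rounded.view(np.float32)

rng = np.random.default_rng(0)
x = rng.uniform(1.0, 2.0, size=10_000).astype(np.float32)

# Relative error introduced by the emulated rounding, measured in float64.
rel_err = np.abs(round_to_tf32(x).astype(np.float64) - x) / x
max_rel_err = float(rel_err.max())  # bounded by 2**-11, roughly 4.9e-4
```

Under these assumptions the worst-case relative error per input is 2**-11, which is where the "~1e-3" figure comes from.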
Why it can corrupt physics
TF32 is harmless when a matmul’s output is the same order of magnitude as
its inputs — it just adds ~0.1% noise. It is catastrophic when the
result is a near-cancellation: a residual far smaller than the input terms
(an LBM momentum moment summed from much larger populations, a spectral
pressure residual). There, the rounding error scales with the large input
terms rather than with the small result, so TF32’s ~1e-3 input error swamps
the answer entirely: a moment that should be 3e-5 can come back as exactly 0.0.
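The cancellation failure can be reproduced on the CPU with a two-population toy problem. This is an illustrative sketch, not MIME code: `round_to_tf32` is a hypothetical helper that emulates TF32 input rounding (assuming round-to-nearest on the dropped bits), and the populations and velocities are made-up values chosen so that the true moment is 3e-5.

```python
import numpy as np

def round_to_tf32(x):
    """Emulate TF32 input rounding by keeping 10 explicit mantissa bits (assumption)."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    rounded = (bits + np.uint32(1 << 12)) & np.uint32(0xFFFFE000)
    return rounded.view(np.float32)

# Two opposing populations of order 0.5 whose difference is the quantity of interest.
c = np.array([1.0, -1.0], dtype=np.float32)                   # lattice velocities
f = np.array([0.5 + 1.5e-5, 0.5 - 1.5e-5], dtype=np.float32)  # populations

moment_fp32 = float(c @ f)                                # about 3e-5
moment_tf32 = float(round_to_tf32(c) @ round_to_tf32(f))  # both inputs round to 0.5

print(moment_fp32, moment_tf32)
```

Both populations lie closer to 0.5 than the emulated TF32 grid spacing near 0.5, so they round to the same value and the difference vanishes, while the float32 result keeps the 3e-5 signal.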
What MIME already does
MIME’s v0.2 fit-up forced full precision on the affected internal paths —
the LBM moment transforms (D3Q19 / D2Q9 moment matrices) and the FVM
pressure solver. The fixes pass precision="highest" per call on the
specific matmuls; TF32-tolerant paths (neural-net surrogates) are left
fast. Nothing is required of you when using these solvers.
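As an illustration of what a per-call fix on a moment transform can look like, here is a hedged sketch for the D2Q9 momentum moment. It is not MIME's actual implementation: the names `C`, `W`, and `momentum` are hypothetical, and the velocity ordering is just one common convention.

```python
import jax.numpy as jnp

# Standard D2Q9 velocities and weights: rest, 4 axis directions, 4 diagonals.
C = jnp.array([[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1],
               [1, 1], [-1, 1], [-1, -1], [1, -1]], dtype=jnp.float32)
W = jnp.array([4 / 9] + [1 / 9] * 4 + [1 / 36] * 4, dtype=jnp.float32)

def momentum(f):
    """First moment sum_q c_q f_q over the last axis, requested at full precision."""
    return jnp.einsum("...q,qd->...d", f, C, precision="highest")

# Near-rest populations: linearised equilibrium for unit density and a tiny velocity.
u = jnp.array([3e-5, 0.0], dtype=jnp.float32)
c_dot_u = jnp.sum(C * u, axis=-1)   # elementwise, so no matmul is involved here
f = W * (1.0 + 3.0 * c_dot_u)

m = momentum(f)  # close to (3e-5, 0) when evaluated at full float32 precision
```

The populations are of order 0.1 while the momentum is of order 1e-5, which is exactly the near-cancellation regime described above.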
If you write your own GPU code
For precision-sensitive operations — anything where the result is much smaller than its inputs — request full precision explicitly:
import jax.numpy as jnp
y = jnp.matmul(a, b, precision="highest") # also jnp.dot / jnp.einsum / jnp.tensordot
Prefer per-call precision= over the global
jax_default_matmul_precision flag, so TF32-tolerant code stays fast.
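The trade-off between the per-call argument and the broader settings can be sketched as follows. The scoped context manager `jax.default_matmul_precision` is an addition not mentioned above, shown as a middle ground; the array shapes and names are arbitrary illustrations.

```python
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (64, 64), dtype=jnp.float32)
b = jax.random.normal(key_b, (64, 64), dtype=jnp.float32)

# Per call: only this matmul pays for full precision.
y_call = jnp.matmul(a, b, precision="highest")

# Scoped: float32 matmuls issued inside the block default to full precision.
with jax.default_matmul_precision("float32"):
    y_scoped = jnp.matmul(a, b)

# Global: affects the whole process, including TF32-tolerant code.
# jax.config.update("jax_default_matmul_precision", "float32")
```

The per-call form keeps the precision decision next to the operation that needs it, which makes it easier to see during review which matmuls were judged sensitive.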
Small matmuls (all dimensions ≲ 16–32 — 3×3 rotations, 4×4 transforms)
are never dispatched to tensor cores, so TF32 cannot reach them.
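When it is unclear whether one of your own matmuls falls into the precision-sensitive category described in this section, one option is to measure it. The helper below is a hypothetical sketch, not part of MIME or JAX: it runs the same product at the backend default and at `precision="highest"` and reports the largest difference relative to the largest full-precision entry.

```python
import jax
import jax.numpy as jnp

def matmul_precision_gap(a, b):
    """Largest |default - highest| difference, relative to the largest |highest| entry."""
    fast = jnp.matmul(a, b)                        # backend default precision
    exact = jnp.matmul(a, b, precision="highest")  # full float32
    # Guard against division by zero when the exact result is all zeros.
    scale = jnp.maximum(jnp.max(jnp.abs(exact)), jnp.finfo(exact.dtype).tiny)
    return float(jnp.max(jnp.abs(fast - exact)) / scale)

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (128, 128), dtype=jnp.float32)
b = jax.random.normal(key_b, (128, 128), dtype=jnp.float32)

gap = matmul_precision_gap(a, b)
print(f"relative gap between default and highest precision: {gap:.2e}")
```

On a CPU, or on a GPU path that does not use TF32, the gap should be zero or negligible. Run it with representative inputs from your own solver: a gap that is large compared with the accuracy you need from the result suggests that the matmul should request full precision.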
Further reading
The full investigation — every audited float32 matmul path in MIME, with
verdicts and measured errors — is in
tf32_matmul_precision_audit.md.