TF32 matmul-precision audit: float32 physics paths

Date: 2026-05-21

Branch: audit/tf32-matmul-precision

Trigger: a GPU matmul-precision bug (TF32) was found silently corrupting the D3Q19 LBM by up to 100%+ (see couette_torque_mass_conservation.md). TF32 is a process-wide GPU default, so this audit checks every other float32 matmul-based physics path in MIME.

The TF32 trap

On Ampere+ NVIDIA GPUs, JAX’s default float32 matmul precision is TF32 — a ~10-bit mantissa (~3 decimal digits, ~1e-3 relative error). Any float32 matmul dispatched to GPU tensor cores silently runs at TF32 unless precision="highest" is passed.

TF32 is catastrophic only when the matmul result is a near-cancellation: a residual much smaller in magnitude than the terms summed to produce it (relative condition number ≫ 1). TF32's ~1e-3 relative error on the inputs then swamps the result. It is harmless (~0.1% noise, tolerable) when the output is of the same order of magnitude as the inputs.
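The distinction can be seen without a GPU by emulating TF32 input rounding in NumPy. To first order, rounding each input to a unit roundoff u perturbs a dot product by at most about 2u·Σ|a_i b_i|, so the relative error is roughly 2u·κ with κ = Σ|a_i b_i| / |Σ a_i b_i|. This is a minimal sketch under stated assumptions: the hypothetical `to_tf32` helper assumes round-to-nearest (ties away from zero) on the 13 dropped mantissa bits, and `tf32_dot` accumulates in float64 to isolate the input rounding; real tensor-core rounding and accumulation may differ.

```python
import numpy as np


def to_tf32(x):
    """Round float32 values to TF32's 10 explicit mantissa bits.

    Emulation assumption: round-to-nearest on the 13 dropped bits, ties away
    from zero; the rounding used by real tensor-core paths may differ.
    """
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    rounded = (bits + np.uint32(0x1000)) & np.uint32(0xFFFFE000)
    return rounded.view(np.float32)


def tf32_dot(a, b):
    """Dot product of TF32-rounded inputs, accumulated in float64 to isolate input rounding."""
    return float(np.dot(to_tf32(a).astype(np.float64), to_tf32(b).astype(np.float64)))


def rel_err(approx, exact):
    return abs(approx - exact) / abs(exact)


ones = np.ones(2, dtype=np.float32)
cancelling = np.array([1.0003, -1.0], dtype=np.float32)  # ~3e-4 result from O(1) terms
benign = np.array([1.0003, 1.0], dtype=np.float32)       # ~2 result from O(1) terms

for name, v in [("cancelling", cancelling), ("benign", benign)]:
    exact = float(np.dot(v.astype(np.float64), ones.astype(np.float64)))
    print(name, rel_err(tf32_dot(v, ones), exact))
```

Near 1.0 the TF32 spacing is 2^-10 ≈ 9.8e-4, so the 3e-4 excess in 1.0003 is rounded away: the cancelling sum returns 0 (100% error), while the benign sum loses only about 1.5e-4 of its value.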

Matmuls with all dimensions ≲ 16–32 (3×3 rotations, 4×4 transforms, 6×6 resistance matrices, 3-vector dots) are never dispatched to tensor cores — XLA does them inline in full float32 — so TF32 cannot reach them.

The audit was run in genuine float32 (jax_enable_x64=False); each suspect path was run at default precision vs precision="highest" and compared.
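The default-vs-highest comparison can be sketched as a small harness. This assumes the function under test exposes a `precision` keyword that it forwards to its matmuls; `precision_sensitivity` is a hypothetical helper, not part of MIME or JAX.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", False)  # genuine float32, as in the audit


def precision_sensitivity(fn, *args):
    """Return ||fn(default) - fn(highest)|| / ||fn(highest)||.

    Assumes `fn` takes a `precision` keyword and forwards it to its matmuls.
    """
    out_default = fn(*args, precision=None)        # None: backend default (TF32 on Ampere+)
    out_highest = fn(*args, precision="highest")   # full float32 products
    diff = jnp.linalg.norm(out_default - out_highest)
    return float(diff / jnp.linalg.norm(out_highest))


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (256, 256), dtype=jnp.float32)
b = jax.random.normal(key_b, (256, 256), dtype=jnp.float32)
print(precision_sensitivity(jnp.matmul, a, b))
```

On CPU the two runs should agree, since TF32 is a GPU tensor-core mode; the ratio is only informative on an Ampere+ GPU.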

Verdicts

| Path | Verdict | Evidence |
|---|---|---|
| FVM dense-DCT pressure / Helmholtz (fvm/pressure.py) | AFFECTED, FIXED | 0.45% error vs highest |
| Twin-LBM step (lbm/pallas_lbm.py) | AFFECTED, FIXED | identical f@e moment bug as the D3Q19 core |
| Stokeslet free-space BEM resistance | not sensitive | bit-identical (no tensor-core matmul) |
| Stokeslet nearest-neighbour resistance | not sensitive | ≤1e-4 rel. (well-conditioned K@P) |
| Stokeslet cylinder-Green's-function resistance | not sensitive | bit-identical |
| Kinematics / robot / actuation / FVM operators | not sensitive | all small or contract over dim ≤3 |

Affected paths (fixed)

1. FVM dense-DCT pressure / Helmholtz solver: fvm/pressure.py

make_pressure_solver / make_helmholtz_solver apply the DCT/DST/real-DFT as dense matmuls (jnp.tensordot, in _apply_dct_along_axis). This is the path the default transform_backend="auto" selects for meshes ≲256³, i.e. essentially all production runs (FVMFluidNode builds a float32 mesh).

The transform matrices are orthonormal (condition number 1), so this is not the catastrophic class. However, the solver multiplies each spectral coefficient by 1/λ, where the Laplacian eigenvalues λ have a dynamic range of ~N², and this mildly amplifies TF32 noise across the six chained transforms (three forward, three inverse).
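For the Neumann case the size of that dynamic range can be made concrete. Assuming the standard second-order Laplacian on N cell-centred points of spacing h along one axis (an assumption about the discretisation, not taken from fvm/pressure.py), the DCT-II diagonalises it with eigenvalues

$$
\lambda_k = -\frac{4}{h^2}\,\sin^2\!\left(\frac{\pi k}{2N}\right), \qquad k = 0, 1, \dots, N-1,
$$

$$
\frac{|\lambda_{N-1}|}{|\lambda_1|} = \frac{\sin^2\!\big(\pi (N-1)/(2N)\big)}{\sin^2\!\big(\pi/(2N)\big)} \approx \left(\frac{2N}{\pi}\right)^2 .
$$

For N = 64 the nonzero eigenvalues along one axis span a factor of about 1.66 × 10³, and the 3D operator sums one such term per axis. The k = 0 (constant) mode is excluded; the zero-mean gauge fixes it.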

Test (make_pressure_solver, smooth zero-mean rhs, genuine float32):

| grid | ‖p_default − p_highest‖ / ‖p‖ |
|---|---|
| 32³ neumann | 4.6e-3 |
| 64³ neumann | 4.5e-3 |
| 128³ neumann | 4.2e-3 |
| 64³ periodic | 3.8e-3 |

A direct spectral linear solver should be accurate to near float32 roundoff, so a ~0.45% error is a genuine degradation. The dense transform is not the PISO bottleneck (its own docstring notes this), so precision="highest" costs essentially nothing.

Fix: precision="highest" on the jnp.tensordot in _apply_dct_along_axis — the single chokepoint feeding both solvers.
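A minimal sketch of where the fix sits in a dense-DCT spectral solve, under stated assumptions: a cell-centred Neumann grid, a second-order Laplacian and orthonormal DCT-II matrices. `dct2_matrix`, `apply_along_axis` and `neumann_poisson_solve` are hypothetical names, not MIME's `_apply_dct_along_axis` / `make_pressure_solver`; the point is that one `precision="highest"` on the shared `tensordot` covers every forward and inverse transform.

```python
import numpy as np
import jax.numpy as jnp


def dct2_matrix(n):
    """Orthonormal DCT-II matrix C (n x n): C @ C.T = I, so the inverse is C.T."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    c = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    c[0, :] /= np.sqrt(2.0)
    return jnp.asarray(c, dtype=jnp.float32)


def apply_along_axis(mat, x, axis):
    """Apply `mat` along one axis of x as a dense matmul (hypothetical helper)."""
    # The transform is orthonormal, but the solve divides by eigenvalues spanning
    # ~N^2, so request full float32 instead of the TF32 backend default.
    y = jnp.tensordot(mat, x, axes=([1], [axis]), precision="highest")
    return jnp.moveaxis(y, 0, axis)


def neumann_poisson_solve(rhs, h):
    """Solve lap(p) = rhs for zero-mean rhs on a cell-centred Neumann grid."""
    mats = [dct2_matrix(n) for n in rhs.shape]
    r_hat = rhs
    for axis, c in enumerate(mats):                     # forward transforms
        r_hat = apply_along_axis(c, r_hat, axis)
    lam = jnp.zeros(rhs.shape, dtype=jnp.float32)
    for axis, n in enumerate(rhs.shape):                # separable eigenvalues
        k = jnp.arange(n, dtype=jnp.float32)
        lam_1d = -(4.0 / h**2) * jnp.sin(jnp.pi * k / (2 * n)) ** 2
        shape = tuple(n if a == axis else 1 for a in range(rhs.ndim))
        lam = lam + lam_1d.reshape(shape)
    zero_mode = lam == 0.0                              # only the constant (all k = 0) mode
    p_hat = jnp.where(zero_mode, 0.0, r_hat / jnp.where(zero_mode, 1.0, lam))
    p = p_hat
    for axis, c in enumerate(mats):                     # inverse transforms (C^-1 = C^T)
        p = apply_along_axis(c.T, p, axis)
    return p
```

Because C is orthonormal, the inverse is just C.T; the only ill-conditioned step is the division by λ, which is why the noise injected by each transform matters.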

2. Twin-LBM step: lbm/pallas_lbm.py

Despite the filename, this is plain JAX, not a Pallas/Triton kernel. It is used by DefectCorrectionFluidNode’s twin-LBM path and was explicitly left out of the D3Q19 LBM fix.

momentum = f @ e (line 125) is the same operation as the D3Q19 moment bug: an LBM momentum, a ~1e-5 result summed from ~0.05-magnitude populations. Verified directly on this path: for a flow with u_x ≈ 3e-5, f @ e returns 0.0 at default precision and the correct 2.9996e-5 at precision="highest", so TF32 destroys the moment entirely. u @ e.T (equilibrium) and force @ e.T (Guo forcing) are the same moment matmuls that were fixed in the D3Q19 core.

Fix: precision="highest" on all three moment matmuls — the same fix the D3Q19 core received.
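For reference, the three moment matmuls with the fix applied might look like the sketch below. It assumes a D3Q19 lattice, a second-order BGK equilibrium with c_s² = 1/3 and Guo's forcing term; the twin-LBM's actual lattice, array layout and function names in lbm/pallas_lbm.py may differ, and `d3q19_lattice`, `macroscopic`, `equilibrium` and `guo_forcing` are hypothetical names.

```python
import itertools

import jax.numpy as jnp
import numpy as np

HIGHEST = "highest"  # full float32 on GPU; TF32 would wipe out these moments


def d3q19_lattice():
    """D3Q19 velocities e (19x3) and weights w (19,), ordered rest / face / edge."""
    stencil = list(itertools.product((-1, 0, 1), repeat=3))
    vecs = [(0, 0, 0)]
    vecs += [v for v in stencil if sum(map(abs, v)) == 1]   # 6 face neighbours
    vecs += [v for v in stencil if sum(map(abs, v)) == 2]   # 12 edge neighbours
    w = [1 / 3] + [1 / 18] * 6 + [1 / 36] * 12
    return jnp.asarray(np.array(vecs), dtype=jnp.float32), jnp.asarray(w, dtype=jnp.float32)


def macroscopic(f, e):
    """Density and velocity from populations f[..., 19]."""
    rho = jnp.sum(f, axis=-1)
    # f @ e: an O(1e-5) momentum summed from O(0.05) populations (near-cancellation).
    mom = jnp.matmul(f, e, precision=HIGHEST)
    return rho, mom / rho[..., None]


def equilibrium(rho, u, e, w):
    """Second-order BGK equilibrium with c_s^2 = 1/3."""
    eu = jnp.matmul(u, e.T, precision=HIGHEST)        # e_i . u
    usq = jnp.sum(u * u, axis=-1, keepdims=True)
    return w * rho[..., None] * (1 + 3 * eu + 4.5 * eu**2 - 1.5 * usq)


def guo_forcing(u, force, e, w, tau):
    """Guo source S_i = (1 - 1/(2 tau)) w_i [3 (e_i - u).F + 9 (e_i.u)(e_i.F)]."""
    eu = jnp.matmul(u, e.T, precision=HIGHEST)
    ef = jnp.matmul(force, e.T, precision=HIGHEST)    # e_i . F
    uf = jnp.sum(u * force, axis=-1, keepdims=True)
    return (1 - 0.5 / tau) * w * (3 * (ef - uf) + 9 * eu * ef)


e, w = d3q19_lattice()
u0 = jnp.array([[3e-5, 0.0, 0.0]], dtype=jnp.float32)
rho, u = macroscopic(equilibrium(jnp.ones(1, dtype=jnp.float32), u0, e, w), e)
print(rho, u)
```

Building an equilibrium at u_x = 3e-5 and taking its first moment should return u_x to float32 roundoff on any backend; with the `precision` arguments removed, the same round trip on an Ampere+ GPU is the kind of check that exposes the TF32 failure.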

Checked, not sensitive

Stokeslet BEM / resistance (stokeslet/)

Resistance matrices computed at default vs highest precision (genuine float32, separate processes so jit-caching cannot mask the difference):

| method | ‖R_default − R_highest‖ / ‖R‖ |
|---|---|
| free-space BEM (compute_resistance_matrix) | 0 (bit-identical) |
| free-space NN (compute_nn_resistance_matrix) | 3e-6 |
| confined NN (compute_nn_confined_resistance_matrix) | 1e-4 |
| cylinder-Green’s-fn (compute_gcyl_confined_resistance_matrix) | 0 (bit-identical) |

Why the Stokeslet is safe:

  • assemble_system_matrix builds the regularised-Stokeslet matrix with a double vmap over stokeslet_tensor: element-wise, not a matmul.

  • solve_bem / solve_bem_multi_rhs use jax.scipy.linalg lu_factor/lu_solve (cuSOLVER), not tensor-core GEMM.

  • The NN method’s A = K @ P is a large matmul, but P is a nearest-neighbour selection matrix, so each A[i,j] is a sum of similar-sign Stokeslet-kernel values weighted by positive quadrature weights — no near-cancellation, well-conditioned.

  • The cylinder Green’s-function correction is assembled in NumPy on the CPU (no TF32); its only JAX-side matmuls are batched 3×3 Green’s-tensor × rotation products (too small for tensor cores).

  • bem.py K @ u_n and flow_field.py S @ f are 3×3 matvecs.
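The conditioning argument for the NN method can be checked entry by entry with κ[i,j] = Σ_k |K[i,k] P[k,j]| / |A[i,j]| for A = K @ P. The sketch below uses a synthetic K (random magnitudes, not the regularised Stokeslet kernel) and a hypothetical `nn_selection` builder for P; it only illustrates that same-sign values within each nearest-neighbour cluster, weighted by positive weights, give κ = 1, whereas freely mixed signs do not.

```python
import numpy as np


def entrywise_condition(K, P):
    """kappa[i, j] = sum_k |K[i,k] P[k,j]| / |sum_k K[i,k] P[k,j]| for A = K @ P.

    kappa ~ 1: no cancellation, so TF32 input rounding stays near its ~1e-3 level.
    kappa >> 1: near-cancellation, so that error is amplified by roughly kappa.
    """
    return (np.abs(K) @ np.abs(P)) / np.abs(K @ P)


def nn_selection(nearest, weights, n_targets):
    """Hypothetical NN selection matrix: P[k, nearest[k]] = weights[k], else 0."""
    P = np.zeros((len(nearest), n_targets))
    P[np.arange(len(nearest)), nearest] = weights
    return P


rng = np.random.default_rng(0)
n_eval, n_quad, n_force = 8, 40, 5
nearest = rng.permutation(np.arange(n_quad) % n_force)            # every cluster non-empty
P = nn_selection(nearest, rng.uniform(0.1, 1.0, n_quad), n_force)  # positive weights

mag = rng.uniform(0.5, 1.5, size=(n_eval, n_quad))
cluster_sign = rng.choice([-1.0, 1.0], size=(n_eval, n_force))
K_similar = mag * cluster_sign[:, nearest]                        # same sign within each cluster
K_mixed = mag * rng.choice([-1.0, 1.0], size=(n_eval, n_quad))    # signs vary freely

print(entrywise_condition(K_similar, P).max())   # 1.0 up to rounding
print(entrywise_condition(K_mixed, P).max())     # > 1 where signs cancel
```

In the real NN assembly the same-sign property comes from P grouping nearby quadrature points, where the kernel values are similar; the synthetic K above simply imposes it.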

Everything else

A full triage of every jnp.matmul / jnp.dot / jnp.einsum / jnp.tensordot / @ site in control/kinematics/, nodes/robot/, nodes/actuation/, core/, and the rest of nodes/environment/fvm/ (PISO, operators, IBM, SDF, lifting, simple) found no dangerous sites:

  • Kinematics / robot dynamics / quaternion / transforms — all 3×3, 4×4, 6×6, or N×N with N = joint count (≲30) — never tensor-core dispatched.

  • FVM einsums ("fd,fd->f", "...i,i->...", etc.) all contract over the spatial dimension (2 or 3) — far below the tensor-core threshold; XLA lowers them as batched multiply-reduce in full float32, not a GEMM. (The Rhie–Chow grad_p_along result is residual-like, but the einsum that produces it contracts over dim ≤3 and the cancellation itself is plain float32 subtraction — not a matmul.)

Out of scope (correctly skipped)

  • D3Q19 / D2Q9 LBM core (d3q19.py, d2q9.py, bounce_back.py) — already fixed in prior work.

  • NN surrogates (fvm/gnn.py, surrogates/cholesky_mlp.py, the MLP in mlp_resistance_node.py) — neural-net inference is TF32-tolerant by design; forcing full precision would only slow them.

Fix pattern

Per-call precision="highest" on the affected matmul — never a global jax_default_matmul_precision flag (that would needlessly slow the TF32-tolerant surrogates). Each fixed site carries a comment explaining why TF32 corrupts that specific matmul.
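Because each fix is a per-call argument, a later refactor could silently drop it. One possible guard, sketched here as an assumption rather than the audit's method, is to trace a function with `jax.make_jaxpr` and check that every `dot_general` requests `Precision.HIGHEST`. `all_dots_highest` is a hypothetical helper; it recurses into jaxprs held directly in equation parameters (such as inner `jit` calls) but not into tuples of branch jaxprs, and it only understands precisions given as strings or `jax.lax.Precision` values.

```python
import jax
import jax.numpy as jnp


def _dot_precisions(jaxpr):
    """Yield the `precision` param of every dot_general, recursing into nested jaxprs."""
    for eqn in jaxpr.eqns:
        if eqn.primitive.name == "dot_general":
            yield eqn.params["precision"]
        for param in eqn.params.values():
            inner = getattr(param, "jaxpr", param)   # ClosedJaxpr -> Jaxpr
            if hasattr(inner, "eqns"):
                yield from _dot_precisions(inner)


def all_dots_highest(fn, *args):
    """True if fn traces to at least one matmul and every one requests HIGHEST."""
    precisions = list(_dot_precisions(jax.make_jaxpr(fn)(*args).jaxpr))
    highest = jax.lax.Precision.HIGHEST
    return bool(precisions) and all(
        p is not None and all(q == highest for q in p) for p in precisions
    )


f = jnp.ones((4, 19), dtype=jnp.float32)
e = jnp.ones((19, 3), dtype=jnp.float32)
print(all_dots_highest(lambda f, e: jnp.matmul(f, e, precision="highest"), f, e))
print(all_dots_highest(lambda f, e: f @ e, f, e))
```

The first call should report True and the second False, since a bare `@` leaves the precision unset. A check like this runs on CPU, so it can cover paths such as pallas_lbm.py that lack a numerical baseline.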

Test status

Existing tests stay green. The FVM verification tests run in float64 (where precision="highest" is a no-op — TF32 is float32-only), so they are unchanged. pallas_lbm.py has no numerical baseline test. No re-baselining was required.