The static_data / StaticArray channel
MIME environment nodes carry large, non-evolving arrays — unstructured FVM
meshes, BEM LU factors, lattice wall masks, MLP weights. Before MADDENING
v0.2 these were plain instance attributes closed over by update(), which
baked them into the compiled HLO as constants.
v0.2 added the static_data channel to hold them outside the JIT closure;
MIME adopted it across its environment nodes in the v0.2 fit-up (§3).
The problem it solves
A node has two places to put a tensor:
- initial_state(): the array joins the JAX state pytree and is threaded through every scan / fori_loop / multi-rate step / gradient pass, even though it never changes. It is also checkpointed on every save.
- A plain self._x attribute closed over by update(): JAX bakes it into the compiled HLO as a constant. This is numerically correct, but the constant is re-embedded each time the step function is traced, and a 1 GB FVM mesh baked into HLO bloats compile time and memory.
static_data is the third option: a declared channel for arrays that
participate in update() but do not evolve. The GraphManager tracks it for
JIT-cache invalidation, keeps it out of the state pytree, and keeps it out
of checkpoints.
The StaticArray wrapper
Array values in static_data must be wrapped in
maddening.core.static_data.StaticArray — a frozen dataclass that carries
the array alongside its sharding policy:
```python
StaticArray(value, replication="replicate", shard_axis=None)
```
- value: a NumPy or JAX array (anything with .shape and .dtype). Built once at construction and held by reference, not copied. Do not mutate it after wrapping.
- replication:
  - "replicate" (default): every device gets the full array.
  - "shard": each device gets a slice along shard_axis.
- shard_axis: required when replication="shard", and must be None otherwise. The GraphManager slices value along this axis at sharding time.
StaticArray is strict by construction. __post_init__ raises:
- TypeError if value is not array-like, or is itself a list / tuple / dict / set. Nested structures are unsupported: a pytree of arrays cannot be wrapped whole, so unfold it into multiple top-level keys.
- ValueError if replication is unknown, if shard_axis is missing for a sharded array or set for a replicated one, or if shard_axis is out of range for the array's shape.
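The validation rules above fit in a small frozen dataclass. The sketch below is not MADDENING's implementation. StaticArraySketch is a hypothetical stand-in, it treats anything with .shape and .dtype as array-like, and it accepts only non-negative shard_axis values. That last choice is an assumption about how "out of range" is checked.

```python
from dataclasses import dataclass
from typing import Any, Optional

import numpy as np

_MODES = ("replicate", "shard")


@dataclass(frozen=True)
class StaticArraySketch:
    """Hypothetical stand-in mirroring the documented StaticArray checks."""

    value: Any
    replication: str = "replicate"
    shard_axis: Optional[int] = None

    def __post_init__(self):
        # Containers are rejected even if they only hold arrays.
        if isinstance(self.value, (list, tuple, dict, set)):
            raise TypeError("nested structures are unsupported; unfold into top-level keys")
        # "Array-like" in this sketch means: exposes .shape and .dtype.
        if not (hasattr(self.value, "shape") and hasattr(self.value, "dtype")):
            raise TypeError(f"value must be array-like, got {type(self.value).__name__}")
        if self.replication not in _MODES:
            raise ValueError(f"unknown replication {self.replication!r}")
        if self.replication == "shard":
            if self.shard_axis is None:
                raise ValueError("shard_axis is required when replication='shard'")
            ndim = len(self.value.shape)
            # Assumption: only non-negative axes are accepted.
            if not 0 <= self.shard_axis < ndim:
                raise ValueError(
                    f"shard_axis {self.shard_axis} out of range for shape {self.value.shape}"
                )
        elif self.shard_axis is not None:
            raise ValueError("shard_axis must be None unless replication='shard'")


mask = np.zeros((64, 32), dtype=bool)
replicated = StaticArraySketch(mask)  # every device would get the full array
sharded = StaticArraySketch(mask, replication="shard", shard_axis=0)
```

Because the dataclass is frozen, a wrapper cannot be re-pointed at a new policy after construction. The wrapped array itself is still mutable, which is why the text asks you not to mutate it.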
A bare array left in static_data is, for now, coerced to
StaticArray(value=arr) with a FutureWarning. That coercion path is
removed in v0.3 — bare arrays will raise. Wrap them now.
Scalars, strings, and tuples do not need wrapping: they carry no sharding decision and stay bare in the dict.
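Here is a minimal sketch of the transition rule. It uses a hypothetical coerce_static_data helper and a StaticArraySketch stand-in, and neither is MADDENING's API. Wrapped arrays pass through unchanged. Bare arrays are wrapped with the default replicate policy and raise a FutureWarning. Scalars, strings and tuples are left alone.

```python
import warnings
from dataclasses import dataclass
from typing import Any, Optional

import numpy as np


@dataclass(frozen=True)
class StaticArraySketch:
    # Hypothetical stand-in for StaticArray (validation omitted).
    value: Any
    replication: str = "replicate"
    shard_axis: Optional[int] = None


def coerce_static_data(sd: dict) -> dict:
    """Hypothetical sketch of the v0.2 coercion rule for static_data values."""
    out = {}
    for key, val in sd.items():
        if isinstance(val, StaticArraySketch):
            out[key] = val  # already wrapped: keep as-is
        elif hasattr(val, "shape") and hasattr(val, "dtype"):
            # Bare array: wrap with the default policy, but warn.
            warnings.warn(
                f"static_data[{key!r}] is a bare array; wrap it in StaticArray "
                "(bare arrays will raise in v0.3)",
                FutureWarning,
                stacklevel=2,
            )
            out[key] = StaticArraySketch(value=val)
        else:
            out[key] = val  # scalars / strings / tuples stay bare
    return out
```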
The static_data property contract
SimulationNode.static_data is a property; the default returns {}.
Override it on nodes that need the channel.
- Keys are strings.
- Values are bare Python scalars / strings / tuples, or StaticArray-wrapped arrays.
- The returned dict must be stable across calls for a given node instance. Build it once in __init__ and stash it on self; never reconstruct it per call. A property that rebuilds the dict (and the arrays inside it) on every access defeats the channel.
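You can check the stability requirement by object identity. In this sketch the classes are hypothetical and StaticArraySketch stands in for StaticArray. The first node follows the contract, and the second rebuilds its dict, wrapper and array on every access.

```python
from dataclasses import dataclass
from typing import Any

import numpy as np


@dataclass(frozen=True)
class StaticArraySketch:
    # Hypothetical stand-in for StaticArray.
    value: Any


class CachedNode:
    """Follows the contract: build once in __init__, return the same dict."""

    def __init__(self, n: int):
        self._static_data = {
            "mask": StaticArraySketch(np.zeros(n, dtype=bool)),
            "n": n,
        }

    @property
    def static_data(self) -> dict:
        return self._static_data  # pure getter


class RebuildingNode:
    """Anti-pattern: a new dict, wrapper and array on every access."""

    def __init__(self, n: int):
        self._n = n

    @property
    def static_data(self) -> dict:
        return {"mask": StaticArraySketch(np.zeros(self._n, dtype=bool)), "n": self._n}
```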
MIME’s FVM node is the reference adoption: __init__ calls a
_build_static_data() helper once and stores the result, and the property
is a pure getter.
```python
# src/mime/nodes/environment/fvm/fluid_node.py
def __init__(self, ...):
    ...
    # Built once here; the property below just returns it.
    self._static_data = self._build_static_data()

def _build_static_data(self) -> dict:
    """Unfold the FVMMesh pytree into flat static_data keys.

    StaticArray rejects nested structures, so the mesh (a frozen
    pytree of arrays plus a tuple of BoundaryPatch) cannot be
    wrapped whole. Each leaf array becomes its own top-level key.
    """
    from maddening.core.static_data import StaticArray

    m = self._mesh
    sd = {
        "mesh_owner": StaticArray(m.owner),
        "mesh_neighbour": StaticArray(m.neighbour),
        "mesh_Sf": StaticArray(m.Sf),
        # ... mesh_n / mesh_area / mesh_d / mesh_w / mesh_V / mesh_x
        # Bare scalar / tuple metadata, hashed by repr().
        "N_cells": m.N_cells,
        "dim": m.dim,
        "cartesian_shape": m.cartesian_shape,
    }
    for p in m.patches:
        sd[f"patch_{p.name}_owner"] = StaticArray(p.owner)
        # ... per-patch Sf / n / area / d / face_x
    return sd

@property
def static_data(self) -> dict:
    return self._static_data
```
Every FVM array uses replication="replicate": the face graph mixes
face-indexed (owner, Sf, …) and cell-indexed (V, x) arrays on
different axes, so no single shard_axis is valid — and the node’s
halo_width() returns {0: 1}, which already blocks MADDENING’s pointwise
sharder from sharding it.
static_data_hash()
SimulationNode.static_data_hash() produces a stable hash over
static_data, used as part of the JIT cache key so a geometry change
triggers a recompile.
For each array value it hashes (key, shape, dtype, replication, shard_axis) — not the array contents. static_data is expected to be
far larger than the state, so content hashing would be prohibitive; the
assumption is that an array’s identity is captured by its shape, dtype, and
sharding policy. Bare scalars hash by repr(value). An empty static_data
hashes to 0.
Because sharding policy is in the key, a node that switches an array from
"replicate" to "shard" is correctly recognised as a recompile-worthy
change. Conversely, swapping in a different array of the same shape and
dtype will not invalidate the cache — if a node must change array
contents, it must also change shape/dtype or be reconstructed.
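Below is a sketch of a metadata-only hash that follows the rules above. The SHA-256 digest, the sorted key order and the name static_data_hash_sketch are assumptions for illustration. MADDENING's static_data_hash() may combine the same fields differently.

```python
import hashlib
from dataclasses import dataclass
from typing import Any, Optional

import numpy as np


@dataclass(frozen=True)
class StaticArraySketch:
    # Hypothetical stand-in for StaticArray (validation omitted).
    value: Any
    replication: str = "replicate"
    shard_axis: Optional[int] = None


def static_data_hash_sketch(sd: dict) -> int:
    """Metadata-only hash over (key, shape, dtype, replication, shard_axis)."""
    if not sd:
        return 0  # empty static_data hashes to 0
    h = hashlib.sha256()
    for key in sorted(sd):  # assumption: sort so insertion order is irrelevant
        val = sd[key]
        if isinstance(val, StaticArraySketch):
            # Contents are never read: only shape, dtype and sharding policy.
            token = (key, tuple(val.value.shape), str(val.value.dtype),
                     val.replication, val.shard_axis)
        else:
            token = (key, repr(val))  # bare scalars / strings / tuples
        h.update(repr(token).encode())
    return int.from_bytes(h.digest()[:8], "big")
```

With these definitions, an array of zeros and an array of ones with the same shape and dtype produce the same hash. Switching the same array from "replicate" to "shard" changes it.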
Checkpoint caveat
static_data is not checkpointed. The .npz manifest carries no record
of it — not the arrays, not the replication / shard_axis metadata.
The consequence for node authors: the configuration needed to rebuild the
static arrays must live in self.params, which is checkpointed. On a
checkpoint/restore round-trip the node is reconstructed from its persisted
params, and __init__ rebuilds static_data from them.
MIME’s LBM node shows the pattern: the pipe-wall occupancy mask and its
D3Q19 missing-link mask are derived in __init__ from nx / ny / nz /
vessel_radius_lu, all of which are passed up to super().__init__() as
params. A resumed node recomputes the masks from the persisted params — the
masks themselves never need to survive the checkpoint.
```python
# src/mime/nodes/environment/lbm/fluid_node.py
@property
def static_data(self) -> dict:
    """Non-evolving lattice masks closed over by update().

    Both masks are rebuilt from self.params (nx/ny/nz/
    vessel_radius_lu) in __init__; a checkpoint/restore round-trip
    reconstructs them from the persisted params, since static_data
    itself is not checkpointed.
    """
    from maddening.core.static_data import StaticArray

    return {
        "pipe_wall": StaticArray(self._pipe_wall),
        "pipe_missing": StaticArray(self._pipe_missing),
    }
```
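The checkpoint round-trip can be sketched with a toy node. PipeNodeSketch, save_checkpoint and restore are hypothetical. The wall geometry (a straight pipe along x, with wall cells farther than vessel_radius_lu from the centre line) is an assumption, not MIME's mask construction. The checkpoint holds only state and params, and the mask comes back because __init__ rebuilds it.

```python
import numpy as np


class PipeNodeSketch:
    """Hypothetical node: the static mask is derived from params, never saved."""

    def __init__(self, nx: int, ny: int, nz: int, vessel_radius_lu: float):
        # Everything needed to rebuild the mask lives in params.
        self.params = dict(nx=nx, ny=ny, nz=nz, vessel_radius_lu=vessel_radius_lu)
        self._pipe_wall = self._build_pipe_wall(**self.params)

    @staticmethod
    def _build_pipe_wall(nx, ny, nz, vessel_radius_lu):
        # Assumed geometry: pipe along x; cells outside the radius are wall.
        y = np.arange(ny) - (ny - 1) / 2.0
        z = np.arange(nz) - (nz - 1) / 2.0
        r = np.sqrt(y[:, None] ** 2 + z[None, :] ** 2)
        return np.broadcast_to(r > vessel_radius_lu, (nx, ny, nz)).copy()


def save_checkpoint(node, state):
    # Only evolving state and params are persisted; static arrays are not.
    return {"state": state, "params": dict(node.params)}


def restore(ckpt):
    node = PipeNodeSketch(**ckpt["params"])  # __init__ rebuilds the mask
    return node, ckpt["state"]


node = PipeNodeSketch(nx=4, ny=9, nz=9, vessel_radius_lu=3.0)
ckpt = save_checkpoint(node, state={"f": np.zeros(3)})
restored, state = restore(ckpt)
```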
Per-step closures
A node’s update() is JIT-compiled once — by the GraphManager, as a unit —
not afresh on every simulation step. A make_*_step() factory called
inside update() therefore runs a single time, during that one trace, and
the static arrays it closes over are captured once.
The §3 fit-up initially planned to memoise FVM’s make_piso_step closure on
self, on the theory that rebuilding it per call would re-bake the mesh.
Investigation showed the concern does not apply: make_piso_step returns a
plain (un-jit’d) closure and is cheap, so FVMFluidNode deliberately keeps
the call inside update() — the factory runs once per trace, not per step.
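Trace-once behaviour is observable through a Python side effect, because side effects inside a jitted function run only while JAX traces it. This sketch uses a hypothetical make_scale_step factory in place of make_piso_step. A single jax.jit around update() plays the GraphManager's role.

```python
import jax
import jax.numpy as jnp

factory_calls = []  # one entry per execution of the factory body


def make_scale_step(scale: float):
    """Hypothetical make_*_step() factory: returns a plain, un-jitted closure."""
    factory_calls.append(scale)

    def step(x):
        return x * scale

    return step


def update(x):
    # The factory is called inside update(); it runs whenever update() is traced.
    step = make_scale_step(2.0)
    return step(x)


# A single JIT boundary around update(), standing in for the GraphManager.
jitted_update = jax.jit(update)

x = jnp.ones(3)
for _ in range(5):
    y = jitted_update(x)
```

After five calls with the same input shape and dtype, factory_calls has one entry. The factory ran during the single trace, not once per step.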
The one pitfall to avoid: never call jax.jit inside update(). That
builds a fresh jitted function — and a fresh compilation — on every call,
which does defeat static_data. Leave the JIT boundary to the
GraphManager.
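The same side-effect counter exposes the pitfall. update_bad wraps a freshly defined closure in jax.jit on each call, so every eager call traces it again. update_good is jitted once, outside, and is traced once for a fixed input shape and dtype. Both functions are hypothetical illustrations. The counter observes tracing, the step that precedes each fresh compilation.

```python
import jax
import jax.numpy as jnp

traces = {"bad": 0, "good": 0}


def update_bad(x):
    # Anti-pattern: a new closure wrapped in a new jax.jit on every call.
    def inner(y):
        traces["bad"] += 1  # Python side effect: runs only while JAX traces
        return y * 2.0

    return jax.jit(inner)(x)


def update_good(x):
    traces["good"] += 1  # also runs only while JAX traces
    return x * 2.0


good = jax.jit(update_good)  # one JIT boundary, created once, outside

x = jnp.ones(3)
for _ in range(3):
    update_bad(x)
    good(x)
```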
MIME nodes that adopted the channel
| Node | What moved to static_data |
|---|---|
| … | The … |
| IBLBM | Pipe-wall occupancy mask (…) |
| … | BEM LU factors (…) |
| … | MLP weight tensors and the input/output normalization statistics. |
| … | LU factors and pipe masks. |
The GNN flux corrector is not applicable — it is a standalone pytree
utility with no host SimulationNode, so it has no static_data property
to populate.
This page documents the §3 deliverable of the MIME v0.2 fit-up. Ground
truth: maddening.core.static_data (the StaticArray dataclass) and
SimulationNode.static_data / static_data_hash() in
maddening.core.node.