---
orphan: false
---

# Sharded static data and domain integrals (v0.2.1)

```{versionadded} v0.2.1
Per-device materialisation of `StaticArray(replication="shard")` under
{class}`~maddening.cloud.multigpu.sharded_node.ShardedStencilNode`, and the
{meth}`~maddening.core.node.SimulationNode.domain_integral_fields` API for
cross-device reduction of partial-sum outputs.
```

This page explains how to write a stencil node whose non-evolving arrays are *sharded* alongside state (needed when the array is big enough that replicating it per device would defeat the point of sharding), and how to declare outputs that are domain integrals over the lattice so the wrapper all-reduces them.

The motivating consumer is MIME's `IBLBMFluidNode`: a D3Q19 LBM node carrying a `(nx, ny, nz)` pipe-wall mask and a `(19, nx, ny, nz)` "missing-link" bounce-back mask, both far too large to replicate on every device when the simulation is sharded across a pencil mesh.

## The picture

```{mermaid}
flowchart LR
    subgraph "ShardedStencilNode.update()"
        A[state pytree<br/>and StaticArray] --> B[device_put<br/>NamedSharding per shard_axis]
        B --> C[shard_map]
        C --> D[halo_exchange<br/>state slab + static slab]
        D --> E[inner.update_padded<br/>static_padded=...<br/>shard_info=...]
        E --> F[strip halos<br/>or psum integrals]
        F --> G[next state]
    end
```

Three things happen per step:

1. The wrapper takes each `StaticArray(replication="shard")` on the inner node and `device_put`s it with a `NamedSharding` whose `PartitionSpec` puts the matching mesh axis on the array's `shard_axis` (cached by `static_data_hash`).
2. Inside `shard_map`, each device's slab is halo-exchanged along the matching spatial axis (boundary `"edge"`: static arrays don't evolve, so periodic wrap is wrong even if state uses periodic).
3. The inner node's `update_padded` receives the padded slab through the `static_padded` dict, keyed by the static-data name (for example `static_padded["pipe_wall"]`), and uses it like any other padded array.

## Declaring a sharded static array

The declaration lives on the node:

```python
from maddening.core.node import SimulationNode
from maddening.core.static_data import StaticArray


class WallBouncebackLBM(SimulationNode):
    def __init__(self, name, timestep, *, nx, ny, nz, pipe_radius):
        super().__init__(name, timestep, nx=nx, ny=ny, nz=nz,
                         pipe_radius=pipe_radius)
        # _build_pipe_mask is a helper that is not shown here.
        self._pipe_wall = _build_pipe_mask(nx, ny, nz, pipe_radius)

    def halo_width(self):
        # D3Q19: read ±1 neighbour on every spatial axis
        return {0: 1, 1: 1, 2: 1}

    @property
    def static_data(self):
        return {
            "pipe_wall": StaticArray(
                self._pipe_wall,
                replication="shard",
                shard_axis=0,  # shard along the x axis
            ),
        }
```

`shard_axis` is the array's *own* axis. It must match one of the spatial axes the wrapping `ShardedStencilNode` actually shards (`shard_axis ∈ axis_map.values()`) **and** the node must declare a non-zero `halo_width()` entry on that axis (otherwise there's no neighbour slab to exchange with). Both invariants are checked at `ShardedStencilNode.__init__` time:

```python
sharded = ShardedStencilNode(
    WallBouncebackLBM("lbm", 0.001, nx=128, ny=64, nz=64, pipe_radius=20),
    mesh=mesh,
    axis_map={"spatial_x": 0},  # shard axis 0 on mesh "spatial_x"
    boundary="periodic",
)
# Raises ValueError immediately if pipe_wall's shard_axis isn't in
# axis_map.values() (i.e. it's not on the sharded spatial axis), or
# if halo_width() has no entry for shard_axis.
```

Replicate-mode `StaticArray` entries are not affected: they stay closure-captured in full on every device, exactly as in v0.2.0.

## Reading the sharded slab in `update_padded`

The inner node's `update_padded` gets a new keyword-only argument:

```python
def update_padded(
    self, state_padded, boundary_inputs, dt, *,
    static_padded=None, shard_info=None,
) -> dict:
    """Receives a per-shard slab of pipe_wall via `static_padded`."""
    f_pad = state_padded["f"]                # (nx_local+2, ny, nz, 19)
    wall_pad = static_padded["pipe_wall"]    # (nx_local+2, ny, nz)
    # interior view (halo stripped) is wall_pad[1:-1, :, :]
    # cells outside the interior come from the neighbour shard's
    # boundary cells via halo_exchange
    ...
```

The keyword is `None` when there are no sharded `static_data` entries on the node, or when the node is run outside of `ShardedStencilNode` (e.g. the single-device path). Default to the closure-captured full array in that case:

```python
import jax.numpy as jnp


def update_padded(
    self, state_padded, boundary_inputs, dt, *,
    static_padded=None, shard_info=None,
):
    if static_padded is not None and "pipe_wall" in static_padded:
        wall = static_padded["pipe_wall"]
    else:
        # Unsharded path; the full wall is closure-captured on self.
        wall = jnp.pad(self._pipe_wall, [(1, 1), (0, 0), (0, 0)],
                       mode="edge")
    ...
```

## Declaring domain-integral outputs

A node that computes a `jnp.sum`-over-lattice output (a drag force, a total mass, a heat-flux integral) needs cross-device reduction under sharding. Declare which output keys are integrals:

```python
def domain_integral_fields(self) -> set[str]:
    return {"drag_force", "drag_torque"}
```

and have `update_padded` compute the partial sum on the local slab, with no `psum` in the node:

```python
import jax.numpy as jnp


def update_padded(self, state_padded, boundary_inputs, dt, *,
                  static_padded=None, shard_info=None):
    ...
    force_partial = jnp.sum(
        per_cell_force * wall_pad[1:-1, :, :, None],
        axis=(0, 1, 2),
    )
    return {
        "f": f_new,
        "drag_force": force_partial,  # (3,): partial sum on this shard
    }
```

`ShardedStencilNode` applies `lax.psum(value, axis_name=tuple(mesh.axis_names))` to every key listed in `domain_integral_fields()` after `update_padded` returns, and sets the corresponding `out_spec` to `P()` (fully replicated post-psum). Returned integral values are floating-point; `psum` on integer types risks wrap-around on large meshes.

```{important}
Every key returned from `update_padded` must be either in
`state_fields()` or in `domain_integral_fields()`. The wrapper
defensively raises `ValueError` at trace time on an unknown key,
because it has no way to infer the partition spec for an
unclassified output.
```

## `shard_info`: when you'd rather recompute the mask

Not every static array needs to be materialised. For analytic masks (a sphere, a cylinder, a coordinate range), it's often cheaper to recompute per shard than to ship the full mask through `jax.device_put` once and `halo_exchange` it every step. `ShardedStencilNode` populates a `shard_info` dict for the inner node:

```python
import jax.numpy as jnp


def update_padded(self, state_padded, boundary_inputs, dt, *,
                  static_padded=None, shard_info=None):
    # shard_info = {0: (global_offset, local_extent)}
    #
    # global_offset is a TRACED jax scalar:
    #     lax.axis_index(mesh_axis) * local_extent
    # local_extent is a Python int.
    if shard_info is not None and 0 in shard_info:
        offset, extent = shard_info[0]
        # offset usable in dynamic_slice; NOT in Python int slicing
        global_x = offset + jnp.arange(extent + 2)  # +2 for halos
    ...
```

`shard_info` is `None` when the node is run outside of `ShardedStencilNode`.

## Sharding policy is part of the JIT cache key

`SimulationNode.static_data_hash()` already incorporates `(key, shape, dtype, replication, shard_axis)`.
Switching a `StaticArray` from `replication="replicate"` to `replication="shard"` between graph compiles invalidates the JIT trace; no extra machinery is needed in user code. The shard_map cache inside `ShardedStencilNode` also keys on `static_data_hash()`, so *identity* changes (a user replacing `self._mask` between steps) invalidate cleanly even when shape and dtype are unchanged.

## What's still TODO

* **Partial-axis psum.** `domain_integral_fields()` triggers a full-mesh `psum` over every mesh axis. Reductions over a subset of axes (e.g. for outputs that vary along one axis but integrate along the others) are out of scope for v0.2.1; open an issue if you need it.
* **Sharded static arrays without halos.** v0.2.1 rejects a sharded `StaticArray` whose `shard_axis` doesn't appear in `node.halo_width()`. A future relaxation could allow such arrays by skipping the halo exchange, but the semantics get subtle near slab boundaries.

## See also

* {doc}`/release_notes/v0.2.1`: release notes for the cycle this shipped in.
* [`StaticArray` API reference](https://github.com/Microrobotics-Simulation-Framework/MADDENING/blob/main/src/maddening/core/static_data.py): dataclass contract and hash semantics.
* {doc}`edge_validation_migration`: companion v0.2.1 change.
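
The partial-sum contract for domain integrals on this page (sum over the halo-stripped interior of each slab, then let the wrapper reduce across devices) can be checked on a toy 1-D lattice without JAX. The sketch below is illustrative only: `split_with_halos` and `partial_sum` are hypothetical helpers, not part of MADDENING, and a plain Python `sum` over the slabs stands in for `lax.psum`.

```python
def split_with_halos(values, n_shards):
    """Split a 1-D list into equal slabs and add a one-cell halo per side.

    Halos between slabs copy the neighbouring slab's boundary cell; halos
    at the domain boundary repeat the edge cell ("edge" mode).
    """
    extent = len(values) // n_shards
    slabs = []
    for s in range(n_shards):
        lo, hi = s * extent, (s + 1) * extent
        left = values[lo - 1] if lo > 0 else values[lo]
        right = values[hi] if hi < len(values) else values[hi - 1]
        slabs.append([left] + values[lo:hi] + [right])
    return slabs


def partial_sum(force_pad, wall_pad, strip_halo=True):
    """Masked sum on one slab; the interior view is pad[1:-1]."""
    if strip_halo:
        force_pad, wall_pad = force_pad[1:-1], wall_pad[1:-1]
    return sum(f * w for f, w in zip(force_pad, wall_pad))


# Toy global lattice: a per-cell force and a 0/1 wall mask.
force = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
wall = [1, 1, 0, 0, 0, 0, 1, 1]

force_slabs = split_with_halos(force, n_shards=2)
wall_slabs = split_with_halos(wall, n_shards=2)

# Reference value computed on the unsharded lattice.
global_sum = sum(f * w for f, w in zip(force, wall))

# Stand-in for the cross-device reduction: add up the per-slab partial sums.
reduced = sum(partial_sum(f, w) for f, w in zip(force_slabs, wall_slabs))

# Summing over the padded slab counts halo cells twice.
double_counted = sum(
    partial_sum(f, w, strip_halo=False)
    for f, w in zip(force_slabs, wall_slabs)
)

print(global_sum, reduced, double_counted)
```

The interior-only partial sums reproduce the unsharded value, while summing over the padded slabs does not, because every halo cell is also an interior cell of a neighbouring slab (or a repeated edge cell).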