JAX Pallas GPU Issues (for filing)

Discovered while implementing a D3Q19 Lattice Boltzmann kernel. All issues were tested on a RunPod A40 (Ampere, CC 8.6) with CUDA 12.8.


Issue 1: reduce_sum with axis fails in multi-tile kernel

Environment: JAX 0.9.2 (Mosaic GPU backend), A40, CUDA 12.8
Also reproduced on: JAX 0.5.3 (Triton backend), with a different error but the same outcome

Minimal reproducer:

import jax, jax.numpy as jnp
from jax.experimental import pallas as pl

N = 16; BX = BY = BZ = 8; Q = 32
x = jnp.ones((N, N, N, Q))

def sum_kernel(x_ref, o_ref):
    o_ref[...] = jnp.sum(x_ref[...], axis=-1)

o = pl.pallas_call(
    sum_kernel,
    out_shape=jax.ShapeDtypeStruct((N, N, N), jnp.float32),
    grid=(N // BX, N // BY, N // BZ),
    in_specs=[pl.BlockSpec((BX, BY, BZ, Q), lambda i, j, k: (i*BX, j*BY, k*BZ, 0))],
    out_specs=pl.BlockSpec((BX, BY, BZ), lambda i, j, k: (i*BX, j*BY, k*BZ)),
)(x)
# Expected: all elements = 32.0
# Actual (0.9.2): "No support for axes yet"
# Actual (0.5.3): First tile correct (32.0), subsequent tiles write zeros

Expected: o[i,j,k] = 32.0 for all i,j,k
Actual (0.9.2): NotImplementedError: No support for axes yet
Actual (0.5.3): o[0,0,0] = 32.0 but o[8,0,0] = 0.0 (second tile not written)

Use case: D3Q19 Lattice Boltzmann needs a sum over the Q=19 distributions (padded to 32) at each lattice node. This is the density computation ρ = Σ_q f_q, the zeroth moment of the distributions, evaluated at every node on every time step.

Workaround: Manual accumulation loop unrolled at trace time:

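# f holds the Q = 32 padded distributions; the padding entries must be zero
rho = f[..., 0]
for q in range(1, 32):  # Python loop, so it unrolls into 31 adds at trace time
    rho = rho + f[..., q]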
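Since this loop still needs validation, a quick numerical check outside Pallas can at least confirm that the unrolled accumulation over the zero-padded axis agrees with a plain reduction over the 19 physical directions. The sketch below runs in ordinary JAX. The lattice size, random data, and names f19, rho_ref, and max_err are illustrative assumptions, not part of the kernel.

```python
import jax
import jax.numpy as jnp

Q, Q_PAD = 19, 32

# Random stand-in for the distributions on a small 4x4x4 lattice (illustrative data).
f19 = jax.random.uniform(jax.random.PRNGKey(0), (4, 4, 4, Q), dtype=jnp.float32)
# Zero-pad the last axis from 19 to 32, matching the padded kernel layout.
f = jnp.pad(f19, ((0, 0), (0, 0), (0, 0), (0, Q_PAD - Q)))

# Workaround: Python loop, unrolled at trace time into 31 adds.
rho = f[..., 0]
for q in range(1, Q_PAD):
    rho = rho + f[..., q]

# Reference: reduction over the 19 physical directions only.
rho_ref = jnp.sum(f19, axis=-1)
max_err = float(jnp.max(jnp.abs(rho - rho_ref)))
print("max |rho - rho_ref| =", max_err)
```

This only checks the arithmetic. Whether the same loop lowers correctly inside a Pallas GPU kernel still has to be tested on the device.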

Issue 2: slice primitive not implemented in GPU lowering

Environment: JAX 0.5.3 (Triton backend), A40, CUDA 12.8
Also fails on: JAX 0.9.2 (Mosaic backend)

Minimal reproducer:

import jax, jax.numpy as jnp
from jax.experimental import pallas as pl

def slice_kernel(x_ref, o_ref):
    x = x_ref[...]
    o_ref[...] = x[..., 0]  # extract first component

o = pl.pallas_call(
    slice_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 8, 8), jnp.float32),
    grid=(1,),
    in_specs=[pl.BlockSpec((8, 8, 8, 4), lambda i: (0, 0, 0, 0))],
    out_specs=pl.BlockSpec((8, 8, 8), lambda i: (0, 0, 0)),
)(jnp.ones((8, 8, 8, 4)))

Expected: o[i,j,k] = 1.0
Actual: NotImplementedError: Unimplemented primitive in Pallas GPU lowering: slice

Use case: Extracting the x, y, z components of a force vector, e.g. force[..., 0:3]. Essential for any multi-component physics kernel.

Workaround: Dot-product with one-hot mask vector passed as explicit input:

mask_x = jnp.array([1.0, 0.0, 0.0, 0.0])  # passed as kernel input
fx = jnp.sum(force * mask_x, axis=-1)
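The one-hot workaround can be checked against ordinary slicing in plain JAX. The sketch below uses a small illustrative force array and takes the masks from an identity matrix; the names masks, bad, and leaked are assumptions for the example. Two caveats are visible here. The mask form depends on an axis reduction, so on a backend that hits Issue 1 it would have to be combined with that workaround. And multiplying by zero does not hide non-finite values, so an Inf in a masked-out component turns into NaN in the result.

```python
import jax.numpy as jnp

# Small illustrative force field with 4 components per node (the last one is padding).
force = jnp.arange(2 * 2 * 2 * 4, dtype=jnp.float32).reshape(2, 2, 2, 4)
masks = jnp.eye(4, dtype=jnp.float32)  # row c is the one-hot mask for component c

# Workaround: select a component by multiply-and-reduce instead of slicing.
fx = jnp.sum(force * masks[0], axis=-1)
fy = jnp.sum(force * masks[1], axis=-1)
fz = jnp.sum(force * masks[2], axis=-1)

# Caveat: 0 * inf is NaN, so a non-finite masked-out component leaks into fx.
bad = force.at[0, 0, 0, 3].set(jnp.inf)
leaked = bool(jnp.isnan(jnp.sum(bad * masks[0], axis=-1)[0, 0, 0]))
print("NaN leaked from masked-out component:", leaked)
```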

Issue 3: dot_general requires all dimensions ≥ 16

Environment: JAX 0.5.3 (Triton backend), A40, CUDA 12.8
Also fails on: JAX 0.9.2 (Mosaic backend), with a ValueError

Minimal reproducer:

import jax, jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(f_ref, e_ref, o_ref):
    o_ref[...] = f_ref[...] @ e_ref[...]

o = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 8, 8, 3), jnp.float32),
    grid=(1,),
    in_specs=[
        pl.BlockSpec((8, 8, 8, 19), lambda i: (0, 0, 0, 0)),
        pl.BlockSpec((19, 3), lambda i: (0, 0)),
    ],
    out_specs=pl.BlockSpec((8, 8, 8, 3), lambda i: (0, 0, 0, 0)),
)(jnp.ones((8, 8, 8, 19)), jnp.ones((19, 3)))

Expected: o[i,j,k] = [19, 19, 19]
Actual (0.5.3): ValueError: all dimensions of b must be >= 16
Actual (0.9.2): ValueError (similar constraint)

Use case: Computing momentum p = f @ E where f is (nx,ny,nz,19) distributions and E is (19,3) velocity vectors. Standard in lattice Boltzmann, finite element, and particle methods.

Workaround: Element-wise multiply + manual accumulation:

px = f[...,0]*E[0,0]; py = f[...,0]*E[0,1]; pz = f[...,0]*E[0,2]
for q in range(1, 19):
    px = px + f[...,q]*E[q,0]
    py = py + f[...,q]*E[q,1]
    pz = pz + f[...,q]*E[q,2]
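The unrolled accumulation can be compared with f @ E in plain JAX. In the sketch below, E is built as the D3Q19 velocity set (all vectors in {-1, 0, 1}^3 with at most two non-zero components). The ordering of the 19 directions, the random f, and the names E_list, p_loop, p_ref, and max_err are assumptions for illustration; a real kernel would use its own ordering.

```python
import itertools
import jax
import jax.numpy as jnp

# D3Q19 velocity set: rest vector, 6 axis directions, 12 face diagonals (ordering is arbitrary here).
E_list = [v for v in itertools.product((-1, 0, 1), repeat=3)
          if sum(abs(c) for c in v) <= 2]
E = jnp.array(E_list, dtype=jnp.float32)  # shape (19, 3)

# Random stand-in for the distributions on a 4x4x4 lattice.
f = jax.random.uniform(jax.random.PRNGKey(0), (4, 4, 4, 19), dtype=jnp.float32)

# Workaround: element-wise multiply plus manual accumulation over q.
px = f[..., 0] * E[0, 0]; py = f[..., 0] * E[0, 1]; pz = f[..., 0] * E[0, 2]
for q in range(1, 19):
    px = px + f[..., q] * E[q, 0]
    py = py + f[..., q] * E[q, 1]
    pz = pz + f[..., q] * E[q, 2]
p_loop = jnp.stack([px, py, pz], axis=-1)

# Reference: the matmul that the kernel cannot express directly.
p_ref = f @ E
max_err = float(jnp.max(jnp.abs(p_loop - p_ref)))
print("max |p_loop - p_ref| =", max_err)
```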

Issue 4: concatenate limited to 2-argument, [..., 1] shapes only

Environment: JAX 0.5.3 (Triton backend), A40, CUDA 12.8
Also fails on: JAX 0.9.2 (Mosaic backend), with a different error

Minimal reproducer:

import jax, jax.numpy as jnp
from jax.experimental import pallas as pl

a = jnp.ones((8, 8, 8, 1))

def concat_kernel(a_ref, b_ref, c_ref, o_ref):
    o_ref[...] = jnp.concatenate([a_ref[...], b_ref[...], c_ref[...]], axis=-1)

o = pl.pallas_call(
    concat_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 8, 8, 3), jnp.float32),
    grid=(1,),
    in_specs=[pl.BlockSpec((8,8,8,1), lambda i: (0,0,0,0))] * 3,
    out_specs=pl.BlockSpec((8, 8, 8, 3), lambda i: (0, 0, 0, 0)),
)(a, a, a)

Expected: o.shape = (8, 8, 8, 3), all ones
Actual (0.5.3): NotImplementedError: Only 2-argument concatenate is supported
Actual (0.9.2): GMEM strides alignment error

Use case: Assembling velocity output u = stack([ux, uy, uz]) from per-component scalars. Common in any physics kernel that outputs vector fields.

Workaround: Use separate output refs for each component.


Issue 5: Non-power-of-2 array dimensions

Environment: JAX 0.5.3 (Triton backend), A40, CUDA 12.8

Minimal reproducer:

import jax, jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(f_ref, o_ref):
    o_ref[...] = f_ref[...]

o = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct((16,16,16,19), jnp.float32),
    grid=(1,),
    in_specs=[pl.BlockSpec((16,16,16,19), lambda i: (0,0,0,0))],
    out_specs=pl.BlockSpec((16,16,16,19), lambda i: (0,0,0,0)),
)(jnp.ones((16,16,16,19)))

Expected: Identity copy
Actual: ValueError: ...size is a power of 2. Encountered an array of shape (16, 16, 16, 19)

Note: JAX 0.9.2 (Mosaic backend) has a different constraint — shared memory size limit instead of power-of-2.

Use case: D3Q19 lattice Boltzmann has Q=19 velocity directions — a prime number. D3Q27 (Q=27) and D2Q9 (Q=9) are also non-power-of-2.

Workaround: Pad to the next power of 2 (19→32). 13 of every 32 slots (about 41%) are then padding, which wastes compute and memory.
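The padding step can be sketched as a pair of helpers around the kernel call. The names pad_q and unpad_q are hypothetical, not taken from the original code. The padding is filled with zeros so that sums over the padded axis are unaffected.

```python
import jax.numpy as jnp

Q, Q_PAD = 19, 32

def pad_q(f):
    """Zero-pad the last axis from Q to Q_PAD (hypothetical helper)."""
    widths = [(0, 0)] * (f.ndim - 1) + [(0, Q_PAD - f.shape[-1])]
    return jnp.pad(f, widths)

def unpad_q(f_padded):
    """Drop the padding and keep the Q physical directions (hypothetical helper)."""
    return f_padded[..., :Q]

f = jnp.ones((16, 16, 16, Q), jnp.float32)
f_padded = pad_q(f)                     # shape (16, 16, 16, 32), a power-of-2 size
wasted = (Q_PAD - Q) / Q_PAD            # fraction of the padded buffer that is padding
print(f_padded.shape, f"{wasted:.1%}")
```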


Summary

Together, these five issues block a straightforward implementation of a standard D3Q19 LBM kernel in Pallas GPU. The workarounds (manual loops, padding, separate outputs) add complexity but are functional, although the manual loop approach (Issue 1 workaround) still needs validation (see separate test).

Impact: Lattice Boltzmann is one of the most widely used computational fluid dynamics methods. Pallas GPU support for the operations above would enable high-performance LBM kernels that bypass XLA’s autotuning overhead (which causes 60+ min compilation on H100 for standard JAX LBM code).