general:
- The rewritten program should be semantically equivalent to the original program, within a small numerical tolerance.
- Keep the same function name and signature as the original program (helper functions can be renamed or deleted).
- Pallas kernel functions take `Ref` objects as inputs and return no values; all outputs must be written to output `Ref` parameters
  via indexed assignment (e.g., `o_ref[...] = value`). Reading a `Ref` requires explicit indexing (`x_ref[...]` or `x_ref[:]`);
  a bare `Ref` cannot be used in arithmetic.
- On TPU, block shapes must have rank ≥ 1 (no scalar/0D blocks). The second-to-last dimension of a block shape must be
  divisible by 8 and the last dimension by 128, unless that dimension equals the corresponding full array dimension. For
  rank-1 blocks, the dimension must equal the array dimension, be a multiple of 1024, or be a power of 2 ≥ 128*(32/bitwidth).
- On TPU, elementwise operations only support 32-bit types natively; narrower-precision operands (e.g., bfloat16, int8) must
  be upcast to a 32-bit type before applying elementwise operations, then cast back if needed.
- The `index_map` function in a `BlockSpec` must accept exactly as many arguments as the length of the `grid` tuple and must
  return exactly as many block indices as the rank of the array (i.e., `len(block_shape)`). `in_specs` and `out_specs` must
  match the number and order of input arrays and output shapes respectively.
- Matrix multiplication (MXU) always accumulates in float32; even float32 operands are rounded to bfloat16 unless higher precision
  is explicitly requested via `jax.default_matmul_precision` or the `precision` parameter.
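The block-shape rule is easy to misapply, so here is a hypothetical checker, `block_shape_ok`, that encodes only the conditions listed in the general section: rank ≥ 1, the 8/128 divisibility or full-dimension escape for the last two axes, and the rank-1 cases. It is a reading aid under those stated rules; the Mosaic compiler is the real authority and may check more.

```python
def block_shape_ok(block_shape, array_shape, bitwidth=32):
    """Hypothetical checker for the TPU block-shape rules stated above."""
    if len(block_shape) == 0 or len(block_shape) != len(array_shape):
        return False  # 0-D blocks are not allowed; ranks must match
    if any(d <= 0 for d in block_shape):
        return False
    if len(block_shape) == 1:
        b, a = block_shape[0], array_shape[0]
        is_pow2 = (b & (b - 1)) == 0
        min_pow2 = 128 * (32 // bitwidth)  # 128 for 32-bit, 256 for 16-bit, ...
        return b == a or b % 1024 == 0 or (is_pow2 and b >= min_pow2)
    b_rows, b_cols = block_shape[-2], block_shape[-1]
    a_rows, a_cols = array_shape[-2], array_shape[-1]
    rows_ok = b_rows % 8 == 0 or b_rows == a_rows
    cols_ok = b_cols % 128 == 0 or b_cols == a_cols
    return rows_ok and cols_ok


print(block_shape_ok((8, 128), (64, 512)))            # True
print(block_shape_ok((8, 100), (64, 512)))            # False: 100 is not a multiple of 128
print(block_shape_ok((1, 512), (1, 512)))             # True: 1 equals the full row dimension
print(block_shape_ok((256,), (4096,), bitwidth=16))   # True: power of 2 >= 256
```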
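A minimal sketch of the upcast rule, assuming a bfloat16 input and output: the kernel reads the block, converts it to float32, does the elementwise math, and casts back before storing. `affine_bf16_kernel` is an illustrative name, and `interpret=True` is assumed so the example runs without a TPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def affine_bf16_kernel(x_ref, o_ref):
    x = x_ref[...].astype(jnp.float32)   # upcast bf16 -> f32 before elementwise math
    y = x * 2.0 + 1.0                    # computed in 32-bit
    o_ref[...] = y.astype(o_ref.dtype)   # cast back to the bf16 output dtype


x = jnp.linspace(-4.0, 4.0, 16 * 128, dtype=jnp.float32).reshape(16, 128).astype(jnp.bfloat16)
out = pl.pallas_call(
    affine_bf16_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, jnp.bfloat16),
    interpret=True,  # CPU interpreter for illustration; drop on TPU
)(x)
```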
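One of the two routes the precision rule names is the `precision` argument. This sketch passes `jax.lax.Precision.HIGHEST` to `jnp.dot` inside the kernel and keeps a float32 result via `preferred_element_type`. The kernel name and the single-block setup (no grid, so each ref covers the whole array) are illustrative assumptions, and `interpret=True` lets it run on CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def matmul_highest_kernel(x_ref, y_ref, o_ref):
    # Request full float32 precision instead of the default bf16 rounding,
    # and keep the float32 accumulator as the result dtype.
    o_ref[...] = jnp.dot(
        x_ref[...],
        y_ref[...],
        precision=jax.lax.Precision.HIGHEST,
        preferred_element_type=jnp.float32,
    )


kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (128, 128), dtype=jnp.float32)
y = jax.random.normal(ky, (128, 128), dtype=jnp.float32)

out = pl.pallas_call(
    matmul_highest_kernel,
    out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
    interpret=True,  # CPU interpreter for illustration; drop on TPU
)(x, y)
```

The other route is wrapping the call in `with jax.default_matmul_precision("highest"):`; the per-call `precision` argument keeps the choice local to this one product.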
planning:
- Limit the scope of the plan to the selected strategy.
- Do not rule out any strategy unless it is clearly irrelevant to the code.
- Block shapes should evenly divide the corresponding array dimensions. If they do not, size the grid with ceiling division
  (`pl.cdiv`); out-of-bounds elements are then padded with garbage on input and discarded on output, and every block must
  contain at least one in-bounds element.
- Output `Ref` blocks that are accumulated across multiple grid iterations must be revisited on consecutive iterations in
  lexicographic grid order, so the reduction axis must be the innermost (last) grid dimension. Because output refs start
  uninitialized, initialize the accumulator under `pl.when(pl.program_id(axis) == 0)`, where `axis` is the reduction grid axis.
- Tensor dimensions used in matrix computations should be multiples of 128 to match the 128×128 MXU systolic arrays; sub-128
  dimensions waste compute due to zero-padding. Arrays with singleton dimensions in the last two axes are padded to 8×128
  tiles, causing up to 1024× register waste (a (1, 1) block still occupies all 8 × 128 = 1024 slots of one tile).
- VMEM capacity is limited (~16 MB); all input and output blocks (which the pipeline double-buffers), scratch buffers, and
  spilled registers must fit together. Exceeding VMEM causes a low-level compiler OOM error, so plan block sizes accordingly.
- Loop primitives (`fori_loop`, `for_loop`) are fully unrolled at compile time; trip counts must be kept small to avoid compilation
  blowup. Control flow (`cond`, `when`) is supported but excessive branching degrades code generation.
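The padding cost of small trailing dimensions can be estimated with simple arithmetic. `tile_utilization` is a hypothetical helper that assumes 32-bit elements laid out in 8×128 tiles, as the text describes. It reports the fraction of tile slots in the last two dimensions that hold real data.

```python
def tile_utilization(rows, cols, tile=(8, 128)):
    """Fraction of tile slots holding real data (hypothetical helper).

    Assumes 32-bit elements and the 8x128 tiling of the last two dims.
    """
    tile_rows, tile_cols = tile
    padded_rows = (rows + tile_rows - 1) // tile_rows * tile_rows
    padded_cols = (cols + tile_cols - 1) // tile_cols * tile_cols
    return (rows * cols) / (padded_rows * padded_cols)


for shape in [(1, 1), (1024, 1), (100, 200), (256, 512)]:
    print(shape, f"{tile_utilization(*shape):.4f}")
```

For a (1, 1) array this gives 1/1024, the worst case the text mentions; a (1024, 1) column still uses only 1/128 of its tile slots.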
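A back-of-the-envelope VMEM check helps when choosing block sizes. `vmem_estimate` is a hypothetical helper that assumes every pipelined input and output block is double-buffered, and it treats the ~16 MB figure from the text as the budget. It ignores register spills and compiler-internal scratch, so its result is a lower bound on actual usage.

```python
import math

# Assumed budget taken from the ~16 MB figure in the text; the real limit
# depends on the chip generation and compiler settings.
VMEM_BUDGET_BYTES = 16 * 1024 * 1024


def vmem_estimate(blocks, scratch_bytes=0, buffers=2):
    """Rough VMEM footprint of one pallas_call (hypothetical helper).

    blocks: (block_shape, itemsize) for every pipelined input and output.
    Each block is assumed double-buffered; register spills are not modeled.
    """
    per_step = sum(math.prod(shape) * itemsize for shape, itemsize in blocks)
    return buffers * per_step + scratch_bytes


def fits_in_vmem(blocks, scratch_bytes=0):
    return vmem_estimate(blocks, scratch_bytes) <= VMEM_BUDGET_BYTES


# float32 matmul tiles: x (bm, bk), y (bk, bn), out (bm, bn).
tiles_512 = [((512, 512), 4)] * 3     # 3 blocks of 1 MiB, x2 buffers = 6 MiB
tiles_1024 = [((1024, 1024), 4)] * 3  # 3 blocks of 4 MiB, x2 buffers = 24 MiB
print(fits_in_vmem(tiles_512), fits_in_vmem(tiles_1024))
```

With float32 matmul tiles, 512×512 blocks need about 6 MiB and fit, while 1024×1024 blocks need about 24 MiB and would not.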
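Where a kernel does need a loop, a small static trip count keeps it in line with the loop advice in the planning list. This sketch (illustrative names, `interpret=True` assumed) sums the four 8-row slabs of a 32×128 block with `jax.lax.fori_loop`, using `pl.ds` for the dynamic row slice and `pl.multiple_of` to mark the start offset as 8-aligned.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def slab_sum_kernel(x_ref, o_ref):
    # Sum the four 8-row slabs of a (32, 128) block with a short loop;
    # the trip count (4) is a small static constant.
    def body(s, acc):
        start = pl.multiple_of(s * 8, 8)  # hint that the offset is 8-aligned
        return acc + x_ref[pl.ds(start, 8), :]

    acc0 = jnp.zeros((8, 128), jnp.float32)
    o_ref[...] = jax.lax.fori_loop(0, 4, body, acc0)


x = jnp.arange(32 * 128, dtype=jnp.float32).reshape(32, 128)
out = pl.pallas_call(
    slab_sum_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    interpret=True,  # CPU interpreter for illustration; drop on TPU
)(x)
```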
coding:
- Wrap the generated code with ``` at the beginning and ``` at the end.
- On TPU, scalar values (e.g., from `program_id`) that need to be stored as outputs must be placed in SMEM via `pl.BlockSpec(memory_space=pltpu.SMEM)`.
  A buffer with `memory_space=pl.ANY` (HBM) cannot be dereferenced directly; its data must first be copied to VMEM or SMEM
  (e.g., with `pltpu.make_async_copy`).
- The kernel function must not close over JAX array constants; all needed arrays must be passed explicitly as inputs with
  proper `BlockSpec`s. Inside the kernel body, the only external values should be `Ref` parameters and Python literals.
- 'When using `PrefetchScalarGridSpec` with `num_scalar_prefetch > 0`: (1) the kernel signature must order arguments as prefetch
  refs, then input refs, then output refs, then scratch refs; (2) `index_map` must accept `(*grid_indices, *prefetch_refs)`;
  (3) at call site, prefetch args come before input args.'
- '`BlockSpec` expects `block_shape` before `index_map` in its constructor. `compiler_params` must be a `pltpu.CompilerParams()`
  dataclass, not a dict. `pl.load`/`pl.store`/`pl.swap` and `pl.atomic_*` are removed — use indexing syntax (`ref[...] = val`,
  `val = ref[...]`) instead.'
- Reductions (`jnp.sum`, `jnp.max`, `jnp.min`) are only supported for floating-point values; boolean reductions use `jnp.any`/`jnp.all`;
  integer reductions are not supported. Reshapes that modify the last two dimensions are restricted to flattening leading
  dims into the second-to-last dimension or restoring a reduced dimension.
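For the SMEM rule, here is a sketch that records each grid step's `program_id` into an int32 output. The output `BlockSpec` has no `block_shape`, so every step sees the whole 8-element array and writes one entry. The kernel name is illustrative and `interpret=True` is assumed so the example runs on CPU; on a TPU the same spec places the buffer in scalar memory.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def record_program_ids_kernel(o_ref):
    # o_ref is the whole (8,) int32 output, placed in SMEM for scalar stores.
    i = pl.program_id(0)
    o_ref[i] = i


ids = pl.pallas_call(
    record_program_ids_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    grid=(8,),
    out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),  # full array, no block_shape
    interpret=True,  # CPU interpreter for illustration; drop on TPU
)()
```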
