general:
- The rewritten program should be semantically equivalent to the original program, within a small numerical tolerance.
- Keep the same function name and signature as the original program (helper functions can be renamed or deleted).
- Pallas kernel functions receive `Ref` objects, not arrays. A bare `Ref` cannot be used in arithmetic — it must be explicitly
  read via indexing (e.g., `x_ref[...]` or `x_ref[:]`) to produce an array, and results must be written back to output `Ref`s
  via indexed assignment (e.g., `o_ref[...] = value`). The kernel must return `None`.
- On TPU, block shapes must have rank ≥ 1 (no scalar/0D blocks). The second-to-last block dimension must be divisible by 8 and
  the last by 128, unless that block dimension equals the full array dimension. Violating this causes compiler errors.
- Elementwise operations on TPU only support 32-bit types natively; narrower-precision operands (bf16, int8, etc.) must be
  upcast to a 32-bit type before elementwise computation. Matrix multiplication always accumulates in float32, and f32 inputs
  are silently rounded to bf16 unless explicit precision is requested.
- Inside TPU kernels, integer reductions (sum, max, min) are not supported; these reductions work only on floating-point values.
  Boolean reductions (any/all) work only on booleans.
- Python `for` / `range` loops whose bounds are traced values fail during tracing; loops with static Python-int bounds are
  simply unrolled. Use `jax.lax.fori_loop` / `jax.lax.while_loop` when the bound depends on a traced value.
- The `index_map` in `BlockSpec` must accept exactly one argument per `grid` dimension (followed by the scalar-prefetch refs when
  using `PrefetchScalarGridSpec`) and return exactly one index per dimension of the corresponding array, including `None`/squeezed
  dims (whose block size is 1). In blocked indexing mode, `index_map` returns block indices (not element indices); the element
  start is computed as `block_index * block_size`.
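A minimal sketch of several of the rules above: Refs are read and written by indexing, the kernel returns `None`, block shapes are 8×128-aligned, `index_map` returns block indices, and bf16 operands are upcast to f32 before elementwise math. The names `scale_add` and `_scale_add_kernel` and the block sizes are hypothetical; it assumes a recent JAX release where `pl.BlockSpec` accepts `block_shape`/`index_map` keywords, and it passes `interpret=True` so it also runs on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _scale_add_kernel(x_ref, y_ref, o_ref):
    # Refs are not arrays: read them by indexing, then upcast bf16 -> f32
    # because TPU elementwise ops work natively on 32-bit types.
    x = x_ref[...].astype(jnp.float32)
    y = y_ref[...].astype(jnp.float32)
    # Write through indexed assignment, casting to the output Ref's dtype.
    o_ref[...] = (2.0 * x + y).astype(o_ref.dtype)
    # No return statement: the kernel returns None.


def scale_add(x, y, *, block_rows=64, block_cols=256, interpret=True):
    """Hypothetical example computing 2 * x + y block by block."""
    rows, cols = x.shape
    # Second-to-last block dim: multiple of 8; last block dim: multiple of 128.
    assert block_rows % 8 == 0 and block_cols % 128 == 0
    assert rows % block_rows == 0 and cols % block_cols == 0
    # One index_map argument per grid axis; it returns one block index per
    # array dim, which Pallas scales by the block size to get element offsets.
    spec = pl.BlockSpec(block_shape=(block_rows, block_cols), index_map=lambda i, j: (i, j))
    return pl.pallas_call(
        _scale_add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(rows // block_rows, cols // block_cols),
        in_specs=[spec, spec],  # one BlockSpec per input
        out_specs=spec,  # single output, single BlockSpec
        interpret=interpret,  # True lets the sketch run on CPU; use False on TPU
    )(x, y)
```

The upcast happens only inside the kernel; the HBM arrays stay bf16, so memory traffic is unchanged.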
planning:
- Limit the scope of the plan to the selected strategy.
- Do not count out any of the strategies unless they are clearly irrelevant to the code.
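- All live tiles (input buffers, output buffers, scratch/accumulators, plus compiler-spilled registers) must fit simultaneously
  in VMEM (~16 MB default scoped limit). Pipelined input and output blocks are double-buffered by default, so count each one
  twice. Exceeding the limit causes compilation to fail.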
- Block sizes should evenly divide the corresponding array dimensions, with each grid dimension equal to array dim / block size.
  If a block does not divide evenly, the last block is partial: its out-of-bounds input elements contain garbage and its
  out-of-bounds output writes are discarded, so reductions over that block must mask the padding. For reductions across grid
  iterations, the reduction axis must be the innermost (last) grid dimension so the accumulator persists in VMEM across
  consecutive writes to the same output slice.
- Output buffers contain garbage initially and must be explicitly initialized (e.g., zeroed on the first iteration inside
  `pl.when(pl.program_id(axis) == 0)`). When multiple invocations write to the same output elements, all such invocations must
  be consecutive in lexicographic grid order.
- When using `dimension_semantics`, it must be a tuple with one entry per `grid` dimension, and `"parallel"` axes must precede
  `"arbitrary"` axes. A dimension marked `"parallel"` must have truly independent iterations; marking a dependent dimension (such
  as a reduction axis) as parallel produces incorrect results.
- 'Reshapes modifying the last two dimensions are only supported for: (1) flattening leading dims onto the second-to-last
  dim, or (2) restoring a dim removed by reduction. Transpositions of non-trailing axes require ≥4 dimensions. All other reshape/transpose
  patterns on the last two dims cause compiler errors.'
- 'Inside kernels, `jnp.max`/`jnp.sum` along the last axis can trigger "Unsupported implicit dim change" if tile
  dims are not multiples of 8×128.'
- 'Inside kernels, `jnp.arange` only supports int32. Use int32 dtype and cast to float afterward.'
- 'Stride indexing (`x[:, ::2]`) is unsupported inside kernels. Use reshape + slice instead:
  `x.reshape(*x.shape[:-1], x.shape[-1]//2, 2)[..., 0]`.'
- 'Ref dtype must match the written value. Cast explicitly: `out_ref[...] = acc.astype(out_ref.dtype)`.'
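The VMEM-budget, initialization, and reduction-ordering rules above can be illustrated with a tiled matmul. This is a sketch under assumptions: `matmul`, `estimate_vmem_bytes`, and the 128×128×128 tiling are hypothetical, the estimate ignores spilled registers, and the f32 output block doubles as the accumulator instead of a separate scratch buffer. `dimension_semantics` is only mentioned in a comment because the TPU compiler-params class name differs across JAX versions.

```python
import math

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def estimate_vmem_bytes(block_shapes, itemsizes, buffers=2):
    """Rough VMEM estimate assuming every pipelined block is double-buffered."""
    return sum(buffers * math.prod(s) * n for s, n in zip(block_shapes, itemsizes))


def _matmul_kernel(a_ref, b_ref, o_ref):
    # Grid is (i, j, k); k, the reduction axis, is the innermost grid axis,
    # so the same (i, j) output block is revisited on consecutive steps.
    k = pl.program_id(2)

    # Output blocks start as garbage: zero them on the first reduction step.
    @pl.when(k == 0)
    def _init():
        o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)

    # Accumulate in float32 directly into the resident output block.
    o_ref[...] += jnp.dot(a_ref[...], b_ref[...], preferred_element_type=jnp.float32)


def matmul(a, b, *, bm=128, bn=128, bk=128, interpret=True):
    m, kdim = a.shape
    kdim_b, n = b.shape
    assert kdim == kdim_b
    # Block sizes must divide the array dims; the grid counts blocks per axis.
    assert m % bm == 0 and n % bn == 0 and kdim % bk == 0
    budget = estimate_vmem_bytes(
        [(bm, bk), (bk, bn), (bm, bn)],
        [a.dtype.itemsize, b.dtype.itemsize, jnp.dtype(jnp.float32).itemsize],
    )
    assert budget < 16 * 2**20, "tiles would not fit the ~16 MB VMEM budget"
    # On TPU hardware one would also set dimension_semantics to
    # ("parallel", "parallel", "arbitrary") through the TPU compiler params.
    return pl.pallas_call(
        _matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
        grid=(m // bm, n // bn, kdim // bk),
        in_specs=[
            pl.BlockSpec(block_shape=(bm, bk), index_map=lambda i, j, k: (i, k)),
            pl.BlockSpec(block_shape=(bk, bn), index_map=lambda i, j, k: (k, j)),
        ],
        out_specs=pl.BlockSpec(block_shape=(bm, bn), index_map=lambda i, j, k: (i, j)),
        interpret=interpret,  # True lets the sketch run on CPU
    )(a, b)
```

Putting `k` last is what makes the accumulation valid: all writes to a given `(i, j)` block happen on consecutive grid steps, and only that axis is a candidate for `"arbitrary"`.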
coding:
- Wrap the generated code with ```python at the beginning and ``` at the end.
- '`pallas_call` requires `out_shape` (a `jax.ShapeDtypeStruct` or list thereof) to define output shapes/dtypes. When a `grid`
  is provided, `in_specs` must be a list with one `BlockSpec` per input argument, and `out_specs` must mirror `out_shape` (a
  single `BlockSpec` for a single output, a list for several). The PyTree structure of `in_specs` must match the input
  arguments, and that of `out_specs` must match `out_shape`.'
- On TPU, scalar values (e.g., from `program_id`) used as outputs must be placed in SMEM via `BlockSpec(memory_space=pltpu.SMEM)`.
  A buffer in `pl.ANY` (HBM) memory space cannot be dereferenced with array indexing — data must first be copied to VMEM or
  SMEM via `pltpu.sync_copy` or `pltpu.async_copy`.
- 'When using `PrefetchScalarGridSpec` with `num_scalar_prefetch=n`, no `BlockSpec`s are specified for the first n arguments
  (they go to SMEM automatically). The kernel signature must order arguments as: prefetch refs, then input refs, then output
  refs, then scratch refs. The `index_map` signature places prefetch refs after grid indices.'
- DMA operations via `pltpu.async_copy` require a DMA semaphore scratch (`pltpu.SemaphoreType.DMA`), and `.wait()` must be
  called on the returned copy object before accessing the destination data. Barrier semaphores from `get_barrier_semaphore()`
  must be waited back down to 0 after use or they will be corrupted.
- When using `input_output_aliases`, the alias key is the positional index in the combined argument list (including scalar
  prefetch args) and the value is the output index. The aliased input and output must have the same shape and dtype. When
  `should_accumulate_out` is set in `emit_pipeline`, the pipeline handles accumulation automatically — the caller must not
  manually accumulate into those same output refs.
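A sketch of the scalar-prefetch argument ordering and the `input_output_aliases` indexing rule, assuming a recent JAX release that provides `pltpu.PrefetchScalarGridSpec` and supports it with `interpret=True`. The hypothetical `bump_selected_blocks` adds 1 in place to the row blocks listed in `block_ids`; because `block_ids` is the first call argument, the aliased `x` is input 1, not 0.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def _bump_blocks_kernel(ids_ref, x_ref, o_ref):
    # Kernel argument order: scalar-prefetch refs, then inputs, then outputs
    # (scratch refs would come last; there are none here).
    del ids_ref  # only the index_maps need the block ids
    o_ref[...] = x_ref[...] + 1.0


def bump_selected_blocks(x, block_ids, *, block_rows=8, interpret=True):
    """Hypothetical example: add 1 in place to the row blocks in block_ids."""
    rows, cols = x.shape
    assert rows % block_rows == 0
    # index_map receives the grid index first, then the prefetch ref(s).
    spec = pl.BlockSpec(
        block_shape=(block_rows, cols),
        index_map=lambda i, ids_ref: (ids_ref[i], 0),
    )
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=1,  # block_ids goes to SMEM and gets no BlockSpec
        grid=(block_ids.shape[0],),
        in_specs=[spec],  # one BlockSpec, for x only
        out_specs=spec,
    )
    return pl.pallas_call(
        _bump_blocks_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid_spec=grid_spec,
        # Alias keys count the prefetch argument: block_ids is 0, x is 1.
        input_output_aliases={1: 0},
        interpret=interpret,  # True lets the sketch run on CPU
    )(block_ids, x)
```

Blocks that no grid step visits keep their input values only because the output aliases `x`; without the alias they would hold uninitialized data.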
