strategies:
- strategy: Convert high-level code to hardware-specific kernel code
- strategy: Convert a small amount of high-level code to hardware-specific kernel code
- strategy: Decompose matrix multiplications into a blocked grid with `pl.BlockSpec` index maps, accumulating partial products
    in a `pltpu.VMEM` scratch buffer via `jnp.dot(..., preferred_element_type=jnp.float32)`; use `pl.when` on the reduction-axis
    program id to zero the accumulator on the first reduction step and write it to the output on the last.
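#
# Illustrative sketch (not part of the strategy text): a minimal blocked matmul. It assumes a recent JAX release where
# `pl.BlockSpec(block_shape=..., index_map=...)`, the `scratch_shapes=` argument of `pl.pallas_call` and
# `pltpu.CompilerParams` exist (older releases call the last one `pltpu.TPUCompilerParams`). It also assumes every matrix
# dimension divides evenly by its block size. `matmul`, `matmul_kernel`, `bm`, `bk` and `bn` are hypothetical names.
#
#   import jax
#   import jax.numpy as jnp
#   from jax.experimental import pallas as pl
#   from jax.experimental.pallas import tpu as pltpu
#
#   def matmul_kernel(x_ref, y_ref, o_ref, acc_ref):
#       k = pl.program_id(2)  # K (reduction) is the innermost grid axis
#
#       @pl.when(k == 0)
#       def _init():
#           acc_ref[...] = jnp.zeros(acc_ref.shape, jnp.float32)
#
#       # Accumulate this K-slice's partial product in float32.
#       acc_ref[...] += jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
#
#       @pl.when(k == pl.num_programs(2) - 1)
#       def _flush():
#           o_ref[...] = acc_ref[...].astype(o_ref.dtype)
#
#   def matmul(x, y, bm=128, bk=128, bn=128):
#       m, kdim = x.shape
#       _, n = y.shape
#       return pl.pallas_call(
#           matmul_kernel,
#           out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
#           grid=(m // bm, n // bn, kdim // bk),
#           in_specs=[
#               pl.BlockSpec(block_shape=(bm, bk), index_map=lambda i, j, k: (i, k)),
#               pl.BlockSpec(block_shape=(bk, bn), index_map=lambda i, j, k: (k, j)),
#           ],
#           out_specs=pl.BlockSpec(block_shape=(bm, bn), index_map=lambda i, j, k: (i, j)),
#           scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],  # f32 accumulator, resident across K
#           compiler_params=pltpu.CompilerParams(
#               dimension_semantics=("parallel", "parallel", "arbitrary")),
#       )(x, y)
#
# The output BlockSpec ignores `k`, so the same output tile stays in VMEM for the whole K loop. That is what makes
# init-on-first / flush-on-last correct.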
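- strategy: Replace element-wise operations on large arrays with pipelined `pl.pallas_call` kernels that tile inputs into
    (8, 128)-aligned blocks via `pl.BlockSpec` over a `grid`, letting the automatic double-buffered pipeline overlap HBM↔VMEM
    transfers with VPU compute; mark independent grid axes `"parallel"` via `dimension_semantics` in the TPU compiler params
    (a plain `grid` suffices, since element-wise kernels need no scalar prefetch or `pltpu.PrefetchScalarGridSpec`).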
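#
# Illustrative sketch (not part of the strategy text): a tiled element-wise add. It assumes a recent JAX release with
# `pl.BlockSpec(block_shape=..., index_map=...)` and `pltpu.CompilerParams` (older: `pltpu.TPUCompilerParams`). It also
# assumes 2-D inputs whose dimensions divide by the hypothetical block sizes `bm` and `bn`. `add` and `add_kernel` are
# hypothetical names.
#
#   import jax
#   import jax.numpy as jnp
#   from jax.experimental import pallas as pl
#   from jax.experimental.pallas import tpu as pltpu
#
#   def add_kernel(x_ref, y_ref, o_ref):
#       # Works on one VMEM-resident (bm, bn) tile per grid step.
#       o_ref[...] = x_ref[...] + y_ref[...]
#
#   def add(x, y, bm=256, bn=512):
#       m, n = x.shape  # assumes m % bm == 0 and n % bn == 0
#       spec = pl.BlockSpec(block_shape=(bm, bn), index_map=lambda i, j: (i, j))
#       return pl.pallas_call(
#           add_kernel,
#           out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
#           grid=(m // bm, n // bn),
#           in_specs=[spec, spec],
#           out_specs=spec,
#           # Every tile is independent, so both grid axes can be "parallel".
#           compiler_params=pltpu.CompilerParams(dimension_semantics=("parallel", "parallel")),
#       )(x, y)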
- strategy: Convert reduction operations (e.g. `jnp.sum`, `jnp.max` over an axis) into Pallas kernels where the reduction
    axis is the innermost grid dimension (marked `"arbitrary"` in `dimension_semantics`); initialize the output ref on the
    first reduction step via `pl.when(pl.program_id(axis) == 0)` and accumulate into it in place across the consecutive
    iterations that map to the same output block.
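#
# Illustrative sketch (not part of the strategy text): a row sum over the last axis of a 2-D float32 array. It assumes
# a recent JAX release with `pl.BlockSpec(block_shape=..., index_map=...)` and `pltpu.CompilerParams`. It also assumes
# the dimensions divide by the hypothetical block sizes `bm` and `bn`. `rowsum` and `rowsum_kernel` are hypothetical names.
#
#   import jax
#   import jax.numpy as jnp
#   from jax.experimental import pallas as pl
#   from jax.experimental.pallas import tpu as pltpu
#
#   def rowsum_kernel(x_ref, o_ref):
#       @pl.when(pl.program_id(1) == 0)  # first step along the reduction axis
#       def _init():
#           o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)
#
#       # o_ref maps to the same (bm, 1) block for every k, so this accumulates in place.
#       o_ref[...] += jnp.sum(x_ref[...], axis=1, keepdims=True)
#
#   def rowsum(x, bm=256, bn=512):
#       m, n = x.shape  # assumes m % bm == 0, n % bn == 0, float32 input
#       return pl.pallas_call(
#           rowsum_kernel,
#           out_shape=jax.ShapeDtypeStruct((m, 1), x.dtype),
#           grid=(m // bm, n // bn),  # reduction axis (1) is innermost
#           in_specs=[pl.BlockSpec(block_shape=(bm, bn), index_map=lambda i, k: (i, k))],
#           out_specs=pl.BlockSpec(block_shape=(bm, 1), index_map=lambda i, k: (i, 0)),
#           compiler_params=pltpu.CompilerParams(dimension_semantics=("parallel", "arbitrary")),
#       )(x)
#
# For `jnp.max`, initialize with `-jnp.inf` and combine with `jnp.maximum` instead of `+=`. For low-precision outputs,
# accumulate in a float32 VMEM scratch buffer and cast on the last step.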
- strategy: Map gather/scatter or dynamic-index operations to `pltpu.PrefetchScalarGridSpec`, passing the index arrays as
    scalar prefetch arguments (loaded into SMEM); the prefetched refs are passed to every `BlockSpec.index_map` after the
    grid indices (and to the kernel ahead of the array refs), so each grid iteration can select its block from data-dependent
    indices.
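#
# Illustrative sketch (not part of the strategy text): a block-granular gather. It treats `x` as row blocks of height
# `bm` and copies block `block_ids[i]` to output block `i`, i.e. the same result as
# `x.reshape(-1, bm, d)[block_ids].reshape(-1, d)`. It assumes a recent JAX release with
# `pl.BlockSpec(block_shape=..., index_map=...)`. `block_gather`, `gather_kernel` and `bm` are hypothetical names.
#
#   import jax
#   import jax.numpy as jnp
#   from jax.experimental import pallas as pl
#   from jax.experimental.pallas import tpu as pltpu
#
#   def gather_kernel(ids_ref, x_ref, o_ref):
#       # ids_ref (SMEM) arrives first; the selection already happened in index_map,
#       # so the body only copies the chosen (bm, d) block.
#       o_ref[...] = x_ref[...]
#
#   def block_gather(x, block_ids, bm=8):
#       _, d = x.shape  # assumes x.shape[0] % bm == 0
#       nb = block_ids.shape[0]
#       grid_spec = pltpu.PrefetchScalarGridSpec(
#           num_scalar_prefetch=1,
#           grid=(nb,),
#           # index_map gets the grid index, then the prefetched ids ref.
#           in_specs=[pl.BlockSpec(block_shape=(bm, d), index_map=lambda i, ids: (ids[i], 0))],
#           out_specs=pl.BlockSpec(block_shape=(bm, d), index_map=lambda i, ids: (i, 0)),
#       )
#       return pl.pallas_call(
#           gather_kernel,
#           out_shape=jax.ShapeDtypeStruct((nb * bm, d), x.dtype),
#           grid_spec=grid_spec,
#       )(block_ids.astype(jnp.int32), x)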
- strategy: Fuse post-matmul activations, transposes, or casts directly into the matmul kernel body by applying them to the
    accumulator on the final reduction step (`pl.when(pl.program_id(k_axis) == num_k_steps - 1)`), avoiding a separate memory-bound
    elementwise kernel pass.
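#
# Illustrative sketch (not part of the strategy text): a matmul kernel body that fuses a ReLU and a downcast into the
# final K step. It assumes the kernel is launched like an ordinary blocked matmul: grid `(i, j, k)` with K innermost and
# marked "arbitrary", a float32 `pltpu.VMEM((bm, bn), jnp.float32)` scratch buffer, and the usual `(i, k)`, `(k, j)`,
# `(i, j)` BlockSpec index maps. `matmul_relu_kernel` is a hypothetical name.
#
#   import jax.numpy as jnp
#   from jax.experimental import pallas as pl
#
#   def matmul_relu_kernel(x_ref, y_ref, o_ref, acc_ref):
#       k = pl.program_id(2)
#
#       @pl.when(k == 0)
#       def _init():
#           acc_ref[...] = jnp.zeros(acc_ref.shape, jnp.float32)
#
#       acc_ref[...] += jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
#
#       # Epilogue on the last K step: the tile is still in VMEM, so ReLU + cast cost no extra HBM traffic.
#       @pl.when(k == pl.num_programs(2) - 1)
#       def _epilogue():
#           o_ref[...] = jnp.maximum(acc_ref[...], 0.0).astype(o_ref.dtype)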
