optimizations:
- strategy: reduce data movement
- strategy: overlap data movement and compute
- strategy: cache reused data in local memory instead of reloading from main memory
- strategy: loop tiling
- strategy: loop reordering and restructuring
- strategy: loop unrolling
- strategy: fuse operations
- strategy: use lower precision
- strategy: double buffering
- strategy: software pipelining
- strategy: hoist redundant operations out of loops
- strategy: eliminate redundant computation
- strategy: simplify or remove unnecessary code
- strategy: try new parameter values
- strategy: rewrite the algorithm to reduce total work
- strategy: Place control-flow scalars and block indices in SMEM via `pltpu.SMEM` for fast scalar access
- strategy: Use `pltpu.emit_pipeline` for nested pipelines within a single kernel
- strategy: Align block shapes so last two dimensions are multiples of 8 and 128 respectively
- strategy: Avoid singleton dimensions in the last two axes to prevent wasteful tile padding
- strategy: Fuse RHS transpose into matmul via `jax.lax.dot_general` instead of separate transpose
- strategy: Use `preferred_element_type=jnp.float32` with bf16 inputs for native MXU accumulation
- strategy: Mark independent grid dimensions as `"parallel"` in `dimension_semantics` for dual-core Megacore execution
- strategy: Use hardware PRNG via `pltpu.prng_seed` and `pltpu.stateful_uniform` instead of software threefry
- strategy: Enable lookahead prefetch with `pl.Buffered(use_lookahead=True)` for variable-work blocks
- strategy: Increase buffer count beyond 2 with `pl.Buffered(buffer_count=N)` to hide memory latency
- strategy: Maximize block sizes to increase FLOPs-per-memory-transfer ratio toward compute-bound regime
- strategy: Use VMEM scratch buffers for f32 accumulation then downcast to bf16 on final iteration only
- strategy: Reduce grid iterations by using larger blocks to minimize pipeline bubble overhead
- strategy: Flatten leading dimensions into the second-to-last dimension, which is a free reshape, and avoid costly reshapes that change the last axis
- strategy: Upcast narrow types to 32-bit before elementwise operations to match native hardware width
- strategy: Use `input_output_aliases` to enable in-place read-write on buffers without extra copies
- strategy: Perform reductions over second-to-last axis instead of last axis for faster hardware reduction
- strategy: Keep reduction dimension as innermost grid axis so the output block stays resident in VMEM during accumulation
- strategy: Use `pltpu.sync_copy` for explicit HBM-to-VMEM transfers when manual memory control is needed
- strategy: Precompute data-independent metadata with numpy before the kernel so it compiles to constants
- strategy: Skip or shrink iterations that produce zero-contribution results due to masking or structure
- strategy: Tune block sizes jointly by systematically varying BM, BN, and BK together under the VMEM budget (16 MiB, including double-buffering and scratch); weigh fewer grid tiles (larger blocks → less pipeline overhead) against fitting in VMEM, and test whether eliminating the K-reduction loop (BK = full K) with a smaller BN outperforms a K-tiled approach with a larger BN
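
The joint block-size search in the last strategy can be narrowed before any benchmarking by filtering candidate tilings against the 16 MiB VMEM budget. The following is a minimal sketch in plain Python, not a measured cost model. The footprint formula is an assumption: bf16 inputs and outputs, two buffers per pipelined tile, and one f32 scratch accumulator. The helper names `vmem_bytes` and `candidates` and the candidate size list are hypothetical.

```python
from itertools import product

VMEM_BUDGET = 16 * 1024 * 1024  # 16 MiB, the budget named in the list above


def vmem_bytes(bm, bn, bk, in_bytes=2, out_bytes=2, acc_bytes=4, buffers=2):
    """Rough VMEM footprint of one matmul grid step (assumed cost model)."""
    lhs = bm * bk * in_bytes * buffers   # double-buffered LHS tile
    rhs = bk * bn * in_bytes * buffers   # double-buffered RHS tile
    out = bm * bn * out_bytes * buffers  # double-buffered output tile
    acc = bm * bn * acc_bytes            # single f32 scratch accumulator
    return lhs + rhs + out + acc


def candidates(m, n, k, sizes=(128, 256, 512, 1024, 2048)):
    """Yield (grid_steps, bm, bn, bk) for tilings that divide the problem and fit."""
    for bm, bn, bk in product(sizes, repeat=3):
        if m % bm or n % bn or k % bk:
            continue  # keep only tilings that divide the problem evenly
        if vmem_bytes(bm, bn, bk) > VMEM_BUDGET:
            continue  # over the VMEM budget under the assumed model
        yield (m // bm) * (n // bn) * (k // bk), bm, bn, bk


if __name__ == "__main__":
    # Fewest grid steps first; these are starting points to benchmark, not winners.
    for steps, bm, bn, bk in sorted(candidates(4096, 4096, 2048))[:5]:
        mib = vmem_bytes(bm, bn, bk) / 2**20
        print(f"steps={steps:4d} BM={bm} BN={bn} BK={bk} vmem={mib:.1f} MiB")
```

Every size in the candidate list is a multiple of 128, so each tile also satisfies the 8 and 128 alignment rule for the last two dimensions. Sorting by grid-step count only reflects the pipeline-overhead side of the tradeoff. Whether BK = full K with a smaller BN beats a K-tiled configuration still has to be measured on the target hardware.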
