optimizations:
- strategy: reduce data movement
- strategy: overlap data movement and compute
- strategy: cache reused data in local memory instead of reloading from main memory
- strategy: loop tiling
- strategy: loop reordering and restructuring
- strategy: loop unrolling
- strategy: fuse operations
- strategy: use lower precision
- strategy: double buffering
- strategy: software pipelining
- strategy: hoist redundant operations out of loops
- strategy: eliminate redundant computation
- strategy: simplify or remove unnecessary code
- strategy: try new parameter values (e.g., block sizes, buffer counts)
- strategy: rewrite the algorithm to reduce total work
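- strategy: Place the reduction axis last in the grid so the output block index stays fixed across consecutive reduction steps, letting results accumulate in place in VMEM without HBM round-trips
- strategy: Make the last two block dimensions multiples of the native tile (8×128 for 32-bit types; packed types such as bf16 need a larger sublane multiple) to avoid padding waste and costly relayouts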
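- strategy: Use `scratch_shapes=[pltpu.VMEM(...)]` for float32 accumulators that must persist across the grid steps of a reduction
- strategy: Grow block sizes up to the scoped VMEM limit (16 MiB by default on several TPU generations, adjustable via `vmem_limit_bytes`) to increase arithmetic intensity per pipeline step, counting double buffering in the footprint
- strategy: Use scalar prefetch (`pltpu.PrefetchScalarGridSpec`) to load indices/metadata into SMEM before the grid runs, so index maps and the kernel can read them without stalling the vector core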
- strategy: Upcast bf16/int8 to float32 before elementwise ops, downcast only on final output write
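- strategy: Fuse transposes into the contracting dimensions of `lax.dot_general` (e.g., contract dim 1 of both operands to compute A @ B.T) instead of materializing transposed operands
- strategy: Order the grid so consecutive steps map an input to the same block index; the pipeline then skips re-fetching that already-resident slice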
- strategy: Increase pipeline buffer count beyond double buffering to hide memory latency for bandwidth-bound kernels
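- strategy: Generate random numbers inside the kernel with the hardware PRNG (seeded from a key held in SMEM) instead of passing precomputed random arrays from HBM
- strategy: Avoid size-1 dimensions in the last two block axes; every (8, 128) tile is allocated in full, so a shape like (N, 1) uses only 1 of 128 lanes per row
- strategy: Reduce along the second-to-last (sublane) dimension rather than the last (lane) dimension when possible, since sublane reductions are mostly elementwise vector adds while lane reductions need cross-lane operations
- strategy: Prefer add/multiply over exp/tanh/division; restructure math to minimize expensive elementwise ops (e.g., multiply by a precomputed reciprocal instead of dividing repeatedly)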
- strategy: Tune block sizes jointly; vary BM, BN, and BK together under the VMEM budget (16 MiB, counting double buffering and scratch), trade fewer grid steps (larger blocks, less per-step pipeline overhead) against fitting in VMEM, and test whether eliminating the K-reduction loop (BK = full K) with a smaller BN beats a K-tiled schedule with a larger BN
- strategy: Estimate arithmetic intensity including tiling amplification (in a tiled matmul, A is re-read N/BN times and B is re-read M/BM times) to predict whether the kernel is compute-bound or memory-bound
- strategy: Minimize control flow inside kernels; keep the body as close to a single basic block as possible to avoid branch and loop overhead
- strategy: Pass all data as explicit kernel inputs with BlockSpec instead of closing over constants
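- strategy: Use scoped `pltpu.VMEM` scratch buffers for temporaries that only need to live during the kernel, instead of writing them out to HBM
- strategy: Balance block size against the number of grid steps so the pipeline's startup and drain bubbles (first fetch, last writeback) are amortized over enough iterations
- strategy: Explicitly zero accumulator buffers on the first reduction iteration (e.g., under `pl.when(pl.program_id(axis) == 0)`), since VMEM scratch contents are undefined when the kernel starts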
