def matmul_kernel(x_ref, y_ref, z_ref, acc_ref):
    """
    Multi-level tiled matmul kernel with inner K-dimension loop.
    
    Uses scratch memory for accumulation to enable inner-loop tiling
    over the K dimension within each block.
    """
    # Get block dimensions from refs
    bm, bk = x_ref.shape
    _, bn = y_ref.shape
    
    # Inner tile size for K dimension
    sk = 32
    
    # Initialize accumulator on first K-block
    @pl.when(pl.program_id(2) == 0)
    def _():
        acc_ref[...] = jnp.zeros_like(acc_ref[...])
    
    # Inner loop over K dimension with smaller tiles
    # This increases register-level data reuse
    def inner_loop_body(k_idx, acc):
        # Slice inner tiles from the loaded VMEM blocks
        x_tile = jax.lax.dynamic_slice(x_ref[...], (0, k_idx * sk), (bm, sk))
        y_tile = jax.lax.dynamic_slice(y_ref[...], (k_idx * sk, 0), (sk, bn))
        # Accumulate partial matrix multiplication
        return acc + jnp.dot(x_tile, y_tile, preferred_element_type=jnp.float32)
    
    # Number of inner iterations
    num_inner_iters = bk // sk
    
    # Load current accumulator
    acc = acc_ref[...]
    
    # Execute inner loop using lax.fori_loop for efficient compilation
    acc = jax.lax.fori_loop(0, num_inner_iters, inner_loop_body, acc)
    
    # Store back to accumulator
    acc_ref[...] = acc
    
    # Copy to output on last K-block iteration
    grid_k = pl.num_programs(2)
    @pl.when(pl.program_id(2) == grid_k - 1)
    def _():
        z_ref[...] = acc_ref[...]


def matmul_kernel_simple(x_ref, y_ref, z_ref, acc_ref):
    """
    Simplified kernel with explicit inner tiling.
    """
    bm, bk = x_ref.shape
    _, bn = y_ref.shape
    
    # Initialize on first K iteration
    @pl.when(pl.program_id(2) == 0)
    def _():
        acc_ref[...] = jnp.zeros((bm, bn), dtype=jnp.float32)
    
    # Perform the matrix multiply and accumulate
    # Using preferred_element_type for better precision
    partial = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
    acc_ref[...] += partial
    
    # Write output on last K iteration
    grid_k = pl.num_programs(2)
    @pl.when(pl.program_id(2) == grid_k - 1)
    def _():
        z_ref[...] = acc_ref[...].astype(z_ref.dtype)


def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 256,
    bn: int = 128,
    interpret: bool = False,
):
    """
    Compute ``x @ y`` with a blocked Pallas kernel.

    The (m, n) output is tiled into (bm, bn) blocks and the contraction
    dimension into bk-wide blocks. ``matmul_kernel_simple`` multiplies one
    (bm, bk) block of ``x`` by one (bk, bn) block of ``y`` per grid step. It
    accumulates the partial products in a float32 VMEM scratch buffer and
    writes the result, cast to the output dtype, on the last K step.

    Args:
      x: Left operand of shape (m, k).
      y: Right operand of shape (k, n).
      bm, bk, bn: Positive block sizes along m, k and n. Each must divide its
        corresponding dimension.
      interpret: Forwarded to ``pl.pallas_call``. When True the kernel runs in
        Pallas interpret mode, which is useful for debugging off-accelerator.

    Returns:
      Array of shape (m, n) with dtype ``x.dtype``.

    Raises:
      ValueError: if the inner dimensions differ, a block size is not
        positive, or a dimension is not divisible by its block size.
    """
    m, k = x.shape
    k2, n = y.shape
    # Use explicit raises rather than `assert`. Asserts are stripped under
    # `python -O`, and the grid below uses floor division, so a non-divisible
    # shape would silently skip the trailing partial blocks.
    if k != k2:
        raise ValueError(f"Inner dimensions must match: {k} != {k2}")
    if bm <= 0 or bk <= 0 or bn <= 0:
        raise ValueError(
            f"Block sizes must be positive: (bm,bk,bn)=({bm},{bk},{bn})"
        )
    if m % bm or k % bk or n % bn:
        raise ValueError(
            f"Shapes must be divisible by block sizes: "
            f"(m,k,n)=({m},{k},{n}) vs (bm,bk,bn)=({bm},{bk},{bn})"
        )

    return pl.pallas_call(
        matmul_kernel_simple,
        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, kk: (i, kk)),
            pl.BlockSpec((bk, bn), lambda i, j, kk: (kk, j)),
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, kk: (i, j)),
        # K is the last grid axis, so all K steps for one (i, j) output block
        # run back-to-back. The scratch accumulator relies on this.
        grid=(m // bm, n // bn, k // bk),
        scratch_shapes=[pltpu.VMEM((bm, bn), dtype=jnp.float32)],
        compiler_params=pltpu.CompilerParams(
            # The K axis carries the accumulator between steps, so it is marked
            # "arbitrary". M and N blocks are independent of each other.
            dimension_semantics=("parallel", "parallel", "arbitrary")
        ),
        interpret=interpret,
    )(x, y)


@jax.jit
def test(x, y):
    """
    Multiply ``x @ y`` via ``matmul`` with block sizes chosen from the shapes.

    Block-size rules:
      * bk: 512 if it divides k, else 256 if it divides k, else 128.
      * bm: 256 if m is a positive multiple of 256, else 128.
      * bn: 256 if n is a positive multiple of 256, else 128.

    Shapes are static under ``jax.jit``, so the choice is made at trace time.
    The 128 fallback is not checked here; ``matmul`` validates divisibility.
    """
    m, k = x.shape
    _, n = y.shape

    # First K candidate that divides k wins; otherwise fall back to 128.
    bk = next((size for size in (512, 256) if k % size == 0), 128)

    # Wider M/N blocks only when the dimension is a positive multiple of 256.
    bm = 256 if (m >= 256 and m % 256 == 0) else 128
    bn = 256 if (n >= 256 and n % 256 == 0) else 128

    return matmul(x, y, bm=bm, bk=bk, bn=bn)
