@nki.jit
def solution(hidden_states, lm_head_weight):
    """
    Optimized for constant shapes (batch 32 version):
      hidden_states: [32, 1, 2048]
      lm_head_weight: [2048, 64128]
    Returns:
      out: [32, 1, 64128]

    Key optimizations:
      1) Precompute x^T tiles once per batch: hidden_states[b,1,128] -> transpose -> xT[128,1] in SBUF.
      2) Load RHS in larger contiguous N-blocks (2048 columns) and slice into 4x512 for matmul.
      3) Load full hidden_states per batch (1, 2048) at once to avoid small inefficient DMAs.
    """
    # ---- Constants (fixed shapes) ----
    B = 32
    K = 2048
    N = 64128

    TILE_K = 128
    TILE_N = 512

    NUM_K = K // TILE_K              # 16
    FULL_CHUNK_TILES = 4             # 4 * 512 = 2048 columns at a time
    FULL_CHUNK_N = FULL_CHUNK_TILES * TILE_N
    FULL_CHUNKS = N // FULL_CHUNK_N  # 31
    TAIL_COLS = N - FULL_CHUNKS * FULL_CHUNK_N  # 640
    TAIL_TILES = (TAIL_COLS + TILE_N - 1) // TILE_N  # 2
    TAIL_CHUNK_N = TAIL_TILES * TILE_N  # 1024

    # Partition dimension for M is 1 (single row / single token)
    P_M = 1

    # Output on HBM: keep final shape [32,1,N]
    out = nl.ndarray((B, 1, N), dtype=hidden_states.dtype, buffer=nl.shared_hbm)

    # ---- Process each batch element ----
    for b in nl.affine_range(B):
        # Load all hidden_states for this batch at once: (32, 1, 2048) -> (1, 2048) tile in SBUF
        # Using pure basic indexing [b, 0:1, :] avoids mixed indexing errors.
        x_full = nl.load(hidden_states[b, 0:1, :]) 

        # ---- Precompute xT tiles once (SBUF residency) ----
        # xT_all[k] is a tile of shape [128, 1] with partition dim = 128 (K tile)
        xT_all = nl.ndarray((NUM_K, nl.par_dim(TILE_K), P_M),
                            dtype=hidden_states.dtype, buffer=nl.sbuf)

        for k in nl.affine_range(NUM_K):
            k0 = k * TILE_K

            # Slice x block from the full loaded tile in SBUF
            # Uses basic indexing on SBUF tile: [:, slice]
            x_blk = x_full[:, nl.ds(k0, TILE_K)]  # (1, 128) on SBUF

            # Transpose P/F => (128,1) in PSUM, then copy to SBUF
            xT_psum = nisa.nc_transpose(x_blk, engine=nisa.tensor_engine)  # (128, 1) on PSUM
            xT_all[k] = nisa.tensor_copy(xT_psum, dtype=hidden_states.dtype)  # (128, 1) on SBUF

        # ---- Main compute: full chunks (no masks needed) ----
        for chunk in nl.affine_range(FULL_CHUNKS):
            n_base = chunk * FULL_CHUNK_N  # start column in N

            # 4 independent accumulators for 4x512 tiles in this chunk
            acc = nl.zeros((FULL_CHUNK_TILES, nl.par_dim(P_M), TILE_N),
                           dtype=nl.float32, buffer=nl.psum)

            for k in nl.affine_range(NUM_K):
                k0 = k * TILE_K

                # Load a large RHS block [128,2048] once, slice into 4x [128,512]
                # Uses pure basic indexing: [nl.ds, nl.ds]
                w_big = nl.load(lm_head_weight[nl.ds(k0, TILE_K), nl.ds(n_base, FULL_CHUNK_N)])  # (128,2048)

                xT = xT_all[k]  # (128,1) stationary for this k

                # 4 matmuls into 4 accumulators
                for sub in range(FULL_CHUNK_TILES):
                    w_tile = w_big[:, nl.ds(sub * TILE_N, TILE_N)]  # (128,512)
                    acc[sub] += nisa.nc_matmul(stationary=xT, moving=w_tile)  # (1,512) on PSUM

            # Store 4 output tiles back to HBM
            for sub in range(FULL_CHUNK_TILES):
                n0 = n_base + sub * TILE_N
                acc_sb = nisa.tensor_copy(acc[sub], dtype=hidden_states.dtype)  # PSUM->SBUF cast
                
                # Use pure basic indexing for store: [b, 0:1, nl.ds]
                # Since n0 is aligned and fully in-bounds in main loop, this is safe.
                nl.store(out[b, 0:1, nl.ds(n0, TILE_N)], value=acc_sb)

        # ---- Tail chunk (masked load/store) ----
        if TAIL_COLS != 0:
            tail_base = FULL_CHUNKS * FULL_CHUNK_N  # 63488

            acc_tail = nl.zeros((TAIL_TILES, nl.par_dim(P_M), TILE_N),
                                dtype=nl.float32, buffer=nl.psum)

            # Indices for masked RHS load (avoid OOB on N)
            i_n_tail_big = tail_base + nl.arange(TAIL_CHUNK_N)[None, :]  # (1,1024)

            for k in nl.affine_range(NUM_K):
                k0 = k * TILE_K
                i_k = k0 + nl.arange(TILE_K)[:, None]  # (128,1)

                # Mask only on N dimension (K is exact)
                mask_w = (i_n_tail_big < N)
                
                # Pure advanced indexing: [i_k, i_n_tail_big]
                w_tail_big = nl.load(lm_head_weight[i_k, i_n_tail_big], mask=mask_w)  # (128,1024)

                xT = xT_all[k]

                for sub in range(TAIL_TILES):  # 2 tiles
                    w_tile = w_tail_big[:, nl.ds(sub * TILE_N, TILE_N)]  # (128,512)
                    acc_tail[sub] += nisa.nc_matmul(stationary=xT, moving=w_tile)

            # Masked stores for the last partial 512-tile
            # We need purely advanced indexing here to support masking/OOB handling properly.
            # Create indices for all dimensions
            for sub in range(TAIL_TILES):
                n0 = tail_base + sub * TILE_N
                i_n = n0 + nl.arange(TILE_N)[None, :]  # (1,512)
                mask_store = (i_n < N)

                acc_sb = nisa.tensor_copy(acc_tail[sub], dtype=hidden_states.dtype)
                
                # Create batch and M indices for advanced indexing
                # For batch dimension: create a (1,1) array with value b
                # Use arange to create base, then adjust to get b
                i_b_base = nl.arange(1)[:, None]  # (1, 1) with value 0
                i_b_adv = i_b_base * 0 + b  # (1, 1) with value b
                i_m_p_adv = nl.arange(P_M)[:, None]  # (1, 1), Value 0
                
                # Pure advanced indexing: out[i_b_adv, i_m_p_adv, i_n]
                # Broadcasting: (1,1), (1,1), (1,512) -> (1,512) tile
                nl.store(out[i_b_adv, i_m_p_adv, i_n], value=acc_sb, mask=mask_store)

    return out
