@nki.jit
def solution(Q, K, V, past_key_value, attention_mask):
    """
    Single-token (decode step) grouped-query attention written in NKI.

    The query of the new token attends over the cached keys/values plus the
    key/value of the new token itself. The softmax is taken jointly over
    both; ``attention_mask`` is applied to the cached positions only.

    The kernel is specialised to the fixed shapes below (loop bounds and the
    softmax scale are hard-coded to match them):

      Q:                 [32, 16, 1, 64]
      K:                 [32, 4, 1, 64]
      V:                 [32, 4, 1, 64]
      past_key_value[0]: [32, 4, 512, 64]  (cached keys)
      past_key_value[1]: [32, 4, 512, 64]  (cached values)
      attention_mask:    [32, 16, 1, 512]  (bool; True keeps a position)

    Returns:
      A tensor allocated in shared HBM with the shape and dtype of Q.
    """
    # Fixed problem geometry (compile-time constants).
    BATCH = 32
    N_Q_HEADS = 16
    N_KV_HEADS = 4
    D_HEAD = 64
    N_PRIOR = 512
    Q_PER_KV = N_Q_HEADS // N_KV_HEADS  # 4 query heads share one KV head
    TILE = 128                          # cached rows handled per step
    N_TILES = N_PRIOR // TILE

    softmax_scale = 1.0 / 8.0  # 1 / sqrt(D_HEAD)
    masked_score = -1.0e4      # large negative finite fill for masked slots

    result = nl.ndarray(Q.shape, dtype=Q.dtype, buffer=nl.shared_hbm)

    cached_k = past_key_value[0]
    cached_v = past_key_value[1]

    for batch in nl.affine_range(BATCH):
        for kv_head in nl.affine_range(N_KV_HEADS):
            # First of the Q_PER_KV query heads served by this KV head.
            first_q_head = kv_head * Q_PER_KV

            # ---- Cached keys, transposed into SBUF as [64(P), 512(F)] ----
            keys_T = nl.ndarray((nl.par_dim(D_HEAD), N_PRIOR),
                                dtype=cached_k.dtype, buffer=nl.sbuf)

            for t in nl.affine_range(N_TILES):
                start = t * TILE

                # [128(P), 64(F)] slab of the key cache
                key_rows = nl.load(cached_k[batch, kv_head, nl.ds(start, TILE)])

                # -> [64(P), 128(F)], then copied out and placed in its column band
                key_rows_T_ps = nisa.nc_transpose(key_rows, engine=nisa.tensor_engine)
                key_rows_T = nisa.tensor_copy(key_rows_T_ps, dtype=cached_k.dtype)
                keys_T[:, nl.ds(start, TILE)] = key_rows_T

            # ---- Key of the new token: [1, 64] -> [64, 1] ----
            new_key = nl.load(K[batch, kv_head, 0:1, :])
            new_key_T_ps = nisa.nc_transpose(new_key, engine=nisa.tensor_engine)
            new_key_T = nisa.tensor_copy(new_key_T_ps, dtype=K.dtype)

            # ---- Queries of the group: [4, 64] -> [64, 4] (matmul stationary) ----
            queries = nl.load(Q[batch, nl.ds(first_q_head, Q_PER_KV), 0, :])
            queries_T_ps = nisa.nc_transpose(queries, engine=nisa.tensor_engine)
            queries_T = nisa.tensor_copy(queries_T_ps, dtype=Q.dtype)

            # ---- Scaled logits against the cache: [4, 512] fp32, then masked ----
            prior_logits_ps = nisa.nc_matmul(queries_T, keys_T)
            prior_logits = nisa.tensor_copy(prior_logits_ps, dtype=nl.float32)
            prior_logits = nl.multiply(prior_logits, softmax_scale)

            keep = nl.load(attention_mask[batch, nl.ds(first_q_head, Q_PER_KV), 0, :])
            prior_logits = nl.where(keep, prior_logits, masked_score)

            # ---- Scaled logit against the new token: [4, 1] fp32 (never masked) ----
            active_logit_ps = nisa.nc_matmul(queries_T, new_key_T)
            active_logit = nisa.tensor_copy(active_logit_ps, dtype=nl.float32)
            active_logit = nl.multiply(active_logit, softmax_scale)

            # ---- Stable softmax numerators; the division is deferred ----
            # Row maximum over cached positions and the new token together.
            row_max = nl.maximum(nl.max(prior_logits, axis=1, keepdims=True),
                                 active_logit)                       # [4, 1]

            prior_exp = nl.exp(nl.subtract(prior_logits, row_max))   # [4, 512]
            active_exp = nl.exp(nl.subtract(active_logit, row_max))  # [4, 1]

            # Denominator stays in fp32: sum over the cache plus the new token.
            softmax_denom = nl.add(nl.sum(prior_exp, axis=1, keepdims=True),
                                   active_exp)                       # [4, 1]

            # Unnormalised weights, cast once to the input dtype for the matmuls.
            prior_w = nisa.tensor_copy(prior_exp, dtype=Q.dtype)
            active_w = nisa.tensor_copy(active_exp, dtype=Q.dtype)

            # ---- Weighted value sum, accumulated in a PSUM fp32 tile [4, 64] ----
            acc = nl.zeros((Q_PER_KV, D_HEAD), dtype=nl.float32, buffer=nl.psum)

            for t in nl.affine_range(N_TILES):
                start = t * TILE

                # Weights for this slab: [4, 128] -> [128, 4]
                w_T_ps = nisa.nc_transpose(prior_w[:, nl.ds(start, TILE)],
                                           engine=nisa.tensor_engine)
                w_T = nisa.tensor_copy(w_T_ps, dtype=Q.dtype)

                # Matching [128, 64] slab of the value cache
                value_rows = nl.load(cached_v[batch, kv_head, nl.ds(start, TILE)])

                # [4, 128] @ [128, 64] -> [4, 64]
                acc += nisa.nc_matmul(w_T, value_rows)

            # New token's value: [4, 1] @ [1, 64] -> [4, 64]
            active_w_T_ps = nisa.nc_transpose(active_w, engine=nisa.tensor_engine)
            active_w_T = nisa.tensor_copy(active_w_T_ps, dtype=Q.dtype)

            new_value = nl.load(V[batch, kv_head, 0:1, :])
            acc += nisa.nc_matmul(active_w_T, new_value)

            # ---- Single normalisation, cast, and write-back ----
            acc_fp32 = nisa.tensor_copy(acc, dtype=nl.float32)
            normalized = nl.divide(acc_fp32, softmax_denom)  # [4, 1] broadcast over [4, 64]
            out_rows = nisa.tensor_copy(normalized, dtype=Q.dtype)

            nl.store(result[batch, nl.ds(first_q_head, Q_PER_KV), 0, :],
                     value=out_rows)

    return result
