# Magic number to replace -inf similar to what Tensorizer uses.
# It is a large-magnitude finite value rather than a true -inf; the attention
# kernel in this file writes it into masked-out score positions and into the
# score/max slots of KV tiles that it skips.
NEG_INF = -9984.0

@nki.jit
def transpose_p_local(
    p_local_transposed,
    p_local,
    Q_TILE_SIZE,
    LARGE_KV_TILE_SIZE,
):
    """
    Transpose p_local block by block into p_local_transposed.

    p_local: (Q_TILE_SIZE, LARGE_KV_TILE_SIZE), checked by assertion.
    p_local_transposed: written in place. With R = REDUCTION_SIZE, columns
        [b * Q_TILE_SIZE, (b + 1) * Q_TILE_SIZE) receive the transpose of
        p_local[:, b * R:(b + 1) * R], converted to p_local_transposed.dtype.

    The free axis is processed in chunks of nl.tile_size.gemm_moving_fmax
    columns, so LARGE_KV_TILE_SIZE must be a multiple of that value; otherwise
    the trailing columns would never be written.
    """
    assert p_local.shape == (Q_TILE_SIZE, LARGE_KV_TILE_SIZE)
    B_P_SIZE = nl.tile_size.pmax
    REDUCTION_SIZE = min(B_P_SIZE, LARGE_KV_TILE_SIZE)
    B_F_SIZE = nl.tile_size.gemm_moving_fmax
    # The loop below only visits whole B_F_SIZE-wide chunks. Reject sizes that
    # would leave part (or all) of p_local_transposed unwritten.
    assert (
        LARGE_KV_TILE_SIZE % B_F_SIZE == 0
    ), f"{LARGE_KV_TILE_SIZE=} not divisible by {B_F_SIZE=}"
    for i in nl.affine_range(LARGE_KV_TILE_SIZE // B_F_SIZE):
        # float32 PSUM staging buffer holding every transposed block of
        # chunk i side by side along the free axis.
        p_local_t_tmp = nl.ndarray(
            (
                nl.par_dim(REDUCTION_SIZE),
                B_F_SIZE // REDUCTION_SIZE * Q_TILE_SIZE,
            ),
            buffer=nl.psum,
            dtype=np.float32,
        )
        for j in nl.affine_range(B_F_SIZE // REDUCTION_SIZE):

            # Destination columns of block j inside the staging buffer.
            j_128_slice = nl.ds(j * Q_TILE_SIZE, Q_TILE_SIZE)
            # Source columns of block j of chunk i inside p_local.
            i_j_128_slice = nl.ds(
                i * B_F_SIZE + j * REDUCTION_SIZE, REDUCTION_SIZE
            )
            p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
                p_local[:, i_j_128_slice],
                engine=nisa.tensor_engine,
            )
        # Copy the whole staged chunk out in one go, converting to the
        # destination dtype.
        p_local_transposed[
            :,
            nl.ds(
                i * (B_F_SIZE // REDUCTION_SIZE * Q_TILE_SIZE),
                (B_F_SIZE // REDUCTION_SIZE * Q_TILE_SIZE),
            ),
        ] = nl.copy(p_local_t_tmp, dtype=p_local_transposed.dtype)

@nki.jit
def solution(
    q,
    k,
    v,
    olm_prev,
    kernel_dtype,
    acc_type,
    tile_mask,
    q_tile_idx=None,
    Q_TILE_SIZE=128,
    LARGE_KV_TILE_SIZE=16384,
    B_P_SIZE=128,
    B_F_SIZE=512,
    B_D_SIZE=128,
):
    """
    The flash attention core function to calculate self attention between a tile
    of q and a block of K and V, merged with a previous running result.

    q: (B_D_SIZE, seq_q); columns
        [q_tile_idx * Q_TILE_SIZE, (q_tile_idx + 1) * Q_TILE_SIZE) are loaded
        as the query tile. (The original description gave the second dimension
        as LARGE_KV_TILE_SIZE; the code only needs it to cover the selected
        tile.)
    k: (B_D_SIZE, LARGE_KV_TILE_SIZE)
    v: (B_P_SIZE, LARGE_KV_TILE_SIZE // B_P_SIZE, B_D_SIZE)
    olm_prev: previous result; columns [0, B_D_SIZE) are read as the output
        accumulator, column B_D_SIZE as the running sum and column
        B_D_SIZE + 1 as the running max.
    tile_mask: indexed as tile_mask[:, kv_columns]; positions where it is
        false get NEG_INF instead of the q.k score. Presumably
        (Q_TILE_SIZE, LARGE_KV_TILE_SIZE) -- confirm against the caller.
    kernel_dtype: dtype used for loads, the probabilities and the output.
    acc_type: dtype used for the scores, the maxima and the row sums.
    q_tile_idx: index of the query tile. Required despite the None default:
        it selects the q columns and decides which K tiles are skipped.

    No softmax scale is applied here (the exponent uses scale=1.0), so q is
    presumably pre-scaled by the caller -- confirm.

    The results are returned in olm
    olm: (Q_TILE_SIZE, B_D_SIZE + 2), allocated in shared HBM, with the same
        column layout as olm_prev.
    """
    # Fail early with a clear message instead of a TypeError from
    # `None * Q_TILE_SIZE` further down.
    assert (
        q_tile_idx is not None
    ), "q_tile_idx is required: it selects the q tile and the causal K tiles"
    assert (
        LARGE_KV_TILE_SIZE % B_P_SIZE == 0
    ), f"{LARGE_KV_TILE_SIZE=} not divisive by {B_P_SIZE=}"
    assert (
        LARGE_KV_TILE_SIZE % B_F_SIZE == 0
    ), f"{LARGE_KV_TILE_SIZE=} not divisive by {B_F_SIZE=}"
    num_k_tile_per_large_tile = LARGE_KV_TILE_SIZE // B_F_SIZE

    # Masked scores for the whole large KV tile, plus one row max per K tile.
    qk_res_buf = nl.ndarray(
        (nl.par_dim(Q_TILE_SIZE), LARGE_KV_TILE_SIZE),
        buffer=nl.sbuf,
        dtype=acc_type,
    )
    max_local = nl.zeros(
        (nl.par_dim(Q_TILE_SIZE), num_k_tile_per_large_tile),
        dtype=acc_type,
    )
    for k_i in nl.affine_range(num_k_tile_per_large_tile):
        k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)

        # Apply causal masking: only compute when q_tile_idx * Q_TILE_SIZE >= k_i * B_F_SIZE
        multiplication_required_selection = (
            q_tile_idx * Q_TILE_SIZE >= k_i * B_F_SIZE
        )

        if multiplication_required_selection:
            qk_psum = nl.ndarray(
                (nl.par_dim(Q_TILE_SIZE), B_F_SIZE),
                dtype=np.float32,
                buffer=nl.psum,
            )  # (128, 512)
            q_local_tile = nl.load(q[:, q_tile_idx * Q_TILE_SIZE:(q_tile_idx + 1) * Q_TILE_SIZE], dtype=kernel_dtype)
            k_local_tile = nl.load(k[:, k_i_b_f_slice], dtype=kernel_dtype)
            qk_psum[:, :] = nl.matmul(
                q_local_tile, k_local_tile, transpose_x=True
            )  # (p(128), 512)
            tile_mask_local_tile = nl.load(tile_mask[:, k_i_b_f_slice])
            # Keep the score where the mask is set, NEG_INF elsewhere.
            qk_res_buf[:, k_i_b_f_slice] = nl.where(
                tile_mask_local_tile,
                qk_psum[:, nl.ds(0, B_F_SIZE)],
                NEG_INF,
                dtype=acc_type,
            )
            # Calculate max of the current tile
            max_local[:, k_i] = nisa.tensor_reduce(
                np.max,
                qk_res_buf[:, k_i_b_f_slice],
                axis=(1,),
                dtype=acc_type,
                negate=False,
            )
        else:
            # Skipped tile: fill its scores and its max slot with NEG_INF.
            qk_res_buf[:, k_i_b_f_slice] = NEG_INF
            max_local[:, k_i] = NEG_INF

    # Reduce the per-tile maxima to one row max over the whole large KV tile
    max_ = nisa.tensor_reduce(
        np.max,
        max_local[:, :],
        axis=(1,),
        dtype=acc_type,
        negate=False,
    )

    olm_buffer = nl.ndarray((Q_TILE_SIZE, B_D_SIZE + 2), dtype=kernel_dtype, buffer=nl.sbuf)
    o_previous_scaled = nl.ndarray(
        (nl.par_dim(Q_TILE_SIZE), B_D_SIZE),
        dtype=kernel_dtype,
    )

    m_previous = nl.load(olm_prev[:, B_D_SIZE + 1], dtype=kernel_dtype)
    # m_current_neg = -max(max_, m_previous); kept negated so it can be used
    # directly as the bias of the exp activations below.
    m_current_neg = nisa.tensor_scalar(
        max_,
        nl.maximum,
        m_previous,
        op1=nl.multiply,
        operand1=-1,
    )

    p_local = nl.ndarray(
        (nl.par_dim(Q_TILE_SIZE), LARGE_KV_TILE_SIZE),
        dtype=kernel_dtype,
    )
    REDUCTION_TILE = min(2048, LARGE_KV_TILE_SIZE // 2)

    p_partial_sum = nl.ndarray(
        (nl.par_dim(Q_TILE_SIZE), LARGE_KV_TILE_SIZE // REDUCTION_TILE),
        dtype=acc_type,
    )

    for k_r_i in nl.affine_range(LARGE_KV_TILE_SIZE // REDUCTION_TILE):
        k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)

        # compute exp(qk - max)
        # Compute partial row - tile sum of exp(qk - max))
        # FIXME : Use activation accumulate to accumulate over k_r_i loop ?
        p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
            np.exp,
            qk_res_buf[:, k_r_i_reduce_slice],
            bias=m_current_neg,
            scale=1.0,
            reduce_op=nl.add,
            reduce_res=p_partial_sum[:, k_r_i],
            dtype=kernel_dtype,
        )

    # Row sum of exp(qk - max) over the whole large KV tile.
    ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)

    # NOTE(review): transpose_p_local takes its block sizes from nl.tile_size,
    # not from the B_P_SIZE / B_F_SIZE passed to this function; the layout
    # allocated here matches only if they agree -- confirm.
    p_local_transposed = nl.ndarray(
        (nl.par_dim(B_P_SIZE), LARGE_KV_TILE_SIZE // B_P_SIZE * Q_TILE_SIZE),
        dtype=kernel_dtype,
    )
    transpose_p_local(
        p_local_transposed=p_local_transposed,
        p_local=p_local,
        Q_TILE_SIZE=Q_TILE_SIZE,
        LARGE_KV_TILE_SIZE=LARGE_KV_TILE_SIZE,
    )

    # Accumulate P @ V over the B_P_SIZE-row slices of V.
    pv_psum = nl.zeros(
        (nl.par_dim(Q_TILE_SIZE), B_D_SIZE),
        dtype=np.float32,
        buffer=nl.psum,
    )
    v_local = nl.load(v[:, :, :], dtype=kernel_dtype)
    for k_i in nl.affine_range(LARGE_KV_TILE_SIZE // B_P_SIZE):
        pv_psum[:, :] += nl.matmul(
            p_local_transposed[:, nl.ds(k_i * Q_TILE_SIZE, Q_TILE_SIZE)],
            v_local[:, k_i, :],
            transpose_x=True,
        )  # (128, 128) (p(Br), d)

    # Compute scaling factor: alpha = exp(m_previous - m_current)
    alpha = nisa.activation(
        np.exp,
        m_previous,
        bias=m_current_neg,
        scale=1.0,
    )

    # Store the new running max (undo the negation with scale=-1.0).
    olm_buffer[:, B_D_SIZE + 1] = nisa.activation(
        nl.copy,
        m_current_neg,
        scale=-1.0,
    )
    # New output accumulator: previous accumulator rescaled by alpha, plus P @ V.
    o_previous_scaled[...] = nl.multiply(
        nl.load(olm_prev[:, nl.ds(0, B_D_SIZE)], dtype=kernel_dtype),
        alpha,
    )
    olm_buffer[:, nl.ds(0, B_D_SIZE)] = nl.add(o_previous_scaled, pv_psum)

    # New running sum: previous sum rescaled by alpha, plus this tile's row sum.
    l_prev = nl.load(olm_prev[:, B_D_SIZE], dtype=kernel_dtype) * alpha
    olm_buffer[:, B_D_SIZE] = l_prev + ps
    olm = nl.ndarray((Q_TILE_SIZE, B_D_SIZE + 2), dtype=kernel_dtype, buffer=nl.shared_hbm)
    nl.store(olm, olm_buffer)
    return olm
