# SPDX-License-Identifier: Apache-2.0

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
import torch
from neuronxcc import nki
from neuronxcc.nki.language import par_dim


def ceil_div(a, b):
    """Return ``a / b`` rounded up to the next whole number (positive ``b``)."""
    quotient, _remainder = divmod(a + b - 1, b)
    return quotient


def is_power_of_2(x):
    """Return True iff ``x`` is a positive integer power of two."""
    if x <= 0:
        return False
    # A power of two has a single set bit, so clearing its lowest set bit
    # with ``x & (x - 1)`` leaves zero.
    return (x & (x - 1)) == 0


@nki.jit
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
    """
    Copy the flat block table from HBM into SBUF, one tile per row.

    `block_tables_hbm` is 1-D with `num_tiles * num_blocks_per_tile` entries.
    It is viewed as `(num_tiles, num_blocks_per_tile)` and loaded in chunks of
    `B_P_SIZE` rows, so `num_tiles` may exceed the partition count.

    Returns:
        int32 tensor of shape
        `(ceil_div(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile)`.
        Table row `c * B_P_SIZE + p` is stored at `[c, p, :]`. Padding rows
        past `num_tiles` are left at zero.
    """
    B_P_SIZE = 128

    assert len(block_tables_hbm.shape) == 1
    (num_total_blocks, ) = block_tables_hbm.shape
    assert num_blocks_per_tile * num_tiles == num_total_blocks
    # One row of block ids per tile.
    tiles_hbm = block_tables_hbm.reshape((num_tiles, num_blocks_per_tile))

    num_chunks = ceil_div(num_tiles, B_P_SIZE)
    block_tables_sbuf = nl.zeros(
        (num_chunks, par_dim(B_P_SIZE), num_blocks_per_tile),
        dtype=nl.int32,
    )
    for chunk in nl.affine_range(num_chunks):
        row = nl.arange(B_P_SIZE)[:, None]
        col = nl.arange(num_blocks_per_tile)[None, :]
        # Rows beyond the real table are masked out and keep their zero fill.
        block_tables_sbuf[chunk, row, col] = nl.load(
            tiles_hbm[row + chunk * B_P_SIZE, col],
            dtype=nl.int32,
            mask=(row + chunk * B_P_SIZE < num_tiles),
        )
    return block_tables_sbuf


@nki.jit
def transform_block_tables_for_indirect_load(
    block_tables,
    block_size_tiling_factor,
    num_head,
    head_id,
):
    """
    This function does two things:
    1. calculate new `block_tables` for a `head_id` after flattening
    `num_block`, `num_head`, and `block_size_tiling_factor` dimensions
    2. transpose the result so that `block_table` for each tile is mapped to
    SBUF Partition dimension for vectorized DMA

    Tiling trick to further improve DMA performance:
    Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M
    blocks of a given `head_id` from HBM, the load `cache[block_tables,
    head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not
    fully utilize hardware parallelization. The solution is to tile `block_size`
    into `(block_size_tiling_factor, tiled_block_size)` s.t. `M *
    block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape
    `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`.

    Args:
        block_tables: int32 SBUF tensor of shape
            `(num_partitions, B_P_SIZE, num_blocks_per_tile)`, laid out as
            returned by `load_block_tables`. The second dim must be exactly
            B_P_SIZE, which is asserted below.
        block_size_tiling_factor: number of sub-blocks each cache block is
            split into. When > 1, `num_blocks_per_tile *
            block_size_tiling_factor` must equal B_P_SIZE (asserted).
        num_head: number of KV heads fused into the flattened row index.
        head_id: KV head to address. It is only used when `num_head > 1`.

    Returns:
        int32 tensor of shape
        `(num_loads, par_dim(B_P_SIZE), num_partitions * B_P_SIZE)`, with
        `num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)`. Entry
        `[i, p, t]` is the flattened cache row
        `(block * num_head + head_id) * block_size_tiling_factor + offset`
        for the `(i * B_P_SIZE + p)`-th row entry of tile `t`.

    Note:
    We don't further tile D dimension as small DMA size also hurts performance.
    When `block_size_tiling_factor == 1`, the transpose slices
    `nl.ds(i * B_P_SIZE, B_P_SIZE)` assume `num_blocks_per_tile >= B_P_SIZE`.
    The caller (`flash_paged_attention`) only picks factor 1 in that case.
    """
    B_P_SIZE = 128
    num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
        block_tables.shape)
    assert num_tiles_per_partition == B_P_SIZE
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"

    num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)
    block_tables_transposed = nl.ndarray(
        (
            num_loads,
            par_dim(B_P_SIZE),
            num_partitions * num_tiles_per_partition,
        ),
        dtype=nl.int32,
    )

    # prepare iota ahead of time to avoid repeatedly using Gpsimd
    if num_head > 1:
        # Turn the (trace-time) head_id into a tensor broadcast to the tile
        # shape so it can be added element-wise below. This rebinds the
        # `head_id` argument.
        head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
        head_id = nl.transpose(
            head_id.broadcast_to((1, num_tiles_per_partition)))
        if num_blocks_per_tile > 1:
            head_id = head_id.broadcast_to(
                (num_tiles_per_partition, num_blocks_per_tile))

    if block_size_tiling_factor > 1:
        broadcast_shape = (
            num_tiles_per_partition,
            num_blocks_per_tile,
            block_size_tiling_factor,
        )
        # Sub-block offsets 0..block_size_tiling_factor-1 along the last dim.
        offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :],
                           dtype=nl.int32).broadcast_to(broadcast_shape)

    for partition_id in nl.affine_range(num_partitions):
        block_tables_partition = block_tables[partition_id]
        if num_head > 1:
            # fuse num_block and num_head dimension
            block_tables_partition = block_tables_partition * num_head + head_id

        if block_size_tiling_factor > 1:
            # need to apply block size tiling trick
            assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
            # Each fused block row expands into block_size_tiling_factor
            # consecutive sub-block rows: row * factor + offset.
            block_tables_partition = ((block_tables_partition *
                                       block_size_tiling_factor).reshape(
                                           (num_tiles_per_partition,
                                            num_blocks_per_tile,
                                            1)).broadcast_to(broadcast_shape))
            new_block_tables = block_tables_partition + offset
            new_block_tables = new_block_tables.reshape(
                (num_tiles_per_partition, B_P_SIZE))
        else:
            new_block_tables = block_tables_partition

        # transpose the block table so that it can be used by vector DGE
        for i in nl.affine_range(num_loads):
            i_p = nl.arange(B_P_SIZE)[:, None]
            # Columns of this partition's tiles in the flattened tile axis.
            i_f = (partition_id * num_tiles_per_partition +
                   nl.arange(num_tiles_per_partition)[None, :])
            block_tables_transposed[i, i_p, i_f] = nl.transpose(
                new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
    return block_tables_transposed


@nki.jit
def load_kv_tile_from_cache(
    cur_k_tile,
    cur_v_tile,
    kv_cache,
    block_tables,
    large_k_tile_idx,
    num_blocks_per_large_tile,
    tiled_block_size,
    B_P_SIZE,
    B_D_SIZE,
):
    """
    Load KV cache and transform Key and Value into layout required by Matmul

    Vectorized DMA Load layout:
    Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)

    Layout used by attention matmuls:
    Key: (par_dim(B_D_SIZE), seqlen_kv)
    Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE)
           equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)

    Args:
        cur_k_tile: output key tile, written in place. Column
            `load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE + p`
            receives token `tb_i` of the cache row addressed by partition `p`
            of load `load_idx`.
        cur_v_tile: output value tile, written in place. Columns
            `[load_idx * tiled_block_size * B_D_SIZE, +tiled_block_size *
            B_D_SIZE)` receive the rows of load `load_idx`.
        kv_cache: flattened cache indexed as `kv_cache[kv, row, col]`. Index 0
            selects keys and index 1 selects values.
        block_tables: transposed block table (as produced by
            `transform_block_tables_for_indirect_load`), indexed as
            `[load_idx, partition, large_k_tile_idx]`.
        large_k_tile_idx: which large KV tile (column of `block_tables`) to load.
        num_blocks_per_large_tile: block-table entries per large tile. Loads
            are issued in groups of B_P_SIZE.
        tiled_block_size: tokens per cache row after block-size tiling.
        B_P_SIZE: partition size.
        B_D_SIZE: head dimension.
    """
    num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
    # Index grids shared by the key and value loads. They are defined once, up
    # front, so the value loop no longer depends on names leaked from the key
    # loop (previously it read i_p/i_f before assigning them).
    i_p = nl.arange(B_P_SIZE)[:, None]
    i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]

    # load key cache
    for load_idx in nl.affine_range(num_loads):
        loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_k_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
        # Transpose SBUF tensor using PE
        for tb_i in nl.affine_range(tiled_block_size):
            cur_k_tile[
                :,
                nl.ds(
                    load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
                    B_P_SIZE,
                ),
            ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])

    # load value cache
    for load_idx in nl.affine_range(num_loads):
        loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_v_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
        # Values stay in the loaded (partition = cache row) layout. They are
        # only packed side by side along the free dimension.
        cur_v_tile[
            :,
            nl.ds(
                load_idx * tiled_block_size * B_D_SIZE,
                tiled_block_size * B_D_SIZE,
            ),
        ] = loaded


@nki.jit
def transpose_p_local(p_local_transposed,
                      p_local,
                      LARGE_TILE_SZ,
                      B_F_SIZE=512):
    """
    Transpose `p_local` in 128x128 sub-tiles into `p_local_transposed`.

    The free dimension is processed in chunks of `B_F_SIZE` columns. Each
    chunk is staged in a temporary buffer and then copied out with the
    destination's dtype.

    - gen3 NeuronCores use DMA transposes into an SBUF staging buffer of
      `p_local.dtype`.
    - Other versions use `nc_transpose` into a float32 PSUM staging buffer.
    """
    # The NeuronCore version is fixed for the whole trace; query it once.
    use_dma_transpose = nisa.get_nc_version() == nisa.nc_version.gen3
    num_chunks = LARGE_TILE_SZ // B_F_SIZE
    num_sub_tiles = B_F_SIZE // 128

    for chunk in nl.affine_range(num_chunks):
        if use_dma_transpose:
            staging = nl.ndarray((par_dim(128), B_F_SIZE),
                                 buffer=nl.sbuf,
                                 dtype=p_local.dtype)
        else:
            staging = nl.ndarray((par_dim(128), B_F_SIZE),
                                 buffer=nl.psum,
                                 dtype=np.float32)

        for sub in nl.affine_range(num_sub_tiles):
            dst_cols = nl.ds(sub * 128, 128)
            src_cols = nl.ds(chunk * B_F_SIZE + sub * 128, 128)

            if use_dma_transpose:
                staging[:, dst_cols] = nisa.dma_transpose(
                    p_local[:, src_cols])
            else:
                staging[:, dst_cols] = nisa.nc_transpose(
                    p_local[:, src_cols])

        p_local_transposed[:, nl.ds(chunk * B_F_SIZE, B_F_SIZE)] = nl.copy(
            staging, dtype=p_local_transposed.dtype)


@nki.jit
def _flash_attention_core(
    q_local_tile,
    k,
    v,
    o_buffer,
    l_buffer,
    m_buffer,
    kernel_dtype,
    acc_type,
    tile_mask,
    use_causal_mask,
    q_tile_idx=None,
    initialize=False,
    LARGE_TILE_SZ=2048,
    B_P_SIZE=128,
    B_F_SIZE=512,
    B_D_SIZE=128,
    qk_res_buffer=None,
):
    """
    The flash attention core function to calculate self attention between a tile
    of q and a block of K and V. It updates the running online-softmax
    statistics in place.

    Inputs (layouts as allocated by the caller `flash_paged_attention`):
    q_local_tile: (B_D_SIZE, B_P_SIZE); the caller pre-multiplies it by
       softmax_scale
    k: (par_dim(B_D_SIZE), LARGE_TILE_SZ); its free dimension is split into
       B_F_SIZE-wide tiles for the QK^T matmul
    v: (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE), i.e.
       LARGE_TILE_SZ // B_P_SIZE chunks of B_D_SIZE columns
    tile_mask: (B_P_SIZE, LARGE_TILE_SZ); scores where it is false are replaced
       by -9984.0 before the softmax
    use_causal_mask / q_tile_idx: when use_causal_mask is set, K tiles starting
       after row q_tile_idx * B_P_SIZE are not multiplied and are filled with
       -9984.0 instead
    initialize: True for the first KV tile. The buffers are overwritten
       instead of merged with their previous contents.
    qk_res_buffer: optional (B_P_SIZE, LARGE_TILE_SZ) debug output that
       receives the masked QK scores

    The results are stored in the following three buffers
    o_buffer: (B_P_SIZE, d), un-normalized output accumulator, scaled
       relative to the running max stored in m_buffer
    l_buffer: (B_P_SIZE, 1), running log-sum-exp of the scores
    m_buffer: (B_P_SIZE, 1), running row max of the scores

    All IO buffers are in SBUF.
    """
    num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE

    # Masked QK^T scores for the whole large tile, plus one row max per
    # B_F_SIZE-wide sub-tile.
    qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                            buffer=nl.sbuf,
                            dtype=acc_type)
    max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile),
                           dtype=acc_type)
    for k_i in nl.affine_range(num_k_tile_per_large_tile):
        k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)

        if use_causal_mask:
            # mask are used to only apply computation to the lower half of the
            # matrix, which reduce the arithmetic intensity by up to 50%
            multiplication_required_selection = (q_tile_idx * B_P_SIZE
                                                 >= k_i * B_F_SIZE)
        else:
            multiplication_required_selection = True

        if multiplication_required_selection:
            qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
                                 dtype=np.float32,
                                 buffer=nl.psum)  # (128, 512)
            qk_psum[:, :] = nl.matmul(q_local_tile,
                                      k[:, k_i_b_f_slice],
                                      transpose_x=True)  # (p(128), 512)
            # -9984.0 acts as the "masked" score; presumably chosen as a large
            # finite negative instead of -inf (NOTE(review): confirm).
            qk_res_buf[:, k_i_b_f_slice] = nl.where(
                tile_mask[:, k_i_b_f_slice],
                qk_psum[:, nl.ds(0, B_F_SIZE)],
                -9984.0,
                dtype=acc_type,
            )
        else:
            qk_res_buf[:, k_i_b_f_slice] = -9984.0

        # Calculate max of the current tile
        max_local[:, k_i] = nisa.tensor_reduce(
            np.max,
            qk_res_buf[:, k_i_b_f_slice],
            axis=(1, ),
            dtype=acc_type,
            negate=False,
        )

    if qk_res_buffer is not None:
        qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :])

    # Row max over the whole large tile.
    max_ = nisa.tensor_reduce(
        np.max,
        max_local[:, :],
        axis=(1, ),
        dtype=acc_type,
        negate=False,
    )

    o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
                                   dtype=o_buffer.dtype)

    if initialize:
        m_buffer[:, 0] = nl.copy(max_)
        m_current = max_
    else:
        m_previous = nl.copy(m_buffer[:, 0])
        m_buffer[:, 0] = nl.maximum(m_previous, max_)  # (128,1)

        m_current = m_buffer[:, 0]
        # Compute scaling factor
        # alpha = exp(m_previous - m_current) rescales the old accumulator to
        # the new running max.
        alpha = nisa.activation(
            np.exp,
            m_previous,
            bias=-1 * m_current,
            scale=1.0,
        )
        o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)

    p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                         dtype=kernel_dtype)
    # Width of each exp + row-sum chunk. It is capped at 2048 and uses at
    # least two chunks per large tile.
    REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)

    p_partial_sum = nl.ndarray(
        (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
        dtype=acc_type,
    )

    for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
        k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)

        # compute exp(qk - max)
        # Compute partial row - tile sum of exp(qk - max))
        # FIXME : Use activation accumulate to accumulate over k_r_i loop ?
        p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
            np.exp,
            qk_res_buf[:, k_r_i_reduce_slice],
            bias=-1 * m_current,
            scale=1.0,
            reduce_op=nl.add,
            reduce_res=p_partial_sum[:, k_r_i],
            dtype=kernel_dtype,
        )

    # ps = sum over the tile of exp(qk - m_current), per row.
    ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)

    p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                                    dtype=kernel_dtype)
    transpose_p_local(
        p_local_transposed=p_local_transposed,
        p_local=p_local,
        LARGE_TILE_SZ=LARGE_TILE_SZ,
        B_F_SIZE=B_F_SIZE,
    )

    # P @ V, accumulated over B_P_SIZE-wide chunks of the KV axis.
    pv_psum = nl.zeros(
        (par_dim(B_P_SIZE), B_D_SIZE),
        dtype=np.float32,
        buffer=nl.psum,
    )
    for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
        pv_psum[:, :] += nl.matmul(
            p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
            v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
            transpose_x=True,
        )  # (128, 128) (p(Br), d)

    if initialize:
        o_buffer[:, :] = nl.copy(pv_psum[:, :])
        # l = max + log(sum exp(qk - max)) = logsumexp(qk)
        l_buffer[:, 0] = nl.add(nl.log(ps), max_)
    else:
        o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)

        # Merge the log-sum-exp: l = m + log(exp(l_prev - m) + ps)
        l_prev = l_buffer[:, 0]
        l_exp = nl.add(
            nl.exp(nl.subtract(l_prev, m_current)),
            ps,
        )
        l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))


@nki.jit
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
    """
    Load one B_P_SIZE-row slice of V from HBM into `cur_v_tile`.

    Rows `large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i` onwards of
    `v_hbm_tile` are loaded. They are cast to `cur_v_tile.dtype` if needed
    and written to columns `[v_i * d, (v_i + 1) * d)` of `cur_v_tile`, where
    `d` is the last dimension of `v_hbm_tile`.
    """
    B_P_SIZE = 128
    head_dim = v_hbm_tile.shape[-1]
    row_start = large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i

    v_rows = nl.load(v_hbm_tile[nl.ds(row_start, B_P_SIZE), :])
    if v_rows.dtype != cur_v_tile.dtype:
        v_rows = nl.copy(v_rows, dtype=cur_v_tile.dtype)
    cur_v_tile[:, nl.ds(v_i * head_dim, head_dim)] = v_rows


@nki.jit
def flash_paged_attention(
    query,
    key,
    value,
    kv_cache,
    block_tables,
    mask,
    softmax_scale=None,
    mixed_precision=True,
    LARGE_TILE_SZ=2048,
    return_debug_tensors=False,
):
    """
    Flash PagedAttention Forward Kernel.

    IO tensor layouts:
      - query: shape   (1, n_heads, d, seq_q)
      - key:   shape   (1, n_kv_heads, d, seq_k)
      - value: shape   (1, n_kv_heads, seq_v, d)
      - kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
      - block_tables: (num_active_blocks, )
      - mask: (seq_q, num_active_blocks * block_size + seq_q)
      - o: shape (1, n_heads, seq_q, d)

      - This kernel requires seq_k == seq_v. The shape assertions below
        additionally require both to equal seq_q. `key` and `value` may be
        None, in which case only the cached context is attended to.
      - We use continuous batching by default, so the batch dimension is
        always 1, and different requests are concatenated along sequence
        dimension.
      - We use paged cache blocks (kv_cache) to store KV cache.
      - The context columns of `mask` are consumed in the vectorized-load
        token order (see `reorder_context_mask`).

    IO tensor dtypes:
      - This kernel assumes all IO tensors have the same dtype except for
        block_tables (int32) and mask (int32)
      - If mixed_precision is True, then all Tensor Engine operation will be
        performed in bfloat16 and accumulation will be performed in float32.
        Otherwise the intermediates will be in the same type as the inputs.

    Compile-time Constants:
      - softmax_scale: scaling for softmax; if None, default is `1.0/(d**0.5)`
      - mixed_precision: flag to set non-matmul ops in fp32 precision, default
        is set to `true`, if false, we use same precision as input types
      - LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention
        computation reduction
      - return_debug_tensors: if True, also return the m/l statistics and
        the QK scores of the new-token pass

    Returns:
      o, or `(o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res)` when
      return_debug_tensors is True.

    GQA support Notes:
      the spmd kernel for launching kernel should be on kv_heads instead of
      nheads

    Example usage:
      MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
        usage: `flash_fwd[b, h](q, k, v, ...)`
      GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
        usage: `flash_fwd[b, kv_h](q, k, v, ...)`
    """
    B_F_SIZE = 512
    B_P_SIZE = 128
    b, h, d, seqlen_q = query.shape
    B_D_SIZE = d
    n_tile_q = seqlen_q // B_P_SIZE  # since q will be loaded on tensor engine
    _, num_blocks, k_h, block_size, _ = kv_cache.shape
    # Number of query heads sharing each KV head (GQA group size).
    q_h_per_k_h = h // k_h
    assert b == 1, f"invalid batch size {b=}"
    assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
    cache_shape = (2, num_blocks, k_h, block_size, d)
    assert (tuple(kv_cache.shape) == cache_shape
            ), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
    assert key is None or tuple(key.shape) == (
        1,
        k_h,
        d,
        seqlen_q,
    ), f"key shape {key.shape} mismatch!"
    assert value is None or tuple(value.shape) == (
        1,
        k_h,
        seqlen_q,
        d,
    ), f"value shape {value.shape} mismatch!"

    # Grid axis 0 is the batch, axis 1 the KV head handled by this instance.
    assert (
        nl.program_ndim() == 2
    ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
    batch_id = nl.program_id(axis=0)
    head_id = nl.program_id(axis=1)

    (num_active_blocks, ) = block_tables.shape
    context_kv_len = num_active_blocks * block_size
    assert (
        LARGE_TILE_SZ % B_F_SIZE == 0
    ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p"
    assert (context_kv_len % LARGE_TILE_SZ == 0
            ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"

    num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
    assert is_power_of_2(
        num_blocks_per_large_tile
    ), f"{num_blocks_per_large_tile=} is expected of be power of 2"
    # The new-token pass uses LARGE_TILE_SZ = seqlen_q, so seqlen_q must tile
    # evenly for the reduction in _flash_attention_core.
    if seqlen_q > B_F_SIZE:
        MAX_REDUCTION_TILE = 2048
        if seqlen_q // 2 > MAX_REDUCTION_TILE:
            assert (
                seqlen_q % MAX_REDUCTION_TILE == 0
            ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}"
        else:
            assert (seqlen_q % B_F_SIZE == 0
                    ), f"{seqlen_q=} should be divisible by {B_F_SIZE=})"

    kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
    acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
    softmax_scale = softmax_scale or (1.0 / (d**0.5))
    num_large_k_tile = context_kv_len // LARGE_TILE_SZ

    o = nl.ndarray((b, h, seqlen_q, d),
                   dtype=query.dtype,
                   buffer=nl.shared_hbm)
    hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = (
        None,
        None,
        None,
        None,
    )
    if return_debug_tensors:
        hbm_l_buffer = nl.ndarray((b, h, seqlen_q),
                                  dtype=acc_type,
                                  buffer=nl.shared_hbm)
        hbm_m_buffer = nl.ndarray((b, h, seqlen_q),
                                  dtype=acc_type,
                                  buffer=nl.shared_hbm)
        hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q),
                                dtype=acc_type,
                                buffer=nl.shared_hbm)
        qk_res_buffer = nl.zeros(
            (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q),
            dtype=acc_type,
            buffer=nl.sbuf,
            lazy_initialization=True,
        )
    block_tables_sbuf = load_block_tables(
        block_tables_hbm=block_tables,
        num_tiles=num_large_k_tile,
        num_blocks_per_tile=num_blocks_per_large_tile,
    )

    # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
    if num_blocks_per_large_tile < B_P_SIZE:
        # we checked num_blocks_per_tile is a power of 2
        assert B_P_SIZE % num_blocks_per_large_tile == 0
        block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile
        # We assume block_size >= block_size_tiling_factor
        assert block_size % block_size_tiling_factor == 0
    else:
        block_size_tiling_factor = 1
    tiled_block_size = block_size // block_size_tiling_factor

    # Indirect DMA load must be placed along Partition Dimension
    block_tables_sbuf = transform_block_tables_for_indirect_load(
        block_tables_sbuf,
        block_size_tiling_factor=block_size_tiling_factor,
        num_head=k_h,
        head_id=head_id,
    )

    # Flatten KV cache to be 3D for loading into SBUF
    # Row index = (block * k_h + head) * tiling_factor + sub_block, matching
    # the rows computed by transform_block_tables_for_indirect_load.
    new_cache_shape = (
        2,
        num_blocks * k_h * block_size_tiling_factor,
        tiled_block_size * d,
    )
    kv_cache = kv_cache.reshape(new_cache_shape)

    # Global Flash Attention accumulators
    o_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )
    l_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )
    m_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )

    # Pass 1: attend to the cached context, one large KV tile at a time. The
    # loop is sequential because each step updates the running o/l/m buffers.
    for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
        num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
        cur_k_tile = nl.ndarray(
            (par_dim(B_D_SIZE), LARGE_TILE_SZ),
            dtype=kernel_dtype,
        )
        cur_v_tile = nl.ndarray(
            (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE),
            dtype=kernel_dtype,
        )
        load_kv_tile_from_cache(
            cur_k_tile=cur_k_tile,
            cur_v_tile=cur_v_tile,
            kv_cache=kv_cache,
            block_tables=block_tables_sbuf,
            large_k_tile_idx=large_k_tile_idx,
            num_blocks_per_large_tile=num_blocks_per_large_tile,
            tiled_block_size=tiled_block_size,
            B_P_SIZE=B_P_SIZE,
            B_D_SIZE=B_D_SIZE,
        )

        for i in nl.affine_range(n_tile_q):
            cur_mask = nl.load(mask[
                nl.ds(i * B_P_SIZE, B_P_SIZE),
                nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
            ])
            for i_q_h in nl.affine_range(q_h_per_k_h):
                q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
                q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
                q_sbuf_tile = nl.load(q_hbm_tile[:,
                                                 nl.ds(i *
                                                       B_P_SIZE, B_P_SIZE)])
                if q_sbuf_tile.dtype != kernel_dtype:
                    q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
                q_tile[:, :] = q_sbuf_tile * softmax_scale

                _flash_attention_core(
                    q_local_tile=q_tile,
                    k=cur_k_tile,
                    v=cur_v_tile,
                    o_buffer=o_buffer[i, i_q_h],
                    l_buffer=l_buffer[i, i_q_h],
                    m_buffer=m_buffer[i, i_q_h],
                    kernel_dtype=kernel_dtype,
                    acc_type=acc_type,
                    tile_mask=cur_mask,
                    use_causal_mask=False,
                    q_tile_idx=i,
                    initialize=large_k_tile_idx == 0,
                    LARGE_TILE_SZ=LARGE_TILE_SZ,
                    B_P_SIZE=B_P_SIZE,
                    B_F_SIZE=B_F_SIZE,
                    B_D_SIZE=B_D_SIZE,
                )

    # compute attention between input query, key and value
    # Pass 2: the new tokens form a single tile of seqlen_q columns, with
    # the causal tile skip enabled.
    if key is not None and value is not None:
        B_F_SIZE = min(seqlen_q, B_F_SIZE)
        LARGE_TILE_SZ = seqlen_q

        cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
                                dtype=kernel_dtype)
        cur_v_tile = nl.ndarray(
            (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE),
            dtype=kernel_dtype,
        )

        loaded = nl.load(key[batch_id, head_id, :, :])
        if loaded.dtype != kernel_dtype:
            loaded = nl.copy(loaded, dtype=kernel_dtype)
        cur_k_tile[:, :] = loaded

        v_hbm_tile = value[batch_id, head_id]
        for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
            load_v_tile(
                v_hbm_tile=v_hbm_tile,
                cur_v_tile=cur_v_tile,
                large_tile_idx=0,
                v_i=v_i,
                LARGE_TILE_SZ=LARGE_TILE_SZ,
            )

        for i in nl.affine_range(n_tile_q):
            cur_mask = nl.load(mask[
                nl.ds(i * B_P_SIZE, B_P_SIZE),
                nl.ds(context_kv_len, LARGE_TILE_SZ),
            ])
            for i_q_h in nl.affine_range(q_h_per_k_h):

                q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
                q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
                q_sbuf_tile = nl.load(q_hbm_tile[:,
                                                 nl.ds(i *
                                                       B_P_SIZE, B_P_SIZE)])
                if q_sbuf_tile.dtype != kernel_dtype:
                    q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
                q_tile[:, :] = q_sbuf_tile * softmax_scale
                _flash_attention_core(
                    q_local_tile=q_tile,
                    k=cur_k_tile,
                    v=cur_v_tile,
                    o_buffer=o_buffer[i, i_q_h],
                    l_buffer=l_buffer[i, i_q_h],
                    m_buffer=m_buffer[i, i_q_h],
                    kernel_dtype=kernel_dtype,
                    acc_type=acc_type,
                    tile_mask=cur_mask,
                    use_causal_mask=True,
                    q_tile_idx=i,
                    initialize=False,
                    LARGE_TILE_SZ=LARGE_TILE_SZ,
                    B_P_SIZE=B_P_SIZE,
                    B_F_SIZE=B_F_SIZE,
                    B_D_SIZE=B_D_SIZE,
                    qk_res_buffer=(qk_res_buffer[i, i_q_h]
                                   if qk_res_buffer is not None else None),
                )

    # -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
    for i_q_h in nl.affine_range(q_h_per_k_h):
        for i in nl.affine_range(n_tile_q):
            # l = m + log(sum), so exp(m - l) = 1 / sum normalizes o.
            out = nl.multiply(
                o_buffer[i, i_q_h],
                nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]),
                dtype=kernel_dtype,
            )

            nl.store(
                o[
                    batch_id,
                    head_id * q_h_per_k_h + i_q_h,
                    nl.ds(i * B_P_SIZE, B_P_SIZE),
                    :,
                ],
                out,
            )
            # maximum and summation statistics
            if return_debug_tensors:
                nl.store(
                    hbm_m_buffer[
                        batch_id,
                        head_id * q_h_per_k_h + i_q_h,
                        nl.ds(i * B_P_SIZE, B_P_SIZE),
                    ],
                    m_buffer[i, i_q_h, :, :],
                )
                nl.store(
                    hbm_l_buffer[
                        batch_id,
                        head_id * q_h_per_k_h + i_q_h,
                        nl.ds(i * B_P_SIZE, B_P_SIZE),
                    ],
                    l_buffer[i, i_q_h],
                )
                # NOTE(review): qk_res_buffer's first axis is the q-tile index
                # (n_tile_q), but it is indexed with batch_id (always 0 here),
                # and hbm_qk_res only holds one B_P_SIZE tile. So only q tile 0
                # is exported, rewritten for every i. Presumably intentional
                # debug output; confirm.
                nl.store(
                    hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
                    qk_res_buffer[batch_id, i_q_h, :, :],
                )

    if return_debug_tensors:
        return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res
    return o


def reorder_context_mask(mask, LARGE_TILE_SZ, block_size):
    """
    Permute the context columns of an attention mask into kernel load order.

    The vectorized KV-cache load places each tile of `B_P_SIZE *
    tiled_block_size` tokens in a `(B_P_SIZE, tiled_block_size)` layout. The
    engine then consumes it column by column, so tokens are visited with a
    stride. This swaps the last two axes of every tile of the context part of
    the mask so that the mask lines up with that order. The trailing
    `total_query_len` columns (the new tokens) are left untouched.

    Args:
        mask: 2-D tensor `(total_query_len, context_kv_len + total_query_len)`.
        LARGE_TILE_SZ: KV tile size used by the kernel (>= 128).
        block_size: KV-cache block size.

    Returns:
        The reordered mask on `mask`'s device. Returns `mask` itself when
        no reordering is needed (tiled_block_size == 1).
    """
    B_P_SIZE = 128
    num_queries, num_keys = mask.shape
    context_kv_len = num_keys - num_queries

    assert (LARGE_TILE_SZ
            >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}"
    num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
    tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks
    if tiled_block_size <= 1:
        # Tokens are already consumed in their natural order.
        return mask

    original_device = mask.device
    cpu_mask = mask.cpu()
    reordered_context = (cpu_mask[:, :context_kv_len].view(
        num_queries,
        context_kv_len // LARGE_TILE_SZ,
        num_tiled_blocks // B_P_SIZE,
        B_P_SIZE,
        tiled_block_size,
    ).transpose(3, 4).reshape(num_queries, context_kv_len))
    new_token_part = cpu_mask[:, context_kv_len:]
    return torch.cat([reordered_context, new_token_part],
                     dim=1).to(original_device)


def flash_attn_varlen_nkifunc(
    query,
    key,
    value,
    kv_cache,
    block_table,
    attn_mask,
    n_kv_head=None,
    head_size=None,
    LARGE_TILE_SZ=2048,
    mixed_precision=True,
):
    """
    Run flash paged attention over variable-length, concatenated sequences.

    This wraps the `flash_paged_attention` NKI kernel. It launches one
    program per KV head on a `[1, n_kv_head]` grid and uses a softmax scale
    of `1 / sqrt(head_size)`.

    Expected inputs:
      - query: (1, n_heads, d, seq_q)
      - key:   (1, n_kv_heads, d, seq_k)
      - value: (1, n_kv_heads, seq_v, d)
      - kv_cache:   (2, n_blocks, n_kv_heads, block_size, d)
      - block_tables: (n_active_blocks, )
      - attn_mask: (seq_q, n_active_blocks * block_size + seq_q)

    `n_kv_head` defaults to `kv_cache.shape[2]` and `head_size` to
    `kv_cache.shape[-1]`.

    Notes:
      - attn_mask must be reordered outside using `reorder_context_mask`
      - Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d)
        for better DMA throughput
    """
    n_kv_head = kv_cache.shape[2] if n_kv_head is None else n_kv_head
    assert kv_cache.shape[0] == 2
    assert kv_cache.shape[2] == n_kv_head
    head_size = kv_cache.shape[-1] if head_size is None else head_size

    softmax_scale = 1.0 / (head_size**0.5)
    launch = flash_paged_attention[1, n_kv_head]
    return launch(
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache,
        block_tables=block_table,
        mask=attn_mask,
        softmax_scale=softmax_scale,
        mixed_precision=mixed_precision,
        LARGE_TILE_SZ=LARGE_TILE_SZ,
    )


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """
    Scatter per-token keys and values into the paged KV cache, in place.

    Each entry of `slot_mapping` is a flat slot id. It is split into
    `(slot // block_size, slot % block_size)`, and every head of that token
    is written at `[plane, block, head, offset, :]`. Plane 0 of the cache
    receives `key` and plane 1 receives `value`.

    Args:
        key (torch.Tensor): Key tensor with shape
            (num_tokens, n_kv_head, d_head)
        value (torch.Tensor): Value tensor with shape
            (num_tokens, n_kv_head, d_head)
        kv_cache (torch.Tensor): Key/value cache tensor with shape
            (2, num_blocks, n_kv_head, block_size, d_head)
        slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
            with shape (num_tokens)

    Returns:
        None: Updates the kv_cache tensor in-place
    """
    device = key.device
    block_size = kv_cache.size(3)
    num_heads = key.size(1)

    # Column vectors over tokens and a row vector over heads; together they
    # broadcast to a (num_tokens, n_kv_head) grid of cache positions.
    blocks = torch.div(slot_mapping, block_size,
                       rounding_mode="floor")[:, None]
    offsets = (slot_mapping % block_size)[:, None]
    heads = torch.arange(num_heads, device=device)[None, :]

    for plane, source in enumerate((key, value)):
        plane_index = torch.tensor([plane], device=device)
        kv_cache.index_put_((plane_index, blocks, heads, offsets), source)
