# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import pytest
import torch
import torch.nn.functional as F


class BlockDiagonalCausalFromBottomRightMask:
    """Boolean attention masks for a ragged (variable-length) batch.

    Rows index the concatenated query tokens of all sequences, columns index
    the concatenated key tokens. Each sequence occupies its own diagonal
    block; ``True`` marks a (query, key) pair that is allowed to attend.
    """

    @staticmethod
    def _from_seqlens(query_lens, seq_lens, block_size=None):
        """Build a single block-diagonal boolean mask.

        Args:
            query_lens: Per-sequence number of query tokens (a list of ints;
                it is concatenated with ``[0]`` below).
            seq_lens: Per-sequence total length (context + query tokens).
            block_size: ``None`` builds a causal mask over each full
                sequence whose causal diagonal is aligned to the
                bottom-right corner of the sequence's block (the last query
                sees every key of its sequence). Otherwise, each sequence's
                key span is its context length rounded up to a multiple of
                ``block_size``; every query may see all real context keys
                but none of the block-padding slots.

        Returns:
            A bool tensor of shape ``(sum(query_lens), n_keys)``, where
            ``n_keys`` is ``sum(seq_lens)`` when ``block_size is None`` and
            the sum of block-aligned context lengths otherwise.
        """
        from torch import logical_and, logical_or

        contexted = block_size is None
        context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
        n_queries = sum(query_lens)
        num_seqs = len(query_lens)
        if contexted:
            key_lens_blockaligned = seq_lens
        else:
            # Round each context length up to a whole number of blocks.
            n_blocks_per_seq = (context_lens + block_size - 1) // block_size
            offset_per_seq = n_blocks_per_seq * block_size
            key_lens_blockaligned = offset_per_seq[:num_seqs].tolist()
        n_keys = sum(key_lens_blockaligned)

        # a[r, c] = global query row index, b[r, c] = global key column index.
        a = (torch.arange(n_queries).reshape(n_queries,
                                             1).expand(n_queries, n_keys))
        b = torch.arange(n_keys).reshape(1, n_keys).expand(n_queries, n_keys)
        q_cumsum = torch.tensor([0] + query_lens).cumsum(dim=0)
        k_cumsum = torch.tensor([0] + key_lens_blockaligned).cumsum(dim=0)

        # Start from an all-False *bool* mask so the result is bool even for
        # an empty batch (a float zeros tensor would leak through unchanged).
        prior_mask = torch.zeros(n_queries, n_keys, dtype=torch.bool)
        for seq_id in range(num_seqs):
            ri = q_cumsum[seq_id]  # first query row of this sequence
            ci = k_cumsum[seq_id]  # first key column of this sequence
            nr = query_lens[seq_id]

            if contexted:
                nc = seq_lens[seq_id]
                # Shift so local query i may see local keys j <= i + ctx_len.
                a_offset = ci + nc - ri - nr
                new_mask = (a + a_offset) >= b
            else:
                nc = context_lens[seq_id]
                # Every query sees the real context keys, not block padding.
                a_offset = ci + nc - 1
                new_mask = a_offset >= b

            # Restrict to this sequence's diagonal block.
            left_mask = b >= ci
            top_mask = a >= ri
            bottom_mask = a < (ri + nr)

            new_mask = logical_and(
                logical_and(logical_and(new_mask, left_mask), top_mask),
                bottom_mask,
            )
            prior_mask = logical_or(prior_mask, new_mask)
        return prior_mask

    @staticmethod
    def from_seqlens(query_lens, seq_lens, block_size=None):
        """Build the ``(prior_mask, active_mask)`` pair.

        Without ``block_size``, ``prior_mask`` is the full bottom-right
        causal mask over each whole sequence and ``active_mask`` is ``None``.
        With ``block_size``, ``prior_mask`` covers the block-aligned context
        keys, and ``active_mask`` is the causal mask of the query tokens
        against themselves.
        """
        contexted = block_size is None
        if contexted:
            prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
                query_lens, seq_lens)
            active_mask = None
        else:
            prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
                query_lens, seq_lens, block_size)
            active_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
                query_lens, query_lens)
        return prior_mask, active_mask


def ref_softmax(x: torch.Tensor,
                dim: int,
                mixed_precision=False,
                return_max_reduce=False):
    """Numerically stable reference softmax of ``x`` along ``dim``.

    Args:
        x: Input scores.
        dim: Dimension to normalize over.
        mixed_precision: If True, the normalizing sum is accumulated in
            float32 and cast back to ``x.dtype``.
        return_max_reduce: If True, also return the row maxima and the
            reciprocal of the row sums (both with ``dim`` kept as size 1).

    Returns:
        ``softmax`` or ``(softmax, max_value, 1 / sum_value)``.
    """
    max_value = torch.amax(x, dim=dim, keepdim=True)
    # Subtract the max before exponentiating to avoid overflow.
    exp = torch.exp(x - max_value)
    if mixed_precision:
        # torch tensors have no NumPy-style ``.astype``; use ``.to`` instead.
        sum_value = torch.sum(exp.to(torch.float32), dim=dim,
                              keepdim=True).to(x.dtype)
    else:
        sum_value = torch.sum(exp, dim=dim, keepdim=True)
    if return_max_reduce:
        return exp / sum_value, max_value, torch.reciprocal(sum_value)
    return exp / sum_value


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
    return_max_reduce: Optional[bool] = False,
) -> tuple[torch.Tensor, ...]:
    """Reference scaled-dot-product attention with an optional additive mask.

    Args:
        query: ``(num_queries, num_heads, head_dim)``.
        key: ``(num_keys, num_heads, head_dim)``.
        value: ``(num_keys, num_heads, head_dim)``.
        scale: Multiplier applied to ``query @ key^T``.
        attn_mask: Optional bias added to the ``(heads, queries, keys)``
            scores before softmax (e.g. large negatives for masked slots).
        return_max_reduce: If True, also return softmax intermediates.

    Returns:
        ``(out,)`` with ``out`` shaped ``(num_queries, num_heads, head_dim)``,
        or ``(out, cached_max, cached_sum_reciprocal, norm_score,
        masked_score, scaled_qk)`` when ``return_max_reduce`` is True.
    """
    scaled_qk = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    # Without a mask the scores are used as-is (previously this path left
    # ``masked_score`` unbound and raised NameError).
    if attn_mask is None:
        masked_score = scaled_qk
    else:
        masked_score = scaled_qk + attn_mask.float()
    if return_max_reduce:
        norm_score, cached_max, cached_sum_reciprocal = ref_softmax(
            masked_score, dim=-1, return_max_reduce=True)
    else:
        norm_score = ref_softmax(masked_score, dim=-1)
    out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value)
    if return_max_reduce:
        return (
            out,
            cached_max,
            cached_sum_reciprocal,
            norm_score,
            masked_score,
            scaled_qk,
        )
    return (out, )


def ref_context_attention(
    query,
    key,
    value,
    query_lens,
    seq_lens,
    head_size,
    num_queries_per_kv,
    return_max_reduce=False,
):
    """Reference causal attention over a ragged batch with prior context.

    Builds the bottom-right-aligned block-diagonal causal mask for the whole
    batch, turns disallowed positions into a -30000 additive bias, and runs
    ``ref_masked_attention``. The output gains a size-1 axis at dim 1.

    When ``return_max_reduce`` is True, the softmax intermediates are
    returned too: ``(output, cached_max, cached_sum_reciprocal, norm_score,
    masked_score, scaled_qk)``. The fourth item is the normalized score
    matrix (not a log-sum-exp).
    """
    softmax_scale = float(1.0 / (head_size**0.5))

    if num_queries_per_kv > 1:
        # MQA / GQA: replicate each KV head so it lines up with its query heads.
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

    allowed, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        query_lens, seq_lens)
    # Masked-out positions get a large negative (finite) bias instead of -inf.
    bias = torch.logical_not(allowed).float() * -30000

    results = ref_masked_attention(
        query,
        key,
        value,
        softmax_scale,
        bias,
        return_max_reduce=return_max_reduce,
    )
    output = results[0].unsqueeze(1)

    if not return_max_reduce:
        return output

    (cached_max, cached_sum_reciprocal, norm_score, masked_score,
     scaled_qk) = results[1:]
    return (
        output,
        cached_max,
        cached_sum_reciprocal,
        norm_score,
        masked_score,
        scaled_qk,
    )


def sample_inputs(
    prefill_batch_size,
    decode_batch_size,
    min_query_len,
    max_query_len,
    min_ctx_len,
    max_ctx_len,
    block_size,
    num_heads,
    num_kv_heads,
    head_size,
    dtype,
):
    """Generate random paged-KV inputs for a mixed prefill/decode batch.

    Prefill sequences get a random query length in
    ``[min_query_len, max_query_len]``; decode sequences get query length 1.
    Every sequence gets a random context length in
    ``[min_ctx_len, max_ctx_len]``. Context tokens of each sequence are
    scattered into randomly chosen cache blocks according to
    ``block_table``.

    Returns:
        ``(query, k, v, kv_cache, block_table, key, value, query_lens,
        seq_lens)`` where ``key``/``value`` hold all tokens of all sequences
        (context followed by query tokens, per sequence), ``k``/``v`` hold
        only the query ("active") tokens, and ``kv_cache`` is
        ``(2, cache_size, block_size, num_kv_heads, head_size)``.
    """
    batch_size = prefill_batch_size + decode_batch_size
    max_model_len = (max_query_len + max_ctx_len) * 4
    max_block_per_request = max_model_len // block_size
    # The block table references only batch_size * max_block_per_request
    # blocks, so two cache blocks are never used.
    cache_size = (batch_size * max_block_per_request) + 2
    prefill_ctx_lens = torch.randint(min_ctx_len,
                                     max_ctx_len + 1, (prefill_batch_size, ),
                                     dtype=torch.long).tolist()
    decode_ctx_lens = torch.randint(min_ctx_len,
                                    max_ctx_len + 1, (decode_batch_size, ),
                                    dtype=torch.long).tolist()
    ctx_lens = prefill_ctx_lens + decode_ctx_lens
    query_lens = torch.randint(
        min_query_len,
        max_query_len + 1,
        (prefill_batch_size, ),
        dtype=torch.long,
    ).tolist() + [1] * decode_batch_size
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1, 1)

    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1, 1)
    key, value = kv.unbind(dim=1)

    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=dtype)
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:batch_size * max_block_per_request].view(
        batch_size, max_block_per_request)

    # Flat slot views: slot = block_id * block_size + offset_in_block.
    k_cache_slots = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_slots = v_cache.view(-1, num_kv_heads, head_size)

    token_start = 0  # start of this sequence's queries in k / v
    seq_start = 0  # start of this sequence in key / value
    for i in range(batch_size):
        q_len = query_lens[i]
        ctx_len = ctx_lens[i]

        # Active tokens are the last q_len tokens of the sequence.
        active_src = slice(seq_start + ctx_len, seq_start + ctx_len + q_len)
        k[token_start:token_start + q_len].copy_(key[active_src])
        v[token_start:token_start + q_len].copy_(value[active_src])

        # Copy the context into the cache one (possibly partial) block at a
        # time, following this sequence's row of the block table.
        for block_id, cur_ctx in enumerate(range(0, ctx_len, block_size)):
            n = min(block_size, ctx_len - cur_ctx)
            start_loc = seq_start + cur_ctx
            start_slot = int(block_table[i, block_id]) * block_size
            k_cache_slots[start_slot:start_slot + n].copy_(
                key[start_loc:start_loc + n])
            v_cache_slots[start_slot:start_slot + n].copy_(
                value[start_loc:start_loc + n])

        token_start += q_len
        seq_start += seq_lens[i]
    kv_cache = torch.stack([k_cache, v_cache])

    return (
        query,
        k,
        v,
        kv_cache,
        block_table,
        key,
        value,
        query_lens,
        seq_lens,
    )


def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                            num_blocks):
    """Flatten the blocks actually holding context into one padded 1-D table.

    Args:
        block_tables: 2-D tensor; row ``i`` lists sequence ``i``'s cache
            block ids.
        query_lens: Tensor of per-sequence query lengths.
        seq_lens: Tensor of per-sequence total lengths (context + query).
        block_size: Tokens per cache block.
        num_blocks: Length of the returned table.

    Returns:
        An int32 tensor holding, for each sequence in order, the first
        ``ceil(context_len / block_size)`` entries of its row, right-padded
        with zeros to ``num_blocks``.
        NOTE(review): if more than ``num_blocks`` blocks are collected,
        ``F.pad`` receives a negative pad and silently crops the tail.
    """
    context_lens = seq_lens - query_lens
    blocks_per_seq = (context_lens + block_size - 1) // block_size
    active_blocks: list[int] = []
    for seq_id in range(len(seq_lens)):
        # extend() in place instead of rebuilding the list each iteration.
        active_blocks.extend(
            block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
    return F.pad(
        torch.tensor(active_blocks, dtype=torch.int32),
        (0, num_blocks - len(active_blocks)),
        "constant",
        0,
    )


@pytest.mark.parametrize(
    "prefill_batch_size,decode_batch_size,block_size,large_tile_size,num_heads,num_queries_per_kv,head_size,mixed_precision",
    [
        # Test minimal configurations (small block size)
        (1, 199, 1, 512, 4, 2, 8, False
         ),  # minimal block size, small dimensions
        (1, 199, 1, 512, 4, 2, 8, True),  # same with mixed precision

        # Test common/medium configurations
        (4, 12, 32, 2048, 32, 8, 64, False),  # common case, larger heads
        (4, 12, 32, 2048, 16, 4, 32,
         True),  # medium size, mixed precision, grouped-query attention (GQA)

        # Test large configurations
        (4, 12, 256, 8192, 8, 1, 128, False),  # large blocks, large head size
        (4, 12, 256, 8192, 64, 8, 64, True),  # large blocks, many heads

        # Test asymmetric configurations
        (2, 24, 64, 4096, 12, 4, 96, False),  # varied batch sizes
        (8, 8, 128, 2048, 24, 2, 48, True),  # balanced batches

        # Test edge cases
        (1, 128, 16, 1024, 4, 2, 16, False),  # large decode batch
        (16, 4, 8, 1024, 4, 2, 128, True),  # large prefill batch
        (4, 12, 32, 2048, 16, 1, 32, True),  # multi-head attention (MHA)
        (4, 12, 32, 2048, 16, 16, 32, True),  # multi-query attention (MQA)
    ])
@torch.inference_mode()
def test_contexted_kv_attention(
    monkeypatch: pytest.MonkeyPatch,
    prefill_batch_size: int,
    decode_batch_size: int,
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    block_size: int,
    large_tile_size: int,
    mixed_precision: bool,
) -> None:
    """Compare ``flash_attn_varlen_nkifunc`` with the PyTorch reference.

    Random paged-KV inputs are generated on CPU, the reference output is
    computed with ``ref_context_attention``, and the inputs are then padded,
    permuted, and masked into the layout the kernel call below expects. The
    kernel runs on the XLA device and its output is compared with the
    reference (atol=1e-2, rtol=0) over the real (unpadded) query tokens.
    """

    import torch_xla.core.xla_model as xm

    from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc,
                                                   reorder_context_mask)

    # Each large tile must hold a whole number of cache blocks.
    assert large_tile_size % block_size == 0

    device = xm.xla_device()

    compiler_flags_str = " ".join([
        "-O1",
        "--retry_failed_compilation",
    ])
    # NEURON_CC_FLAGS is set only within this monkeypatch context.
    with monkeypatch.context() as m:
        m.setenv("NEURON_CC_FLAGS", compiler_flags_str)

        torch.manual_seed(0)
        torch.set_printoptions(sci_mode=False)
        # NOTE(review): this changes the process-wide default device and
        # print options without restoring them after the test.
        torch.set_default_device("cpu")
        dtype = torch.float32

        min_ctx_len = 32
        max_ctx_len = 1024
        min_query_len = 16
        max_query_len = 512
        num_kv_heads = num_heads // num_queries_per_kv
        (
            query,
            k_active,
            v_active,
            kv_cache,
            block_table,
            key,
            value,
            query_lens,
            seq_lens,
        ) = sample_inputs(
            prefill_batch_size=prefill_batch_size,
            decode_batch_size=decode_batch_size,
            min_query_len=min_query_len,
            max_query_len=max_query_len,
            min_ctx_len=min_ctx_len,
            max_ctx_len=max_ctx_len,
            block_size=block_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=dtype,
        )

        # Reference output: (num_tokens, 1, num_heads, head_size).
        output_ref = ref_context_attention(
            query,
            key,
            value,
            query_lens,
            seq_lens,
            head_size,
            num_queries_per_kv,
            return_max_reduce=False,
        )

        # build neuron program
        # Presumably the kernel's partition/tile size — confirm against
        # nki_flash_attn.
        B_P_SIZE = 128
        assert (large_tile_size >= B_P_SIZE
                ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"

        def ceil_div(a, b):
            return (a + b - 1) // b

        def pad_to_multiple(a, b):
            return ceil_div(a, b) * b

        def pad_to_next_power_of_2(a):
            assert a > 0
            return 2**int(a - 1).bit_length()

        # calculate input shapes
        # Query count is padded to the next power of two.
        max_num_queries = pad_to_next_power_of_2(sum(query_lens))
        context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
        # Total context blocks across all sequences, rounded up to a whole
        # number of large tiles.
        num_active_blocks = ceil_div(context_lens, block_size).sum().item()
        num_active_blocks = pad_to_multiple(num_active_blocks,
                                            large_tile_size // block_size)
        context_kv_len = num_active_blocks * block_size
        assert (
            context_kv_len %
            large_tile_size == 0), f"invalid context_kv_len={context_kv_len}"

        # pad QKV tensors
        # Only dim 0 (the token dim) is padded, at the end, to max_num_queries.
        pad_dims = (
            0,
            0,
            0,
            0,
            0,
            max_num_queries - query.shape[0],
        )
        query = F.pad(query, pad_dims, "constant", 0)
        k = F.pad(k_active, pad_dims, "constant", 0)
        v = F.pad(v_active, pad_dims, "constant", 0)

        # permute QKV tensors
        # query: (1, n_heads, d, seq_q)
        # key:   (1, n_kv_heads, d, seq_k)
        # value: (1, n_kv_heads, seq_v, d)
        query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
        k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
        v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
        # kv_cache: (2, blocks, block_size, kv_heads, d)
        #        -> (2, blocks, kv_heads, block_size, d)
        kv_cache = kv_cache.permute(0, 1, 3, 2, 4).contiguous()

        # transform block table
        # Flat list of the context blocks in use, zero-padded to
        # num_active_blocks.
        active_block_table = get_active_block_tables(
            block_table.cpu(),
            torch.tensor(query_lens).cpu(),
            torch.tensor(seq_lens).cpu(),
            block_size,
            num_active_blocks,
        )

        # Build attention masks
        # prior_mask: queries vs. block-aligned context keys;
        # active_mask: causal mask of queries vs. the active (query) keys.
        prior_mask, active_mask = (
            BlockDiagonalCausalFromBottomRightMask.from_seqlens(
                query_lens, seq_lens, block_size=block_size))
        prior_mask_padded = F.pad(
            prior_mask,
            (
                0,
                context_kv_len - prior_mask.shape[1],
                0,
                max_num_queries - prior_mask.shape[0],
            ),
            "constant",
            0,
        ).bool()
        active_mask_padded = F.pad(
            active_mask,
            (
                0,
                max_num_queries - active_mask.shape[1],
                0,
                max_num_queries - active_mask.shape[0],
            ),
            "constant",
            0,
        ).bool()
        # Columns: context keys first, then active keys ->
        # (max_num_queries, context_kv_len + max_num_queries).
        attn_mask = torch.concat([prior_mask_padded, active_mask_padded],
                                 dim=1)

        # Presumably rearranges the context columns into the kernel's
        # large-tile/block order — see nki_flash_attn.reorder_context_mask.
        attn_mask = reorder_context_mask(attn_mask, large_tile_size,
                                         block_size)

        input_args = (
            query.to(device=device),
            k.to(device=device),
            v.to(device=device),
            kv_cache.to(device=device),
            active_block_table.to(device=device),
            attn_mask.to(device=device),
        )
        input_kwargs = dict(
            n_kv_head=num_kv_heads,
            head_size=head_size,
            mixed_precision=mixed_precision,
            LARGE_TILE_SZ=large_tile_size,
        )

        output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)

        num_actual_tokens = sum(query_lens)
        # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
        output_nki = output_nki.cpu().permute(0, 2, 1, 3)
        output_nki = output_nki[0, :num_actual_tokens, :, :]
        # Pad the reference along the token dim, then drop the size-1 axis
        # and slice back to the real tokens; the padding itself is unused
        # by the comparison.
        output_ref_padded = F.pad(
            output_ref,
            (0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
            "constant",
            0,
        )
        output_ref = output_ref_padded.transpose(
            0, 1)[0, :num_actual_tokens, :, :]

        torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0)
