# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import pytest
import torch

from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.attention.ops.triton_merge_attn_states import (
    merge_attn_states as merge_attn_states_triton)
from vllm.platforms import current_platform


# Naive PyTorch implementation of section 2.2 of
# https://www.arxiv.org/pdf/2501.01005, which can be used to combine partial
# attention results (in the split-KV case).
def merge_attn_states_torch(
        output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        prefix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
        suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        suffix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
        output_lse: Optional[torch.Tensor] = None,  # [NUM_HEADS, NUM_TOKENS]
):
    """Reference (pure PyTorch) merge of two partial attention results.

    Each partial output is weighted by ``exp(lse)`` of its own partial
    softmax, i.e. the merged output is

        (exp(p_lse) * prefix_output + exp(s_lse) * suffix_output)
        / (exp(p_lse) + exp(s_lse))

    computed stably by first subtracting the element-wise maximum LSE.
    A ``+inf`` LSE entry is treated as ``-inf`` (that partial result
    contributes nothing). The input tensors are never modified.

    Args:
        output: Not read; kept for signature parity with the Triton/CUDA
            kernels. The merged output is returned instead.
        prefix_output: First partial output, [NUM_TOKENS, NUM_HEADS, HEAD_SIZE].
        prefix_lse: LSE of ``prefix_output``, [NUM_HEADS, NUM_TOKENS].
        suffix_output: Second partial output, same shape as ``prefix_output``.
        suffix_lse: LSE of ``suffix_output``, [NUM_HEADS, NUM_TOKENS].
        output_lse: If not None, the merged LSE is computed and returned
            (the passed tensor itself is not written).

    Returns:
        ``(merged_output, merged_lse)``. ``merged_lse`` is None when
        ``output_lse`` is None. ``merged_output`` has the promoted dtype of
        the outputs and the scale factors (float32 when the LSEs are
        float32, even for half/bfloat16 outputs).

    Note:
        If a position has -inf (after the +inf -> -inf replacement) in both
        LSEs, the max LSE there is -inf and the result at that position is
        NaN.
    """
    # inf -> -inf. masked_fill returns a new tensor, so the caller's LSE
    # tensors are left untouched.
    p_lse = prefix_lse.masked_fill(prefix_lse == torch.inf, -torch.inf)
    s_lse = suffix_lse.masked_fill(suffix_lse == torch.inf, -torch.inf)
    # max_lse [NUM_HEADS, NUM_TOKENS]; subtracting it keeps exp() from
    # overflowing.
    max_lse = torch.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    p_lse_exp = torch.exp(p_lse)
    s_lse_exp = torch.exp(s_lse)
    out_se = p_lse_exp + s_lse_exp
    if output_lse is not None:
        # Undo the max shift: log(exp(p) + exp(s)).
        output_lse = torch.log(out_se) + max_lse
    p_scale = p_lse_exp / out_se  # [NUM_HEADS, NUM_TOKENS]
    s_scale = s_lse_exp / out_se  # [NUM_HEADS, NUM_TOKENS]
    # Reorder to [NUM_TOKENS, NUM_HEADS, 1] so the scales broadcast over
    # HEAD_SIZE.
    p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2)
    s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2)
    output = prefix_output * p_scale + suffix_output * s_scale
    return output, output_lse


# Parameter grid for test_merge_attn_states: every combination is run, so the
# grid has len(NUM_BATCH_TOKENS) * len(NUM_QUERY_HEADS) * len(HEAD_SIZES)
# * len(DTYPES) cases.
# Includes 613, a non-power-of-two count (presumably to exercise unaligned
# token counts).
NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536, 4096]
NUM_QUERY_HEADS = [4, 8, 16, 32, 48, 64]
HEAD_SIZES = [32, 48, 64, 96, 128, 256]
# dtype of the partial/merged attention outputs; the LSE tensors in the test
# are always float32.
DTYPES = [torch.float32, torch.half, torch.bfloat16]

# One tuple per finished test case:
# (num_tokens, num_heads, head_size, dtype, device_name,
#  torch_ms, triton_ms, cuda_ms, triton_ms / cuda_ms).
# Appended to by test_merge_attn_states and printed by generate_markdown_table.
all_case_info: list[tuple] = []


def generate_markdown_table():
    """Print every collected benchmark row as a Markdown table.

    Reads the module-level ``all_case_info`` list and writes a header line,
    a separator line, and one line per recorded case to stdout. Timings are
    shown in milliseconds with five decimals; the last column is the stored
    speedup value with four decimals.
    """
    columns = ("tokens", "heads", "headsize", "dtype", "device", "torch",
               "triton", "cuda", "speedup")
    print("| " + " | ".join(columns) + " |")
    print("|" + " --- |" * len(columns))

    for row in all_case_info:
        (tokens, heads, head_dim, dt, dev_name, t_torch, t_triton, t_cuda,
         ratio) = row
        # Shorten labels: "torch.float16" -> "float16", "NVIDIA H100" -> "H100".
        dtype_label = str(dt).removeprefix("torch.")
        device_label = dev_name.removeprefix("NVIDIA").strip()
        cells = (
            f"{tokens}",
            f"{heads}",
            f"{head_dim}",
            dtype_label,
            device_label,
            f"{t_torch:.5f}ms",
            f"{t_triton:.5f}ms",
            f"{t_cuda:.5f}ms",
            f"{ratio:.4f}x",
        )
        print("| " + " | ".join(cells) + " |")


@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(num_tokens: int, num_query_heads: int,
                           head_size: int, output_dtype: torch.dtype):
    """Benchmark and cross-check the three merge_attn_states implementations.

    For one (num_tokens, num_query_heads, head_size, output_dtype) case:
    builds random prefix/suffix partial outputs and float32 LSEs (with some
    +inf LSE entries), then times the PyTorch reference, the Triton kernel
    and the custom CUDA kernel using CUDA events. It asserts that the CUDA
    kernel's output and output LSE match the Triton kernel's. The PyTorch
    reference is only timed and reported (max abs diff), not asserted.

    The timings are appended to ``all_case_info``. Once every case of the
    parameter grid has been recorded, a Markdown summary table is printed.
    Skipped on non-CUDA platforms.
    """
    if not current_platform.is_cuda():
        pytest.skip('Currently only support compare triton merge_attn_states '
                    'with custom cuda merge_attn_states kernel')

    # Upper-case aliases used throughout the rest of the test.
    NUM_TOKENS = num_tokens
    NUM_HEADS = num_query_heads
    HEAD_SIZE = head_size

    print(f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
          f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
          f"Device: {current_platform.get_device_name()}")

    # prefix_lse and suffix_lse contain inf and normal values:
    # [NUM_HEADS, NUM_TOKENS] float32 drawn from a standard normal, with some
    # entries overwritten by +inf below.
    # NOTE(review): no random seed is set, so inputs differ between runs.
    prefix_lse = torch.randn(NUM_HEADS,
                             NUM_TOKENS,
                             dtype=torch.float32,
                             device="cuda")
    suffix_lse = torch.randn(NUM_HEADS,
                             NUM_TOKENS,
                             dtype=torch.float32,
                             device="cuda")

    # Generate boolean masks selecting roughly 10% of positions each.
    # NOTE(review): the masks are created on the CPU and then used to index
    # the CUDA LSE tensors.
    mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
    mask_suffix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
    # Ensure that the same position is never +inf in both prefix_lse and
    # suffix_lse (presumably because the merge is ill-defined when both
    # partial results would be discarded).
    combined_mask = torch.logical_and(mask_prefix, mask_suffix)
    mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
    mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)

    prefix_lse[mask_prefix] = float('inf')
    suffix_lse[mask_suffix] = float('inf')

    # Other input tensors: zero-initialised destination buffers (each
    # implementation below gets its own clone) and the random partial
    # attention outputs to be merged.
    output = torch.zeros((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
                         dtype=output_dtype,
                         device="cuda")
    output_lse = torch.zeros((NUM_HEADS, NUM_TOKENS),
                             dtype=torch.float32,
                             device="cuda")
    prefix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
                                dtype=output_dtype,
                                device="cuda")
    suffix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
                                dtype=output_dtype,
                                device="cuda")

    # Each implementation gets warmup_times untimed calls, then repeat_times
    # calls, each timed with a pair of CUDA events. The reported time is the
    # mean of the timed calls.
    warmup_times = 2
    repeat_times = 20

    output_torch = output.clone()
    output_lse_torch = output_lse.clone()
    total_time_torch_kernel = 0
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # 0. Run the Torch kernel
    # The reference gets its own LSE clones, so any in-place change it makes
    # cannot leak into the prefix_lse/suffix_lse read by the Triton and CUDA
    # runs. It returns its results, so output_torch/output_lse_torch are
    # rebound on every call.
    prefix_lse_torch = prefix_lse.clone()
    suffix_lse_torch = suffix_lse.clone()
    for _ in range(warmup_times):
        output_torch, output_lse_torch = merge_attn_states_torch(
            output_torch, prefix_output, prefix_lse_torch, suffix_output,
            suffix_lse_torch, output_lse_torch)
    torch.cuda.synchronize()

    for _ in range(repeat_times):
        start.record()
        output_torch, output_lse_torch = merge_attn_states_torch(
            output_torch, prefix_output, prefix_lse_torch, suffix_output,
            suffix_lse_torch, output_lse_torch)
        end.record()
        torch.cuda.synchronize()
        total_time_torch_kernel += start.elapsed_time(end)

    avg_time_torch_kernel = total_time_torch_kernel / repeat_times

    # 1. Run the Triton kernel
    # Its results are read back from output_ref_triton/output_lse_ref_triton,
    # which later serve as the reference for the CUDA kernel.
    output_ref_triton = output.clone()
    output_lse_ref_triton = output_lse.clone()

    total_time_triton_kernel = 0
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for _ in range(warmup_times):
        merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse,
                                 suffix_output, suffix_lse,
                                 output_lse_ref_triton)
    torch.cuda.synchronize()

    for _ in range(repeat_times):
        start.record()
        merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse,
                                 suffix_output, suffix_lse,
                                 output_lse_ref_triton)
        end.record()
        torch.cuda.synchronize()
        total_time_triton_kernel += start.elapsed_time(end)

    avg_time_triton_kernel = total_time_triton_kernel / repeat_times

    # 2. Run the CUDA kernel
    # Reuses the start/end events created for the Triton run.
    # NOTE(review): both kernels read the same prefix_lse/suffix_lse; this
    # assumes neither kernel modifies them in place -- TODO confirm.
    total_time_cuda_kernel = 0
    output_cuda = output.clone()
    output_lse_cuda = output_lse.clone()

    for _ in range(warmup_times):
        merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse,
                               suffix_output, suffix_lse, output_lse_cuda)
    torch.cuda.synchronize()

    for _ in range(repeat_times):
        start.record()
        merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse,
                               suffix_output, suffix_lse, output_lse_cuda)
        end.record()
        torch.cuda.synchronize()
        total_time_cuda_kernel += start.elapsed_time(end)

    avg_time_cuda_kernel = total_time_cuda_kernel / repeat_times

    # 3. Performance compare
    # Ratio of Triton time to CUDA time: > 1 means the CUDA kernel was faster.
    performance_improved = avg_time_triton_kernel / avg_time_cuda_kernel
    print(f" Torch time: {avg_time_torch_kernel:.6f}ms")
    print(f"Triton time: {avg_time_triton_kernel:.6f}ms")
    print(f"  CUDA time: {avg_time_cuda_kernel:.6f}ms, "
          f"Performance: {performance_improved:.5f}x")
    print("-" * 100)

    # 4. Correctness compare
    # Liger Kernel: Efficient Triton Kernels for LLM Training
    # https://arxiv.org/pdf/2410.10989, 3.3 Correctness
    # use rtol = 1e-2 for bfloat16 (atol is fixed at 1e-3 below).
    rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3

    # Maximum absolute element-wise difference, computed in float32.
    def diff(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        max_diff = torch.max(torch.abs(a.float() - b.float()))
        return max_diff

    # Use Triton output as reference because we want to replace
    # the Triton kernel with custom CUDA kernel for merge attn
    # states operation.
    output_ref = output_ref_triton
    output_lse_ref = output_lse_ref_triton
    torch.testing.assert_close(output_cuda.float(),
                               output_ref.float(),
                               atol=1e-3,
                               rtol=rtol)
    # NOTE(review): only CUDA vs Triton is asserted; the Torch diffs printed
    # below are informational and not checked against any tolerance.
    print("Output all match, max abs diff:")
    print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
    print(f"  (CUDA vs Torch) : {diff(output_torch, output_cuda)}")
    print(f"  (CUDA vs Triton): {diff(output_ref, output_cuda)}")
    print("-" * 100)

    torch.testing.assert_close(output_lse_cuda.float(),
                               output_lse_ref.float(),
                               atol=1e-3,
                               rtol=rtol)
    print("Output LSE all match, max abs diff:")
    print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
    print(f"  (CUDA vs Torch) : {diff(output_lse_torch, output_lse_cuda)}")
    print(f"  (CUDA vs Triton): {diff(output_lse_ref, output_lse_cuda)}")
    print("-" * 100)

    # NOTE(review): no assertion in this test checks the inf handling
    # directly; this message relies only on the CUDA/Triton agreement above.
    print("All output values test passed! All inf values "
          "are correctly replaced with -inf.")
    print("-" * 100)

    # Record this case, and print the Markdown summary table once every
    # combination of the parameter grid has been recorded.
    # NOTE(review): if any case is skipped, fails, or is deselected (e.g.
    # via pytest -k), the count never matches and no table is printed.
    device = current_platform.get_device_name()
    all_case_info.append(
        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE, output_dtype, device,
         avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel,
         performance_improved))
    if len(all_case_info) == (len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) *
                              len(NUM_QUERY_HEADS) * len(DTYPES)):
        generate_markdown_table()
