import itertools
import math

import cudnn
import torch
import torch.utils.benchmark as benchmark
import triton
import triton.language as tl
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
from sglang.srt.utils import should_use_tensor_core


def benchmark_forward(
    fn,
    *inputs,
    repeats=10,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Time ``fn(*inputs, **kwinputs)`` with ``torch.utils.benchmark.Timer``.

    Each call runs inside a CUDA ``torch.autocast`` context that is enabled
    only when ``amp`` is true (casting to ``amp_dtype``). The return value of
    ``fn`` is discarded.

    Args:
        fn: Callable to benchmark.
        *inputs: Positional arguments forwarded to ``fn``.
        repeats: Number of timed executions passed to ``Timer.timeit``.
        amp: Whether autocast is enabled around each call.
        amp_dtype: Autocast dtype used when ``amp`` is true.
        **kwinputs: Keyword arguments forwarded to ``fn``.

    Returns:
        ``(timer, measurement)`` where ``measurement`` is
        ``timer.timeit(repeats)``.
    """

    def _run_under_autocast(*call_args, **call_kwargs):
        autocast_ctx = torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp)
        with autocast_ctx:
            fn(*call_args, **call_kwargs)

    # Names the Timer statement string resolves at execution time.
    timer_globals = {
        "fn_amp": _run_under_autocast,
        "inputs": inputs,
        "kwinputs": kwinputs,
    }
    timer = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals=timer_globals,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    return timer, measurement


def time_fwd(func, *args, **kwargs):
    """Return the mean runtime of ``func(*args, **kwargs)`` in microseconds.

    Every extra argument, including ``benchmark_forward`` options such as
    ``repeats``, is forwarded to ``benchmark_forward``. The measurement's mean
    (in seconds) is scaled by 1e6.
    """
    _timer, measurement = benchmark_forward(func, *args, **kwargs)
    seconds = measurement.mean
    return seconds * 1e6


def decode_attention_sglang(
    q,
    kv_data,
    batch_size,
    kv_len,
    head_num_q,
    head_num_kv,
    head_dim,
    num_kv_splits,
    warmup=10,
):
    """Benchmark SGLang's Triton ``decode_attention_fwd`` kernel.

    Every request attends to ``kv_len`` cached tokens. Request ``i`` owns the
    contiguous token slots ``[i * kv_len, (i + 1) * kv_len)`` of the flat KV
    buffers.

    Args:
        q: Query tensor; the output buffer is ``torch.empty_like(q)``.
            Presumably ``(batch_size, head_num_q, head_dim)`` (matches the
            callers in this file).
        kv_data: ``(k, v)`` pair; each is viewed as
            ``(batch_size * kv_len, head_num_kv, head_dim)``.
        batch_size: Number of requests.
        kv_len: KV length of every request.
        head_num_q: Number of query heads (sizes the split scratch buffer).
        head_num_kv: Number of KV heads.
        head_dim: Per-head hidden size.
        num_kv_splits: Number of KV splits passed to the kernel; also sizes
            the ``attn_logits`` scratch buffer.
        warmup: Number of untimed launches before timing.

    Returns:
        ``(mean_time_us, o)``: the mean kernel time in microseconds and the
        output tensor the kernel writes into.
    """
    # Keep every auxiliary buffer on q's device. Previously some went to
    # cuda:0 via .to(0) and others to the current "cuda" device.
    device = q.device

    k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)
    v_buffer = kv_data[1].view(-1, head_num_kv, head_dim)
    o = torch.empty_like(q)

    total_tokens = batch_size * kv_len
    # Row i of the request->token map lists the KV slots owned by request i.
    req_to_token = torch.arange(total_tokens, dtype=torch.int32, device=device).view(
        batch_size, kv_len
    )
    b_req_idx = torch.arange(batch_size, dtype=torch.int32, device=device)
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device=device)
    sm_scale = 1.0 / (head_dim**0.5)

    # Per-split scratch; the extra trailing column presumably holds a
    # per-split softmax statistic (TODO confirm against the kernel).
    attn_logits = torch.empty(
        (batch_size, head_num_q, num_kv_splits, head_dim + 1),
        dtype=torch.float32,
        device=device,
    )

    # Identical positional arguments for the warmup and the timed launches.
    kernel_args = (
        q,
        k_buffer,
        v_buffer,
        o,
        req_to_token,
        b_req_idx,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )

    for _ in range(warmup):
        decode_attention_fwd(*kernel_args)

    f = time_fwd(decode_attention_fwd, *kernel_args)

    return f, o


def decode_attention_flashinfer(dtype, head_num_q, head_num_kv):
    """Build a FlashInfer paged-decode benchmark for one head configuration.

    A 128 MiB workspace and a ``BatchDecodeWithPagedKVCacheWrapper`` (``"NHD"``
    layout; tensor-core use decided by ``should_use_tensor_core``) are created
    once. Every call of the returned class's ``apply`` reuses them.

    Args:
        dtype: KV-cache dtype used to decide tensor-core usage.
        head_num_q: Number of query heads.
        head_num_kv: Number of KV heads.

    Returns:
        A ``torch.autograd.Function`` subclass. Its ``apply`` takes
        ``(q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim,
        dtype)`` and returns ``(mean_time_us, output)``.
    """
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
    use_tensor_cores = should_use_tensor_core(
        kv_cache_dtype=dtype,
        num_attention_heads=head_num_q,
        num_kv_heads=head_num_kv,
    )
    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
    )

    class FlashinferAttention(torch.autograd.Function):
        @staticmethod
        def forward(
            ctx,
            q,
            kv_data,
            batch_size,
            kv_len,
            head_num_q,
            head_num_kv,
            head_dim,
            dtype,
            warmup=10,
        ):
            """Plan a uniform-length decode batch, then warm up and time it.

            Pages hold a single token (page size 1), so request ``i`` owns
            pages ``[i * kv_len, (i + 1) * kv_len)`` and its last page always
            holds exactly one token.

            Returns:
                ``(mean_time_us, output)``.
            """
            device = q.device
            total_tokens = batch_size * kv_len
            kv_indptr = (
                torch.arange(batch_size + 1, dtype=torch.int32, device=device) * kv_len
            )
            kv_indices = torch.arange(total_tokens, dtype=torch.int32, device=device)
            kv_last_page_len = torch.full(
                (batch_size,), 1, dtype=torch.int32, device=device
            )

            # Re-plan the shared wrapper for this batch's page layout.
            flashinfer_decode_wrapper.end_forward()
            flashinfer_decode_wrapper.begin_forward(
                kv_indptr,
                kv_indices,
                kv_last_page_len,
                head_num_q,
                head_num_kv,
                head_dim,
                1,  # page size
                pos_encoding_mode="NONE",
                data_type=dtype,
            )

            # Reshape once rather than on every launch.
            q_flat = q.contiguous().view(-1, head_num_q, head_dim)

            for _ in range(warmup):
                flashinfer_decode_wrapper.forward(q_flat, kv_data)

            f = time_fwd(flashinfer_decode_wrapper.forward, q_flat, kv_data)

            # Produce the returned output explicitly. Previously `o` was only
            # bound inside the warmup loop, so warmup=0 raised UnboundLocalError.
            o = flashinfer_decode_wrapper.forward(q_flat, kv_data)

            return f, o

    return FlashinferAttention


def convert_to_cudnn_type(torch_type):
    """Map a ``torch.dtype`` to the matching ``cudnn.data_type`` member.

    Raises:
        ValueError: if ``torch_type`` is not float16, bfloat16, float32,
            int32 or int64.
    """
    # Candidates are compared in order with ``==``. The cuDNN enum member is
    # looked up by name only after a match is found.
    supported = (
        (torch.float16, "HALF"),
        (torch.bfloat16, "BFLOAT16"),
        (torch.float32, "FLOAT"),
        (torch.int32, "INT32"),
        (torch.int64, "INT64"),
    )
    for candidate, cudnn_name in supported:
        if torch_type == candidate:
            return getattr(cudnn.data_type, cudnn_name)
    raise ValueError("Unsupported tensor data type.")


def decode_attention_cudnn(
    q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10
):
    """Benchmark cuDNN paged-KV SDPA (cudnn frontend graph API) for decode.

    The contiguous K/V buffers are copied into a paged "container" with one
    token per page, plus a page table. A cuDNN SDPA graph is built for that
    layout, and ``graph.execute`` is warmed up and timed.

    Args:
        q: Query tensor, viewed via ``as_strided`` as
            ``(batch_size, head_num_q, 1, head_dim)``. Presumably a contiguous
            ``(batch_size, head_num_q, head_dim)`` tensor (matches callers).
        kv_data: ``(k, v)`` pair, each viewed as
            ``(batch_size, head_num_kv, kv_len, head_dim)`` over a contiguous
            ``(batch_size * kv_len, head_num_kv, head_dim)`` buffer.
        batch_size: Number of requests.
        kv_len: KV length of every request.
        head_num_q: Number of query heads.
        head_num_kv: Number of KV heads.
        head_dim: Per-head hidden size.
        dtype: I/O dtype of the cuDNN graph and of the output buffer.
        warmup: Number of untimed executions before timing.

    Returns:
        ``(mean_time_us, o)`` where ``o`` is ``(batch_size, head_num_q, head_dim)``.
    """
    device = q.device

    # Prepare data: strided 4-D views over the contiguous q, k, v buffers.
    dims_q = (batch_size, head_num_q, 1, head_dim)
    strides_q = (head_num_q * head_dim, head_dim, head_num_q * head_dim, 1)
    q_gpu = q.as_strided(dims_q, strides_q)
    # Allocate the output directly on the device in the graph's I/O dtype.
    # It was previously hard-coded to float16, which mismatches io_data_type
    # for e.g. bfloat16.
    o_gpu = torch.empty(
        batch_size * head_num_q * head_dim, dtype=dtype, device=device
    ).as_strided(dims_q, strides_q)

    dims_kv = (batch_size, head_num_kv, kv_len, head_dim)
    # Token t of request b lives at row b * kv_len + t of the flat buffer.
    strides_kv = (
        kv_len * head_num_kv * head_dim,
        head_dim,
        head_num_kv * head_dim,
        1,
    )
    k_gpu = kv_data[0].as_strided(dims_kv, strides_kv)
    v_gpu = kv_data[1].as_strided(dims_kv, strides_kv)

    # NOTE(review): an int fill value makes these int64. cuDNN's paged-attention
    # samples use int32 sequence lengths; confirm which dtype is required.
    seq_len_q_gpu = torch.full((batch_size, 1, 1, 1), 1, device=device)
    seq_len_kv_gpu = torch.full((batch_size, 1, 1, 1), kv_len, device=device)
    attn_scale = 1.0 / (head_dim**0.5)

    # Prepare data: paged k, v.
    block_size = 1
    blocks_per_batch = math.ceil(kv_len / block_size)
    # [num_blocks, head_num_kv, block_size, head_dim], num_blocks = batch_size * blocks_per_batch.
    # Blocks are stored block-major: container row = block * batch_size + b.
    container_k_gpu = torch.cat(k_gpu.chunk(blocks_per_batch, dim=2), dim=0)
    container_v_gpu = torch.cat(v_gpu.chunk(blocks_per_batch, dim=2), dim=0)
    # page_table[b, 0, block, 0] == block * batch_size + b, matching the
    # container layout. arange yields the same values as the former int32
    # linspace without float rounding.
    page_table_k_gpu = (
        torch.arange(batch_size * blocks_per_batch, device=device, dtype=torch.int32)
        .reshape(blocks_per_batch, 1, batch_size, 1)
        .transpose(0, 2)
    )
    page_table_v_gpu = page_table_k_gpu.clone()

    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )

    # Graph-side tensor descriptors (distinct names so the `q` argument is not shadowed).
    q_t = graph.tensor_like(q_gpu)
    container_k = graph.tensor_like(container_k_gpu)
    container_v = graph.tensor_like(container_v_gpu)
    page_table_k = graph.tensor_like(page_table_k_gpu)
    page_table_v = graph.tensor_like(page_table_v_gpu)

    seq_len_q = graph.tensor_like(seq_len_q_gpu)
    seq_len_kv = graph.tensor_like(seq_len_kv_gpu)

    o_t, _ = graph.sdpa(
        name="sdpa",
        q=q_t,
        k=container_k,  # Container K: non contiguous container with K blocks
        v=container_v,  # Container V: non contiguous container with V blocks
        is_inference=True,
        attn_scale=attn_scale,
        use_causal_mask=False,
        use_padding_mask=True,
        seq_len_q=seq_len_q,
        seq_len_kv=seq_len_kv,
        paged_attention_k_table=page_table_k,  # Page Table K: Tensor containing offsets to the container with K blocks
        paged_attention_v_table=page_table_v,  # Page Table V: Tensor containing offsets to the container with V blocks
        paged_attention_max_seq_len_kv=kv_len,  # The maximum sequence length for K caches (this is optional, but recommended)
    )

    o_t.set_output(True).set_dim(dims_q).set_stride(strides_q)

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A])
    graph.check_support()
    graph.build_plans()

    workspace = torch.empty(
        graph.get_workspace_size(), device=device, dtype=torch.uint8
    )

    variant_pack = {
        q_t: q_gpu,
        container_k: container_k_gpu,
        container_v: container_v_gpu,
        page_table_k: page_table_k_gpu,
        page_table_v: page_table_v_gpu,
        seq_len_q: seq_len_q_gpu,
        seq_len_kv: seq_len_kv_gpu,
        o_t: o_gpu,
    }

    for _ in range(warmup):
        graph.execute(variant_pack, workspace)

    f = time_fwd(
        graph.execute,
        variant_pack,
        workspace,
    )

    return f, o_gpu.squeeze(dim=2)


def calculate_diff(
    dtype=torch.float16,
    batch_size=64,
    kv_len=4096,
    head_num_q=64,
    head_num_kv=8,
    head_dim=128,
    num_kv_splits=8,
    atol=1e-2,
    rtol=1e-2,
):
    """Run all three decode backends on one random input and compare them.

    SGLang's Triton kernel is the reference. The FlashInfer and cuDNN outputs
    are each checked against it with ``torch.allclose(atol, rtol)``. The
    defaults reproduce the original hard-coded configuration.

    Args:
        dtype: Dtype of q, k and v.
        batch_size: Number of requests.
        kv_len: KV length of every request.
        head_num_q: Number of query heads.
        head_num_kv: Number of KV heads.
        head_dim: Per-head hidden size.
        num_kv_splits: KV splits for the SGLang kernel.
        atol: Absolute tolerance for the comparisons.
        rtol: Relative tolerance for the comparisons.
    """
    q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
    # (k, v), generated in that order.
    kv_data = tuple(
        torch.randn(
            batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
        )
        for _ in range(2)
    )

    _, output_sglang = decode_attention_sglang(
        q,
        kv_data,
        batch_size,
        kv_len,
        head_num_q,
        head_num_kv,
        head_dim,
        num_kv_splits=num_kv_splits,
    )

    attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply
    _, output_flashinfer = attn_flashinfer(
        q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
    )

    _, output_cudnn = decode_attention_cudnn(
        q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
    )

    print(f"SGLang output={output_sglang}")
    print(f"FlashInfer output={output_flashinfer}")
    print(f"cuDNN output={output_cudnn}")

    for name, output in (("FlashInfer", output_flashinfer), ("cuDNN", output_cudnn)):
        if torch.allclose(output_sglang, output, atol=atol, rtol=rtol):
            print(f"✅ SGLang[Triton] and {name} match")
        else:
            # Report how far apart the outputs are, not just that they differ.
            max_err = (output_sglang.float() - output.float()).abs().max().item()
            print(f"❌ SGLang[Triton] and {name} differ (max abs diff={max_err:.3e})")


if __name__ == "__main__":
    calculate_diff()

    head_dim = 128
    dtype = torch.float16
    batch_size_range = [2**i for i in range(0, 8, 2)]
    kv_len_range = [2**i for i in range(6, 13)]
    configs = list(itertools.product(batch_size_range, kv_len_range))

    # Column labels for the rows printed below; times are mean microseconds.
    print(
        "head_num_q    head_num_kv    batch_size    kv_len    "
        "cudnn_us    sglang_us    flashinfer_us"
    )

    for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
        # One FlashInfer wrapper (and workspace) per head layout.
        attn_flashinfer = decode_attention_flashinfer(
            dtype, head_num_q, head_num_kv
        ).apply
        for batch_size, kv_len in configs:
            q = torch.randn(
                batch_size, head_num_q, head_dim, dtype=dtype, device="cuda"
            )
            # (k, v), generated in that order.
            kv_data = tuple(
                torch.randn(
                    batch_size * kv_len,
                    head_num_kv,
                    head_dim,
                    dtype=dtype,
                    device="cuda",
                )
                for _ in range(2)
            )
            # Only the timings are needed here; correctness is checked by calculate_diff().
            us_cudnn, _ = decode_attention_cudnn(
                q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
            )
            us_sglang, _ = decode_attention_sglang(
                q,
                kv_data,
                batch_size,
                kv_len,
                head_num_q,
                head_num_kv,
                head_dim,
                num_kv_splits=8,
            )
            us_flashinfer, _ = attn_flashinfer(
                q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
            )
            row = (
                head_num_q,
                head_num_kv,
                batch_size,
                kv_len,
                us_cudnn,
                us_sglang,
                us_flashinfer,
            )
            # Same text as print(a, "  ", b, ...): four spaces between fields.
            print("    ".join(str(value) for value in row))
