import random
import unittest

import torch

from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd_grouped,
)
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
    decode_attention_fwd_grouped_rope,
)
from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
from sglang.test.test_utils import CustomTestCase


class TestTritonAttentionMLA(CustomTestCase):
    """Check ``decode_attention_fwd_grouped_rope`` against a two-step reference.

    The reference applies ``DeepseekScalingRotaryEmbedding`` in PyTorch (when
    RoPE is enabled) and then runs ``decode_attention_fwd_grouped``. The
    outputs, the split logits buffer and (with RoPE) the rotated key vector of
    the last token are compared with a tolerance of 1e-2.
    """

    def _set_all_seeds(self, seed):
        """Set all random seeds for reproducibility."""
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def setUp(self):
        # Set seeds before each test method
        self._set_all_seeds(42)

    def preprocess_kv_cache(self, kv_cache, kv_lora_rank):
        """Split a latent KV cache into the key and value inputs of the kernels.

        Args:
            kv_cache: Tensor whose last dimension holds ``kv_lora_rank`` latent
                features followed by the rope features.
            kv_lora_rank: Number of leading features used as the value.

        Returns:
            ``(k_input, v_input)``. ``k_input`` is ``kv_cache`` with a singleton
            axis inserted at dim 1 (it shares storage with ``kv_cache``);
            ``v_input`` is a contiguous copy of the first ``kv_lora_rank``
            features with the same singleton axis.
        """
        latent_cache = kv_cache
        v_input = latent_cache[..., :kv_lora_rank]
        v_input = v_input.contiguous().unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        # v_input was copied from these very elements, so this write-back
        # stores the same values again; it is kept to preserve the original
        # preprocessing sequence.
        k_input[..., :kv_lora_rank] = v_input

        return k_input, v_input

    def input_helper(
        self,
        B,
        H,
        S,
        kv_lora_rank,
        rotary_dim,
        qk_rope_head_dim,
        num_kv_splits,
        dtype,
        device,
        rope_base=10,
        rope_max_seq_len=16384,
        rope_scaling=1.0,
        is_neox_style=False,
    ):
        """Create random inputs for one decode step over ``B`` sequences.

        Args:
            B: Number of sequences (first dimension of ``q``).
            H: Second dimension of ``q`` (query heads).
            S: Number of cached tokens per sequence; every sequence has the
                same length.
            kv_lora_rank: Width of the latent (non-rope) part of each vector.
            rotary_dim: Rotary dimension passed to the rotary embedding.
            qk_rope_head_dim: Width of the rope part of each vector.
            num_kv_splits: Size of the split axis of the logits buffer.
            dtype: dtype of ``q``, the KV cache and the logits buffer.
            device: Device on which all tensors (and the rotary embedding)
                are placed.
            rope_base: Base forwarded to the rotary embedding.
            rope_max_seq_len: Maximum position forwarded to the rotary
                embedding.
            rope_scaling: Scaling factor forwarded to the rotary embedding.
            is_neox_style: Rotation layout forwarded to the rotary embedding.

        Returns:
            ``(kv_indptr, kv_indices, q, kv_cache, attn_logits, rotary_emb,
            positions)``.
        """
        head_dim = kv_lora_rank + qk_rope_head_dim
        q = torch.randn(B, H, head_dim, device=device, dtype=dtype)
        kv_cache = torch.randn(B * S, head_dim, dtype=dtype, device=device)
        # Sequence i owns cache rows [i * S, (i + 1) * S).
        kv_indptr = torch.arange(B + 1, device=device) * S
        kv_indices = torch.arange(B * S, device=device)
        attn_logits = torch.empty(
            B, H, num_kv_splits, kv_lora_rank + 1, dtype=dtype, device=device
        )
        # Built on CPU first, then moved to the requested device (the original
        # hard-coded ``.cuda()`` and ignored ``device``).
        rotary_emb = DeepseekScalingRotaryEmbedding(
            qk_rope_head_dim,
            rotary_dim,
            rope_max_seq_len,
            rope_base,
            is_neox_style,
            rope_scaling,
            q.dtype,
            device="cpu",
        ).to(device)
        # One position per sequence, all equal to S; shape (B, 1).
        positions = torch.tensor([S], device=device).unsqueeze(0).repeat(B, 1)

        return kv_indptr, kv_indices, q, kv_cache, attn_logits, rotary_emb, positions

    def ref_compute_full_fwd(
        self,
        q,
        k_input,
        v_input,
        kv_lora_rank,
        kv_indptr,
        kv_indices,
        num_kv_splits,
        sm_scale,
        logit_cap,
        rotary_emb,
        positions,
        use_rope,
        device="cuda",
    ):
        """Reference path: optional PyTorch RoPE followed by grouped decode attention.

        Note that ``k_input`` is modified in place: the rope part of the last
        token of every sequence is overwritten with its (optionally rotated)
        value before attention is computed.

        Args:
            q: Query tensor of shape ``(B, H, kv_lora_rank + rope_dim)``.
            k_input: Key input as produced by ``preprocess_kv_cache``.
            v_input: Value input as produced by ``preprocess_kv_cache``.
            kv_lora_rank: Width of the latent (non-rope) part.
            kv_indptr: Sequence offsets; ``kv_indptr[1]`` is used as the common
                sequence length.
            kv_indices: Cache row indices forwarded to the attention kernel.
            num_kv_splits: Number of KV splits forwarded to the kernel.
            sm_scale: Softmax scale forwarded to the kernel.
            logit_cap: Logit cap forwarded to the kernel.
            rotary_emb: Callable applied as ``rotary_emb(positions, q_pe, k_pe)``.
            positions: Positions forwarded to ``rotary_emb``.
            use_rope: Whether to apply ``rotary_emb``.
            device: Device on which the assembled query is allocated.

        Returns:
            ``(attn_logits, o, k_pe)`` where ``k_pe`` is the last-token rope
            part of the keys with all singleton dimensions squeezed out.
        """
        B, H = q.shape[0], q.shape[1]
        S = kv_indptr[1].item()
        qk_rope_head_dim = k_input.shape[-1] - kv_lora_rank

        # Allocate directly on the target device instead of CPU + copy.
        q_input = torch.empty(
            B, H, kv_lora_rank + qk_rope_head_dim, dtype=q.dtype, device=device
        )
        q_nope_out, q_pe = q.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
        # Rope part of the last token of each sequence; a view into k_input.
        k_pe_t = k_input.view(B, 1, S, -1)[:, :, -1:, kv_lora_rank:]

        if use_rope:
            q_pe, k_pe_t = rotary_emb(positions, q_pe.unsqueeze(2), k_pe_t)
            # Remove only the axis inserted above. A bare squeeze() would also
            # drop the batch/head axes whenever B == 1 or H == 1.
            q_pe = q_pe.squeeze(2)

        # Write the (possibly rotated) rope part back into the key cache.
        k_input.view(B, 1, S, -1)[:, :, -1:, kv_lora_rank:] = k_pe_t

        q_input[..., :kv_lora_rank] = q_nope_out
        q_input[..., kv_lora_rank:] = q_pe

        v_head_dim = v_input.shape[-1]
        out_device = q_input.device

        attn_logits = torch.empty(
            B, H, num_kv_splits, v_head_dim + 1, dtype=q_input.dtype, device=out_device
        )
        o = torch.empty(B, H, v_head_dim, dtype=q_input.dtype, device=out_device)

        decode_attention_fwd_grouped(
            q_input,
            k_input,
            v_input,
            o,
            kv_indptr,
            kv_indices,
            attn_logits,
            num_kv_splits,
            sm_scale,
            logit_cap,
        )

        return attn_logits, o, k_pe_t.squeeze()

    def _test_rocm_fused_mla_kernel(
        self,
        B,
        H,
        S,
        kv_lora_rank,
        qk_rope_head_dim,
        rotary_dim,
        dtype,
        use_rope,
        is_neox_style,
        num_kv_splits=2,
        sm_scale=1.0,
        logit_cap=0.0,
        device="cuda",
    ):
        """Run the fused kernel and the reference on one configuration and compare.

        The fused kernel runs first on the untouched cache; the reference runs
        afterwards because it rewrites the rope part of ``k_input`` in place.

        Raises:
            AssertionError: If the rotated keys (only with ``use_rope``), the
                logits buffer or the attention output differ beyond
                ``atol=rtol=1e-2``.
        """
        kv_indptr, kv_indices, q, kv_cache, attn_logits, rotary_emb, positions = (
            self.input_helper(
                B,
                H,
                S,
                kv_lora_rank,
                rotary_dim,
                qk_rope_head_dim,
                num_kv_splits,
                dtype,
                device=device,
                is_neox_style=is_neox_style,
            )
        )

        k_input, v_input = self.preprocess_kv_cache(kv_cache, kv_lora_rank)
        k_pe_tokens = torch.empty(
            B, qk_rope_head_dim, dtype=kv_cache.dtype, device=device
        )
        tri_o = torch.empty(B, H, kv_lora_rank, dtype=kv_cache.dtype, device=device)

        # RoPE-related arguments are passed as None when RoPE is disabled.
        decode_attention_fwd_grouped_rope(
            q,
            k_input,
            v_input,
            tri_o,
            kv_indptr,
            kv_indices,
            k_pe_tokens if use_rope else None,
            kv_lora_rank,
            rotary_dim if use_rope else None,
            rotary_emb.cos_sin_cache if use_rope else None,
            positions if use_rope else None,
            attn_logits,
            num_kv_splits,
            sm_scale,
            logit_cap,
            use_rope,
            is_neox_style,
        )

        tri_logits = attn_logits

        # reference
        ref_logits, ref_o, ref_k_pe_tokens = self.ref_compute_full_fwd(
            q,
            k_input,
            v_input,
            kv_lora_rank,
            kv_indptr,
            kv_indices,
            num_kv_splits,
            sm_scale,
            logit_cap,
            rotary_emb,
            positions,
            use_rope,
            device="cuda",
        )

        if use_rope:
            # Both sides are squeezed the same way, so their shapes agree for
            # B == 1 as well as B > 1.
            torch.testing.assert_close(
                ref_k_pe_tokens, k_pe_tokens.squeeze(), atol=1e-2, rtol=1e-2
            )
        torch.testing.assert_close(ref_logits, tri_logits, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(ref_o, tri_o, atol=1e-2, rtol=1e-2)

    def test_grouped_rocm_fused_mla(self):
        """Sweep shapes, dtypes, RoPE on/off and both rotation layouts."""
        # (B, H, S, kv_lora_rank, qk_rope_head_dim, rotary_dim)
        configs = [
            (1, 128, 2048, 512, 64, 64),
            (1, 128, 2048, 512, 128, 64),
            (1, 128, 2048, 512, 127, 64),
            (1, 128, 2050, 512, 127, 64),
            (1, 128, 2050, 512, 128, 64),
            (8, 128, 2048, 512, 64, 64),
            (8, 128, 2048, 512, 128, 64),
            (8, 128, 2048, 512, 127, 64),
            (8, 128, 2050, 512, 127, 64),
            (8, 128, 2050, 512, 128, 64),
        ]
        dtypes = [torch.bfloat16, torch.float32]
        use_rope_list = [True, False]
        is_neox_style_list = [True, False]

        for B, H, S, kv_lora_rank, qk_rope_head_dim, rotary_dim in configs:
            for dtype in dtypes:
                for use_rope in use_rope_list:
                    for is_neox_style in is_neox_style_list:
                        try:
                            self._test_rocm_fused_mla_kernel(
                                B,
                                H,
                                S,
                                kv_lora_rank,
                                qk_rope_head_dim,
                                rotary_dim,
                                dtype,
                                use_rope,
                                is_neox_style,
                            )
                        except AssertionError as err:
                            # Still fails on the first mismatch, but reports
                            # which parameter combination triggered it.
                            raise AssertionError(
                                "fused MLA decode mismatch for "
                                f"B={B}, H={H}, S={S}, "
                                f"kv_lora_rank={kv_lora_rank}, "
                                f"qk_rope_head_dim={qk_rope_head_dim}, "
                                f"rotary_dim={rotary_dim}, dtype={dtype}, "
                                f"use_rope={use_rope}, "
                                f"is_neox_style={is_neox_style}"
                            ) from err


# Run the test cases defined in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
