import itertools
import random
import unittest
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch

from sglang.srt.layers.moe.ep_moe.kernels import (
    grouped_gemm_triton,
    post_reorder_triton_kernel,
    pre_reorder_triton_kernel,
    run_moe_ep_preproess,
    silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.topk import select_experts
from sglang.test.test_utils import CustomTestCase


# For test
def ep_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    # ep config
    num_experts: int = 256,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
    num_experts_per_partition: int = 128,
    start_expert_id: int = 0,
    end_expert_id: int = 127,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
    w1_scale_inv: Optional[torch.Tensor] = None,
    w2_scale_inv: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
    """Run one expert-parallel MoE forward pass for a single rank (test helper).

    The pipeline chains the imported kernels in this order: expert selection
    (``select_experts``), preprocessing (``run_moe_ep_preproess``),
    ``pre_reorder_triton_kernel``, a grouped GEMM against ``w1``,
    ``silu_and_mul_triton_kernel``, a grouped GEMM against ``w2`` and finally
    ``post_reorder_triton_kernel``.

    Args:
        hidden_states: Input activations; indexed as a 2-D tensor
            (``shape[0]`` rows, ``shape[1]`` columns).
        w1: Weights for the first grouped GEMM; ``w1.shape[1]`` is the width
            of its output.
        w2: Weights for the second grouped GEMM; ``w2.shape[1]`` is the width
            of its output.
        router_logits: Routing scores forwarded to ``select_experts``.
        top_k: Number of experts selected per input row.
        renormalize: Forwarded to ``select_experts``.
        num_experts: Total expert count, forwarded to ``run_moe_ep_preproess``.
        fp8_dtype: dtype of the intermediate activation buffers when per-tensor
            FP8 is active (``use_fp8_w8a8`` without ``block_shape``).
        num_experts_per_partition: Number of experts handled by this rank; used
            as the grouped-GEMM batch size.
        start_expert_id: First expert id (inclusive) owned by this rank.
        end_expert_id: Last expert id (inclusive) owned by this rank.
        use_grouped_topk: Forwarded to ``select_experts``.
        num_expert_group: Forwarded to ``select_experts``.
        topk_group: Forwarded to ``select_experts``.
        custom_routing_function: Forwarded to ``select_experts``.
        use_fp8_w8a8: Forwarded to both grouped GEMMs.
        w1_scale_inv: Passed as ``scale_b`` of the first grouped GEMM.
        w2_scale_inv: Passed as ``scale_b`` of the second grouped GEMM.
        block_shape: Block size for block-wise FP8. When given, activation
            buffers keep ``hidden_states.dtype`` and no activation scales are
            built here (``scale_a`` is ``None``).

    Returns:
        A tensor allocated with ``torch.empty_like(hidden_states)`` and filled
        by ``post_reorder_triton_kernel``.
    """
    use_blockwise_fp8 = block_shape is not None
    # Per-tensor FP8 is the only mode in which this helper allocates FP8
    # activation buffers and builds activation scales itself.
    per_tensor_fp8 = use_fp8_w8a8 and not use_blockwise_fp8
    activation_dtype = fp8_dtype if per_tensor_fp8 else hidden_states.dtype
    device = hidden_states.device
    num_tokens, hidden_size = hidden_states.shape[0], hidden_states.shape[1]

    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        # correction_bias is intentionally not forwarded in this test helper.
        custom_routing_function=custom_routing_function,
    )

    reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)

    # One row per (input row, selected expert) pair.
    gateup_input = torch.empty(
        (int(num_tokens * top_k), hidden_size),
        device=device,
        dtype=activation_dtype,
    )

    if per_tensor_fp8:
        # Single scale derived from the tensor maximum, repeated once per
        # local expert.
        max_value = (
            torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32)
        )
        w1_input_scale = max_value / torch.finfo(fp8_dtype).max
    else:
        w1_input_scale = None

    # PreReorder
    pre_reorder_triton_kernel[(num_tokens,)](
        hidden_states,
        gateup_input,
        src2dst,
        topk_ids,
        w1_input_scale,
        start_expert_id,
        end_expert_id,
        top_k,
        hidden_size,
        BLOCK_SIZE=512,
    )

    # Slice covers the experts owned by this rank: one more boundary entry
    # than there are experts, hence ``+ 2`` on the inclusive end id.
    seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2]
    weight_indices_cur_rank = torch.arange(
        0,
        num_experts_per_partition,
        device=device,
        dtype=torch.int64,
    )

    # GroupGemm-0
    gateup_output = torch.empty(
        gateup_input.shape[0],
        w1.shape[1],
        device=device,
        dtype=hidden_states.dtype,
    )

    gateup_output = grouped_gemm_triton(
        a=gateup_input,
        b=w1,
        c=gateup_output,
        batch_size=num_experts_per_partition,
        weight_column_major=True,
        seg_indptr=seg_indptr_cur_rank,
        weight_indices=weight_indices_cur_rank,
        use_fp8_w8a8=use_fp8_w8a8,
        scale_a=w1_input_scale,
        scale_b=w1_scale_inv,
        block_shape=block_shape,
    )

    # Act: the activation output has half the columns of the gate-up output.
    down_input = torch.empty(
        gateup_output.shape[0],
        gateup_output.shape[1] // 2,
        device=gateup_output.device,
        dtype=activation_dtype,
    )
    if per_tensor_fp8:
        # Unit scale per local expert for the second GEMM's input.
        w2_input_scale = torch.ones(
            num_experts_per_partition,
            dtype=torch.float32,
            device=device,
        )
    else:
        w2_input_scale = None

    silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
        gateup_output,
        down_input,
        gateup_output.shape[1],
        reorder_topk_ids,
        w2_input_scale,
        start_expert_id,
        end_expert_id,
        BLOCK_SIZE=512,
    )

    # GroupGemm-1
    down_output = torch.empty(
        down_input.shape[0],
        w2.shape[1],
        device=device,
        dtype=hidden_states.dtype,
    )

    down_output = grouped_gemm_triton(
        a=down_input,
        b=w2,
        c=down_output,
        batch_size=num_experts_per_partition,
        weight_column_major=True,
        seg_indptr=seg_indptr_cur_rank,
        weight_indices=weight_indices_cur_rank,
        use_fp8_w8a8=use_fp8_w8a8,
        scale_a=w2_input_scale,
        scale_b=w2_scale_inv,
        block_shape=block_shape,
    )

    # PostReorder
    output = torch.empty_like(hidden_states)
    post_reorder_triton_kernel[(hidden_states.size(0),)](
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        top_k,
        hidden_states.size(1),
        BLOCK_SIZE=512,
    )
    return output


# test util
def block_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> torch.Tensor:
    """Dequantize a block-wise quantized tensor to float32.

    Every ``block_size[0] x block_size[1]`` tile of ``x_q_block`` is multiplied
    by its matching entry of ``x_s``. Tiles on the trailing edge may be smaller
    when a dimension is not a multiple of the block size.

    Args:
        x_q_block: Quantized values, either 2-D ``(n, k)`` or 3-D
            ``(batch, n, k)``. It is never modified.
        x_s: Per-tile scales of shape ``(ceil(n / block_n), ceil(k / block_k))``,
            with a matching leading batch dimension for 3-D input.
        block_size: ``[block_n, block_k]``.

    Returns:
        A new float32 tensor with the same shape as ``x_q_block``.

    Raises:
        AssertionError: If the shape of ``x_s`` does not match the tile grid.
    """
    # 3-D input: dequantize each leading slice with its own scale grid.
    if x_q_block.dim() == 3:
        batch_size = x_q_block.size(0)
        return torch.stack(
            [block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)]
        )

    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    # Expand the scale grid to element resolution, trimming the overhang of
    # partial edge tiles, so one elementwise multiply replaces the tile loops.
    scale = (
        x_s.to(torch.float32)
        .repeat_interleave(block_n, dim=0)[:n]
        .repeat_interleave(block_k, dim=1)[:, :k]
    )
    # Out-of-place multiply: the caller's tensor stays untouched even when it
    # is already float32 (where ``.to`` returns the same tensor object).
    return x_q_block.to(torch.float32) * scale


class TestW8A8BlockFP8EPMoE(CustomTestCase):
    """Check block-wise FP8 W8A8 ``ep_moe`` against a dequantized-weight reference.

    Each parameter combination runs ``ep_moe`` twice on the same inputs: once
    with FP8 weights plus block scales, once with the weights dequantized by
    ``block_dequant`` and FP8 disabled. The mean relative error between the
    two outputs must stay below a fixed threshold.
    """

    # Parameter grid; the test iterates over the full Cartesian product.
    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 222, 1024, 2048]
    N = [128, 1024, 2048]
    K = [256, 4096, 5120]
    E = [8, 16]
    ep_size = [2, 4]
    TOP_KS = [2, 4]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        """Skip the whole class without CUDA; otherwise allocate on the GPU by default."""
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_block_fp8_ep_moe(
        self, M, N, K, E, ep_size, topk, block_size, dtype, seed
    ):
        """Run one FP8-vs-reference comparison for a single parameter combination.

        Args:
            M: Number of input rows.
            N: Intermediate size (``w1`` has ``2 * N`` rows, ``w2`` has ``N`` columns).
            K: Hidden size.
            E: Total number of experts.
            ep_size: Number of expert-parallel ranks; one rank is picked at random.
            topk: Experts selected per input row.
            block_size: ``[block_n, block_k]`` of the block-wise quantization.
            dtype: Activation / output dtype.
            seed: Seed for both ``torch`` and ``random``.
        """
        torch.manual_seed(seed)
        random.seed(seed)
        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
        factor_for_scale = 1e-2
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        # The order of the random draws below is kept fixed so a given seed
        # always produces the same data.
        a = torch.randn((M, K), dtype=dtype) / 10

        # Raw weights are drawn in ``dtype`` over the FP8 range, then cast to FP8.
        w1_raw = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max
        w1 = w1_raw.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w2_raw = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max
        w2 = w2_raw.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        block_n, block_k = block_size[0], block_size[1]
        n_tiles_w1 = (2 * N + block_n - 1) // block_n
        n_tiles_w2 = (K + block_n - 1) // block_n
        k_tiles_w1 = (K + block_k - 1) // block_k
        k_tiles_w2 = (N + block_k - 1) // block_k

        w1_s = (
            torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
            * factor_for_scale
        )
        w2_s = (
            torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
            * factor_for_scale
        )

        w1_ref = block_dequant(w1, w1_s, block_size).to(dtype)
        w2_ref = block_dequant(w2, w2_s, block_size).to(dtype)

        score = torch.randn((M, E), dtype=dtype)
        num_experts_per_partition = E // ep_size
        cur_rank = random.randint(0, ep_size - 1)
        start_id = cur_rank * num_experts_per_partition
        end_id = start_id + num_experts_per_partition - 1

        # NOTE(review): w1/w2 hold all E experts while ep_moe builds weight
        # indices 0..num_experts_per_partition-1. Both calls share this, so the
        # comparison is consistent -- confirm this is the intended setup.
        shared_kwargs = dict(
            hidden_states=a,
            router_logits=score,
            top_k=topk,
            renormalize=False,
            num_experts=E,
            num_experts_per_partition=num_experts_per_partition,
            start_expert_id=start_id,
            end_expert_id=end_id,
        )

        with torch.inference_mode():
            out = ep_moe(
                w1=w1,
                w2=w2,
                use_fp8_w8a8=True,
                w1_scale_inv=w1_s,
                w2_scale_inv=w2_s,
                block_shape=block_size,
                **shared_kwargs,
            )
            ref_out = ep_moe(
                w1=w1_ref,
                w2=w2_ref,
                use_fp8_w8a8=False,
                w1_scale_inv=None,
                w2_scale_inv=None,
                block_shape=None,
                **shared_kwargs,
            )

        out_f32 = out.to(torch.float32)
        ref_f32 = ref_out.to(torch.float32)
        # Mean absolute error relative to the mean reference magnitude; the
        # epsilon guards against a zero denominator.
        rel_err = (
            torch.mean(torch.abs(out_f32 - ref_f32))
            / (torch.mean(torch.abs(ref_f32)) + 1e-6)
        ).item()
        # assertLess reports the measured value on failure, unlike assertTrue.
        self.assertLess(
            rel_err,
            0.06,
            f"relative error {rel_err:.4f} exceeds tolerance (rank {cur_rank})",
        )

    def test_w8a8_block_fp8_ep_moe(self):
        """Run the comparison over the full parameter grid, one subTest per combination."""
        param_names = (
            "M",
            "N",
            "K",
            "E",
            "ep_size",
            "topk",
            "block_size",
            "dtype",
            "seed",
        )
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.ep_size,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(**dict(zip(param_names, params))):
                self._w8a8_block_fp8_ep_moe(*params)
            # Release cached GPU memory between combinations.
            torch.cuda.empty_cache()


# Run this module's tests with verbose per-test output when executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
