import torch
from typing import Dict

try:
    import awq_ext  # with CUDA kernels

    AWQ_INSTALLED = True
except Exception:
    # Any ordinary failure to load the extension marks the kernels as
    # unavailable. `Exception` (rather than a bare `except`) is used so that
    # KeyboardInterrupt / SystemExit raised during import still propagate.
    AWQ_INSTALLED = False


class FusedSparseMoeBlock(torch.nn.Module):
    """Sparse mixture-of-experts layer built on ``apply_moe_weights``.

    The gate is applied to the flattened tokens to obtain router logits;
    expert selection and the expert computation are both delegated to
    ``apply_moe_weights``.
    """

    def __init__(self, top_k, gate, ws, w2s):
        """Keep references to the router and the expert weights.

        Args:
            top_k: Forwarded as ``topk`` to ``apply_moe_weights``.
            gate: Callable applied to the flattened hidden states; its
                return value is used as the router logits.
            ws: Forwarded as ``w1`` to ``apply_moe_weights``.
            w2s: Forwarded as ``w2`` to ``apply_moe_weights``.
        """
        super().__init__()
        self.gate = gate
        self.top_k = top_k
        self.ws = ws
        self.w2s = w2s

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token through its selected experts.

        Args:
            hidden_states: 3-D tensor unpacked as
                ``(batch, sequence_length, hidden_dim)``.

        Returns:
            The expert output viewed back to the shape of ``hidden_states``.
        """
        input_shape = hidden_states.shape
        # The unpack doubles as a check that the input is exactly 3-D.
        _, _, feature_dim = input_shape

        # Collapse batch and sequence axes: one row per token.
        flat_tokens = hidden_states.view(-1, feature_dim)

        # One row of logits per token (per the original note:
        # (batch * sequence_length, n_experts)).
        logits = self.gate(flat_tokens)

        mixed = apply_moe_weights(
            self.ws,
            self.w2s,
            flat_tokens,
            logits,
            self.top_k,
            renormalize=True,
        )
        return mixed.view(input_shape)



class FusedDeepseekMoEBlock(torch.nn.Module):
    """Mixture-of-experts layer with an optional always-on shared-expert branch.

    Routed experts are evaluated through ``apply_moe_weights``; when
    ``shared_experts`` is provided, its output on the original input is added
    to the routed result.
    """

    def __init__(
        self,
        top_k,
        gate,
        ws,
        w2s,
        shared_experts,
    ):
        """Keep references to the router, expert weights and shared experts.

        Args:
            top_k: Forwarded as ``topk`` to ``apply_moe_weights``.
            gate: Callable applied to the un-flattened hidden states. It must
                return a 3-tuple; only the second element is used (as the
                router logits).
            ws: Forwarded as ``w1`` to ``apply_moe_weights``.
            w2s: Forwarded as ``w2`` to ``apply_moe_weights``.
            shared_experts: Optional callable applied to the original
                (un-flattened) input; ``None`` disables the shared branch.
        """
        super().__init__()
        self.gate = gate
        self.top_k = top_k
        self.ws = ws
        self.w2s = w2s
        self.shared_experts = shared_experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Combine routed-expert output with the optional shared-expert output.

        Args:
            hidden_states: 3-D tensor unpacked as
                ``(batch, sequence_length, hidden_dim)``.

        Returns:
            Tensor viewed as ``(batch, sequence_length, hidden_dim)``.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        identity = hidden_states

        # router_logits: (batch * sequence_length, n_experts)
        # NOTE(review): the gate is called on the un-flattened input here, so
        # the shape stated above depends on the gate implementation - confirm.
        _, router_logits, _ = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_dim)

        final_hidden_states = apply_moe_weights(
            self.ws,
            self.w2s,
            hidden_states,
            router_logits,
            self.top_k,
            renormalize=True,
        )
        # add shared part.
        if self.shared_experts is not None:
            shared_output = self.shared_experts(identity)
            # The routed output comes from flattened tokens, while the shared
            # experts see the 3-D input. Adding them directly relies on
            # broadcasting, which only lines up when batch == 1 and raises for
            # larger batches. Flatten the shared output to one row per token
            # so the sum is element-wise for any batch size.
            final_hidden_states = final_hidden_states + shared_output.reshape(
                -1, hidden_dim
            )
        return final_hidden_states.view(batch_size, sequence_length, hidden_dim)




def apply_moe_weights(
    w1: Dict[str, torch.Tensor],
    w2: Dict[str, torch.Tensor],
    x: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> torch.Tensor:
    """Run the routed-expert computation through the ``awq_ext`` kernels.

    Pipeline: ``fused_topk`` -> ``moe_align_block_size`` -> grouped GEMM with
    ``w1`` -> ``silu_and_mul`` -> grouped GEMM with ``w2`` -> sum over dim 1.

    Args:
        w1: Weights for the first grouped GEMM. Read via the attributes
            ``qweight``, ``scales`` and ``qzeros``; ``qweight.shape[0]`` is
            used as the number of experts.
        w2: Weights for the second grouped GEMM, read via the same attributes.
        x: Input tokens; the first axis indexes tokens.
        gating_output: Router output handed to ``fused_topk``.
        topk: Number of experts selected per token.
        renormalize: Forwarded to ``fused_topk``.

    Returns:
        The second grouped-GEMM result summed over dimension 1.

    NOTE(review): ``w1``/``w2`` are annotated as ``Dict`` but are accessed by
    attribute, so callers must pass attribute-bearing objects.
    """
    routing_weights, selected_experts = fused_topk(gating_output, topk, renormalize)

    # Block size of 16 for the alignment step; expert count from w1.
    num_experts = w1.qweight.shape[0]
    token_order, block_experts, padded_count = moe_align_block_size(
        selected_experts, 16, num_experts
    )

    # Insert a singleton axis right after the token axis.
    tokens = x.view(x.shape[0], 1, *x.shape[1:])

    # Routing arguments shared verbatim by both grouped GEMM calls.
    routing = (routing_weights, token_order, block_experts, padded_count)

    # Trailing (False, 8): presumably "do not scale by routing weights" and a
    # kernel tuning factor - TODO confirm against the awq_ext signature.
    gate_up = awq_ext.grouped_gemm_forward(
        tokens, w1.qweight, w1.scales, w1.qzeros, *routing, False, 8
    )

    # silu_and_mul writes into a buffer whose last axis is half of gate_up's.
    activated = torch.empty(
        (*gate_up.shape[:-1], gate_up.shape[-1] // 2),
        dtype=tokens.dtype,
        device=tokens.device,
    )
    awq_ext.silu_and_mul(activated, gate_up)

    # Same call shape as above, but with the boolean flag set to True.
    expert_out = awq_ext.grouped_gemm_forward(
        activated, w2.qweight, w2.scales, w2.qzeros, *routing, True, 8
    )

    return expert_out.sum(dim=1)



def moe_align_block_size(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    """
    Pad and group token-to-expert assignments so each expert's share is a multiple of block_size.

    Parameters:
    - topk_ids: Tensor of shape [total_tokens, top_k] holding the selected expert indices per token.
    - block_size: Block size used by the block matrix multiplication.
    - num_experts: Total number of experts.

    Returns:
    - sorted_token_ids: Token indices ordered by the expert they are assigned to.
    - expert_ids: The expert index assigned to each block.
    - num_tokens_post_padded: Total number of tokens after padding, divisible by block_size.

    Each expert's token count is padded up to a multiple of block_size so that the
    dimensions line up during block matrix multiplication.

    Example:
    With topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4:
    - There are 12 assignments (tokens repeated 'top_k' times) over 4 experts, 3 per expert.
    - Because block_size is 4, each expert receives 1 padding token.
    - topk_ids is flattened to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Padding entries [12, 12, 12, 12] are appended, one per block.
    - Sorting by expert index yields token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Index 12 does not correspond to a real token (padding) and is ignored by the later matrix multiplication.
    - The total is now divisible by block_size, as the block matrix operations require.
    """
    device = topk_ids.device
    num_assignments = topk_ids.numel()

    # Worst case: every expert needs (block_size - 1) padding slots. The buffer
    # is pre-filled with `num_assignments`, which serves as the padding index.
    sorted_ids = torch.full(
        (num_assignments + num_experts * (block_size - 1),),
        num_assignments,
        dtype=torch.int32,
        device=device,
    )
    expert_ids = torch.empty(
        (num_assignments + num_experts,), dtype=torch.int32, device=device
    )
    # Single-element output slot for the padded token count.
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=device)

    # The three buffers above are passed to the kernel as outputs.
    # (`moe_alig_block_size` is the extension's symbol name as used in this file.)
    awq_ext.moe_alig_block_size(
        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
    )
    return sorted_ids, expert_ids, num_tokens_post_pad


def fused_topk(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Select the top-k experts per token and return their weights and indices.

    Args:
        gating_output (torch.Tensor): The output of the gating operation (before softmax).
        topk (int): The number of top-k experts to select.
        renormalize (bool): If True, rescale the selected weights so they sum to 1 per token.

    Returns:
        A ``(topk_weights, topk_ids)`` tuple with one row per token and ``topk`` columns.

    NOTE(review): the two paths produce different index dtypes - the kernel
    path allocates int32 ids, while ``torch.topk`` returns int64 indices.
    """
    num_tokens = gating_output.shape[0]

    if torch.version.hip is None:
        # Kernel path: the extension writes into preallocated outputs.
        out_shape = (num_tokens, topk)
        target = gating_output.device
        topk_weights = torch.empty(out_shape, dtype=torch.float32, device=target)
        topk_ids = torch.empty(out_shape, dtype=torch.int32, device=target)
        token_expert_indicies = torch.empty(out_shape, dtype=torch.int32, device=target)
        awq_ext.topk_softmax(
            topk_weights,
            topk_ids,
            token_expert_indicies,
            gating_output.float(),  # TODO(woosuk): Optimize this.
        )
        del token_expert_indicies  # Not used. Will be used in the future.
    else:
        # The MoE kernels are not yet supported on ROCm.
        probabilities = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(probabilities, topk, dim=-1)

    if renormalize:
        row_totals = topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights / row_totals
    return topk_weights, topk_ids
