import typing
from collections.abc import Callable
from collections import defaultdict
from typing import Any, Dict, TYPE_CHECKING, Optional, Tuple, List

import torch
import copy

from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F

# Superclass of MOELayer. The subscripted `Module[Tensor]` form is only seen by
# static type checkers (TYPE_CHECKING is False at runtime); at runtime `Base`
# is plain `torch.nn.Module`.
if TYPE_CHECKING:
    Base = Module[Tensor]
else:
    Base = Module


# Number of experts each token is routed to: the `k` of the top-k selection in
# `gating`. MOELayer also multiplies expert inputs by this value unless
# mixtral gating or `moe_feature_no_mul_topk` is enabled.
MOE_TOP_K = 2
# Number of ConstantExpert instances that `Experts` places in every expert list.
Constant = 2


class CopyExpert(torch.nn.Module):
    """Parameter-free identity expert: returns the routed tokens unchanged.

    Args:
        expert: ignored; accepted so every expert type shares one
            constructor signature.
    """

    def __init__(self, expert):
        super().__init__()

    def forward(self, inputs):
        # Identity: the very same tensor object is handed back.
        return inputs


class ZeroExpert(torch.nn.Module):
    """Parameter-free expert whose output is all zeros.

    Args:
        expert: ignored; accepted so every expert type shares one
            constructor signature.
    """

    def __init__(self, expert):
        super().__init__()

    def forward(self, inputs):
        # `zeros_like` already matches the shape, dtype and device of `inputs`.
        return torch.zeros_like(inputs)


class ConstantExpert(torch.nn.Module):
    """Expert that blends each token with one learned constant vector.

    For every token ``x`` a two-way gate ``(a, b) = softmax(wg(x))`` is computed,
    and the output is ``a * x + b * constant``.

    Args:
        expert: template expert; only its ``hidden_size`` attribute is read.
    """

    def __init__(self, expert):
        super(ConstantExpert, self).__init__()
        # One learned vector of length hidden_size, initialised from N(0, 1).
        self.constant = torch.nn.Parameter(torch.empty(expert.hidden_size))
        torch.nn.init.normal_(self.constant)

        # Per-token two-way gate: column 0 weights the input, column 1 the constant.
        self.wg = torch.nn.Linear(expert.hidden_size, 2, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, inputs):
        """Return ``a * inputs + b * constant`` with ``(a, b) = softmax(wg(inputs))``.

        ``inputs`` may have any number of leading dimensions; its last dimension
        must equal ``hidden_size``. The result has the shape and dtype of ``inputs``.
        """
        weight = self.softmax(self.wg(inputs)).type_as(inputs)
        # `k:k+1` slices keep a trailing singleton dim, so the gate weights
        # broadcast over the hidden dimension for any number of leading dims
        # (the previous 'b,bd->bd' einsum only accepted 2-D inputs).
        return weight[..., 0:1] * inputs + weight[..., 1:2] * self.constant.type_as(inputs)


def _normalize_logits(logits: Tensor, target_std: float) -> Tensor:
    """Rescale every row of ``logits`` so that its standard deviation is ``target_std``."""
    logits_std = logits.std(dim=1, keepdim=True)
    # A constant row (e.g. the all-zero logits of an all-zero input token) has
    # std 0, and 0 / 0 would make every gate of that token NaN. Divide such
    # rows by 1 / target_std instead; a constant row softmaxes to a uniform
    # distribution either way. Rows with std > 0 are scaled exactly as before.
    logits_std = torch.where(logits_std > 0, logits_std, torch.ones_like(logits_std))
    return logits / (logits_std / target_std)


def gating(logits: Tensor, moe_use_mixtral_gating=False, moe_use_logits_norm=False, moe_gate_norm_std=1.0,
           top_k: Optional[int] = None) -> Dict[int, List[Tensor]]:
    """Pick the top-k experts for every token and group the tokens by expert.

    Args:
        logits: router scores of shape ``[num_tokens, num_experts]``.
        moe_use_mixtral_gating: if True, softmax only over the selected top-k
            logits. Otherwise softmax over all experts, keep the top-k
            probabilities, set the gate of the last expert
            (id ``num_experts - 1``; ``Experts`` puts a ZeroExpert there) to 0
            and renormalise the remaining gates to sum to 1.
        moe_use_logits_norm: rescale each token's logits to standard deviation
            ``moe_gate_norm_std`` before gating.
        moe_gate_norm_std: target standard deviation for ``moe_use_logits_norm``.
        top_k: number of experts per token; defaults to ``MOE_TOP_K``.

    Returns:
        A dict with an entry for every expert id in ``range(num_experts)``,
        mapping to ``[token_ids, gates]``: two 1-D tensors of equal (possibly
        zero) length holding the indices of the tokens routed to that expert and
        their gate values.
    """
    k = MOE_TOP_K if top_k is None else top_k
    num_experts = logits.size(1)
    if moe_use_logits_norm:
        logits = _normalize_logits(logits, moe_gate_norm_std)

    if moe_use_mixtral_gating:
        gates, indices = torch.topk(logits, k=k, dim=1)
        gates = F.softmax(gates, dim=1)
    else:
        # gates / indices shape: [num_tokens, k]
        gates, indices = torch.topk(F.softmax(logits, dim=1), k=k, dim=1)
        gates = torch.where(indices == (num_experts - 1), torch.zeros_like(gates), gates)
        denom = gates.sum(dim=1, keepdim=True)
        # denom is 0 only when every selected expert was the zeroed last one
        # (possible when k == 1); keep those gates at 0 instead of 0 / 0 = NaN.
        gates = gates / torch.where(denom > 0, denom, torch.ones_like(denom))

    expert_info = defaultdict(list)
    for expert_id in range(num_experts):
        # slot_ids: which of the k top-k slots selected this expert for each token.
        token_ids, slot_ids = torch.nonzero(indices == expert_id, as_tuple=True)
        expert_info[expert_id] = [token_ids, gates[token_ids, slot_ids]]

    return expert_info


class Router(Module):
    """Token router: scores every token against every expert and runs `gating`.

    The scoring network ``wg`` (a single Linear, or Linear-Tanh-Linear when
    ``moe_2layer_gate``) is kept in float32 and fed float32 inputs, so the
    logits are always computed in float32.
    """

    def __init__(self,
                 model_dim: int,
                 num_experts: int,
                 moe_use_mixtral_gating: bool,
                 moe_2layer_gate: bool,
                 moe_use_logits_norm: bool,
                 moe_gate_norm_std: float,
                 ) -> None:
        super().__init__()

        if moe_2layer_gate:
            self.wg = torch.nn.Sequential(
                torch.nn.Linear(model_dim, num_experts * 8, bias=False).float(),
                torch.nn.Tanh(),
                torch.nn.Linear(num_experts * 8, num_experts, bias=False).float()).float()
        else:
            self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
        # Tag at construction too, so the tag's presence does not depend on
        # whether forward() ever had to convert the weights back to float32.
        self._tag_router_weights()

        # Maps an incoming `gate_residual` (last dim num_experts) onto the logits.
        self.gate_map = torch.nn.Linear(num_experts, num_experts, bias=False)

        self.gate = gating
        self.moe_use_mixtral_gating = moe_use_mixtral_gating
        self.moe_use_logits_norm = moe_use_logits_norm
        self.moe_gate_norm_std = moe_gate_norm_std

    def _router_weights(self) -> List[torch.nn.Parameter]:
        """Return the weights of ``wg``: one for a Linear, two for the 2-layer MLP."""
        if isinstance(self.wg, torch.nn.Linear):
            return [self.wg.weight]
        return [self.wg[0].weight, self.wg[2].weight]

    def _tag_router_weights(self) -> None:
        """Set ``router=True`` on the router weights.

        NOTE(review): the consumer of this tag is not visible in this file
        (presumably optimizer / precision handling elsewhere) -- confirm.
        """
        for weight in self._router_weights():
            setattr(weight, 'router', True)

    def forward(self, input: torch.Tensor, gate_residual=None) -> Tuple[Dict[int, List[Tensor]], Tensor]:
        """Route the tokens in ``input``.

        Args:
            input: token features; MOELayer passes a 2-D ``[num_tokens, model_dim]`` tensor.
            gate_residual: optional tensor whose last dim is ``num_experts``;
                it is projected by ``gate_map`` and added to the logits.

        Returns:
            ``(expert_info, logits)``: the result of ``self.gate`` and the
            float32 logits, including the projected residual if one was given.
        """
        # If something cast `wg` away from float32 (e.g. a whole-model
        # `.half()` / `.bfloat16()`), convert it back and re-tag the weights.
        if self._router_weights()[0].dtype != torch.float32:
            self.wg = self.wg.float()
            self._tag_router_weights()
        logits = self.wg(input.float())

        if gate_residual is not None:
            logits += self.gate_map(gate_residual.to(self.gate_map.weight.dtype))

        gate_output = self.gate(logits, self.moe_use_mixtral_gating, self.moe_use_logits_norm, self.moe_gate_norm_std)

        return gate_output, logits


class Experts(torch.nn.Module):
    """Container holding the expert modules of one MoE layer.

    ``self.experts`` has exactly ``num_local_experts`` entries, in this order:
    ``num_local_experts - 2 - Constant`` deep copies of ``expert``, then
    ``Constant`` ConstantExperts, then one CopyExpert, and finally one
    ZeroExpert (always the last index, the one `gating` zeroes).

    Raises:
        ValueError: if ``num_local_experts < 2 + Constant``. Previously the
            negative ``range`` silently produced a list whose length did not
            match ``num_local_experts``.
    """

    def __init__(self, expert, num_local_experts=1):
        super(Experts, self).__init__()

        # Constant x ConstantExpert + CopyExpert + ZeroExpert.
        num_special = 2 + Constant
        if num_local_experts < num_special:
            raise ValueError(
                f"Experts needs num_local_experts >= {num_special} "
                f"({Constant} ConstantExpert + 1 CopyExpert + 1 ZeroExpert), "
                f"got {num_local_experts}")

        self.experts = torch.nn.ModuleList(
            [copy.deepcopy(expert) for _ in range(num_local_experts - num_special)] +
            [ConstantExpert(expert) for _ in range(Constant)] +
            [CopyExpert(expert), ZeroExpert(expert)])

    def forward(self, inputs):
        # Not called as a whole: MOELayer invokes `self.experts[i]` individually.
        raise NotImplementedError


class MOELayer(Base):
    """Dispatch layer: route tokens with ``gate`` and sum the gated expert outputs.

    Args:
        gate: router returning ``(expert_info, logits)`` (see ``Router``).
        experts: module whose ``.experts`` list is indexed by expert id.
        ep_size: stored only; no expert-parallel communication happens here.
        num_local_experts: stored only.
        moe_use_mixtral_gating: if False (and ``moe_feature_no_mul_topk`` is
            False), expert inputs are multiplied by ``MOE_TOP_K``.
        moe_feature_no_mul_topk: disables that multiplication.
    """

    def __init__(self,
                 gate: Module,
                 experts: Module,
                 ep_size,
                 num_local_experts: int,
                 moe_use_mixtral_gating: bool,
                 moe_feature_no_mul_topk: bool) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        self.moe_use_mixtral_gating = moe_use_mixtral_gating
        self.moe_feature_no_mul_topk = moe_feature_no_mul_topk

    def forward(self, *input: Tensor, gate_residual=None, **kwargs: Any) -> Tuple[Tensor, Tensor]:
        """Run the MoE over ``input[0]``.

        Args:
            *input: ``input[0]`` is the hidden-state tensor (last dim = model
                dim); further positional args (e.g. ``used_token``) are ignored.
            gate_residual: optional tensor forwarded to the router.
            **kwargs: ignored.

        Returns:
            ``(output, logits)``: ``output`` has the shape of ``input[0]``;
            ``logits`` are the router logits (presumably fed to a later layer
            as its ``gate_residual`` -- confirm with caller).
        """
        hidden = input[0]
        d_model = hidden.shape[-1]
        reshaped_input = hidden.reshape(-1, d_model)
        output = torch.zeros_like(reshaped_input)
        expert_info, router_logits = self.gate(reshaped_input, gate_residual)
        if not (self.moe_use_mixtral_gating or self.moe_feature_no_mul_topk):
            # Out-of-place on purpose: `reshape` usually returns a view of the
            # caller's tensor, so `*=` would silently scale the caller's hidden
            # states, and would corrupt autograd when the router's `wg` saved
            # this very tensor (`input.float()` is a no-op for fp32 input).
            reshaped_input = reshaped_input * MOE_TOP_K
        for expert_id, (indices, gates) in expert_info.items():
            tokens = reshaped_input.index_select(dim=0, index=indices)
            expert_output = self.experts.experts[expert_id](tokens)
            expert_output *= gates.unsqueeze(-1)
            # Accumulate each expert's weighted output back at its token rows.
            output.index_add_(dim=0, index=indices, source=expert_output)
        output = output.reshape(hidden.shape)

        return output, router_logits


class MOE(torch.nn.Module):
    """Top-level mixture-of-experts block.

    This block wires a ``Router`` and an ``Experts`` container into a
    ``MOELayer``. The router is sized with the global ``num_experts``, while
    the expert list is sized with ``num_experts // ep_size``.
    """

    def __init__(self,
                 hidden_size,
                 expert,
                 num_experts=1,
                 ep_size=1,
                 moe_use_mixtral_gating=False,
                 moe_2layer_gate=True,
                 moe_use_logits_norm=False,
                 moe_gate_norm_std=1.0,
                 moe_feature_no_mul_topk=False):
        super().__init__()

        # Configuration, kept on the module for introspection.
        self.ep_size = ep_size
        self.num_experts = num_experts
        self.num_local_experts = num_experts // ep_size
        self.moe_use_mixtral_gating = moe_use_mixtral_gating
        self.moe_2layer_gate = moe_2layer_gate
        self.moe_use_logits_norm = moe_use_logits_norm
        self.moe_gate_norm_std = moe_gate_norm_std
        self.moe_feature_no_mul_topk = moe_feature_no_mul_topk

        # Build the experts before the router. Both draw from the global RNG
        # while initialising their parameters, so this order fixes the initial
        # weights for a given seed.
        local_experts = Experts(expert, self.num_local_experts)
        router = Router(
            hidden_size,
            num_experts,
            moe_use_mixtral_gating,
            moe_2layer_gate,
            moe_use_logits_norm,
            moe_gate_norm_std,
        )
        self.moe = MOELayer(
            router,
            local_experts,
            ep_size,
            self.num_local_experts,
            moe_use_mixtral_gating,
            moe_feature_no_mul_topk,
        )

    def forward(self, hidden_states, used_token=None, gate_residual=None):
        """Return ``(output, router_logits)`` from the wrapped MOELayer.

        ``used_token`` is forwarded positionally; MOELayer ignores it.
        """
        moe_out, router_logits = self.moe(hidden_states, used_token, gate_residual=gate_residual)
        return moe_out, router_logits