import copy
import math
import warnings
from typing import Sequence, Union, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
                      build_norm_layer)
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
                        to_2tuple)
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import build_feedforward_network
from mmcv.cnn.bricks.registry import FEEDFORWARD_NETWORK


@FEEDFORWARD_NETWORK.register_module()
class SparseMoE(BaseModule):
    """Sparse mixture-of-experts FFN with softmax top-k token routing.

    A bias-free linear router scores every token against ``num_experts``
    experts. Each token is processed by its ``top_k`` highest-scoring experts
    and the expert outputs are summed, weighted by the router probabilities
    renormalised over the selected experts. Dropout and the residual
    connection are applied once on the combined output, not inside the
    individual experts.

    Args:
        num_experts (int): Number of expert FFNs to build.
        top_k (int): Number of experts each token is routed to. Must not
            exceed ``num_experts``.
        jitter_noise (float): When positive and the module is in training
            mode, the input is multiplied element-wise by noise drawn
            uniformly from ``[1 - jitter_noise, 1 + jitter_noise]``.
        load_balancing_loss_coef (float): Stored on the module as
            ``load_balancing_loss_coef``; ``forward`` does not use it.
        embed_dims (int): Size of the last dimension of the input; also the
            input width of the router and of every expert.
        feedforward_channels (int): Forwarded to each expert's config.
        num_fcs (int): Forwarded to each expert's config.
        act_cfg (dict): Activation config forwarded to each expert.
        ffn_drop (float): Forwarded to each expert's config.
        dropout_layer (dict | None): Config for the dropout applied to the
            combined expert output. A falsy value disables it.
        add_identity (bool): Whether ``forward`` adds the residual input to
            the (dropped-out) expert output.
        init_cfg (dict | None): Passed to ``BaseModule`` and forwarded to
            each expert's config.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``top_k`` is greater than ``num_experts``.
    """

    def __init__(self,
                 num_experts,
                 top_k,
                 jitter_noise=0.,
                 load_balancing_loss_coef=0.02,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        if top_k > num_experts:
            raise ValueError(f'top_k must be less than or equal to num_experts, got {num_experts=} and {top_k=}')

        self.num_experts = num_experts
        self.top_k = top_k
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.add_identity = add_identity

        self.router = Linear(embed_dims, num_experts, bias=False)

        # Dropout and the residual connection are applied once in `forward`,
        # so they are disabled inside every expert.
        expert_cfg = {
            'type': 'FFN',
            'embed_dims': embed_dims,
            'feedforward_channels': feedforward_channels,
            'num_fcs': num_fcs,
            'act_cfg': act_cfg,
            'ffn_drop': ffn_drop,
            'dropout_layer': None,
            'add_identity': False,
            'init_cfg': init_cfg,
        }
        self.experts = ModuleList([build_feedforward_network(expert_cfg) for _ in range(num_experts)])
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()

        self.jitter_noise = jitter_noise
        self.load_balancing_loss_coef = load_balancing_loss_coef
        # Router logits of the most recent training-mode forward pass.
        self.router_logits = None

    def forward(self, hidden_states, identity=None):
        """Route every token through its top-k experts and combine the results.

        In training mode the router logits of this call are stored in
        ``self.router_logits`` (shape ``(num_tokens, num_experts)``).

        Args:
            hidden_states (torch.Tensor): Input of shape
                ``(..., embed_dims)``. It is not modified.
            identity (torch.Tensor | None): Tensor used for the residual
                connection. Defaults to ``hidden_states``.

        Returns:
            torch.Tensor: Tensor with the same shape as ``hidden_states``.
        """
        if identity is None:
            identity = hidden_states
        input_shape = hidden_states.shape
        hidden_dim = input_shape[-1]

        if self.training and self.jitter_noise > 0:
            noise = torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
            # Out-of-place on purpose: an in-place multiply would alter the
            # caller's tensor and, through aliasing, the residual `identity`.
            hidden_states = hidden_states * noise
        # `reshape` (not `view`) so non-contiguous inputs are accepted.
        hidden_states = hidden_states.reshape(-1, hidden_dim)
        router_logits = self.router(hidden_states)  # (num_tokens, num_experts)

        if self.training:
            self.router_logits = router_logits

        # Softmax in float32, keep the top-k probabilities per token and
        # renormalise them so that they sum to one.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = hidden_states.new_zeros(hidden_states.shape)

        # One-hot encode the selected experts; after the permute the mask is
        # laid out as (num_experts, top_k, num_tokens).
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # `idx` is the top-k slot and `top_x` the token index of every
            # (token, slot) pair routed to this expert.
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Run the expert on its tokens and scale each output row by the
            # routing weight of the matching (token, slot) pair.
            current_state = hidden_states[top_x]
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # Accumulate the weighted outputs back at the token positions.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(input_shape)

        if not self.add_identity:
            return self.dropout_layer(final_hidden_states)
        return identity + self.dropout_layer(final_hidden_states)

    def reset_router_logits(self):
        """Discard the router logits stored by the last training forward."""
        self.router_logits = None

    # def compute_load_balancing_loss(self):
    #     if not self.training:
    #         raise RuntimeError('The model is not in training mode, cannot compute load balancing loss')
    #     if self.router_logits is None:
    #         raise RuntimeError('Run forward function first to get the router logits')

    #     router_logits = (self.router_logits,)
    #     loss = load_balancing_loss_func(router_logits, self.num_experts, self.top_k) * self.load_balancing_loss_coef
    #     self.reset_router_logits()
    #     return loss

def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits:
            Tuple of per-layer router logits, each of shape [batch_size X sequence_length, num_experts]. The
            tensors are concatenated along dim 0 on the device of the first one. Any value that is not a
            non-empty tuple (``None``, a bare tensor, ``()``) yields a loss of ``0``.
        num_experts:
            Number of experts. When ``None`` it is inferred from the last dimension of the logits.
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None. Positions with value 0 are excluded from
            the statistics.

    Returns:
        The auxiliary loss as a scalar tensor, or the integer ``0`` when there are no logits to score.
    """
    # Only a non-empty tuple of per-layer logits is scored.
    if not isinstance(gate_logits, tuple) or len(gate_logits) == 0:
        return 0

    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    # `one_hot` and the final scaling both need a concrete expert count.
    if num_experts is None:
        num_experts = concatenated_gate_logits.shape[-1]

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    # (num_tokens, top_k, num_experts)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        # The same mask is reused for every layer stacked in the logits.
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of routing_weights
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
