# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from abc import ABC, abstractmethod

import torch

from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP, SequentialMLPReLU
from megatron.core.transformer.moe.legacy_a2a_token_dispatcher import MoEAlltoAllSEQTokenDispatcher
from megatron.core.transformer.moe.router import TopKRouter, ReLURouter, TopKExpertChoiceRouter, LoryRouter, HashRouter
from megatron.core.transformer.moe.token_dispatcher import (
    MoEAllGatherTokenDispatcher,
    MoEAlltoAllTokenDispatcher,
    MoEAlltoAllTokenDispatcherReLU,
    MoEAlltoAllTokenDispatcherExpertChoice,
)
from megatron.core.transformer.transformer_config import TransformerConfig


class BaseMoELayer(MegatronModule, ABC):
    """Base class for a mixture of experts layer.

    Works out which experts are held by the current expert-parallel rank and
    initialises ``router``, ``experts`` and ``token_dispatcher`` to ``None`` for
    subclasses to build.

    Args:
        config (TransformerConfig): Configuration object for the transformer model.
        layer_number (int, optional): Layer number stored on the module; it can be
            updated later (and pushed to the router) via :meth:`set_layer_number`.
    """

    def __init__(self, config: TransformerConfig, layer_number: int = None):
        super(BaseMoELayer, self).__init__(config)
        self.config = config
        self.expert_parallel_size = parallel_state.get_expert_model_parallel_world_size()
        assert self.expert_parallel_size > 0, "Expected positive expert parallel size"

        if self.config.moe_extended_tp:
            # Extended TP: every rank keeps all num_moe_experts experts locally,
            # so the local indices start at 0.
            self.num_local_experts = self.config.num_moe_experts
            local_expert_indices_offset = 0
        else:
            assert self.config.num_moe_experts % self.expert_parallel_size == 0, (
                "num_moe_experts must be divisible by the expert parallel size"
            )
            # Experts are split evenly across expert-parallel ranks; this rank owns
            # the contiguous slice starting at rank * num_local_experts.
            self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size
            local_expert_indices_offset = (
                parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
            )

        # Global indices of the experts owned by this rank.
        self.local_expert_indices = [
            local_expert_indices_offset + i for i in range(self.num_local_experts)
        ]
        assert all(idx < self.config.num_moe_experts for idx in self.local_expert_indices), (
            "Local expert indices must be smaller than num_moe_experts"
        )
        # Built by subclasses.
        self.router = None
        self.experts = None
        self.token_dispatcher = None
        self.layer_number = layer_number

    @abstractmethod
    def forward(self, hidden_states):
        """Forward method for the MoE layer; implemented by subclasses."""
        pass

    def set_layer_number(self, layer_number: int):
        """Set the layer number for the MoE layer and forward it to the router.

        Requires ``self.router`` to have been built by the subclass.
        """
        self.layer_number = layer_number
        self.router.set_layer_number(layer_number)

from megatron.core.transformer.mlp import MLP, NaiveMLP
class MoELayer(BaseMoELayer):
    """Mixture of experts Layer **currently only supports no token dropping**.

    The concrete router, expert implementation and token dispatcher are chosen
    from ``config``:

    * router: ReLU (``moe_relu_routing``), expert-choice (``"ec"``), Lory
      (``"lory"``), hash (``"hash"``) or top-k (default), by ``moe_routing_type``;
    * experts: ``MLP`` (``moe_naive_mlp``), ``NaiveMLP`` (Lory routing),
      ``TEGroupedMLP``/``GroupedMLP`` (``moe_grouped_gemm``) or ``SequentialMLP``;
    * token dispatcher: by ``moe_token_dispatcher_type``.

    Args:
        config (TransformerConfig): Configuration object for the transformer model.
        submodules (MLPSubmodules, optional): Submodule specs for the expert MLPs.
        layer_number (int, optional): Layer number stored on the module.
    """

    def __init__(
        self, config: TransformerConfig, submodules: MLPSubmodules = None, layer_number: int = None
    ):
        self.submodules = submodules
        super(MoELayer, self).__init__(config=config, layer_number=layer_number)
        # Build order (router, experts, dispatcher) is kept stable.
        self.router = self._build_router()
        self.experts = self._build_experts()
        self.token_dispatcher = self._build_token_dispatcher()
        self.moe_layer_recompute = config.moe_layer_recompute

    def _build_router(self):
        """Instantiate the router selected by ``moe_relu_routing`` / ``moe_routing_type``."""
        if self.config.moe_relu_routing:
            router = ReLURouter(config=self.config)
            # ReLU routing is only supported with the alltoall dispatcher.
            assert self.config.moe_token_dispatcher_type == "alltoall"
            return router
        if self.config.moe_routing_type == "ec":
            return TopKExpertChoiceRouter(config=self.config)
        if self.config.moe_routing_type == "lory":
            return LoryRouter(config=self.config)
        if self.config.moe_routing_type == "hash":
            return HashRouter(config=self.config)
        return TopKRouter(config=self.config)

    def _build_experts(self):
        """Instantiate the expert implementation selected by the config."""
        if self.config.moe_naive_mlp:
            return MLP(self.config, self.submodules)
        if self.config.moe_routing_type == "lory":
            return NaiveMLP(self.config)
        if self.config.moe_grouped_gemm:
            if isinstance(self.submodules, MLPSubmodules):
                return TEGroupedMLP(self.num_local_experts, self.config, self.submodules)
            return GroupedMLP(self.num_local_experts, self.config)
        assert isinstance(self.submodules, MLPSubmodules)
        return SequentialMLP(self.num_local_experts, self.config, self.submodules)

    def _build_token_dispatcher(self):
        """Instantiate the token dispatcher selected by ``moe_token_dispatcher_type``.

        Raises:
            ValueError: If the dispatcher type is not supported.
        """
        dispatcher_type = self.config.moe_token_dispatcher_type
        expert_args = (self.num_local_experts, self.local_expert_indices)
        if dispatcher_type == "allgather":
            return MoEAllGatherTokenDispatcher(*expert_args, config=self.config)
        if dispatcher_type == "alltoall":
            if self.config.moe_relu_routing:
                dispatcher_cls = MoEAlltoAllTokenDispatcherReLU
            elif self.config.moe_routing_type == "ec":
                dispatcher_cls = MoEAlltoAllTokenDispatcherExpertChoice
            else:
                dispatcher_cls = MoEAlltoAllTokenDispatcher
            return dispatcher_cls(*expert_args, config=self.config)
        if dispatcher_type == "alltoall_seq":
            return MoEAlltoAllSEQTokenDispatcher(*expert_args, config=self.config)
        raise ValueError(f"Unsupported token dispatcher type: {dispatcher_type}")

    def _naive_mlp_forward(self, hidden_states: torch.Tensor):
        """Naive-MLP path: feed a dense per-expert probability matrix to the MLP."""
        probs, indices = self.router(hidden_states)
        if not self.config.moe_relu_routing:
            # probs shape: [S*B, K], indices shape: [S*B, K]. Scatter into a dense
            # [S*B, num_moe_experts] matrix, zero for experts that were not selected.
            probs_expanded = torch.zeros(
                probs.shape[0],
                self.config.num_moe_experts,
                dtype=probs.dtype,
                device=probs.device,
            )
            probs_expanded.scatter_(1, indices, probs)
            probs = probs_expanded
        # With ReLU routing the router's probs are used as-is (shape [S*B, E]).
        expert_output, mlp_bias = self.experts.forward_with_probs(hidden_states, probs)
        return expert_output, mlp_bias

    def _moe_forward(self, hidden_states: torch.Tensor):
        """Route, dispatch, run the experts and combine; return ``(output, mlp_bias)``."""
        if self.config.moe_naive_mlp:
            return self._naive_mlp_forward(hidden_states)
        if self.config.moe_routing_type == "lory":
            # Lory: the router output is passed directly to the experts, no dispatch.
            routing = self.router(hidden_states)
            expert_output, mlp_bias = self.experts(hidden_states, routing)
            return expert_output, mlp_bias

        probs, indices = self.router(hidden_states)
        dispatched_input, tokens_per_expert = self.token_dispatcher.token_permutation(
            hidden_states, probs, indices
        )
        expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
        output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)
        return output, mlp_bias

    def forward(self, hidden_states: torch.Tensor):
        """Apply the MoE layer to ``hidden_states`` and return ``(output, mlp_bias)``.

        Raises:
            ValueError: In training mode when tensor parallelism is enabled
                without sequence parallelism.
        """
        if (
            self.training
            and self.config.tensor_model_parallel_size > 1
            and not self.config.sequence_parallel
        ):
            raise ValueError(
                "During training, performance may degrade if MoE and tensor parallelism "
                "are enabled without also enabling sequence parallelism."
            )

        if self.moe_layer_recompute:
            # Wrap the whole MoE computation in tensor_parallel.checkpoint.
            output, mlp_bias = tensor_parallel.checkpoint(self._moe_forward, False, hidden_states)
        else:
            output, mlp_bias = self._moe_forward(hidden_states)

        return output, mlp_bias
