# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union

import torch

from megatron.core import parallel_state, tensor_parallel
from megatron.core.process_groups_config import ModelCommProcessGroups
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.moe_utils import (
    get_default_model_comm_pgs,
    MoEAuxLossAutoScaler,
    compute_and_register_gvo_loss_router,
    compute_and_register_hvo_loss_router,
    save_to_aux_losses_tracker,
    update_expert_coupling_stats,
)
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.transformer.moe.token_dispatcher import (
    MoEAllGatherTokenDispatcher,
    MoEAlltoAllTokenDispatcher,
    MoEFlexTokenDispatcher,
    MoETokenDispatcher,
)
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig

try:
    import transformer_engine as te  # pylint: disable=unused-import

    from megatron.core.extensions.transformer_engine import te_checkpoint

    HAVE_TE = True
except ImportError:
    HAVE_TE = False


@dataclass
class MoESubmodules:
    """MoE Layer Submodule spec.

    Both fields default to ``None``, so the annotations are ``Optional``.
    """

    # Spec (or class) for the routed experts; MoELayer passes it to ``build_module``.
    experts: Optional[Union[ModuleSpec, type]] = None
    # Spec (or class) for the shared experts; MoELayer only builds it when shared
    # experts are enabled in the config.
    shared_experts: Optional[Union[ModuleSpec, type]] = None


class BaseMoELayer(MegatronModule, ABC):
    """Base class for a mixture of experts layer.

    Works out the expert-parallel layout for this rank (how many experts are local and
    which global expert indices they have) and declares the attributes concrete layers
    are expected to fill in: ``router``, ``experts``, ``shared_experts`` and
    ``token_dispatcher``.

    Args:
        config (TransformerConfig): Configuration object for the transformer model.
        layer_number (Optional[int]): Number of this layer; stored on the module.
        model_comm_pgs (Optional[ModelCommProcessGroups]): Process groups; the ``ep``
            and ``tp`` groups are read here. When ``None``, the groups returned by
            ``get_default_model_comm_pgs()`` are used.
    """

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: Optional[int] = None,
        model_comm_pgs: Optional[ModelCommProcessGroups] = None,
    ):
        super().__init__(config)
        self.config = config
        self.layer_number = layer_number

        # The parameter defaults to None, so resolve it before dereferencing.
        if model_comm_pgs is None:
            model_comm_pgs = get_default_model_comm_pgs()
        self.ep_group = model_comm_pgs.ep
        # Note: this stores ``model_comm_pgs.tp``, not an expert tensor-parallel group.
        # MoELayer.forward reads it for its sequence-parallel sanity check.
        self.attn_tp_group = model_comm_pgs.tp

        ep_size = self.ep_group.size()
        ep_rank = self.ep_group.rank()
        assert ep_size > 0, "Expected positive expert parallel size"
        assert self.config.num_moe_experts % ep_size == 0, (
            f"num_moe_experts ({self.config.num_moe_experts}) must be divisible by the "
            f"expert parallel size ({ep_size})"
        )

        # Experts are split evenly across the expert-parallel group; this rank owns the
        # contiguous index range starting at ep_rank * num_local_experts.
        self.num_local_experts = self.config.num_moe_experts // ep_size
        local_expert_indices_offset = ep_rank * self.num_local_experts
        self.local_expert_indices = [
            local_expert_indices_offset + i for i in range(self.num_local_experts)
        ]
        assert all(
            index < self.config.num_moe_experts for index in self.local_expert_indices
        )

        self.use_shared_expert = self.config.moe_shared_expert_intermediate_size is not None
        self.shared_expert_overlap = self.config.moe_shared_expert_overlap

        # Placeholders that concrete subclasses populate.
        self.router: Optional[TopKRouter] = None
        self.experts = None
        self.shared_experts = None
        self.token_dispatcher: Optional[MoETokenDispatcher] = None

    @abstractmethod
    def forward(self, hidden_states):
        """Forward method for the MoE layer."""

    def set_layer_number(self, layer_number: int):
        """Set the layer number for the MoE layer and forward it to the router.

        The router is only notified when one has been created; the base class
        initialises ``router`` to ``None``.
        """
        self.layer_number = layer_number
        if self.router is not None:
            self.router.set_layer_number(layer_number)


class MoELayer(BaseMoELayer):
    """Mixture of Experts layer.

    This layer implements a Mixture of Experts model, where each token is routed to a
    subset of experts. This implementation supports different token dispatching
    strategies such as All-to-All and All-Gather.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: Optional[MoESubmodules] = None,
        layer_number: Optional[int] = None,
        model_comm_pgs: Optional[ModelCommProcessGroups] = None,
    ):
        """Build the router, token dispatcher, experts and optional shared experts.

        Args:
            config (TransformerConfig): Configuration object for the transformer model.
            submodules (Optional[MoESubmodules]): Specs for the experts and the shared
                experts; ``submodules.experts`` is always passed to ``build_module``.
            layer_number (Optional[int]): Number of this layer.
            model_comm_pgs (Optional[ModelCommProcessGroups]): Process groups; when
                ``None`` the result of ``get_default_model_comm_pgs()`` is used.

        Raises:
            ValueError: If ``config.moe_token_dispatcher_type`` is not one of
                ``"allgather"``, ``"alltoall"`` or ``"flex"``.
        """
        self.submodules = submodules
        # TODO(Hepteract): delete the usage of the global parallel_state.
        # No explicit process groups were given: use the default ones.
        model_comm_pgs = (
            get_default_model_comm_pgs() if model_comm_pgs is None else model_comm_pgs
        )
        super().__init__(config=config, layer_number=layer_number, model_comm_pgs=model_comm_pgs)

        # Recompute the whole MoE layer only under selective recompute that lists "moe".
        selective_recompute = config.recompute_granularity == 'selective'
        self.moe_layer_recompute = selective_recompute and "moe" in config.recompute_modules

        # Router
        self.router = TopKRouter(config=self.config, model_comm_pgs=model_comm_pgs)

        # Token dispatcher: all three variants share the same constructor signature,
        # so pick the class from a table keyed by the configured type.
        dispatcher_type = config.moe_token_dispatcher_type
        dispatcher_classes = {
            "allgather": MoEAllGatherTokenDispatcher,
            "alltoall": MoEAlltoAllTokenDispatcher,
            "flex": MoEFlexTokenDispatcher,
        }
        if dispatcher_type not in dispatcher_classes:
            raise ValueError(f"Unsupported token dispatcher type: {dispatcher_type}")
        self.token_dispatcher = dispatcher_classes[dispatcher_type](
            self.num_local_experts,
            self.local_expert_indices,
            config=self.config,
            model_comm_pgs=model_comm_pgs,
        )

        # Routed experts
        self.experts = build_module(
            self.submodules.experts,
            self.num_local_experts,
            self.config,
            model_comm_pgs=model_comm_pgs,
        )

        # Shared experts (optional). With overlap enabled they are handed to the token
        # dispatcher instead of being called from experts_compute.
        if self.use_shared_expert:
            self.shared_experts = build_module(
                self.submodules.shared_experts, config=self.config, model_comm_pgs=model_comm_pgs
            )
            if self.shared_expert_overlap:
                self.token_dispatcher.set_shared_experts(self.shared_experts)

    def router_and_preprocess(self, hidden_states: torch.Tensor):
        """Compute and preprocess token routing for dispatch.

        Steps, in order:

        1. The router produces routing probabilities and a routing map.
        2. Outside training, when ``config.moe_expert_coupling_analysis`` is set, the
           routing map is passed to ``update_expert_coupling_stats``.
        3. While training with autograd enabled, the optional GVO / RVO / HVO
           auxiliary losses are computed from the pre-dispatch hidden states and
           attached to them through ``MoEAuxLossAutoScaler``.
        4. The token dispatcher preprocesses hidden states and probabilities.

        Args:
            hidden_states (torch.Tensor): Input to the MoE layer.

        Returns:
            tuple: ``(hidden_states, probs, residual)`` where the first two come from
            ``token_dispatcher.dispatch_preprocess`` and ``residual`` is the untouched
            input tensor.
        """
        residual = hidden_states
        probs, routing_map = self.router(hidden_states)
        if getattr(self.config, 'moe_expert_coupling_analysis', False) and not self.training:
            update_expert_coupling_stats(self.layer_number, routing_map, self.config)

        # The auxiliary losses only act through the backward pass, so none of the
        # bookkeeping is needed in eval mode or under no_grad.
        if self.training and torch.is_grad_enabled():
            hidden_states = self._apply_expert_aux_losses(hidden_states, routing_map)

        hidden_states, probs = self.token_dispatcher.dispatch_preprocess(
            hidden_states, routing_map, probs
        )
        return hidden_states, probs, residual

    @staticmethod
    def _current_training_iteration() -> int:
        """Return the current training iteration from the global args, or 0."""
        try:
            from megatron.training import get_args

            args = get_args()
        except (ImportError, AssertionError):
            # ImportError: megatron.training is not installed alongside megatron.core.
            # AssertionError: NOTE(review) get_args() is expected to assert when the
            # global args were never initialised (e.g. unit tests) - confirm.
            return 0
        return getattr(args, 'curr_iteration', getattr(args, 'iteration', 0)) or 0

    def _scheduled_aux_loss_coeff(self, loss_name: str, current_iter: int) -> float:
        """Return the coefficient of an aux loss after its iteration schedule.

        Reads ``moe_<loss_name>_loss_coeff``, ``moe_<loss_name>_loss_iter_threshold``
        and ``moe_<loss_name>_loss_iter_warmup`` from the config (used for
        ``loss_name`` in ``{'gvo', 'hvo'}``). Missing or ``None`` values count as 0.
        The coefficient is 0 before the threshold iteration, then ramps linearly to
        its full value over the warmup iterations.
        """
        coeff = getattr(self.config, f'moe_{loss_name}_loss_coeff', 0.0) or 0.0
        threshold = getattr(self.config, f'moe_{loss_name}_loss_iter_threshold', 0) or 0
        warmup = getattr(self.config, f'moe_{loss_name}_loss_iter_warmup', 0) or 0
        if threshold and current_iter < threshold:
            return 0.0
        if warmup > 0 and current_iter < threshold + warmup:
            progress = (current_iter - threshold) / warmup
            return coeff * max(0.0, min(1.0, progress))
        return coeff

    def _local_expert_weight(self, attr_name: str):
        """Return ``self.experts.<attr_name>`` viewed per local expert, or ``None``.

        Only ``'weight1'`` and ``'weight2'`` are supported (the GroupedMLP case).
        ``None`` is returned when the experts module has no such attribute or the
        weight cannot be viewed in the expected layout; the caller then skips the
        losses that need it.
        """
        weight = getattr(self.experts, attr_name, None)
        if weight is None:
            return None
        try:
            if attr_name == 'weight1':
                # weight1: [H, E_local*D_fc1_tp] -> [E_local, H, D_fc1_tp]
                return (
                    weight.view(self.config.hidden_size, self.num_local_experts, -1)
                    .permute(1, 0, 2)
                    .contiguous()
                )
            # weight2: [E_local*D_ffn_tp, H] -> [E_local, D_ffn_tp, H]
            return weight.view(self.num_local_experts, -1, self.config.hidden_size)
        except Exception as err:  # best effort: a bad layout disables the loss
            import warnings

            warnings.warn(
                f"Skipping MoE auxiliary losses that need experts.{attr_name}: {err}",
                RuntimeWarning,
                stacklevel=2,
            )
            return None

    def _attach_aux_loss(self, hidden_states, name, loss, coeff, target_layer):
        """Log an aux loss and attach it to ``hidden_states`` for the backward pass.

        Does nothing when ``loss`` is ``None``. The tracker receives the loss divided
        by its coefficient.
        """
        if loss is None:
            return hidden_states
        num_layers = self.config.num_layers
        if self.config.mtp_num_layers is not None:
            num_layers += self.config.mtp_num_layers
        save_to_aux_losses_tracker(name, loss / coeff, target_layer, num_layers)
        if self.config.calculate_per_token_loss:
            # Scale by the number of tokens, i.e. the product of all leading dims.
            # For an already flattened [tokens, hidden] input this equals shape[0].
            num_tokens = hidden_states.numel() // hidden_states.shape[-1]
            loss = loss * num_tokens
        return MoEAuxLossAutoScaler.apply(hidden_states, loss)

    def _apply_expert_aux_losses(self, hidden_states, routing_map):
        """Attach the enabled GVO, RVO and HVO auxiliary losses to ``hidden_states``.

        The losses are applied in that order. Each is skipped when its coefficient is
        zero, when the expert weights it needs are unavailable, or when no token is
        routed to a local expert. Returns the (possibly wrapped) hidden states.
        """
        current_iter = self._current_training_iteration()
        gvo_coeff = self._scheduled_aux_loss_coeff('gvo', current_iter)
        # RVO has no iteration schedule; a missing or None value disables it.
        rvo_coeff = getattr(self.config, 'moe_rvo_loss_coeff', 0.0) or 0.0
        hvo_coeff = self._scheduled_aux_loss_coeff('hvo', current_iter)
        if not (gvo_coeff or rvo_coeff or hvo_coeff):
            return hidden_states

        # Fetch each weight at most once: RVO and HVO both use W1, GVO and RVO use W2.
        expert_w1 = self._local_expert_weight('weight1') if (rvo_coeff or hvo_coeff) else None
        expert_w2 = self._local_expert_weight('weight2') if (gvo_coeff or rvo_coeff) else None
        if expert_w1 is None and expert_w2 is None:
            return hidden_states

        # Columns of the routing map that belong to this rank's experts.
        local_map = routing_map[
            :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
        ].contiguous()
        if not local_map.any():
            return hidden_states
        local_map = local_map.view(-1, local_map.shape[-1])

        if gvo_coeff and expert_w2 is not None:
            gvo_loss, target_layer = compute_and_register_gvo_loss_router(
                self.layer_number,
                hidden_states.view(-1, hidden_states.shape[-1]),
                local_map,
                expert_w2,
                coeff=gvo_coeff,
                sequence_partition_group=None,
            )
            hidden_states = self._attach_aux_loss(
                hidden_states, "gvo_loss", gvo_loss, gvo_coeff, target_layer
            )

        if rvo_coeff and expert_w1 is not None and expert_w2 is not None:
            from megatron.core.transformer.moe.moe_utils import (
                compute_and_register_rvo_loss_router,
            )

            rvo_loss, target_layer = compute_and_register_rvo_loss_router(
                self.layer_number,
                hidden_states.view(-1, hidden_states.shape[-1]),
                local_map,
                expert_w1,
                expert_w2,
                coeff=rvo_coeff,
                use_swiglu=bool(getattr(self.config, 'gated_linear_unit', False)),
                sequence_partition_group=None,
            )
            hidden_states = self._attach_aux_loss(
                hidden_states, "rvo_loss", rvo_loss, rvo_coeff, target_layer
            )

        if hvo_coeff and expert_w1 is not None:
            hvo_loss, target_layer = compute_and_register_hvo_loss_router(
                self.layer_number,
                hidden_states.view(-1, hidden_states.shape[-1]),
                local_map,
                expert_w1,
                coeff=hvo_coeff,
                sequence_partition_group=None,
            )
            hidden_states = self._attach_aux_loss(
                hidden_states, "hvo_loss", hvo_loss, hvo_coeff, target_layer
            )

        return hidden_states

    def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor):
        """Dispatches tokens to assigned expert ranks via communication.
        This method performs the actual communication (e.g., All-to-All) to distribute
        tokens and their associated probabilities to the devices hosting their assigned
        experts.
        """
        return self.token_dispatcher.token_dispatch(hidden_states, probs)

    def experts_compute(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, residual: torch.Tensor
    ):
        """Computes the output of the experts on the dispatched tokens.

        This method first post-processes the dispatched input to get permuted tokens
        for each expert. It then passes the tokens through the local experts.
        If a shared expert is configured and not overlapped with communication,
        it is also applied. The output from the experts is preprocessed for the
        combine step.
        """
        shared_expert_output = None
        if self.use_shared_expert and not self.shared_expert_overlap:
            # Compute the shared expert separately when not overlapped with communication.
            shared_expert_output = self.shared_experts(residual)
        dispatched_input, tokens_per_expert, permuted_probs = (
            self.token_dispatcher.dispatch_postprocess(hidden_states, probs)
        )

        expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert, permuted_probs)
        assert mlp_bias is None, f"mlp_bias is not supported for {type(self.token_dispatcher)}"
        output = self.token_dispatcher.combine_preprocess(expert_output)

        return output, shared_expert_output, mlp_bias

    def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Tensor]):
        """Combines expert outputs via communication and adds shared expert output.

        This method uses the token dispatcher to combine the outputs from different
        experts (e.g., via an All-to-All communication). It then adds the output
        from the shared expert if it exists.
        """
        output = self.token_dispatcher.token_combine(output)
        output = self.token_dispatcher.combine_postprocess(output)
        if shared_expert_output is not None:
            output = output + shared_expert_output
        return output

    def forward(self, hidden_states: torch.Tensor):
        """Forward pass for the MoE layer.

        The forward pass comprises four main steps:
        1. Routing & Preprocessing: Route tokens to the assigned experts and prepare for dispatch.
        2. Dispatch: Tokens are sent to the expert devices using communication collectives.
        3. Expert Computation: Experts process the dispatched tokens.
        4. Combine: The outputs from the experts are combined and returned.

        Args:
            hidden_states (torch.Tensor): The input tensor to the MoE layer.

        Returns:
            A tuple containing the output tensor and the MLP bias, if any.
        """
        if self.training and self.attn_tp_group.size() > 1 and not self.config.sequence_parallel:
            raise ValueError(
                "During training, performance may degrade if MoE and tensor parallelism"
                "are enabled without also enabling sequence parallelism."
            )

        # MoE forward: route -> dispatch -> compute -> combine
        def custom_forward(hidden_states):
            hidden_states, probs, residual = self.router_and_preprocess(hidden_states)
            dispatched_input, probs = self.dispatch(hidden_states, probs)
            output, shared_expert_output, mlp_bias = self.experts_compute(
                dispatched_input, probs, residual
            )
            output = self.combine(output, shared_expert_output)
            return output, mlp_bias

        if self.moe_layer_recompute:
            if self.config.fp8:
                output, mlp_bias = te_checkpoint(
                    custom_forward,
                    False,
                    tensor_parallel.random.get_cuda_rng_tracker,
                    parallel_state.get_tensor_model_parallel_group(),
                    hidden_states,
                )
            else:
                output, mlp_bias = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
        else:
            output, mlp_bias = custom_forward(hidden_states)

        return output, mlp_bias

    def backward_dw(self):
        """Compute weight gradients for experts and shared experts."""
        self.experts.backward_dw()
        if self.use_shared_expert and not self.shared_expert_overlap:
            self.shared_experts.backward_dw()
