# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from abc import ABC, abstractmethod
from functools import partial
from typing import Callable

import torch
import json
import math

from megatron.core import parallel_state
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.moe_utils import (
    MoEAuxLossAutoScaler,
    SoftTopKAlphaScheduler,
    save_to_aux_losses_tracker,
    save_to_confusion_tracker,
    sequence_load_balancing_loss_func,
    sinkhorn,
    switch_load_balancing_loss_func,
    domain_load_balancing_loss_func,
    topk_softmax_with_capacity,
    z_loss_func,
    base_metrics,
)
from megatron.core.transformer.transformer_config import TransformerConfig


class Router(ABC, MegatronModule):
    """Abstract base class shared by the MoE routers in this module.

    Owns the gate projection ``weight`` of shape ``[num_moe_experts, hidden_size]``,
    an optional one-element parameter ``learnable_k_eta`` (only created when
    ``config.moe_router_soft_topk_learn_k`` is set), and per-layer placeholders that
    subclasses fill in later. Subclasses must implement ``routing`` and ``forward``.
    """

    def __init__(self, config: TransformerConfig) -> None:
        """
        Create the router parameters from ``config``.

        Args:
            config (TransformerConfig): Configuration object for the Transformer model.
        """
        super().__init__(config)
        self.config = config
        self.num_experts = config.num_moe_experts

        # Per-layer settings: unset here, assigned later (e.g. via set_layer_number
        # in subclasses).
        self.moe_aux_loss_func = None
        self.layer_number = None
        self.moe_aux_loss_coeff = None
        self.moe_router_domain_loss_coeff = None

        # Gate weights are allocated and initialized in fp32, then cast to
        # config.params_dtype.
        # TODO: Add support for GPU initialization, which requires updating the golden values.
        gate_shape = (config.num_moe_experts, config.hidden_size)
        self.weight = torch.nn.Parameter(torch.empty(gate_shape, dtype=torch.float32))
        if config.perform_initialization:
            config.init_method(self.weight)
        self.weight.data = self.weight.data.to(dtype=config.params_dtype)
        self.weight.sequence_parallel = config.sequence_parallel

        if not config.moe_router_soft_topk_learn_k:
            self.learnable_k_eta = None
            return

        # One-element learnable parameter, initialized with the same init_method as
        # the gate. NOTE(review): it is not consumed anywhere in this file;
        # presumably used by soft top-k with a learned k — confirm.
        self.learnable_k_eta = torch.nn.Parameter(torch.empty(1, dtype=torch.float32))
        if config.perform_initialization:
            config.init_method(self.learnable_k_eta)
        self.learnable_k_eta.data = self.learnable_k_eta.data.to(dtype=config.params_dtype)

    def gating(self, input: torch.Tensor):
        """Project ``input`` through the gate to obtain router logits.

        If the gate weight still lives on the CPU it is moved to the current CUDA
        device first. The projection runs in fp32 / fp64 when
        ``config.moe_router_dtype`` is ``'fp32'`` / ``'fp64'``; otherwise it runs in
        ``input.dtype``.

        Args:
            input (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Logits tensor whose last dimension has one entry per expert.
        """
        weight = self.weight
        if weight.device.type == 'cpu':
            # Lazily move the gate to the GPU.
            weight.data = weight.data.to(device=torch.cuda.current_device())
        compute_dtype = {'fp32': torch.float32, 'fp64': torch.float64}.get(
            self.config.moe_router_dtype, input.dtype
        )
        return torch.nn.functional.linear(input.to(compute_dtype), weight.to(compute_dtype))

    @abstractmethod
    def routing(self, logits: torch.Tensor):
        """Turn gate logits into a token-to-expert assignment.

        Args:
            logits (torch.Tensor): Logits tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Token assignment probabilities and
            the corresponding mapping.
        """
        raise NotImplementedError("Routing function not implemented.")

    @abstractmethod
    def forward(self, input: torch.Tensor):
        """
        Run the router on ``input``.

        Args:
            input (torch.Tensor): Input tensor.
        """
        raise NotImplementedError("Forward function not implemented.")

    def set_layer_number(self, layer_number: int):
        """Record which transformer layer this router belongs to."""
        self.layer_number = layer_number

class TopKRouter(Router):
    """Route each token to the top-k experts.

    The load-balancing strategy is selected by ``config.moe_router_load_balancing_type``
    ("sinkhorn", "aux_loss", "domain_aux_loss", "seq_aux_loss" or "none"). Optional
    features driven by the config: a supervised domain loss against an expert -> class
    map (``moe_router_domain_topk``), a "soft-topk" score function whose alpha follows a
    schedule, an expert bias, and per-expert token counting.
    """

    def __init__(self, config: TransformerConfig) -> None:
        """Initialize the top-k router.

        Args:
            config (TransformerConfig): The configuration for the transformer model.
        """
        super().__init__(config=config)
        self.topk = self.config.moe_router_topk
        self.routing_type = self.config.moe_router_load_balancing_type
        self.score_function = self.config.moe_router_score_function
        # Jitter sampler is built lazily on first use (see apply_input_jitter).
        self.input_jitter = None

        if self.score_function == "soft-topk":
            self.soft_topk_scheduler = SoftTopKAlphaScheduler(
                self.config.moe_router_soft_topk_alpha,
                self.config.moe_router_soft_topk_schedule_iters,
            )
            # Start from the first configured alpha until the scheduler is stepped.
            self.soft_topk_alpha = self.config.moe_router_soft_topk_alpha[0]
        else:
            self.soft_topk_scheduler = None
            self.soft_topk_alpha = None

        if self.config.moe_router_domain_topk:
            # Expert index -> domain class id; persisted with the checkpoint.
            self.register_buffer(
                'expert_class_map',
                self.load_exper_class_map_from_config(config),
                persistent=True,
            )
        else:
            self.expert_class_map = None

        self.tokens_distribution_metrics = self.config.moe_router_tokens_dist_metrics
        self.enable_expert_bias = self.config.moe_router_enable_expert_bias
        if self.enable_expert_bias or self.tokens_distribution_metrics:
            # Running per-expert token counts (not saved in checkpoints).
            self.register_buffer(
                'local_tokens_per_expert',
                torch.zeros(self.config.num_moe_experts, dtype=torch.float32),
                persistent=False,
            )
        else:
            self.local_tokens_per_expert = None
        if self.enable_expert_bias:
            self.register_buffer(
                'expert_bias', torch.zeros(self.config.num_moe_experts, dtype=torch.float32)
            )
        else:
            self.expert_bias = None

    @staticmethod
    def _select_layer_coeff(coeff, layer_number):
        """Pick a coefficient for ``layer_number`` (1-based) from a scalar or a list.

        A non-list value is returned as-is, a one-element list applies to every layer,
        and a longer list is indexed per layer.
        """
        if not isinstance(coeff, list):
            return coeff
        if len(coeff) == 1:
            return coeff[0]
        return coeff[layer_number - 1]

    def set_layer_number(self, layer_number):
        """Set the layer number and resolve this layer's loss coefficients."""
        super().set_layer_number(layer_number)
        self.moe_aux_loss_coeff = self._select_layer_coeff(
            self.config.moe_aux_loss_coeff, layer_number
        )
        self.moe_router_domain_loss_coeff = self._select_layer_coeff(
            self.config.moe_router_domain_loss_coeff, layer_number
        )

    def update_topk(self, topk):
        """Override the number of experts each token is routed to."""
        self.topk = topk

    def update_soft_topk_alpha(self):
        """Advance the soft-top-k alpha using the global iteration counters."""
        _iteration = parallel_state.get_iteration()
        # NOTE: the key spelling must match what parallel_state.get_iteration() produces.
        current_iter = _iteration['curent_iteration']
        total_iters = _iteration['target_iterations']

        self.soft_topk_alpha = self.soft_topk_scheduler.step(current_iter, total_iters)

    def load_exper_class_map_from_config(self, config: TransformerConfig):
        """Load the expert -> domain-class map from ``config.moe_router_domain_config_path``.

        The file is a JSON object with one entry per expert. Its values, in file
        order, become the class ids of experts ``0 .. num_moe_experts - 1``.

        Returns:
            torch.Tensor: LongTensor of shape ``[num_moe_experts]``.
        """
        with open(config.moe_router_domain_config_path, 'r', encoding='utf-8') as file:
            mapping = list(json.load(file).values())
        assert len(mapping) == config.num_moe_experts, (
            f"Mismatch between expected expert count {config.num_moe_experts} "
            f"and actual class-map size {len(mapping)}"
        )
        return torch.tensor(mapping, dtype=torch.long)

    def sinkhorn_load_balancing(self, logits: torch.Tensor):
        """Apply sinkhorn routing to the logits tensor.

        Args:
            logits (torch.Tensor): The logits tensor, shape [num_tokens, num_experts].

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Token assignment probabilities (zero
            outside the selected experts) and the boolean routing mask.
        """

        def _sinkhorn_activation(logits):
            if self.topk == 1:
                logits = torch.sigmoid(logits)
            else:  # k > 1
                logits = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            return logits

        assert self.moe_aux_loss_coeff == 0, "Sinkhorn routing does not support aux loss."
        if self.training:
            with torch.no_grad():
                # Explicit fp32 conversion for stability.
                norm_logits = sinkhorn(logits.to(dtype=torch.float32))
                _, indices = torch.topk(norm_logits, k=self.topk, dim=1)
            logits = _sinkhorn_activation(logits)
        else:
            logits = _sinkhorn_activation(logits)
            _, indices = torch.topk(logits, k=self.topk, dim=1)
        routing_map = torch.zeros_like(logits).int().scatter(1, indices, 1).bool()
        scores = logits * routing_map
        return scores, routing_map

    def aux_loss_load_balancing(self, logits: torch.Tensor, ae_mask: torch.Tensor = None):
        """Apply loss-based load balancing to the logits tensor.

        Also logs the mean number of active experts per token. In eval mode, when
        ``ae_mask`` is given, this mean is weighted by ``ae_mask``.

        Args:
            logits (torch.Tensor): the logits tensor after gating, shape: [num_tokens, num_experts].
            ae_mask (torch.Tensor, optional): per-token mask, reshaped by ``routing`` to
                [num_tokens, 1]. Only used for the eval-mode metric.

        Returns:
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mask of token to experts assignment.
        """
        if self.score_function == "soft-topk":
            self.update_soft_topk_alpha()

        probs, routing_map, tokens_per_expert = topk_softmax_with_capacity(
            logits,
            self.topk,
            config=self.config,
            layer_number=self.layer_number,
            soft_topk_alpha=self.soft_topk_alpha,
            score_function=self.score_function,
            expert_bias=self.expert_bias,
            training=self.training,
        )

        if ae_mask is not None and not self.training:
            masked_routing = routing_map * ae_mask
            active_experts = (masked_routing.sum().float() / ae_mask.sum()).to(torch.float32)
            activated_per_token = masked_routing.sum(dim=1)
        else:
            active_experts = (routing_map.sum().float() / routing_map.shape[0]).to(torch.float32)
            activated_per_token = routing_map.sum(dim=1)
        self.save_topk_metrics(active_experts, logits.device, activated_per_token)

        if self.training:
            # Apply load balancing loss, using the measured active-expert count as topk.
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
            aux_loss_func = partial(
                switch_load_balancing_loss_func,
                probs=scores,
                tokens_per_expert=tokens_per_expert,
                topk=active_experts,
            )
            probs = self.apply_load_balancing_loss(
                activation=probs, load_balancing_loss_func=aux_loss_func
            )

        return probs, routing_map

    def domain_aux_loss_load_balancing(self, logits: torch.Tensor, classes):
        """Apply domain-aware loss-based load balancing to the logits tensor.

        During training, scores are a softmax computed separately within each
        group of experts sharing a class in ``expert_class_map``. They feed
        ``domain_load_balancing_loss_func``.

        Args:
            logits (torch.Tensor): the logits tensor after gating, shape: [num_tokens, num_experts].
            classes: per-token class ids (currently unused here).

        Returns:
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mask of token to experts assignment.
        """
        probs, routing_map, tokens_per_expert = topk_softmax_with_capacity(
            logits,
            self.topk,
            config=self.config,
            layer_number=self.layer_number,
            soft_topk_alpha=self.soft_topk_alpha,
            score_function=self.score_function,
            expert_bias=self.expert_bias,
        )

        if self.training:
            # Domain softmax: normalize within each expert class independently.
            scores = torch.zeros_like(logits, dtype=torch.float32)
            for class_num in torch.unique(self.expert_class_map):
                mask = self.expert_class_map == class_num
                scores[:, mask] = torch.softmax(logits[:, mask], dim=-1, dtype=torch.float32)

            domain_aux_loss_func = partial(
                domain_load_balancing_loss_func,
                probs=scores,
                routing_map=routing_map,
                tokens_per_expert=tokens_per_expert,
                topk=self.topk,
                expert_class_map=self.expert_class_map,
            )
            probs = self.apply_load_balancing_loss(
                activation=probs, load_balancing_loss_func=domain_aux_loss_func
            )
        return probs, routing_map

    def seq_aux_loss_load_balancing(self, logits: torch.Tensor, bsz: int, seq_length: int):
        """Apply sequence-level loss-based load balancing to the logits tensor.

        Args:
            logits (torch.Tensor): the logits tensor after gating, shape: [num_tokens, num_experts].
            bsz (int): batch size.
            seq_length (int): sequence length.

        Returns:
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mask of token to experts assignment.
        """
        probs, routing_map, tokens_per_expert = topk_softmax_with_capacity(
            logits,
            self.topk,
            config=self.config,
            layer_number=self.layer_number,
            soft_topk_alpha=self.soft_topk_alpha,
            score_function=self.score_function,
            expert_bias=self.expert_bias,
        )

        if self.training:
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
            aux_loss_func = partial(
                sequence_load_balancing_loss_func,
                probs=scores,
                routing_map=routing_map,
                batch_size=bsz,
                seq_length=seq_length,
                topk=self.topk,
            )
            probs = self.apply_load_balancing_loss(
                activation=probs, load_balancing_loss_func=aux_loss_func
            )

        return probs, routing_map

    def _sequence_partition_group(self):
        """Return the process group over which router losses and metrics are reduced.

        The context-parallel group is used for the "alltoall_seq" dispatcher.
        Otherwise the tensor+context-parallel group is used when its world size is
        above 1. Returns None otherwise.
        """
        if self.config.moe_token_dispatcher_type == "alltoall_seq":
            return parallel_state.get_context_parallel_group()
        if parallel_state.get_tensor_and_context_parallel_world_size() > 1:
            return parallel_state.get_tensor_and_context_parallel_group()
        return None

    def apply_load_balancing_loss(
        self, activation: torch.Tensor, load_balancing_loss_func: Callable,
    ):
        """Calculate auxiliary loss, attach gradient function to activation and add to logging.

        Returns ``activation`` unchanged when this layer's aux-loss coefficient is 0.
        """
        moe_aux_loss_coeff = self.moe_aux_loss_coeff
        if moe_aux_loss_coeff == 0:
            return activation
        sequence_partition_group = self._sequence_partition_group()
        if self.config.moe_token_dispatcher_type == "alltoall_seq":
            moe_aux_loss_coeff /= parallel_state.get_tensor_model_parallel_world_size()

        if self.config.moe_router_aux_loss_warmup:
            # During the first 5% of target iterations, scale the coefficient by the
            # configured warmup factor.
            _iteration = parallel_state.get_iteration()
            if _iteration['curent_iteration'] < _iteration['target_iterations'] * 0.05:
                moe_aux_loss_coeff *= self.config.moe_router_aux_loss_warmup

        aux_loss = load_balancing_loss_func(
            moe_aux_loss_coeff=moe_aux_loss_coeff, sequence_partition_group=sequence_partition_group
        )
        save_to_aux_losses_tracker(
            "load_balancing_loss",
            aux_loss / moe_aux_loss_coeff,
            self.layer_number,
            self.config.num_layers,
            reduce_group=sequence_partition_group,
        )

        activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
        return activation

    def apply_z_loss(self, logits):
        """Encourages the router's logits to remain small to enhance stability.
        Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

        Args:
            logits (torch.Tensor): The logits of the router.

        Returns:
            torch.Tensor: The logits after applying the z-loss.
        """
        if self.config.moe_z_loss_coeff is not None and self.training:
            moe_z_loss_coeff = (
                self.config.moe_z_loss_coeff
                / parallel_state.get_tensor_and_context_parallel_world_size()
            )
            z_loss = z_loss_func(logits, moe_z_loss_coeff)
            logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
            save_to_aux_losses_tracker(
                "z_loss", z_loss / moe_z_loss_coeff, self.layer_number, self.config.num_layers
            )
        return logits

    def apply_input_jitter(self, input: torch.Tensor):
        """Add noise to the input tensor.
        Refer to https://arxiv.org/abs/2101.03961.

        Multiplies ``input`` elementwise by samples from U(1 - eps, 1 + eps) when
        ``config.moe_input_jitter_eps`` is set; otherwise returns ``input`` unchanged.

        Args:
            input (Tensor): Input tensor.

        Returns:
            Tensor: Jittered input.
        """
        if self.config.moe_input_jitter_eps is not None:
            eps = self.config.moe_input_jitter_eps
            if self.input_jitter is None:
                self.input_jitter = torch.distributions.uniform.Uniform(
                    torch.tensor(1.0 - eps, device=input.device),
                    torch.tensor(1.0 + eps, device=input.device),
                ).rsample
            return input * self.input_jitter(input.shape)
        else:
            return input

    def transform_classes(self, classes: torch.Tensor, dtype=torch.float32):
        """Build per-expert targets: ``out[t, e] = 1`` iff ``classes[t] == expert_class_map[e]``.

        Args:
            classes (torch.Tensor): 1-D tensor of per-token class ids. Ids not present
                in the map (e.g. negative "unlabeled" ids) yield an all-zero row.
            dtype: dtype of the returned tensor.

        Returns:
            torch.Tensor: ``[num_tokens, num_experts]`` tensor of 0/1 values.
        """
        mapping = self.expert_class_map.to(device=classes.device)
        # Compare via broadcasting before casting so integer ids are not rounded by
        # low-precision dtypes.
        return (classes.unsqueeze(1) == mapping.unsqueeze(0)).to(dtype=dtype)

    @staticmethod
    def calculate_binary_cross_entropy(clamped_scores: torch.Tensor, classes: torch.Tensor):
        """Elementwise log-likelihood ``y*log(p) + (1-y)*log(1-p)`` (negate for BCE)."""
        return classes * torch.log(clamped_scores) + (1 - classes) * torch.log(1 - clamped_scores)

    def domain_loss_func(self, scores: torch.Tensor, classes: torch.Tensor, masking: torch.Tensor, routing_map: torch.Tensor, eps: float = 1e-5):
        """Binary cross-entropy between router scores and per-expert class targets.

        When ``config.moe_router_domain_loss_masking`` is set, only (token, expert)
        pairs selected by ``routing_map`` whose token is labeled (``masking``) contribute.
        In that mode a zero scalar is returned when no such pair exists. Otherwise the
        mean is taken over all entries.

        Args:
            scores (torch.Tensor): [num_tokens, num_experts] router probabilities.
            classes (torch.Tensor): [num_tokens, num_experts] 0/1 targets.
            masking (torch.Tensor): [num_tokens] boolean mask of labeled tokens.
            routing_map (torch.Tensor): [num_tokens, num_experts] boolean routing mask.
            eps (float): clamp margin that keeps the logs finite.

        Returns:
            torch.Tensor: scalar loss.
        """
        loss_masking = self.config.moe_router_domain_loss_masking
        if loss_masking:
            # Restrict the loss to routed (token, expert) pairs. Boolean indexing
            # flattens row-major, so the per-token label mask is expanded onto the
            # same grid before selecting; this keeps it aligned for any topk.
            scores = scores[routing_map].unsqueeze(1)
            classes = classes[routing_map].unsqueeze(1)
            masking = masking.unsqueeze(1).expand_as(routing_map)[routing_map]

        clamped_scores = torch.clamp(scores, min=eps, max=1 - eps)
        domain_loss = TopKRouter.calculate_binary_cross_entropy(clamped_scores, classes)

        if not loss_masking:
            return -torch.mean(domain_loss)
        if masking.any():
            return -torch.mean(domain_loss[masking])
        # Keep the zero loss on the same device/dtype as the rest of the graph.
        return torch.zeros((), dtype=domain_loss.dtype, device=domain_loss.device)

    def apply_domain_loss(self, logits: torch.Tensor, classes: torch.Tensor, activation: torch.Tensor, routing_map: torch.Tensor):
        """Compute the supervised domain loss and, in training, attach it to ``activation``.

        Also logs the unscaled loss, a confusion record (training only) and domain
        precision.

        Args:
            logits (torch.Tensor): [num_tokens, num_experts] router logits.
            classes (torch.Tensor): [num_tokens] class ids; negative ids mark unlabeled tokens.
            activation (torch.Tensor): routing scores the loss gradient is attached to.
            routing_map (torch.Tensor): [num_tokens, num_experts] boolean routing mask.

        Returns:
            torch.Tensor: ``activation`` (with the loss attached when training).

        Raises:
            ValueError: if the score function is neither "softmax" nor "sigmoid".
        """
        masking = classes >= 0  # semi-supervised data: negative ids are unlabeled
        experts_classes = self.transform_classes(classes)

        moe_router_domain_loss_coeff = (
            self.moe_router_domain_loss_coeff
            / parallel_state.get_tensor_and_context_parallel_world_size()
        )
        sequence_partition_group = parallel_state.get_tensor_and_context_parallel_group()

        if self.score_function == "softmax":
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
        elif self.score_function == "sigmoid":
            scores = torch.sigmoid(logits).to(torch.float32)
        else:
            raise ValueError(f"Invalid score_function: {self.score_function}")

        # Log the unscaled loss directly instead of dividing the scaled one back out,
        # which produced NaN when the coefficient is 0.
        unscaled_domain_loss = self.domain_loss_func(scores, experts_classes, masking, routing_map)
        domain_loss = unscaled_domain_loss * moe_router_domain_loss_coeff

        if self.training:
            activation = MoEAuxLossAutoScaler.apply(activation, domain_loss)

            save_to_aux_losses_tracker(
                "domain_loss",
                unscaled_domain_loss,
                self.layer_number,
                self.config.num_layers,
                avg_group=sequence_partition_group,
            )

            save_to_confusion_tracker(
                self.layer_number,
                classes,
                routing_map,
                self.expert_class_map,
                sequence_partition_group,
            )
        else:
            save_to_aux_losses_tracker(
                "val_domain_loss",
                unscaled_domain_loss,
                self.layer_number,
                self.config.num_layers,
                avg_group=sequence_partition_group,
                validation=True,
            )

        self.calculate_f1_score(experts_classes, routing_map, sequence_partition_group)

        return activation

    def calculate_f1_score(self, experts_classes: torch.Tensor, routing_map: torch.Tensor, sequence_partition_group: torch.distributed.ProcessGroup = None):
        """Compute domain routing metrics via ``base_metrics`` and log the precision.

        Only precision is currently logged ("domain_precision" in training,
        "val_domain_precision" in eval); recall and F1 are computed but not recorded.
        """
        domain_precision, _, _ = base_metrics(experts_classes, routing_map)

        if self.training:
            save_to_aux_losses_tracker(
                "domain_precision",
                domain_precision,
                self.layer_number,
                self.config.num_layers,
                avg_group=sequence_partition_group,
            )
        else:
            save_to_aux_losses_tracker(
                "val_domain_precision",
                domain_precision,
                self.layer_number,
                self.config.num_layers,
                avg_group=sequence_partition_group,
                validation=True,
            )

    def save_topk_metrics(self, topk, device, activated_experts):
        """Log the mean active-expert count and, for soft-top-k, the current alpha.

        Args:
            topk (torch.Tensor): scalar mean number of active experts per token. Not
                logged when falsy (e.g. exactly zero).
            device: device on which the soft-top-k alpha tensor is created.
            activated_experts (torch.Tensor): per-token active-expert counts, forwarded
                to the tracker.
        """
        # Only the reduction group is needed here; no loss coefficient is involved.
        sequence_partition_group = self._sequence_partition_group()

        validation = not self.training
        topk_name = "val_active_experts" if validation else "active_experts"
        topk_alpha_name = "val_soft_topk_alpha" if validation else "soft_topk_alpha"

        if topk:
            save_to_aux_losses_tracker(
                topk_name,
                topk,
                self.layer_number,
                self.config.num_layers,
                avg_group=sequence_partition_group,
                validation=validation,
                activated_experts=activated_experts,
            )
        if self.score_function == "soft-topk":
            soft_topk_alpha = torch.tensor(self.soft_topk_alpha, device=device)
            save_to_aux_losses_tracker(
                topk_alpha_name,
                soft_topk_alpha,
                self.layer_number,
                self.config.num_layers,
                avg_group=sequence_partition_group,
                validation=validation,
            )

    def routing(self, logits: torch.Tensor, classes: torch.Tensor = None, ae_mask: torch.Tensor = None):
        """Top-k routing function

        Args:
            logits (torch.Tensor): Logits tensor after gating; the first two dims are
                read as (seq_length, bsz) and the rest is flattened to
                [num_tokens, num_experts].
            classes (torch.Tensor, optional): per-token class ids for the domain loss.
            ae_mask (torch.Tensor, optional): per-token mask used by the "aux_loss"
                active-expert metric.

        Returns:
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mapping of token to experts assignment,
                with shape [num_tokens, num_experts].

        Raises:
            ValueError: for an unsupported ``moe_router_load_balancing_type``.
        """
        seq_length, bsz = logits.shape[:2]
        logits = logits.view(-1, self.config.num_moe_experts)
        if classes is not None:
            classes = classes.view(-1)
        if ae_mask is not None:
            ae_mask = ae_mask.view(-1, 1)

        # Apply Z-Loss
        logits = self.apply_z_loss(logits)

        if self.config.moe_token_dispatcher_type == "alltoall_seq":
            # Gather the logits from the TP region
            logits = gather_from_sequence_parallel_region(logits)

        if self.routing_type == "sinkhorn":
            scores, routing_map = self.sinkhorn_load_balancing(logits)
        elif self.routing_type == "aux_loss":
            scores, routing_map = self.aux_loss_load_balancing(logits, ae_mask)
        elif self.routing_type == "domain_aux_loss":
            scores, routing_map = self.domain_aux_loss_load_balancing(logits, classes)
        elif self.routing_type == "seq_aux_loss":
            scores, routing_map = self.seq_aux_loss_load_balancing(logits, bsz, seq_length)
        elif self.routing_type == "none":
            # A naive top-k routing without load balancing
            scores, routing_map, _ = topk_softmax_with_capacity(
                logits,
                self.topk,
                config=self.config,
                layer_number=self.layer_number,
                soft_topk_alpha=self.soft_topk_alpha,
                score_function=self.score_function,
                expert_bias=self.expert_bias,
            )
        else:
            raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")

        if self.config.moe_router_domain_topk and classes is not None and self.config.moe_token_dispatcher_type == "allgather":
            scores = self.apply_domain_loss(logits, classes, activation=scores, routing_map=routing_map)

        # Accumulate per-expert token counts for the expert bias / distribution metrics.
        # NOTE(review): no training/grad-enabled guard is applied here, so counts also
        # accumulate during evaluation and activation recomputation — confirm intended.
        if self.enable_expert_bias or self.tokens_distribution_metrics:
            with torch.no_grad():
                self.local_tokens_per_expert += routing_map.sum(dim=0)

        return scores, routing_map

    def forward(self, input: torch.Tensor, classes: torch.Tensor = None, ae_mask: torch.Tensor = None):
        """
        Forward pass of the router: optional input jitter, gating, then routing.

        Args:
            input (torch.Tensor): Input tensor.
            classes (torch.Tensor, optional): per-token class ids for the domain loss.
            ae_mask (torch.Tensor, optional): per-token mask for the active-expert metric.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: routing scores and the boolean routing map.
        """
        input = self.apply_input_jitter(input)
        logits = self.gating(input)

        scores, routing_map = self.routing(logits, classes, ae_mask)

        return scores, routing_map
