                                                              
              
                                                      
                                                                                         

from functools import partial

import torch
from packaging import version

from megatron.core import __version__
from megatron.core.transformer.moe.moe_utils import (
    switch_load_balancing_loss_func,
)

from mpatch.core.transformer.moe.moe_utils import topk_softmax_with_capacity


def aux_loss_load_balancing(self, logits: torch.Tensor):
    """Route tokens to experts and attach the auxiliary load-balancing loss.

    Meant to be bound as a router method. ``self`` must provide:

    * ``topk``, ``config``, ``score_function`` and ``expert_bias``;
    * ``apply_load_balancing_loss``;
    * ``compute_routing_scores_for_aux_loss`` (only used for Megatron-Core >= 0.12).

    Routing is delegated to the ``mpatch`` implementation of
    ``topk_softmax_with_capacity``. That implementation also receives the
    ``moe_norm_topk_prob`` / ``moe_norm_topk_prob_eps`` config options.

    How the aux-loss inputs are built depends on the installed Megatron-Core
    version:

    * ``< 0.12.0``: scores are a float32 softmax over ``logits``. Per-expert
      token counts come from the ``topk_softmax_with_capacity`` result. The
      loss is applied whenever ``self.training`` is truthy.
    * ``0.12.x``: scores come from ``self.compute_routing_scores_for_aux_loss``.
      Token counts still come from the ``topk_softmax_with_capacity`` result.
    * ``>= 0.13.0``: ``self.compute_routing_scores_for_aux_loss`` is unpacked
      into ``(scores, routing_map)``. Token counts are that map summed over
      ``dim=0``.

    From 0.12.0 on, the loss is applied only when training *and* autograd is
    enabled.

    Args:
        logits (torch.Tensor): the logits tensor after gating, shape: [num_tokens, num_experts].

    Returns:
        probs (torch.Tensor): The probabilities of token to experts assignment.
            When the aux loss is applied, this is the tensor returned by
            ``self.apply_load_balancing_loss``.
        routing_map (torch.Tensor): The mask of token to experts assignment.
    """
    cfg = self.config
    probs, routing_map, tokens_per_expert = topk_softmax_with_capacity(
        logits,
        self.topk,
        capacity_factor=cfg.moe_expert_capacity_factor,
        pad_to_capacity=cfg.moe_pad_expert_input_to_capacity,
        drop_policy=cfg.moe_token_drop_policy,
        use_pre_softmax=cfg.moe_router_pre_softmax,
        num_groups=cfg.moe_router_num_groups,
        group_topk=cfg.moe_router_group_topk,
        scaling_factor=cfg.moe_router_topk_scaling_factor,
        deterministic_mode=cfg.deterministic_mode,
        score_function=self.score_function,
        expert_bias=self.expert_bias,
        # Extra options forwarded to the mpatch implementation of this helper.
        moe_norm_topk_prob=cfg.moe_norm_topk_prob,
        moe_norm_topk_prob_eps=cfg.moe_norm_topk_prob_eps,
    )

    # Parse the installed Megatron-Core version once and derive both gates.
    # NOTE(review): PEP 440 orders pre-releases before the final release
    # (e.g. '0.12.0rc1' < '0.12.0'), so rc/dev builds take the older branch.
    # Confirm that is intended.
    mcore_version = version.parse(__version__)
    before_0_12 = mcore_version < version.parse('0.12.0')
    before_0_13 = mcore_version < version.parse('0.13.0')

    if before_0_12:
        wants_aux_loss = self.training
    else:
        # Newer releases also skip the aux loss when autograd is disabled.
        wants_aux_loss = self.training and torch.is_grad_enabled()

    if not wants_aux_loss:
        return probs, routing_map

    if before_0_12:
        # Older releases: plain float32 softmax over the raw gating logits.
        aux_scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
        aux_token_counts = tokens_per_expert
    elif before_0_13:
        aux_scores = self.compute_routing_scores_for_aux_loss(logits)
        aux_token_counts = tokens_per_expert
    else:
        # NOTE(review): this patch assumes the method returns
        # (scores, routing_map) in this order on >= 0.13. Confirm against
        # the installed release.
        aux_scores, aux_routing_map = self.compute_routing_scores_for_aux_loss(logits)
        # Counts come from the aux-loss routing map, not from the
        # topk_softmax_with_capacity result computed above.
        aux_token_counts = aux_routing_map.sum(dim=0)

    loss_fn = partial(
        switch_load_balancing_loss_func,
        probs=aux_scores,
        tokens_per_expert=aux_token_counts,
        topk=self.topk,
    )
    probs = self.apply_load_balancing_loss(activation=probs, load_balancing_loss_func=loss_fn)
    return probs, routing_map
