# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import os
import math
import wandb
import logging
from typing import Optional

import torch

from tools.plot_confusion_matrix import plot_confusion_matrix
from tools.plot_probs_distribution import plot_probs_distribution
from megatron.core import parallel_state
from megatron.core.utils import log_single_rank
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer.moe.soft_topk import soft_top_k
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_config import TransformerConfig

try:
    from megatron.core.extensions.transformer_engine import (
        fused_permute,
        fused_sort_chunks_by_index,
        fused_unpermute,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

# Module-level logger for this file (used with log_single_rank).
logger = logging.getLogger(__name__)

def switch_load_balancing_loss_func(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    topk: int,
    moe_aux_loss_coeff: float,
    sequence_partition_group=None,
):
    """Compute the Switch-Transformer auxiliary load-balancing loss.

    See https://arxiv.org/abs/2101.03961. With P_e the summed router probability
    of expert e, T_e the number of tokens routed to e and N the (global) token
    count, the loss is

        sum_e (P_e / N) * (T_e / (N * topk)) * num_experts * moe_aux_loss_coeff

    and all divisions are folded into one scalar factor.

    Args:
        probs (torch.Tensor): Router probabilities, shape [num_tokens, num_experts].
        tokens_per_expert (torch.Tensor): Tokens routed to each expert, shape
            [num_experts]. All-reduced IN PLACE over ``sequence_partition_group``
            when that group is given.
        topk (int): Number of experts selected per token.
        moe_aux_loss_coeff (float): Coefficient applied to the loss.
        sequence_partition_group (optional): Process group the sequence is split
            across (e.g. sequence/context parallelism). ``None`` means the
            sequence is not partitioned.

    Returns:
        torch.Tensor: Scalar auxiliary loss.
    """
    local_tokens, num_experts = probs.shape[0], probs.shape[1]

    partitions = 1
    if sequence_partition_group is not None:
        # Only the token counts are summed across partitions. The probability
        # sum stays local: no gradient flows through tokens_per_expert, so the
        # extra all-reduce of the probabilities can be skipped.
        partitions = torch.distributed.get_world_size(sequence_partition_group)
        torch.distributed.all_reduce(tokens_per_expert, group=sequence_partition_group)

    total_tokens = local_tokens * partitions
    scale = num_experts * moe_aux_loss_coeff / (total_tokens * total_tokens * topk)
    per_expert_prob_mass = probs.sum(dim=0)
    return torch.sum(per_expert_prob_mass * tokens_per_expert) * scale

def domain_load_balancing_loss_func(
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    topk: int,
    moe_aux_loss_coeff: float,
    expert_class_map: torch.Tensor,
    sequence_partition_group=None,
):
    """Switch-style load-balancing loss computed separately inside each expert domain.

    Experts are grouped into domains (classes) by ``expert_class_map``. For each
    domain d:

        loss_d = sum_{e in d} P_e * T_e * |d| * moe_aux_loss_coeff / (N_d^2 * topk)

    where T_e = ``tokens_per_expert[e]``, N_d = sum_{e in d} T_e (clamped to >= 1),
    |d| is the number of experts in d, and P_e is the summed probability of expert
    e over the tokens routed to at least one expert of domain d. The result is the
    mean of loss_d over domains that have more than one expert and more than one
    routed token.

    Args:
        probs (torch.Tensor): Router probabilities, [num_tokens, num_experts].
        routing_map (torch.Tensor): Token-to-expert selection mask, [num_tokens, num_experts].
        tokens_per_expert (torch.Tensor): Tokens routed to each expert, [num_experts].
            All-reduced IN PLACE over ``sequence_partition_group`` when given.
        topk (int): Number of experts selected per token.
        moe_aux_loss_coeff (float): Coefficient applied to the loss.
        expert_class_map (torch.Tensor): Integer domain id of each expert, [num_experts].
            Ids do not need to be contiguous.
        sequence_partition_group (optional): Process group the sequence is split across.

    Returns:
        torch.Tensor: Scalar loss. It is 0 (still attached to the autograd graph) when
        no domain qualifies, instead of the NaN an empty ``mean()`` would produce.
    """
    if sequence_partition_group is not None:
        torch.distributed.all_reduce(tokens_per_expert, group=sequence_partition_group)

    # Map arbitrary domain ids to dense indices 0..num_domains-1 so index_add_ and
    # one_hot below stay in range even for non-contiguous ids.
    _, domain_idx, num_experts_per_class = torch.unique(
        expert_class_map, return_inverse=True, return_counts=True
    )
    num_domains = num_experts_per_class.shape[0]

    # Number of tokens routed to each domain.
    num_domain_tokens = torch.zeros(
        num_domains, dtype=torch.float, device=expert_class_map.device
    ).index_add_(0, domain_idx, tokens_per_expert.float())

    # Expert -> domain one-hot matrix, [num_experts, num_domains].
    expert_class_one_hot = torch.nn.functional.one_hot(domain_idx, num_classes=num_domains).to(
        probs.dtype
    )

    # A token "hits" a domain if it was routed to any expert of that domain. Every
    # expert of the domain then accumulates that token's probability. This is done
    # with a matmul instead of per-domain boolean indexing, which avoids host syncs.
    domain_hits = (
        routing_map.bool().to(torch.float32) @ expert_class_one_hot.to(torch.float32)
    ) > 0
    domain_routing_map = domain_hits[:, domain_idx]
    aggregated_probs_per_expert = (probs * domain_routing_map).sum(dim=0)

    # Per-domain loss; clamp keeps the denominator non-zero for empty domains.
    num_domain_tokens = torch.clamp(num_domain_tokens, min=1)
    denom = num_domain_tokens * num_domain_tokens * topk
    aux_loss = ((aggregated_probs_per_expert * tokens_per_expert) @ expert_class_one_hot) * (
        num_experts_per_class * moe_aux_loss_coeff / denom
    )

    # Ignore single-expert domains and domains with at most one routed token.
    valid = (num_experts_per_class > 1) & (num_domain_tokens > 1)

    # Mean over the qualifying domains; dividing by max(count, 1) yields 0 rather
    # than NaN when nothing qualifies.
    return aux_loss[valid].sum() / valid.sum().clamp(min=1)

def sequence_load_balancing_loss_func(
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    batch_size: int,
    seq_length: int,
    topk: int,
    moe_aux_loss_coeff: float,
    sequence_partition_group=None,
):
    """Sequence-level auxiliary load-balancing loss, computed per sample and averaged.

    Follows the DeepSeek-V2 reference implementation
    (https://huggingface.co/deepseek-ai/DeepSeek-V2). For every sample, the routed
    token count of each expert is normalized by ``seq_length * topk / num_experts``
    (1.0 means perfectly balanced). It is multiplied by the expert's mean probability
    over the sequence and summed over experts. The per-sample values are averaged
    and scaled by ``moe_aux_loss_coeff``.

    Args:
        probs (torch.Tensor): Router probabilities, [num_tokens, num_experts], with
            num_tokens laid out as seq_length x batch_size.
        routing_map (torch.Tensor): Token-to-expert mask, [num_tokens, num_experts].
        batch_size (int): Number of samples.
        seq_length (int): Local sequence length.
        topk (int): Experts selected per token.
        moe_aux_loss_coeff (float): Coefficient applied to the loss.
        sequence_partition_group (optional): Process group the sequence is split
            across. When given, probabilities are gathered along the sequence and the
            sequence length is scaled by the group size.

    Returns:
        torch.Tensor: Scalar sequence auxiliary loss.
    """
    num_experts = probs.shape[1]
    seq_probs = probs.view(seq_length, batch_size, -1)
    seq_routing = routing_map.view(seq_length, batch_size, -1)

    full_seq_length = seq_length
    if sequence_partition_group is not None:
        # Gather the probabilities so the gradient reflects the full sequence.
        full_seq_length = seq_length * torch.distributed.get_world_size(sequence_partition_group)
        seq_probs = gather_from_sequence_parallel_region(seq_probs, group=sequence_partition_group)

    normalizer = full_seq_length * topk / num_experts
    expert_load = seq_routing.sum(dim=0, dtype=torch.float).div_(normalizer)
    mean_probs = seq_probs.mean(dim=0)
    per_sample_loss = (expert_load * mean_probs).sum(dim=1)
    return per_sample_loss.mean() * moe_aux_loss_coeff


def z_loss_func(logits, z_loss_coeff):
    """Router z-loss that discourages large router logits for numerical stability.

    See the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf). The loss is the
    mean over tokens of the squared log-sum-exp of the logits, times the coefficient.

    Args:
        logits (torch.Tensor): Router logits; log-sum-exp is taken over the last dim.
        z_loss_coeff (float): Coefficient the loss is multiplied by.

    Returns:
        torch.Tensor: Scalar z-loss (not the logits).
    """
    log_partition = torch.logsumexp(logits, dim=-1)
    return log_partition.square().mean() * z_loss_coeff


def sinkhorn(cost: torch.Tensor, tol: float = 0.0001):
    """Sinkhorn-based MoE routing.

    Alternately rescales the rows and columns of ``exp(cost)`` until the column
    scaling changes by at most ``tol`` (mean absolute change). The result has
    column sums close to ``1 / num_cols`` and row sums close to ``1 / num_rows``.

    Args:
        cost (torch.Tensor): 2-D score matrix; it is exponentiated first.
        tol (float): Convergence tolerance on the column scaling vector.

    Returns:
        torch.Tensor: The balanced matrix, same shape as ``cost``.
    """
    kernel = torch.exp(cost)
    n_rows, n_cols = kernel.size(0), kernel.size(1)
    row_scale = torch.ones(n_rows, device=kernel.device, dtype=kernel.dtype)
    col_scale = torch.ones(n_cols, device=kernel.device, dtype=kernel.dtype)

    eps = 1e-8  # guards the reciprocals against division by zero
    err = 1e9
    prev_col_scale = col_scale
    while err > tol:
        row_scale = (1 / n_rows) * 1 / (torch.sum(col_scale * kernel, 1) + eps)
        col_scale = (1 / n_cols) * 1 / (torch.sum(row_scale.unsqueeze(1) * kernel, 0) + eps)
        err = torch.mean(torch.abs(prev_col_scale - col_scale))
        prev_col_scale = col_scale

    return col_scale * kernel * row_scale.unsqueeze(1)


def get_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_capacity=None):
    """Return the per-expert token capacity.

    Args:
        num_tokens (int): Number of input tokens.
        num_experts (int): Number of experts.
        capacity_factor (float): Multiplier on the even share ``num_tokens / num_experts``.
        min_capacity (int, optional): Lower bound on the capacity. Defaults to None (no bound).

    Returns:
        int: ``ceil(num_tokens / num_experts * capacity_factor)``, raised to
        ``min_capacity`` when that is larger.
    """
    capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
    if min_capacity is None:
        return capacity
    return max(capacity, min_capacity)

def base_metrics(y_true, y_pred):
    """Mean per-row precision, recall and F1 of 0/1 indicator tensors.

    Both inputs are cast to float32. Counts are taken along dim 1, so each row is
    scored independently and the three metrics are averaged over rows. A 1e-8
    epsilon keeps rows with no positives from dividing by zero.

    Args:
        y_true (torch.Tensor): Ground-truth indicators.
        y_pred (torch.Tensor): Predicted indicators, same shape as ``y_true``.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Scalar mean precision,
        recall and F1.
    """
    eps = 1e-8
    truth = y_true.to(dtype=torch.float32)
    pred = y_pred.to(dtype=torch.float32)

    true_pos = (truth * pred).sum(dim=1)
    false_pos = ((1 - truth) * pred).sum(dim=1)
    false_neg = (truth * (1 - pred)).sum(dim=1)

    precision = true_pos / (true_pos + false_pos + eps)
    recall = true_pos / (true_pos + false_neg + eps)
    f1 = 2 * (precision * recall) / (precision + recall + eps)

    return tuple(metric.mean() for metric in (precision, recall, f1))

class MoEAuxLossAutoScaler(torch.autograd.Function):
    """Identity on ``output`` that feeds a scaled gradient into ``aux_loss`` on backward.

    Attaching the auxiliary loss this way makes it produce gradients without being
    added to the main loss. On backward, ``aux_loss`` receives a gradient equal to
    ``main_loss_backward_scale`` (set through :meth:`set_loss_scale`), which should
    match the scale applied to the main loss.
    """

    # Class-wide scale for the aux-loss gradient; shared by every application.
    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    @staticmethod
    def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
        """Return ``output`` unchanged, saving ``aux_loss`` so it stays alive until backward.

        Args:
            output (torch.Tensor): Tensor passed through untouched.
            aux_loss (torch.Tensor): Auxiliary loss to receive a gradient later.

        Returns:
            torch.Tensor: ``output``.
        """
        ctx.save_for_backward(aux_loss)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Pass ``grad_output`` through and give ``aux_loss`` a gradient equal to the scale.

        Args:
            grad_output (torch.Tensor): Gradient w.r.t. the forward output.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``grad_output`` unchanged, and a tensor
            shaped like ``aux_loss`` filled with ``main_loss_backward_scale``.
        """
        saved_aux_loss = ctx.saved_tensors[0]
        scale = MoEAuxLossAutoScaler.main_loss_backward_scale
        aux_loss_grad = torch.ones_like(saved_aux_loss) * scale
        return grad_output, aux_loss_grad

    @staticmethod
    def set_loss_scale(scale: torch.Tensor):
        """Set the class-wide gradient scale for the auxiliary loss.

        Args:
            scale (torch.Tensor): New scale; it should equal the scale of the main loss.
        """
        MoEAuxLossAutoScaler.main_loss_backward_scale = scale

class SoftTopKAlphaScheduler:
    """Cosine-interpolated schedule for the soft top-k ``alpha`` parameter.

    ``schedule_points`` are fractions of training progress
    (``current_iter / total_iters``), paired element-wise with ``alphas``. Between
    two consecutive points alpha follows a half-cosine ramp from the first value to
    the second. Progress before the first point yields ``alphas[0]``; progress past
    the last point yields ``alphas[-1]``.
    """

    def __init__(self, alphas, schedule_points):
        assert len(schedule_points) == len(alphas)

        self.alphas = alphas
        self.schedule_points = schedule_points
        # Most recently computed alpha (also returned by step()).
        self.soft_topk_alpha = alphas[0]

    def step(self, current_iter, total_iters):
        """Compute, cache and return alpha at ``current_iter`` out of ``total_iters``.

        Args:
            current_iter: Current iteration.
            total_iters: Total iterations; 0 means "use the final alpha".

        Returns:
            The alpha for this point of training.
        """
        if total_iters == 0:
            self.soft_topk_alpha = self.alphas[-1]
            return self.soft_topk_alpha

        t_global = current_iter / total_iters

        # Clamp outside the schedule so that, e.g., resuming past the end of the
        # schedule returns the final alpha instead of a stale value.
        if t_global < self.schedule_points[0]:
            self.soft_topk_alpha = self.alphas[0]
            return self.soft_topk_alpha
        if t_global > self.schedule_points[-1]:
            self.soft_topk_alpha = self.alphas[-1]
            return self.soft_topk_alpha

        for i in range(len(self.schedule_points) - 1):
            t_start = self.schedule_points[i]
            t_end = self.schedule_points[i + 1]
            if t_start <= t_global <= t_end:
                a = self.alphas[i]
                b = self.alphas[i + 1]

                # Zero-length segment: jump straight to its end value.
                if t_end == t_start:
                    u = 1.0
                else:
                    u = (t_global - t_start) / (t_end - t_start)

                # Half-cosine ease: 0 at u=0, 1 at u=1.
                cos_coeff = (1 - math.cos(u * math.pi)) / 2
                self.soft_topk_alpha = a + (b - a) * cos_coeff
                break

        return self.soft_topk_alpha

def permute(
    tokens,
    routing_map,
    num_out_tokens: Optional[int] = None,
    fused: bool = False,
    drop_and_pad: bool = False,
):
    """Group tokens by their assigned expert.

    ``routing_map`` is a [num_tokens, num_experts] mask of the experts each token
    selected. The output lists, expert by expert, the tokens routed to that expert
    in ascending token order.

    With ``drop_and_pad=True`` (and ``num_out_tokens`` given), every expert column
    of ``routing_map`` is expected to hold exactly ``capacity`` non-zeros. That lets
    the function use fixed-size ops that are CUDA-graph friendly.

    Args:
        tokens (torch.Tensor): Input tokens, [num_tokens, hidden].
        routing_map (torch.Tensor): Token-to-expert mask, [num_tokens, num_experts].
        num_out_tokens (int, optional): Number of output tokens; used for capacity
            in drop-and-pad mode and passed to the fused kernel.
        fused (bool, optional): Use the Transformer Engine fused permute.
        drop_and_pad (bool, optional): Token-drop-and-pad mode (see above).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Permuted tokens and the token index of
        each permuted row.
    """
    if fused:
        if not HAVE_TE or fused_permute is None:
            raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.")
        return fused_permute(tokens, routing_map, num_out_tokens)

    num_tokens, _ = tokens.shape
    num_experts = routing_map.shape[1]

    if drop_and_pad and num_out_tokens is not None:
        capacity = num_out_tokens // num_experts
        assert not routing_map.requires_grad
        # Expert-major 0/1 layout: [num_experts, num_tokens].
        expert_major_map = routing_map.to(dtype=torch.int8).T.contiguous()
        # A stable descending argsort moves the selected tokens to the front of each
        # row (in token order); keep exactly `capacity` of them per expert.
        top_slots = expert_major_map.argsort(dim=-1, descending=True, stable=True)
        sorted_indices = top_slots[:, :capacity].contiguous().view(-1)
    else:
        # Expert-major boolean layout: [num_experts, num_tokens].
        expert_major_map = routing_map.bool().T.contiguous()
        token_ids = torch.arange(num_tokens, device=expert_major_map.device)
        token_grid = token_ids.unsqueeze(0).expand(num_experts, -1)
        sorted_indices = token_grid.masked_select(expert_major_map)

    return tokens.index_select(0, sorted_indices), sorted_indices


def unpermute(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    restore_shape: torch.Size,
    probs: torch.Tensor = None,
    routing_map: torch.Tensor = None,
    fused: bool = False,
    drop_and_pad: bool = False,
):
    """Undo ``permute``: scatter-add permuted rows back to their original token slots.

    When ``probs`` is given, each permuted row is first weighted by the probability
    of its (token, expert) pair. Rows belonging to the same token are summed.

    With ``drop_and_pad=True``, ``sorted_indices`` holds ``num_experts * capacity``
    entries, one ``capacity``-sized chunk per expert. The probabilities are then
    gathered with fixed-size index ops that are CUDA-graph friendly.

    Args:
        permuted_tokens (torch.Tensor): Rows produced by ``permute``.
        sorted_indices (torch.Tensor): Token index of each permuted row.
        restore_shape (torch.Size): Output shape, [num_tokens, hidden].
        probs (torch.Tensor, optional): Unpermuted probabilities, [num_tokens, num_experts].
        routing_map (torch.Tensor, optional): Token-to-expert mask,
            [num_tokens, num_experts]; required when ``probs`` is given.
        fused (bool, optional): Use the Transformer Engine fused unpermute.
        drop_and_pad (bool, optional): Token-drop-and-pad mode (see above).

    Returns:
        torch.Tensor: Restored tokens in the input dtype of ``permuted_tokens``.
    """
    if fused:
        if not HAVE_TE or fused_unpermute is None:
            raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.")
        return fused_unpermute(permuted_tokens, sorted_indices, probs, restore_shape)

    _, hidden = restore_shape
    original_dtype = permuted_tokens.dtype

    if probs is not None:
        assert routing_map is not None, "Mask must be provided to permute the probs."
        if drop_and_pad:
            expert_count = routing_map.size(1)
            capacity = sorted_indices.size(0) // expert_count
            token_count = probs.size(0)

            # Flatten probs expert-major: element (e, t) lives at e * token_count + t.
            flat_probs = probs.T.contiguous().view(-1)
            expert_rows = torch.arange(expert_count, device=routing_map.device).unsqueeze(-1)
            token_cols = sorted_indices.view(expert_count, capacity)
            flat_positions = (expert_rows * token_count + token_cols).view(-1)
            gathered_probs = flat_probs.index_select(0, flat_positions)
        else:
            gathered_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
        # If probs are higher precision than the tokens (e.g. a higher-precision
        # router dtype), this promotes permuted_tokens and costs extra memory; the
        # fused path (--moe-permute-fusion) avoids that allocation.
        permuted_tokens = permuted_tokens * gathered_probs.unsqueeze(-1)

    restored = torch.zeros(
        restore_shape, dtype=permuted_tokens.dtype, device=permuted_tokens.device
    )
    scatter_index = sorted_indices.unsqueeze(1).expand(-1, hidden)
    restored.scatter_add_(0, scatter_index, permuted_tokens)
    return restored.to(dtype=original_dtype)


def sort_chunks_by_idxs(
    input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor, fused: bool = False
):
    """Split ``input`` along dim 0 by ``split_sizes`` and concatenate the chunks in
    the order given by ``sorted_idxs``.

    Args:
        input (torch.Tensor): Tensor to reorder along dim 0.
        split_sizes (torch.Tensor): Length of each chunk.
        sorted_idxs (torch.Tensor): Chunk indices in output order.
        fused (bool): Use the Transformer Engine fused kernel.

    Returns:
        torch.Tensor: The reordered tensor.
    """
    if fused:
        if not HAVE_TE or fused_sort_chunks_by_index is None:
            raise ValueError(
                "fused_sort_chunks_by_index is not available. Please install TE >= 2.1.0."
            )
        return fused_sort_chunks_by_index(input, split_sizes, sorted_idxs)

    chunks = torch.split(input, split_sizes.tolist(), dim=0)
    order = sorted_idxs.tolist()
    return torch.cat([chunks[position] for position in order], dim=0)


def group_limited_topk(
    scores: torch.Tensor,
    topk: int,
    num_tokens: int,
    num_experts: int,
    num_groups: int,
    group_topk: int,
):
    """Top-k expert selection restricted to the best expert groups of each token.

    Experts are split into ``num_groups`` equal, contiguous groups. Each group is
    scored per token by the sum of its two highest expert scores. The
    ``group_topk`` best groups are kept, and ``topk`` experts are chosen from them.

    Typical uses:
    - Device-limited routing: ``num_groups`` = expert-parallel size
      (DeepSeek-V2, https://arxiv.org/pdf/2405.04434).
    - Node-limited routing: ``num_groups`` = number of nodes in the EP group
      (DeepSeek-V3, https://arxiv.org/pdf/2412.19437).

    Args:
        scores (torch.Tensor): Router scores, [num_tokens, num_experts].
        topk (int): Experts to select per token.
        num_tokens (int): Number of tokens.
        num_experts (int): Number of experts.
        num_groups (int): Number of expert groups.
        group_topk (int): Groups kept per token.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Selected scores and expert indices.
    """
    per_group = scores.view(num_tokens, num_groups, -1)
    # Group score = sum of the group's two best expert scores.
    group_scores = per_group.topk(2, dim=-1)[0].sum(dim=-1)
    _, kept_groups = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, kept_groups, 1)

    # Broadcast the per-group mask to every expert of the group.
    experts_per_group = num_experts // num_groups
    expert_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_groups, experts_per_group)
        .reshape(num_tokens, -1)
    )

    restricted_scores = scores.masked_fill(~expert_mask.bool(), float('-inf'))
    probs, top_indices = torch.topk(restricted_scores, k=topk, dim=-1)
    return probs, top_indices

def compute_soft_topk(
    scores,
    topk,
    alpha=None,
    threshold=None,
    hard_threshold_coeff: float = None,
    layer_number=None,
    num_experts=None,
    training=False,
):
    """Soft top-k routing probabilities, sparsified by a per-token threshold.

    Steps:
      1. ``soft_top_k`` (project helper) maps float32 ``scores`` to probabilities,
         given per-token weights ``w`` (the topk value per token) and ``alpha``.
      2. A per-token threshold ``threshold * topk / num_experts`` is clamped between
         the ``experts_threshold``-th largest probability, so at most that many experts
         survive (ties aside), and the largest probability, so at least one survives.
      3. Probabilities below the threshold are set to 0.

    Args:
        scores (torch.Tensor): Router scores, [num_tokens, num_experts].
        topk (int | float | torch.Tensor): Target experts per token; a scalar or one
            value per token.
        alpha: Converted to a float32 tensor and passed to ``soft_top_k``.
        threshold (float): Relative threshold (required), scaled by ``topk / num_experts``.
        hard_threshold_coeff (float, optional): Caps the surviving experts at
            ``ceil(max(topk) * hard_threshold_coeff)``, clamped to [1, num_experts].
            ``None`` disables the cap.
        layer_number: Key forwarded to the probability / metrics trackers.
        num_experts (int, optional): Normaliser for ``threshold``; defaults to
            ``scores.shape[-1]``.
        training (bool): When True, statistics are recorded in the trackers.

    Returns:
        torch.Tensor: Sparse routing probabilities in ``scores.dtype``,
        [num_tokens, num_experts].
    """
    # soft_top_k and the threshold math run in float32.
    scores_fp32 = scores.to(torch.float32)
    num_cols = scores_fp32.shape[-1]
    if num_experts is None:
        num_experts = num_cols

    if isinstance(topk, torch.Tensor):
        topk = topk.to(device=scores.device, dtype=torch.float32)
    else:
        topk = torch.tensor(topk, device=scores.device, dtype=torch.float32)

    # Hard cap on surviving experts per token. It is clamped to [1, num_cols] so
    # torch.topk below never asks for zero or more values than exist.
    if hard_threshold_coeff is None:
        experts_threshold = num_cols
    else:
        experts_threshold = math.ceil(topk.detach().max().item() * hard_threshold_coeff)
        experts_threshold = min(max(experts_threshold, 1), num_cols)

    topk_ = topk.unsqueeze(0) if topk.ndim == 0 else topk

    # One weight per token (broadcast when topk is a scalar).
    w = topk_.expand(scores_fp32.shape[0]).clone()

    # --- soft top-k probabilities ---
    alpha = torch.as_tensor(alpha, device=scores.device, dtype=torch.float32)
    probs_fp32 = soft_top_k(scores_fp32, w, alpha=alpha)

    # --- hard top-k bounds on the threshold ---
    topk_vals, _ = torch.topk(probs_fp32.detach(), k=experts_threshold, dim=1)
    lower_bound = topk_vals[:, -1].unsqueeze(1)  # keeps at most experts_threshold experts
    higher_bound = topk_vals[:, 0].unsqueeze(1)  # keeps at least the top-1 expert

    # topk as a column so a per-token topk lines up with the [num_tokens, 1] bounds
    # (a flat per-token vector would otherwise broadcast to [num_tokens, num_tokens]).
    topk_col = topk_.detach().reshape(-1, 1)
    threshold = torch.clamp(
        (threshold * (topk_col / num_experts)).to(scores.device),
        min=lower_bound.to(scores.device),
        max=higher_bound.to(scores.device),
    )

    if training:
        save_to_probs_distribution_tracker(layer_number, probs_fp32.detach(), threshold.detach())

    # --- zero out probabilities below the threshold ---
    probs_fp32 = torch.where(
        probs_fp32 >= threshold,
        probs_fp32,
        torch.zeros(1, device=scores.device, dtype=torch.float32),
    )

    if training:
        save_to_soft_topk_experts_metrics_tracker(layer_number, probs_fp32.detach(), topk.detach())

    return probs_fp32.to(scores.dtype)

def topk_softmax_with_capacity(
    logits: torch.Tensor,
    topk: int | torch.Tensor,
    config: TransformerConfig = None,
    layer_number: int = None,
    soft_topk_alpha: Optional[float] = None,
    score_function: str = "softmax",
    expert_bias: Optional[torch.Tensor] = None,
    training: bool = False,
):
    """Route tokens to experts and optionally enforce per-expert capacity.

    Routing options are read from ``config`` (which is therefore required):
      - ``moe_expert_capacity_factor``: capacity factor; ``None`` disables dropping.
      - ``moe_pad_expert_input_to_capacity``: pad each expert to its capacity; padded
        slots get probability 0.
      - ``moe_token_drop_policy``: ``"probs"`` keeps the highest-probability tokens per
        expert; ``"position"`` keeps routed tokens chosen by ``torch.topk`` over the
        0/1 routing map (intended to drop tokens at the end of the batch).
      - ``moe_router_pre_softmax``: softmax before (True) or after (False) top-k.
      - ``moe_router_num_groups`` / ``moe_router_group_topk``: group-limited routing.
      - ``moe_router_soft_topk_routing_scores_threshold`` /
        ``moe_router_soft_topk_hard_threshold_coeff``: soft top-k thresholds.
      - ``moe_router_topk_scaling_factor``: multiplier applied to routing probs.
      - ``num_moe_experts``: number of experts.
      - ``deterministic_mode``: read but unused.

    Args:
        logits (torch.Tensor): Router logits, [num_tokens, num_experts].
        topk (int | torch.Tensor): Experts selected per token.
        config (TransformerConfig): Source of the options above.
        layer_number (int, optional): Forwarded to soft top-k tracking.
        soft_topk_alpha (float, optional): Alpha for ``score_function="soft-topk"``.
        score_function (str): ``"softmax"``, ``"sigmoid"`` or ``"soft-topk"``.
        expert_bias (torch.Tensor, optional): Bias added to the scores only when
            choosing experts (softmax with pre-softmax, and sigmoid). The returned
            probabilities come from the unbiased scores. Not used by ``"soft-topk"``.
        training (bool): Forwarded to soft top-k to enable tracking.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            - routing_probs: [num_tokens, num_experts] routing probabilities.
            - routing_map: [num_tokens, num_experts] boolean mask of selected experts.
            - tokens_per_expert: [num_experts] local token count per expert before
              dropping and padding.

    Raises:
        ValueError: On an unknown ``score_function`` or ``moe_token_drop_policy``.
    """
    assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
    num_tokens, num_experts = logits.shape
    capacity_factor = config.moe_expert_capacity_factor
    pad_to_capacity = config.moe_pad_expert_input_to_capacity
    drop_policy = config.moe_token_drop_policy
    use_pre_softmax = config.moe_router_pre_softmax
    num_groups = config.moe_router_num_groups
    group_topk = config.moe_router_group_topk
    soft_topk_threshold = config.moe_router_soft_topk_routing_scores_threshold
    soft_topk_hard_threshold_coeff = config.moe_router_soft_topk_hard_threshold_coeff
    scaling_factor = config.moe_router_topk_scaling_factor
    # The configured expert count takes precedence over logits.shape[1].
    num_experts = config.num_moe_experts
    deterministic_mode = config.deterministic_mode  # deprecated; kept for compatibility

    def compute_topk(scores, topk, num_groups=None, group_topk=None):
        # Plain or group-limited top-k over the expert dimension.
        if group_topk:
            return group_limited_topk(
                scores=scores,
                topk=topk,
                num_tokens=num_tokens,
                num_experts=num_experts,
                num_groups=num_groups,
                group_topk=group_topk,
            )
        else:
            if isinstance(topk, float):
                topk = int(topk)
            elif isinstance(topk, torch.Tensor):
                topk = topk.long()
            return torch.topk(scores, k=topk, dim=1)

    if score_function == "softmax":
        if use_pre_softmax:
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            if expert_bias is not None:
                # The bias only steers which experts are chosen; the probabilities
                # are gathered from the unbiased scores.
                scores_for_routing = scores + expert_bias
                _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
                probs = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
            else:
                probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        else:
            # Select on the raw logits, then softmax over the selected ones.
            scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits)
        if expert_bias is not None:
            scores_for_routing = scores + expert_bias
            _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
            scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
        else:
            scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        # Renormalize the selected sigmoid scores when more than one expert is used.
        probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
    elif score_function == "soft-topk":
        probs = compute_soft_topk(
            logits,
            topk,
            soft_topk_alpha,
            soft_topk_threshold,
            soft_topk_hard_threshold_coeff,
            layer_number,
            num_experts,
            training,
        )
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    if scaling_factor:
        probs = probs * scaling_factor

    if score_function == "soft-topk":
        # Soft top-k already returns dense [num_tokens, num_experts] probs;
        # an expert is selected wherever its probability survived thresholding.
        topk_masked_gates = probs
        topk_map = (probs > 0).bool()
        tokens_per_expert = topk_map.sum(dim=0)
    else:
        # TODO Try using element-wise operations instead of scatter?
        topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
        topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
        tokens_per_expert = topk_map.sum(dim=0)

    if capacity_factor is None:
        # TopK without capacity
        return topk_masked_gates, topk_map, tokens_per_expert

    # TopK with capacity
    expert_capacity = get_capacity(
        num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor
    )

    # Mask out tokens beyond each expert's capacity.
    if drop_policy == "probs":
        _, capacity_indices = torch.topk(
            topk_masked_gates, k=expert_capacity, dim=0, sorted=False
        )
        capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool()
    elif drop_policy == "position":
        _, capacity_indices = torch.topk(topk_map.int(), k=expert_capacity, dim=0, sorted=False)
        capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool()
    else:
        raise ValueError(f"Invalid drop_policy: {drop_policy}")

    if pad_to_capacity:
        final_map = capacity_mask
    else:
        # Keep only tokens that were both selected and within capacity.
        final_map = torch.logical_and(topk_map, capacity_mask)
    final_probs = topk_masked_gates * final_map
    return final_probs, final_map, tokens_per_expert

def save_to_confusion_tracker(layer_num, classes: torch.Tensor, routing_map: torch.Tensor, expert_class_map: torch.Tensor, sequence_partition_group):
    """Accumulate a (true class x routed-expert class) confusion matrix for one layer.

    Every selected (token, expert) entry of ``routing_map`` contributes one count at
    row ``classes[token]`` and column ``expert_class_map[expert]``. Tokens whose class
    is negative (e.g. -1) are ignored. The matrix is added to the global tracker under
    ``"conf_matrix_layer_{layer_num}"``, and ``sequence_partition_group`` is stored for
    the later cross-rank reduction.

    Args:
        layer_num: Layer identifier used in the tracker key.
        classes (torch.Tensor): True class per token, flattened to 1-D (one entry per
            row of ``routing_map``).
        routing_map (torch.Tensor): Token-to-expert mask, [num_tokens, num_experts].
        expert_class_map (torch.Tensor): Class id of each expert, [num_experts].
        sequence_partition_group: Process group stored in the tracker.

    NOTE(review): class ids in ``classes`` must be < ``max(expert_class_map) + 1``;
    otherwise bincount produces too many bins and the reshape fails.
    """
    with torch.no_grad():
        device = classes.device
        num_classes = int(expert_class_map.max().item()) + 1

        # One (token, expert) pair per selected entry. Pairing via the token index
        # keeps classes aligned with experts even when a token is routed to zero or
        # several experts (for exactly one expert per token this is the same as
        # pairing token i with the i-th selected expert).
        token_idx, expert_idx = routing_map.nonzero(as_tuple=True)
        true_cls = classes.reshape(-1)[token_idx]
        routed_cls = expert_class_map[expert_idx]

        # Remove invalid class assignments (e.g., -1 in classes).
        valid = true_cls >= 0
        true_cls = true_cls[valid].to(dtype=torch.int64)
        routed_cls = routed_cls[valid].to(dtype=torch.int64)

        # Encode (row, col) as a flat bin index and count.
        flat = true_cls * num_classes + routed_cls
        conf_matrix = (
            torch.bincount(flat, minlength=num_classes * num_classes)
            .reshape(num_classes, num_classes)
            .to(device)
        )

        # Add to global tracker
        tracker = parallel_state.get_moe_domain_confusion_matrix_tracker()
        tracker.setdefault("conf_matrixes", {})

        tracker_key = f"conf_matrix_layer_{layer_num}"
        tracker["conf_matrixes"].setdefault(tracker_key, torch.zeros(num_classes, num_classes, device=device))
        tracker["conf_matrixes"][tracker_key] += conf_matrix
        tracker["sequence_partition_group"] = sequence_partition_group

def reduce_conf_matrix():
    """Sum the accumulated per-layer confusion matrices across ranks, in place.

    Each matrix is all-reduced over the pipeline-model-parallel group. It is then
    all-reduced over the sequence-partition group when ``save_to_confusion_tracker``
    stored one.

    Pipeline stages generally record different ``conf_matrix_layer_*`` keys, and
    collectives are matched only by call order. Without coordination, the matrices of
    different layers would be summed together, or the collectives would hang when
    stages hold different numbers of keys. This function therefore first exchanges the
    union of keys over the PP group. It then zero-fills the keys this rank does not
    hold, so they do not change the sum, and reduces every key in the same sorted
    order on every rank.

    NOTE(review): no reduction over the data-parallel group happens here; confirm
    whether per-DP-replica counts are intended.
    """
    tracker = parallel_state.get_moe_domain_confusion_matrix_tracker()
    conf_matrices = tracker.setdefault("conf_matrixes", {})
    sequence_partition_group = tracker.get("sequence_partition_group")
    pp_group = parallel_state.get_pipeline_model_parallel_group()

    # Exchange key -> shape so ranks missing a layer can materialise a zero matrix.
    local_shapes = {key: tuple(matrix.shape) for key, matrix in conf_matrices.items()}
    gathered_shapes = [None] * torch.distributed.get_world_size(group=pp_group)
    torch.distributed.all_gather_object(gathered_shapes, local_shapes, group=pp_group)

    merged_shapes = {}
    for rank_shapes in gathered_shapes:
        merged_shapes.update(rank_shapes)

    if conf_matrices:
        device = next(iter(conf_matrices.values())).device
    elif torch.cuda.is_available():
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        device = torch.device("cpu")

    for tracker_key in sorted(merged_shapes):
        if tracker_key not in conf_matrices:
            conf_matrices[tracker_key] = torch.zeros(merged_shapes[tracker_key], device=device)
        conf_matrix = conf_matrices[tracker_key]

        torch.distributed.all_reduce(conf_matrix, group=pp_group)
        if sequence_partition_group is not None:
            torch.distributed.all_reduce(conf_matrix, group=sequence_partition_group)

def track_conf_matrix(iteration, wandb_writer, domain_topk: bool):
    """Reduce, persist, plot and reset the domain confusion matrices.

    Does nothing unless ``domain_topk`` is True. Otherwise it reduces the matrices
    across ranks and drops the stored process group, which cannot be serialized. The
    last global rank then saves the tracker to
    ``./logs/conff_matrix/confusion_matrix_{iteration}.pth``. When ``wandb_writer`` is
    set, it also renders plots and uploads the resulting PNGs as a W&B artifact.
    Finally the matrices are zeroed on every rank.
    """
    if not domain_topk:
        return

    tracker = parallel_state.get_moe_domain_confusion_matrix_tracker()
    reduce_conf_matrix()

    # ProcessGroup handles cannot be pickled by torch.save.
    tracker.pop("sequence_partition_group", None)

    is_last_rank = torch.distributed.get_rank() == torch.distributed.get_world_size() - 1
    if is_last_rank:
        output_dir = "./logs/conff_matrix"
        os.makedirs(output_dir, exist_ok=True)

        output_path = os.path.join(output_dir, f"confusion_matrix_{iteration}.pth")
        torch.save(tracker, output_path)

        if wandb_writer:
            # Renders PNGs into a per-iteration subfolder next to the .pth file.
            plot_confusion_matrix(output_path)
            image_dir = os.path.join(output_dir, f"confusion_matrix_{iteration}")

            artifact = wandb.Artifact(f'confusion_matrix_{iteration}', type='dataset')
            png_names = [entry for entry in os.listdir(image_dir) if entry.endswith(".png")]
            for png_name in png_names:
                artifact.add_file(os.path.join(image_dir, png_name))

            wandb_writer.run.log_artifact(artifact)

    clear_conf_matrix_tracker()

def clear_conf_matrix_tracker():
    """Zero every accumulated confusion matrix in place, keeping its key registered.

    Does nothing when no matrix has been recorded yet, matching the other tracker-clearing
    helpers instead of raising KeyError.
    """
    tracker = parallel_state.get_moe_domain_confusion_matrix_tracker()

    for conf_matrix in tracker.get("conf_matrixes", {}).values():
        conf_matrix.zero_()

def save_to_aux_losses_tracker(
    name: str,
    loss: torch.Tensor,
    layer_number: int,
    num_layers: int,
    reduce_group: torch.distributed.ProcessGroup = None,
    avg_group: torch.distributed.ProcessGroup = None,
    validation : bool = False,
    activated_experts: torch.Tensor = None
):
    """Accumulate an auxiliary loss value for one layer into the logging tracker.

    Args:
        name (str): The name of the loss (tracker key).
        loss (torch.Tensor): The loss tensor; its detached value is added to the layer slot.
        layer_number (int): 1-based layer index of the loss. Logging is skipped when None.
        num_layers (int): Number of layer slots allocated on first use of ``name``.
        reduce_group (torch.distributed.ProcessGroup): Group over which the loss is summed
            when reduced for logging.
        avg_group (torch.distributed.ProcessGroup): Group over which the loss is averaged
            when reduced for logging.
        validation (bool): Use the validation tracker and also append a per-layer history entry.
        activated_experts (torch.Tensor, optional): In validation, recorded in the history
            instead of ``loss`` when provided.
    """
    if layer_number is None:
        return

    tracker = (
        parallel_state.get_val_moe_layer_wise_logging_tracker()
        if validation
        else parallel_state.get_moe_layer_wise_logging_tracker()
    )

    if name not in tracker:
        tracker[name] = {
            "values": torch.zeros(num_layers, device=loss.device),
            "values_history": [[] for _ in range(num_layers)],
        }
    entry = tracker[name]
    slot = layer_number - 1

    entry["values"][slot] += loss.detach()
    if validation:
        recorded = loss if activated_experts is None else activated_experts
        entry["values_history"][slot].append(recorded.detach())

    entry["reduce_group"] = reduce_group
    entry["avg_group"] = avg_group

def clear_aux_losses_tracker(validation : bool = False):
    """Reset the auxiliary-loss tracker after logging.

    Every entry's ``values`` is zeroed in place and its reduce/avg groups are dropped.
    For the validation tracker, each per-layer history list is also emptied.
    """
    get_tracker = (
        parallel_state.get_val_moe_layer_wise_logging_tracker
        if validation
        else parallel_state.get_moe_layer_wise_logging_tracker
    )
    tracker = get_tracker()

    for entry in tracker.values():
        entry["values"].zero_()
        entry["reduce_group"] = None
        entry["avg_group"] = None
        if validation:
            history = entry["values_history"]
            history[:] = [[] for _ in history]

def reduce_aux_losses_tracker_across_ranks(validation : bool = False):
    """All-reduce every tracked auxiliary-loss vector in place.

    Each ``values`` tensor is first summed over the pipeline-parallel group. Each stage
    fills only its own layer slots, so the sum merges the stages. The tensor is then
    summed over its ``reduce_group`` and averaged over its ``avg_group`` when those are set.
    """
    tracker = (
        parallel_state.get_val_moe_layer_wise_logging_tracker()
        if validation
        else parallel_state.get_moe_layer_wise_logging_tracker()
    )

    for entry in tracker.values():
        values = entry["values"]
        torch.distributed.all_reduce(
            values, group=parallel_state.get_pipeline_model_parallel_group()
        )

        sum_group = entry.get('reduce_group')
        if sum_group is not None:
            torch.distributed.all_reduce(values, group=sum_group)

        mean_group = entry.get('avg_group')
        if mean_group is not None:
            torch.distributed.all_reduce(
                values, group=mean_group, op=torch.distributed.ReduceOp.AVG
            )

def track_moe_metrics(
    loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None, per_layer_logging=False, validation = False
):
    """Reduce, log and reset the layer-wise MoE auxiliary-loss tracker.

    Args:
        loss_scale: Factor applied to each accumulated per-layer loss before logging.
        iteration (int): Step passed to the writers.
        writer: TensorBoard-style writer exposing ``add_scalar``. When None, nothing is
            logged or added to ``total_loss_dict``, but the tracker is still reduced and cleared.
        wandb_writer (optional): Object exposing ``log(dict, step)``.
        total_loss_dict (dict, optional): Receives, or is incremented by, the mean over
            layers of each scaled loss.
        per_layer_logging (bool): Additionally log one scalar per layer (training only).
        validation (bool): Operate on the validation tracker instead of the training one.
    """
    reduce_aux_losses_tracker_across_ranks(validation = validation)

    tracker = (
        parallel_state.get_val_moe_layer_wise_logging_tracker()
        if validation
        else parallel_state.get_moe_layer_wise_logging_tracker()
    )

    if writer is not None:
        scaled = {loss_name: entry['values'].float() * loss_scale for loss_name, entry in tracker.items()}
        log_layers = per_layer_logging and not validation

        for name, per_layer in scaled.items():
            if total_loss_dict is not None:
                if name in total_loss_dict:
                    total_loss_dict[name] += per_layer.mean()
                else:
                    total_loss_dict[name] = per_layer.mean()

            layer_scalars = (
                {f"moe/{name}_layer_{i}": value for i, value in enumerate(per_layer.tolist())}
                if log_layers
                else {}
            )
            mean_loss = per_layer.mean()

            # add_scalars would turn every series into its own run, so log scalars one by one.
            writer.add_scalar(name, mean_loss, iteration)
            for tag, value in layer_scalars.items():
                writer.add_scalar(tag, value, iteration)

            # W&B cannot group scalars at log time; group them later in a custom panel.
            if wandb_writer:
                wandb_writer.log({f"{name}": mean_loss}, iteration)
                if log_layers:
                    wandb_writer.log(layer_scalars, iteration)

    clear_aux_losses_tracker(validation = validation)

def compute_gcv_metrics(tokens_per_expert: torch.Tensor, expert_class_map: torch.Tensor):
    """Compute the Group Coefficient of Variation (GCV) of expert load for every layer.

    For each class (group) of experts, the coefficient of variation is std / mean of the
    per-expert token counts, using the population std. Classes whose mean is 0, or that
    contain at most one expert, are ignored. The remaining classes are combined as an
    average weighted by their number of experts. A layer with no valid class gets 0.

    Args:
        tokens_per_expert (torch.Tensor): Token counts with shape (layer, expert).
        expert_class_map (torch.Tensor): Class id of each expert, shape (expert,). When
            None, all experts form a single class, which yields the plain (global) CV.

    Returns:
        torch.Tensor: GCV per layer, shape (layer,).
    """
    if expert_class_map is None:
        expert_class_map = torch.zeros(tokens_per_expert.shape[1], dtype=torch.long, device=tokens_per_expert.device)

    n_classes = torch.max(expert_class_map) + 1
    membership = torch.nn.functional.one_hot(expert_class_map, num_classes=n_classes).to(tokens_per_expert.dtype)  # (expert, class)
    experts_in_class = membership.sum(dim=0).unsqueeze(0)  # (1, class)

    grouped = tokens_per_expert.unsqueeze(-1) * membership  # (layer, expert, class)
    mean = grouped.sum(dim=1) / experts_in_class  # (layer, class)

    squared_dev = ((grouped - mean.unsqueeze(1)) ** 2) * membership  # (layer, expert, class)
    std = torch.sqrt(squared_dev.sum(dim=1) / experts_in_class)  # (layer, class)

    keep = (mean > 0) & (experts_in_class > 1)
    cv = torch.where(keep, std / mean, torch.zeros_like(mean))  # (layer, class)

    weight_total = torch.where(keep, experts_in_class, torch.zeros_like(experts_in_class)).sum(dim=1)  # (layer,)
    weighted_sum = (cv * experts_in_class).sum(dim=1)  # (layer,)

    return torch.where(weight_total > 0, weighted_sum / weight_total, torch.zeros_like(weight_total))

def save_to_moe_token_dist_tracker(tokens_per_layer_per_expert: torch.tensor, expert_class_map: torch.tensor, domain_tracking: bool, validation: bool = False):
    """Store expert-load statistics in the MoE bias logging tracker (values are overwritten).

    Args:
        tokens_per_layer_per_expert (torch.Tensor): Aggregated token counts indexed as
            [layer, expert].
        expert_class_map (torch.Tensor): Class/domain id of each expert; only used when
            ``domain_tracking`` is True. Its first value seen is cached under ``"map"``.
        domain_tracking (bool): Also store the per-layer GCV computed within expert classes.
        validation (bool): Store the variance metrics under ``val_``-prefixed keys and
            skip the average tokens-per-expert vector.
    """
    tracker = parallel_state.get_moe_bias_logging_tracker()
    counts = tokens_per_layer_per_expert
    device = counts.device

    prefix = "val_" if validation else ""
    global_key = prefix + "global_experts_variance"
    domain_key = prefix + "domain_experts_variance"

    # Mean tokens per expert over layers (training only).
    if not validation:
        if "tokens_per_expert" not in tracker:
            tracker["tokens_per_expert"] = torch.zeros(counts.shape[1], device=device)
        tracker["tokens_per_expert"].copy_(counts.sum(dim=0) / counts.shape[0])

    # GCV across all experts of each layer.
    if global_key not in tracker:
        tracker[global_key] = torch.zeros(counts.shape[0], device=device)
    tracker[global_key].copy_(compute_gcv_metrics(counts, None))

    # GCV within each expert class of each layer.
    if domain_tracking:
        if "map" not in tracker:
            tracker["map"] = expert_class_map.clone().detach()
        if domain_key not in tracker:
            tracker[domain_key] = torch.zeros(counts.shape[0], device=device)
        tracker[domain_key].copy_(compute_gcv_metrics(counts, expert_class_map))
        
def clear_bias_tracker(validation: bool = False):
    """Zero the token-distribution metrics held in the MoE bias tracker.

    Always resets ``tokens_per_expert``. The global/domain variance vectors are reset
    under their ``val_``-prefixed names when ``validation`` is True. Missing keys are
    skipped, and the cached expert class map is left untouched.
    """
    tracker = parallel_state.get_moe_bias_logging_tracker()

    prefix = "val_" if validation else ""
    keys_to_reset = (
        "tokens_per_expert",
        prefix + "global_experts_variance",
        prefix + "domain_experts_variance",
    )
    for key in keys_to_reset:
        if key in tracker:
            tracker[key].zero_()

def track_bias_metrics(iteration, writer, wandb_writer=None, validation: bool = False):
    """Log the expert token-distribution metrics from the MoE bias tracker, then clear them.

    Args:
        iteration (int): Step passed to the writers.
        writer: TensorBoard-style writer exposing ``add_scalar``. When None, nothing is
            logged but the tracker is still cleared.
        wandb_writer (optional): Object exposing ``log(dict, step)``; only used when
            ``writer`` is set.
        validation (bool): Log the ``val_``-prefixed variance keys instead of the training
            ones. Per-expert token counts and per-layer breakdowns are logged for training only.
    """
    bias_tracker = parallel_state.get_moe_bias_logging_tracker()

    map_key = "map"
    tokens_per_expert_key = "tokens_per_expert"
    key_prefix = "val_" if validation else ""
    global_experts_variance_key = key_prefix + "global_experts_variance"
    domain_experts_variance_key = key_prefix + "domain_experts_variance"

    if writer is not None:
        if tokens_per_expert_key in bias_tracker and not validation:
            tokens_per_expert = bias_tracker[tokens_per_expert_key]

            if map_key in bias_tracker:
                expert_class_map = bias_tracker[map_key].clone().detach()
            else:
                # No domain map recorded: label every expert with class 0.
                expert_class_map = torch.zeros(tokens_per_expert.shape[0], dtype=torch.long, device=tokens_per_expert.device)

            expert_scalars = {
                f"moe/expert_{idx}_{expert_class_map[idx].item()}": tokens
                for idx, tokens in enumerate(tokens_per_expert.tolist())
            }
            for scalar_name, tokens in expert_scalars.items():
                writer.add_scalar(scalar_name, tokens, iteration)
            if wandb_writer:
                wandb_writer.log(expert_scalars, iteration)

        for variance_key, layer_tag_prefix in (
            (global_experts_variance_key, "moe/glob_variance_layer_"),
            (domain_experts_variance_key, "moe/domain_variance_layer_"),
        ):
            if variance_key not in bias_tracker:
                continue
            layer_variance = bias_tracker[variance_key]

            if wandb_writer:
                wandb_writer.log({variance_key: layer_variance.mean()}, iteration)

            if not validation:
                layer_scalars = {
                    f"{layer_tag_prefix}{idx}": gcv
                    for idx, gcv in enumerate(layer_variance.tolist())
                }
                for scalar_name, gcv in layer_scalars.items():
                    writer.add_scalar(scalar_name, gcv, iteration)
                if wandb_writer:
                    wandb_writer.log(layer_scalars, iteration)

    # Clear the same (training or validation) keys that were just logged; previously the
    # training keys were always cleared and the val_ keys never were.
    clear_bias_tracker(validation=validation)

def _compute_bias_offset(tokens_per_expert, moe_router_domain_bias, expert_class_map):
    """Return (reference load - actual load) per expert; positive means under-loaded.

    The reference is the mean over all experts, or the mean over experts of the same
    class when ``moe_router_domain_bias`` is True. Class ids without any expert are
    never referenced, so they cannot inject NaNs.
    """
    if not moe_router_domain_bias:
        average_tokens = tokens_per_expert.sum(dim=-1, keepdim=True) / tokens_per_expert.shape[-1]
        return average_tokens - tokens_per_expert

    num_classes = int(torch.max(expert_class_map).item()) + 1
    expert_class_one_hot = torch.nn.functional.one_hot(expert_class_map, num_classes=num_classes).to(tokens_per_expert.dtype)  # (expert, class)

    class_token_sums = (tokens_per_expert.unsqueeze(-1) * expert_class_one_hot).sum(dim=-2)  # (..., class)
    class_token_counts = expert_class_one_hot.sum(dim=0)  # (class,)
    # Entries for empty classes are 0/0 here, but they are never gathered below.
    class_token_means = class_token_sums / class_token_counts

    return class_token_means[..., expert_class_map] - tokens_per_expert

def get_updated_expert_bias(tokens_per_expert, expert_bias, expert_bias_update_rate, moe_router_domain_bias, expert_class_map):
    """Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1#

    Each expert's bias moves by ``expert_bias_update_rate`` toward balancing its load:
    up when it received fewer tokens than the reference, and down when it received more.

    Args:
        tokens_per_expert (torch.Tensor): The number of tokens assigned to each expert.
            All-reduced in place over the TP x DP x CP group.
        expert_bias (torch.Tensor): The bias for each expert.
        expert_bias_update_rate (float): The update rate for the expert bias.
        moe_router_domain_bias (bool): Balance within each expert class instead of globally.
        expert_class_map (torch.Tensor): Class id of each expert (used when domain biasing
            and for logging).

    Returns:
        torch.Tensor: The updated expert bias.
    """
    with torch.no_grad():
        # All Reduce Across TPxCPxDP group
        torch.distributed.all_reduce(
            tokens_per_expert,
            group=parallel_state.get_tensor_and_data_parallel_group(with_context_parallel=True),
        )

        offset = _compute_bias_offset(tokens_per_expert, moe_router_domain_bias, expert_class_map)
        updated_expert_bias = expert_bias + torch.sign(offset) * expert_bias_update_rate

        save_to_moe_token_dist_tracker(tokens_per_expert, expert_class_map, moe_router_domain_bias)

        return updated_expert_bias

def update_router_expert_dist_metrics(tokens_per_expert, domain_tracking, expert_class_map, validation: bool = False):
    """Sum expert token counts over the TP x DP x CP group and record distribution metrics.

    ``tokens_per_expert`` is all-reduced in place. The result is then passed to
    ``save_to_moe_token_dist_tracker`` together with the class map and flags.
    """
    with torch.no_grad():
        reduce_group = parallel_state.get_tensor_and_data_parallel_group(with_context_parallel=True)
        torch.distributed.all_reduce(tokens_per_expert, group=reduce_group)

        save_to_moe_token_dist_tracker(
            tokens_per_expert, expert_class_map, domain_tracking, validation=validation
        )

def _soft_topk_local_stats(probs: torch.Tensor) -> Optional[torch.Tensor]:
    """Return float32 [max, min, sum, count, sum_sq] of experts selected per token.

    An expert counts as selected for a token when its probability is > 0 (row-wise over
    dim 1). All values stay on device, so no host synchronisation happens. Returns None
    when there are no tokens.
    """
    experts_per_token = (probs > 0).sum(dim=1)
    if experts_per_token.numel() == 0:
        return None
    as_float = experts_per_token.float()
    return torch.stack([
        experts_per_token.max().float(),
        experts_per_token.min().float(),
        experts_per_token.sum().float(),
        torch.tensor(float(experts_per_token.numel()), device=experts_per_token.device),
        (as_float ** 2).sum(),
    ])

def _merge_soft_topk_stats(accumulator: torch.Tensor, stats: torch.Tensor) -> None:
    """Fold one call's [max, min, sum, count, sum_sq] into a running accumulator in place.

    Max/min are combined with max/min rather than summed. An accumulator whose count is 0
    (freshly allocated or cleared) simply takes the incoming max/min.
    """
    is_empty = accumulator[3] == 0
    accumulator[0] = torch.where(is_empty, stats[0], torch.maximum(accumulator[0], stats[0]))
    accumulator[1] = torch.where(is_empty, stats[1], torch.minimum(accumulator[1], stats[1]))
    accumulator[2:] += stats[2:]

def save_to_soft_topk_experts_metrics_tracker(num_layer : int, probs: torch.Tensor, topk: torch.Tensor):
    """Accumulate soft-top-k statistics for one layer into the global tracker.

    Args:
        num_layer (int): Layer identifier used in the tracker keys.
        probs (torch.Tensor): Routing probabilities; entries > 0 count as selected experts
            (counted over dim 1).
        topk (torch.Tensor): Current (learnable) k, summed into a per-layer accumulator.
    """
    with torch.no_grad():
        device = probs.device
        tracker = parallel_state.get_moe_soft_topk_stats_tracker()
        stats_by_layer = tracker.setdefault("soft_topk_stats", {})
        tracker.setdefault("soft_topk_avg_probs", {})
        learnable_k_by_layer = tracker.setdefault("soft_topk_learnable_k", {})

        tracker_key = f"soft_topk_stats_layer_{num_layer}"
        learnable_k_key = f"soft_topk_learnable_k_{num_layer}"

        # Register the key even for empty input so all ranks reduce the same keys.
        if tracker_key not in stats_by_layer:
            stats_by_layer[tracker_key] = torch.zeros(5, device=device)
        local_stats = _soft_topk_local_stats(probs)
        if local_stats is not None:
            _merge_soft_topk_stats(stats_by_layer[tracker_key], local_stats)

        if learnable_k_key not in learnable_k_by_layer:
            learnable_k_by_layer[learnable_k_key] = torch.zeros(1, device=device)
        learnable_k_by_layer[learnable_k_key] += topk

def reduce_soft_topk_expert_metrics():
    """Reduce the soft-top-k trackers over the TP x DP x CP group.

    Each per-layer ``soft_topk_stats`` vector is replaced by the concatenation of every
    rank's vector, so the reader can merge max/min/sum/count per rank. Each learnable-k
    accumulator is averaged in place. Missing tracker sections are treated as empty
    instead of raising KeyError.
    """
    tracker = parallel_state.get_moe_soft_topk_stats_tracker()
    stats_by_layer = tracker.get("soft_topk_stats", {})
    learnable_k_by_layer = tracker.get("soft_topk_learnable_k", {})
    if not stats_by_layer and not learnable_k_by_layer:
        return

    group = parallel_state.get_tensor_and_data_parallel_group(with_context_parallel=True)

    for tracker_key in list(stats_by_layer):
        stats_by_layer[tracker_key] = gather_from_sequence_parallel_region(
            stats_by_layer[tracker_key], group=group,
        )

    for learnable_k in learnable_k_by_layer.values():
        torch.distributed.all_reduce(learnable_k, group=group, op=torch.distributed.ReduceOp.AVG)

def _summarize_soft_topk_stats(stats: torch.Tensor) -> Optional[dict]:
    """Merge gathered per-rank [max, min, sum, count, sum_sq] rows into global statistics.

    ``stats`` may hold any number of 5-element rows (one per rank in the gather group).
    Rows with a zero count contributed no tokens and are ignored, so they cannot drag
    the min down to 0. Returns None when no row has data; otherwise returns a dict with
    keys max, min, mean, var, std (population variance).
    """
    rows = stats.reshape(-1, 5).to(torch.float64)
    rows = rows[rows[:, 3] > 0]
    if rows.shape[0] == 0:
        return None

    total_count = rows[:, 3].sum()
    total_sum = rows[:, 2].sum()
    total_sumsq = rows[:, 4].sum()

    mean = total_sum / total_count
    var = (total_sumsq - 2 * mean * total_sum + total_count * mean ** 2) / total_count
    var = torch.clamp(var, min=0.0)  # guard against tiny negative rounding error

    return {
        "max": rows[:, 0].max().item(),
        "min": rows[:, 1].min().item(),
        "mean": mean.item(),
        "var": var.item(),
        "std": var.sqrt().item(),
    }

def track_soft_topk_expert_metrics(iteration, wandb_writer, moe_loss_scale):
    """Reduce, log to W&B (last global rank only) and reset the soft-top-k trackers.

    Only the layers present in the last rank's tracker are logged. The learnable-k
    accumulator is multiplied by ``moe_loss_scale`` before logging.
    """
    tracker = parallel_state.get_moe_soft_topk_stats_tracker()
    reduce_soft_topk_expert_metrics()

    is_last_rank = torch.distributed.get_rank() == torch.distributed.get_world_size() - 1
    if is_last_rank and wandb_writer:
        wandb_metrics = {}

        for tracker_key, stats in tracker.get("soft_topk_stats", {}).items():
            # The gather group can be smaller than the world (e.g. with pipeline
            # parallelism), so rows are inferred from the tensor, not world_size.
            summary = _summarize_soft_topk_stats(stats)
            if summary is None:
                continue
            for stat_name, value in summary.items():
                wandb_metrics[f'softtopk/{tracker_key}_{stat_name}'] = value

        for tracker_key, learnable_k in tracker.get("soft_topk_learnable_k", {}).items():
            wandb_metrics[f'learnable_k/{tracker_key}'] = (learnable_k * moe_loss_scale).item()

        wandb_writer.log(wandb_metrics, iteration)

    clear_soft_topk_expert_metrics_tracker()

def clear_soft_topk_expert_metrics_tracker():
    """Reset the soft-top-k trackers after logging.

    Stats entries are replaced by fresh 5-element zero vectors, because the reduce step
    swapped them for gathered multi-rank tensors. Learnable-k accumulators are zeroed in
    place. Does nothing if no stats have been recorded.
    """
    tracker = parallel_state.get_moe_soft_topk_stats_tracker()

    if "soft_topk_stats" not in tracker:
        return

    stats_by_layer = tracker["soft_topk_stats"]
    for key in list(stats_by_layer):
        current = stats_by_layer[key]
        if torch.is_tensor(current):
            stats_by_layer[key] = torch.zeros(5, device=current.device)

    for learnable_k in tracker["soft_topk_learnable_k"].values():
        if torch.is_tensor(learnable_k):
            learnable_k.zero_()

def _probs_histogram(probs: torch.Tensor, n_bins: int = 1000) -> torch.Tensor:
    """Return int64 counts of ``probs`` over ``n_bins`` equal-width bins covering [0, 1].

    NaN and -inf are counted as 0, +inf as 1, and everything is clamped to [0, 1].
    Bins are half-open [lo, hi), except the last one, which also contains 1.0.
    """
    bin_edges = torch.linspace(0.0, 1.0, steps=n_bins + 1, device=probs.device)

    # reshape (not view) accepts non-contiguous input; float32 matches the edges' dtype.
    values = probs.detach().reshape(-1).float()
    values = torch.nan_to_num(values, nan=0.0, posinf=1.0, neginf=0.0).clamp(0.0, 1.0)

    bin_indices = torch.bucketize(values, bin_edges, right=True) - 1
    # With right=True a value equal to the last edge (1.0) maps to n_bins; fold it into
    # the last bin instead of discarding it.
    bin_indices = bin_indices.clamp(0, n_bins - 1)

    return torch.bincount(bin_indices, minlength=n_bins)

def save_to_probs_distribution_tracker(layer_number : int, probs : torch.Tensor, threshold : torch.Tensor):
    """Accumulate a 1000-bin histogram of ``probs`` and the mean ``threshold`` for one layer.

    Args:
        layer_number (int): Layer identifier used in the tracker keys.
        probs (torch.Tensor): Probabilities of any shape; values are histogrammed on [0, 1].
        threshold (torch.Tensor): Its mean is summed into a per-layer accumulator.
    """
    with torch.no_grad():
        n_bins = 1000
        bin_counts = _probs_histogram(probs, n_bins)

        tracker = parallel_state.get_moe_probs_dist_tracker()
        distributions = tracker.setdefault("probs_dist", {})
        thresholds = tracker.setdefault("threshold", {})

        dist_key = f"probs_dist_layer_{layer_number}"
        if dist_key not in distributions:
            distributions[dist_key] = torch.zeros(n_bins, device=probs.device)
        distributions[dist_key] += bin_counts

        threshold_key = f"threshold_layer_{layer_number}"
        if threshold_key not in thresholds:
            thresholds[threshold_key] = torch.zeros(1, device=probs.device)
        thresholds[threshold_key] += threshold.mean()

def reduce_probs_distribution_metrics():
    """All-reduce the probability histograms and threshold accumulators in place.

    Histograms are summed and thresholds averaged over the tensor-and-data-parallel group
    (with context parallelism). Missing tracker sections (nothing recorded on this rank)
    are treated as empty instead of raising KeyError.
    """
    tracker = parallel_state.get_moe_probs_dist_tracker()
    distributions = tracker.get("probs_dist", {})
    thresholds = tracker.get("threshold", {})
    if not distributions and not thresholds:
        return

    group = parallel_state.get_tensor_and_data_parallel_group(with_context_parallel=True)

    for bin_counts in distributions.values():
        torch.distributed.all_reduce(bin_counts, group=group, op=torch.distributed.ReduceOp.SUM)

    for threshold in thresholds.values():
        torch.distributed.all_reduce(threshold, group=group, op=torch.distributed.ReduceOp.AVG)


def track_probs_distribution_metrics(iteration, wandb_writer, iteration_mean):
    """Reduce, persist, plot and reset the router probability-distribution trackers.

    On the last global rank, and only when ``wandb_writer`` is set, the tracker is saved to
    ``./logs/probs_dist/probs_dist_{iteration}.pth``. Plots are then rendered with
    ``plot_probs_distribution`` and the resulting PNGs are uploaded as a W&B artifact.
    The trackers are cleared on every rank.
    """
    reduce_probs_distribution_metrics()
    tracker = parallel_state.get_moe_probs_dist_tracker()

    is_last_rank = torch.distributed.get_rank() == torch.distributed.get_world_size() - 1
    if is_last_rank and wandb_writer:
        output_dir = "./logs/probs_dist"
        os.makedirs(output_dir, exist_ok=True)

        output_path = os.path.join(output_dir, f"probs_dist_{iteration}.pth")
        torch.save(tracker, output_path)

        # Renders probability-distribution PNGs into a per-iteration subfolder.
        plot_probs_distribution(output_path, iteration_mean)
        image_dir = os.path.join(output_dir, f"probs_dist_{iteration}")

        artifact = wandb.Artifact(f'probs_dist_{iteration}', type='dataset')
        png_names = [entry for entry in os.listdir(image_dir) if entry.endswith(".png")]
        for png_name in png_names:
            artifact.add_file(os.path.join(image_dir, png_name))

        wandb_writer.run.log_artifact(artifact)

    clear_probs_distribution_tracker()

def clear_probs_distribution_tracker():
    """Zero the histogram and threshold accumulators in place, keeping their keys.

    Does nothing if no histogram has been recorded yet.
    """
    tracker = parallel_state.get_moe_probs_dist_tracker()

    if "probs_dist" not in tracker:
        return

    for section in ("probs_dist", "threshold"):
        for accumulator in tracker[section].values():
            if torch.is_tensor(accumulator):
                accumulator.zero_()