# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import math
from typing import List, Optional, Union

import torch
import torch.distributed
import torch.nn.functional as F

from megatron.core import parallel_state
from megatron.core.process_groups_config import ModelCommProcessGroups
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region

try:
    import transformer_engine as te  # pylint: disable=unused-import

    from megatron.core.extensions.transformer_engine import (
        fused_permute,
        fused_permute_with_probs,
        fused_sort_chunks_by_index,
        fused_sort_chunks_by_index_with_probs,
        fused_unpermute,
        te_general_gemm,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False


# MOE logging
_MOE_LAYER_WISE_LOGGING_TRACKER: dict = {}

# -----------------------------------------------------------------------------
# PathFinder, PathGuard, PathDPP, PathRPCA and IntraGuard loss utilities
# -----------------------------------------------------------------------------
# One module-level dictionary per loss family. Each maps an MoE layer index
# (starting from 0) to the router probabilities that layer registered during
# the current forward pass, so the next MoE layer can pair them with its own.
# NOTE(review): the IntraGuard helpers in this file do not read or write
# _INTRAGUARD_PROBS_TRACKER. Confirm whether it is still needed.
_PATHFINDER_PROBS_TRACKER: dict = {}
_PATHGUARD_PROBS_TRACKER: dict = {}
_PATHDPP_PROBS_TRACKER: dict = {}
_PATHRPCA_PROBS_TRACKER: dict = {}
_INTRAGUARD_PROBS_TRACKER: dict = {}

# -----------------------------------------------------------------------------
# Expert Coupling Analysis utilities
# -----------------------------------------------------------------------------
# Module-level state for expert-coupling analysis. The loss helpers here do not
# reference it; presumably it is populated elsewhere (verify).
_EXPERT_COUPLING_STATS: dict = {}
_LAST_ROUTING_MAP_ACROSS_LAYERS: dict = {}
_EXPERT_COUPLING_CONFIG = None


def _calc_pathfinder_loss(prev_probs: torch.Tensor, curr_probs: torch.Tensor, coeff: float):
    """Compute the PathFinder loss between two consecutive MoE layers.

    With the token-averaged joint matrix P_ij = 1/T * sum_t g_ti^(L) g_tj^(L+1) and its
    row marginals P_i = sum_j P_ij, the loss is

        coeff * sum_ij P_ij * (log(P_ij + eps) - log(P_i + eps)),   eps = 1e-9.

    When each token's probabilities sum to 1, this is the negative conditional entropy of
    the layer-(L+1) expert given the layer-L expert.

    Args:
        prev_probs: [num_tokens, num_experts] probabilities from layer L.
        curr_probs: [num_tokens, num_experts] probabilities from layer L+1.
        coeff: Scaling coefficient.

    Returns:
        torch.Tensor: Scalar PathFinder loss, already multiplied by ``coeff``.
    """
    eps = 1e-9
    # Guard against division by zero for an empty batch.
    token_count = max(prev_probs.shape[0], 1)

    # [E, E] joint distribution over (expert at layer L, expert at layer L+1).
    joint = (prev_probs.T @ curr_probs) / token_count
    # Row marginals stay 2-D ([E, 1]) so they broadcast across each row of ``joint``.
    row_marginal = joint.sum(dim=1, keepdim=True)

    log_ratio = torch.log(joint + eps) - torch.log(row_marginal + eps)
    return torch.sum(joint * log_ratio) * coeff


def _calc_pathguard_loss(prev_probs: torch.Tensor, curr_probs: torch.Tensor, coeff: float, k: int = 1):
    """Compute PathGuard loss between two consecutive MoE layers.

    Steps:
    - Build the global inter-layer joint matrix J = (prev_probs^T @ curr_probs) / T, shape [E, E].
    - For every source expert i, take s_i = sum of the top-k entries of row J[i, :].
    - For each token, sum s_i over the rows i active for that token in the previous layer
      (``prev_probs > 0``). Average over tokens, negate, and scale by ``coeff``.

    The loss is ``-coeff * mean_t sum_{i active for t} s_i``. Unlike the IntraGuard and
    PathRPCA losses, no log is applied.

    Restricting rows per token while ranking on the global J avoids a degeneracy of per-token
    outer products, where the top-k selection would not depend on the row.

    Args:
        prev_probs: [num_tokens, num_experts] probabilities from layer L. Assumed to be sparse,
            non-zero only on the experts selected for each token.
        curr_probs: [num_tokens, num_experts] probabilities from layer L+1.
        coeff: Scaling coefficient.
        k: Number of top joint entries summed per active row. Clamped to [1, num_experts].

    Returns:
        torch.Tensor: The PathGuard loss (scaled by coeff). Returns a zero tensor (prev_probs'
        dtype/device) when there are no tokens, no experts, or ``coeff == 0``.
    """
    num_tokens = prev_probs.shape[0]
    num_experts = prev_probs.shape[1]

    if num_tokens == 0 or num_experts == 0 or coeff == 0.0:
        return prev_probs.new_tensor(0.0)

    # Clamp k to a valid topk size.
    k = max(1, min(k, num_experts))

    # Global joint matrix J = 1/T * sum_t g_prev(t) g_curr(t)^T, shape [E, E].
    # num_tokens > 0 is guaranteed by the early return above.
    joint_global = (prev_probs.T @ curr_probs) / num_tokens

    # Row-wise top-k sums on the global joint matrix: s_i = sum_topk_j J[i, j], shape [E].
    row_topk_vals, _ = joint_global.topk(k, dim=1)
    row_topk_sum = row_topk_vals.sum(dim=1)

    # Per token, sum s_i only over the experts that token activated in the previous layer.
    active_rows = prev_probs > 0  # [T, E], bool
    per_token_sum = (active_rows * row_topk_sum.unsqueeze(0)).sum(dim=1)  # [T]

    # Average over tokens to keep the scale independent of the batch size.
    return -per_token_sum.mean() * coeff


def _calc_pathdpp_loss(prev_probs: torch.Tensor, curr_probs: torch.Tensor, coeff: float):
    """Compute PathDPP loss between two consecutive MoE layers.

    A Determinantal Point Process (DPP) style diversity objective:
    - Build the expert routing-profile matrix R = prev_probs^T @ curr_probs / T, shape [E, E].
    - Form the kernel K = R R^T + eps * I.
    - Return ``-log det(K) * coeff``.

    det(K) grows with the volume spanned by the rows of R (the per-expert routing profiles).
    Minimizing the loss therefore pushes experts towards distinct routing patterns.

    log det(K) uses a Cholesky factorization. If that factorization fails (K not numerically
    positive-definite), it falls back to summing the logs of K's eigenvalues clamped to eps.

    Args:
        prev_probs: [num_tokens, num_experts] probabilities from layer L.
        curr_probs: [num_tokens, num_experts] probabilities from layer L+1.
        coeff: Scaling coefficient.

    Returns:
        torch.Tensor: The PathDPP loss (scaled by coeff).

    Raises:
        Any error other than a failed Cholesky factorization (for example an unsupported
        dtype) propagates instead of being silently swallowed.
    """
    num_tokens = prev_probs.shape[0]
    # R matrix [E, E]: expert routing-profile matrix.
    routing_profile_matrix = (prev_probs.T @ curr_probs) / max(num_tokens, 1)

    # Kernel K = R R^T, with eps * I added for positive-definiteness and numerical stability.
    # torch.eye uses torch's default dtype, so a lower-precision R is promoted by this addition.
    eps = 1e-6
    kernel = routing_profile_matrix @ routing_profile_matrix.T
    kernel = kernel + eps * torch.eye(kernel.shape[0], device=kernel.device)

    try:
        chol = torch.linalg.cholesky(kernel)
    except torch.linalg.LinAlgError:
        # Factorization failed: use eigenvalues clamped to eps instead.
        eigenvalues = torch.linalg.eigvalsh(kernel)
        log_det = torch.sum(torch.log(torch.clamp(eigenvalues, min=eps)))
    else:
        # For K = L L^T, log det(K) = 2 * sum(log diag(L)).
        log_det = 2 * torch.sum(torch.log(torch.diagonal(chol)))

    return -log_det * coeff


def _calc_pathrpca_loss(prev_probs: torch.Tensor, curr_probs: torch.Tensor, coeff: float,
                       nuclear_coeff: float = 1.0, l1_coeff: float = 1.0):
    """Compute PathRPCA loss between two consecutive MoE layers.

    Builds the inter-layer expert coupling matrix C = prev_probs^T @ curr_probs / T, shape [E, E].
    It then combines two RPCA-style terms of C: the nuclear norm (sum of singular values, the
    low-rank term) and the L1 norm (sum of absolute values, the sparsity term):

        loss = -log(nuclear_coeff * ||C||_* + l1_coeff * ||C||_1) * coeff

    NOTE(review): because of the leading minus sign, minimizing this loss with coeff > 0
    *increases* the weighted norm sum. The RPCA motivation would *minimize* that sum. The
    sign is preserved here; confirm which direction is intended.

    Args:
        prev_probs: [num_tokens, num_experts] probabilities from layer L.
        curr_probs: [num_tokens, num_experts] probabilities from layer L+1.
        coeff: Overall scaling coefficient for the PathRPCA loss.
        nuclear_coeff: Coefficient for the nuclear norm component.
        l1_coeff: Coefficient for the L1 norm component.

    Returns:
        torch.Tensor: The PathRPCA loss (scaled by coeff). Returns a zero tensor when there are
        no tokens. The log argument is clamped to eps, so an all-zero coupling matrix gives a
        finite loss instead of +inf (or nan when coeff == 0).
    """
    num_tokens = prev_probs.shape[0]
    if num_tokens == 0:
        return prev_probs.new_tensor(0.0)

    # [E, E] inter-layer expert coupling matrix.
    coupling_matrix = (prev_probs.T @ curr_probs) / num_tokens

    # Nuclear norm (promotes low rank). torch.linalg.svdvals supports only float/double
    # (and complex) inputs, so the SVD runs in fp32.
    singular_values = torch.linalg.svdvals(coupling_matrix.to(torch.float32))
    nuclear_norm = torch.sum(singular_values)

    # L1 norm (promotes sparsity).
    l1_norm = torch.sum(torch.abs(coupling_matrix))

    eps = 1e-9
    combined = nuclear_coeff * nuclear_norm + l1_coeff * l1_norm
    return -torch.log(combined.clamp_min(eps)) * coeff


def compute_and_register_pathfinder_loss(
    layer_number: int,
    probs: torch.Tensor,
    coeff: float,
    sequence_partition_group=None,
):
    """Pair this layer's probs with the previous layer's to form a PathFinder loss, then store them.

    If probabilities for ``layer_number - 1`` are stored in ``_PATHFINDER_PROBS_TRACKER``, the
    PathFinder loss between that layer and this one is computed. When the sequence-partition
    group has more than one rank, both tensors are first gathered across the group. The
    tracker then receives this layer's un-gathered ``probs``, and the entry for
    ``layer_number - 2`` is dropped so only the most recent layers are kept.

    Args:
        layer_number: Index of current MoE layer (starting from 0).
        probs: Soft probabilities of current layer [num_tokens, num_experts].
        coeff: Scaling coefficient.
        sequence_partition_group (optional): Parallel group for sequence partition.

    Returns:
        Tuple[Optional[torch.Tensor], Optional[int]]: ``(loss, layer_number - 1)`` when a previous
        layer was registered, otherwise ``(None, None)``.
    """
    prev_layer = layer_number - 1
    loss, target_layer = None, None

    if prev_layer in _PATHFINDER_PROBS_TRACKER:
        prev_probs = _PATHFINDER_PROBS_TRACKER[prev_layer]
        curr_probs = probs

        # Gather to the full sequence so the loss (and its gradient) sees every token.
        needs_gather = sequence_partition_group is not None and sequence_partition_group.size() > 1
        if needs_gather:
            prev_probs = gather_from_sequence_parallel_region(prev_probs, group=sequence_partition_group)
            curr_probs = gather_from_sequence_parallel_region(probs, group=sequence_partition_group)

        loss = _calc_pathfinder_loss(prev_probs, curr_probs, coeff)
        target_layer = prev_layer

    # Make this layer's probs available to the next MoE layer and forget the one before last.
    _PATHFINDER_PROBS_TRACKER[layer_number] = probs
    _PATHFINDER_PROBS_TRACKER.pop(layer_number - 2, None)

    return loss, target_layer


def compute_and_register_pathguard_loss(
    layer_number: int,
    probs: torch.Tensor,
    coeff: float,
    k: int = 1,
    sequence_partition_group=None,
):
    """Pair this layer's probs with the previous layer's to form a PathGuard loss, then store them.

    If probabilities for ``layer_number - 1`` are stored in ``_PATHGUARD_PROBS_TRACKER``, the
    PathGuard loss between that layer and this one is computed. When the sequence-partition
    group has more than one rank, both tensors are first gathered across the group. The
    tracker then receives this layer's un-gathered ``probs``, and the entry for
    ``layer_number - 2`` is dropped.

    Args:
        layer_number: Index of current MoE layer (starting from 0).
        probs: Soft probabilities of current layer [num_tokens, num_experts].
        coeff: Scaling coefficient.
        k: Number of top joint entries summed per active row, forwarded to the PathGuard loss.
        sequence_partition_group (optional): Parallel group for sequence partition.

    Returns:
        Tuple[Optional[torch.Tensor], Optional[int]]: ``(loss, layer_number - 1)`` when a previous
        layer was registered, otherwise ``(None, None)``.
    """
    prev_layer = layer_number - 1
    loss, target_layer = None, None

    if prev_layer in _PATHGUARD_PROBS_TRACKER:
        prev_probs = _PATHGUARD_PROBS_TRACKER[prev_layer]
        curr_probs = probs

        # Gather to the full sequence so the loss (and its gradient) sees every token.
        needs_gather = sequence_partition_group is not None and sequence_partition_group.size() > 1
        if needs_gather:
            prev_probs = gather_from_sequence_parallel_region(prev_probs, group=sequence_partition_group)
            curr_probs = gather_from_sequence_parallel_region(probs, group=sequence_partition_group)

        loss = _calc_pathguard_loss(prev_probs, curr_probs, coeff, k)
        target_layer = prev_layer

    # Make this layer's probs available to the next MoE layer and forget the one before last.
    _PATHGUARD_PROBS_TRACKER[layer_number] = probs
    _PATHGUARD_PROBS_TRACKER.pop(layer_number - 2, None)

    return loss, target_layer


def compute_and_register_pathdpp_loss(
    layer_number: int,
    probs: torch.Tensor,
    coeff: float,
    sequence_partition_group=None,
):
    """Pair this layer's probs with the previous layer's to form a PathDPP loss, then store them.

    If probabilities for ``layer_number - 1`` are stored in ``_PATHDPP_PROBS_TRACKER``, the
    PathDPP loss between that layer and this one is computed. When the sequence-partition
    group has more than one rank, both tensors are first gathered across the group. The
    tracker then receives this layer's un-gathered ``probs``, and the entry for
    ``layer_number - 2`` is dropped.

    Args:
        layer_number: Index of current MoE layer (starting from 0).
        probs: Soft probabilities of current layer [num_tokens, num_experts].
        coeff: Scaling coefficient.
        sequence_partition_group (optional): Parallel group for sequence partition.

    Returns:
        Tuple[Optional[torch.Tensor], Optional[int]]: ``(loss, layer_number - 1)`` when a previous
        layer was registered, otherwise ``(None, None)``.
    """
    prev_layer = layer_number - 1
    loss, target_layer = None, None

    if prev_layer in _PATHDPP_PROBS_TRACKER:
        prev_probs = _PATHDPP_PROBS_TRACKER[prev_layer]
        curr_probs = probs

        # Gather to the full sequence so the loss (and its gradient) sees every token.
        needs_gather = sequence_partition_group is not None and sequence_partition_group.size() > 1
        if needs_gather:
            prev_probs = gather_from_sequence_parallel_region(prev_probs, group=sequence_partition_group)
            curr_probs = gather_from_sequence_parallel_region(probs, group=sequence_partition_group)

        loss = _calc_pathdpp_loss(prev_probs, curr_probs, coeff)
        target_layer = prev_layer

    # Make this layer's probs available to the next MoE layer and forget the one before last.
    _PATHDPP_PROBS_TRACKER[layer_number] = probs
    _PATHDPP_PROBS_TRACKER.pop(layer_number - 2, None)

    return loss, target_layer


def compute_and_register_pathrpca_loss(
    layer_number: int,
    probs: torch.Tensor,
    coeff: float,
    nuclear_coeff: float = 1.0,
    l1_coeff: float = 1.0,
    sequence_partition_group=None,
):
    """Pair this layer's probs with the previous layer's to form a PathRPCA loss, then store them.

    If probabilities for ``layer_number - 1`` are stored in ``_PATHRPCA_PROBS_TRACKER``, the
    PathRPCA loss between that layer and this one is computed. When the sequence-partition
    group has more than one rank, both tensors are first gathered across the group. The
    tracker then receives this layer's un-gathered ``probs``, and the entry for
    ``layer_number - 2`` is dropped.

    Args:
        layer_number: Index of current MoE layer (starting from 0).
        probs: Soft probabilities of current layer [num_tokens, num_experts].
        coeff: Overall scaling coefficient for the PathRPCA loss.
        nuclear_coeff: Coefficient for the nuclear norm component.
        l1_coeff: Coefficient for the L1 norm component.
        sequence_partition_group (optional): Parallel group for sequence partition.

    Returns:
        Tuple[Optional[torch.Tensor], Optional[int]]: ``(loss, layer_number - 1)`` when a previous
        layer was registered, otherwise ``(None, None)``.
    """
    prev_layer = layer_number - 1
    loss, target_layer = None, None

    if prev_layer in _PATHRPCA_PROBS_TRACKER:
        prev_probs = _PATHRPCA_PROBS_TRACKER[prev_layer]
        curr_probs = probs

        # Gather to the full sequence so the loss (and its gradient) sees every token.
        needs_gather = sequence_partition_group is not None and sequence_partition_group.size() > 1
        if needs_gather:
            prev_probs = gather_from_sequence_parallel_region(prev_probs, group=sequence_partition_group)
            curr_probs = gather_from_sequence_parallel_region(probs, group=sequence_partition_group)

        loss = _calc_pathrpca_loss(prev_probs, curr_probs, coeff, nuclear_coeff, l1_coeff)
        target_layer = prev_layer

    # Make this layer's probs available to the next MoE layer and forget the one before last.
    _PATHRPCA_PROBS_TRACKER[layer_number] = probs
    _PATHRPCA_PROBS_TRACKER.pop(layer_number - 2, None)

    return loss, target_layer


def _calc_marginguard_loss(logits: torch.Tensor, coeff: float, k: int = 1):
    """Compute MarginGuard loss: the (K+1)-th minus the K-th largest logit, averaged over tokens.

    MarginGuard maximizes the score (logit) gap between the K-th selected expert and the
    (K+1)-th, i.e. first discarded, expert. This pushes the gating network towards confident
    decisions away from the top-K decision boundary.

    Args:
        logits: [num_tokens, num_experts] raw scores (logits) from the gating network.
        coeff: Scaling coefficient.
        k: Number of selected experts (K). A non-zero loss requires 1 <= k < num_experts.

    Returns:
        torch.Tensor: ``coeff * mean_t(s_{t,(K+1)} - s_{t,(K)})``, where s_{t,(i)} is the i-th
        largest logit of token t. Returns a zero tensor (logits' dtype/device) when there are no
        tokens or k is outside [1, num_experts - 1].
    """
    num_tokens = logits.shape[0]
    num_experts = logits.shape[1]

    # k < 1 would index the sorted scores from the end (k - 1 == -1) and give a meaningless
    # margin. k >= num_experts leaves no discarded expert to compare against.
    if num_tokens == 0 or k < 1 or k >= num_experts:
        return torch.tensor(0.0, device=logits.device, dtype=logits.dtype)

    # Only the K+1 largest scores matter. topk returns them in descending order, so a full
    # sort over every expert is unnecessary.
    top_scores, _ = torch.topk(logits, k + 1, dim=1)  # [num_tokens, k + 1]
    kth_selected_score = top_scores[:, k - 1]  # s'_{t,(K)}
    kplus1_discarded_score = top_scores[:, k]  # s'_{t,(K+1)}

    # Minimizing s'_(K+1) - s'_(K) maximizes the margin s'_(K) - s'_(K+1).
    token_losses = kplus1_discarded_score - kth_selected_score  # [num_tokens]
    return torch.mean(token_losses) * coeff


def _calc_intraguard_loss(probs: torch.Tensor, routing_map: torch.Tensor, coeff: float, k: int = 1):
    """Compute IntraGuard loss for intra-layer expert coupling.

    Steps:
    - Build the co-routing matrix P_ij = (# tokens routed to both expert i and expert j) / T.
    - Zero its diagonal and sum the top-k off-diagonal entries of every row.
    - Return ``-log(sum) * coeff``.

    Minimizing this with coeff > 0 rewards experts that are often co-selected with a few
    specific partners ("stable partnerships").

    NOTE(review): the loss is computed from ``routing_map`` alone. ``probs`` only supplies the
    device/dtype of the zero-loss early return. If ``routing_map`` is binary/boolean as
    documented, this loss sends no gradient back to the router. Confirm that is intended.

    Args:
        probs: [num_tokens, num_experts] soft probabilities from the router (device/dtype only).
        routing_map: [num_tokens, num_experts] binary routing decisions (0 or 1).
        coeff: Scaling coefficient.
        k: Number of top partners per expert (k=1 is the original max formulation). Values
            below 1 are treated as 1; values above num_experts - 1 are clamped.

    Returns:
        torch.Tensor: The IntraGuard loss (scaled by coeff). Returns zero when there are no
        tokens or fewer than two experts, since no partner can exist.
    """
    num_tokens = routing_map.shape[0]
    num_experts = routing_map.shape[1]

    if num_tokens == 0 or num_experts < 2:
        return torch.tensor(0.0, device=probs.device, dtype=probs.dtype)

    # Same clamping convention as the PathGuard loss. Each expert has num_experts - 1
    # possible partners, so at most that many can be selected. k == 1 reduces topk to the
    # row max of the original formulation.
    k_partners = min(max(1, k), num_experts - 1)

    # Joint probabilities: co-occurrence counts divided by the number of tokens, [E, E].
    routing_map_float = routing_map.float()
    joint_probs = (routing_map_float.T @ routing_map_float) / num_tokens
    # Exclude self-coupling (i == j). joint_probs is a freshly computed tensor, so the
    # in-place fill does not alter any caller-visible data.
    joint_probs.fill_diagonal_(0.0)

    topk_joint_probs, _ = joint_probs.topk(k_partners, dim=1)  # [E, k_partners]
    intraguard_total = topk_joint_probs.sum()

    # Without the clamp, a batch where no two experts are co-routed (e.g. top-1 routing)
    # gives log(0) = -inf and an infinite loss.
    eps = 1e-9
    return -torch.log(intraguard_total.clamp_min(eps)) * coeff


def compute_and_register_intraguard_loss(
    layer_number: int,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    coeff: float,
    k: int = 1,
    sequence_partition_group=None,
):
    """Compute the IntraGuard loss for this layer, which depends on this layer's routing only.

    Unlike PathGuard, no tracker is consulted or updated. When the sequence-partition group
    has more than one rank, ``probs`` and then ``routing_map`` are gathered across the group
    before the loss is computed.

    Args:
        layer_number: Index of current MoE layer (starting from 0).
        probs: Soft probabilities of current layer [num_tokens, num_experts].
        routing_map: Binary routing decisions [num_tokens, num_experts].
        coeff: Scaling coefficient.
        k: Number of top conditional probabilities to consider.
        sequence_partition_group (optional): Parallel group for sequence partition.

    Returns:
        Tuple[torch.Tensor, int]: ``(loss, layer_number)``.
    """
    spans_multiple_ranks = (
        sequence_partition_group is not None and sequence_partition_group.size() > 1
    )
    if spans_multiple_ranks:
        # Gather to the full sequence so the statistics cover every token.
        probs = gather_from_sequence_parallel_region(probs, group=sequence_partition_group)
        routing_map = gather_from_sequence_parallel_region(routing_map, group=sequence_partition_group)

    return _calc_intraguard_loss(probs, routing_map, coeff, k), layer_number


def compute_and_register_marginguard_loss(
    layer_number: int,
    logits: torch.Tensor,
    coeff: float,
    k: int = 1,
    sequence_partition_group=None,
):
    """Compute the MarginGuard loss for this layer from the gating network's raw logits.

    No tracker is consulted or updated. When the sequence-partition group has more than one
    rank, ``logits`` is gathered across the group before the loss is computed.

    Args:
        layer_number: Index of current MoE layer (starting from 0).
        logits: Raw logits from the gating network [num_tokens, num_experts].
        coeff: Scaling coefficient.
        k: Number of top experts to consider (K value in the algorithm).
        sequence_partition_group (optional): Parallel group for sequence partition.

    Returns:
        Tuple[torch.Tensor, int]: ``(loss, layer_number)``.
    """
    spans_multiple_ranks = (
        sequence_partition_group is not None and sequence_partition_group.size() > 1
    )
    if spans_multiple_ranks:
        # Gather to the full sequence so the margin is averaged over every token.
        logits = gather_from_sequence_parallel_region(logits, group=sequence_partition_group)

    return _calc_marginguard_loss(logits, coeff, k), layer_number


def _calc_gvo_loss(
    dispatched_hidden: torch.Tensor,
    permuted_probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    expert_w2: torch.Tensor,
    topk: int,
    eps: float = 1e-8,
    coeff: float = 0.0,
):
    """Compute a Gate Vector Orthogonalization (GVO) surrogate loss, expert by expert.

    The dispatched tokens and their probs are split into per-expert chunks using
    ``tokens_per_expert``. For each expert e:
    - Compute g = h_e @ expert_w2[e].
    - Form probability-weighted copies G[t, k] = p_e[t, k] * g[t] and normalize them along
      the last dimension.
    - Accumulate the squared strictly-upper-triangular cosine similarities between the K copies.

    NOTE(review): assuming ``permuted_probs`` is [T_all, K] as documented, every G[t, k] is a
    scalar multiple of the same vector g[t]. Copies with positive probability therefore
    normalize to the same unit vector, so the cosine entries are ±1 (or 0 when the norm is
    clamped). The loss is then essentially a pair count with near-zero gradient. Confirm this
    surrogate is intended.

    Args:
        dispatched_hidden: Concatenated token representations per expert, [T_all, H].
        permuted_probs: Concatenated top-k probs aligned with dispatched_hidden, [T_all, K].
        tokens_per_expert: Number of tokens routed to each local expert, [E_local].
        expert_w2: Local experts' projection weights, stacked along dim 0; ``h_e @ expert_w2[e]``
            must be a valid matmul.
        topk: Number of activated experts per token; the loss is skipped when < 2.
        eps: Lower bound on vector norms before normalization.
        coeff: Loss scaling coefficient (already scaled for groups upstream).

    Returns:
        Optional[torch.Tensor]: ``coeff * sum of squared pairwise cosines``. Returns None when
        inactive (coeff == 0, topk < 2, no dispatched tokens, or all expert chunks empty).
    """
    if coeff == 0.0 or topk < 2 or dispatched_hidden.numel() == 0:
        return None

    # torch.split needs host-side sizes (this copies tokens_per_expert to the CPU).
    e_splits = tokens_per_expert.tolist()
    if len(e_splits) == 0 or sum(e_splits) == 0:
        return None
    hidden_chunks = torch.split(dispatched_hidden, e_splits, dim=0)
    prob_chunks = torch.split(permuted_probs, e_splits, dim=0)

    total_loss = dispatched_hidden.new_zeros(())
    for e_idx, (h_e, p_e) in enumerate(zip(hidden_chunks, prob_chunks)):
        if h_e.numel() == 0:
            continue
        # Gating vectors g_i(x) = x W_{2,i}. The previous code branched on W2's shape but
        # computed this same product in both branches.
        g = h_e @ expert_w2[e_idx]

        # Probability-weighted copies of g per top-k column: [T_e, K, d].
        G = p_e.unsqueeze(-1) * g.unsqueeze(1)
        G_hat = G / torch.linalg.norm(G, dim=-1, keepdim=True).clamp_min(eps)

        # Cosine similarity among the K columns: [T_e, K, K]. Keep each pair once.
        C = torch.matmul(G_hat, G_hat.transpose(-1, -2))
        triu = torch.triu(C, diagonal=1)
        total_loss = total_loss + (triu * triu).sum()

    return coeff * total_loss


def compute_and_register_gvo_loss(
    layer_number: int,
    dispatched_hidden: torch.Tensor,
    permuted_probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    expert_w2: torch.Tensor,
    coeff: float,
    topk: int,
    sequence_partition_group=None,
):
    """Compute the GVO loss for this layer and, if a group is given, all-reduce it in place.

    No tracker is consulted or updated.

    Returns:
        ``(loss, layer_number)`` when the loss is active, otherwise ``(None, None)``
        (coeff == 0 or the underlying GVO computation returned None).
    """
    inactive = (None, None)
    if coeff == 0.0:
        return inactive

    loss = _calc_gvo_loss(
        dispatched_hidden=dispatched_hidden,
        permuted_probs=permuted_probs,
        tokens_per_expert=tokens_per_expert,
        expert_w2=expert_w2,
        topk=topk,
        coeff=coeff,
    )
    if loss is None:
        return inactive

    if sequence_partition_group is not None:
        # Default reduce op (sum) across the sequence-partition group, in place.
        torch.distributed.all_reduce(loss, group=sequence_partition_group)
    return loss, layer_number


def compute_and_register_gvo_loss_router(
    layer_number: int,
    hidden_states: torch.Tensor,
    local_routing_map: torch.Tensor,
    expert_w2_local: torch.Tensor,
    coeff: float,
    sequence_partition_group=None,
):
    """Compute the GVO loss from router-side inputs using the local experts' W2 weights.

    For every token and local expert, the gating vector is g_e(x) = x @ W2_e^T. The loss is
    ``coeff`` times the sum, over tokens, of the squared cosine similarities between the
    gating vectors of each pair of distinct selected experts.

    Args:
        layer_number: Index of the current MoE layer.
        hidden_states: [T, H].
        local_routing_map: [T, E_local] boolean/int top-k selections intersected with local experts.
        expert_w2_local: [E_local, D_ffn, H].
        coeff: Scaling coefficient.
        sequence_partition_group: Optional group; when given, the loss is all-reduced in place.

    Returns:
        ``(loss, layer_number)``, or ``(None, None)`` when coeff == 0, hidden_states is empty,
        or there are no local experts.
    """
    if coeff == 0.0 or hidden_states.numel() == 0:
        return None, None
    # The unpack also rejects hidden_states that are not 2-D.
    _num_tokens, _hidden_size = hidden_states.shape
    if expert_w2_local.shape[0] == 0:
        return None, None

    # Gating vectors for every (token, local expert): [T, E_local, D_ffn].
    w2_t = expert_w2_local.transpose(1, 2)  # [E_local, H, D_ffn]
    gate_vecs = torch.einsum('th,ehd->ted', hidden_states, w2_t)

    # Unit-normalize along D_ffn, then take pairwise cosine similarity per token.
    unit = gate_vecs / torch.linalg.norm(gate_vecs, dim=-1, keepdim=True).clamp_min(1e-8)
    cosine = unit @ unit.transpose(-1, -2)  # [T, E_local, E_local]

    # Keep only pairs where both experts are selected for the token.
    selected = local_routing_map.to(gate_vecs.dtype)  # [T, E_local]
    pair_mask = selected[:, :, None] * selected[:, None, :]
    upper = torch.triu(cosine * pair_mask, diagonal=1)

    loss = (upper * upper).sum() * coeff
    if sequence_partition_group is not None:
        torch.distributed.all_reduce(loss, group=sequence_partition_group)
    return loss, layer_number


def compute_and_register_rvo_loss_router(
    layer_number: int,
    hidden_states: torch.Tensor,
    local_routing_map: torch.Tensor,
    expert_w1_local: torch.Tensor,
    expert_w2_local: torch.Tensor,
    coeff: float,
    use_swiglu: bool = True,
    sequence_partition_group=None,
):
    """Compute RVO loss from router outputs using local expert weights.

    RVO runs every token through every local expert's full FFN (FC1 -> activation -> FC2).
    It then sums, over tokens, the squared cosine similarities between the outputs of each
    pair of distinct activated experts, and scales the result by ``coeff``.

    Args:
        layer_number: Index of current MoE layer (starting from 0).
        hidden_states: [T, H]
        local_routing_map: [T, E_local] boolean/int with top-k selections intersect local experts
        expert_w1_local: [E_local, H, D_fc1] where D_fc1 = 2*D_ffn for SwiGLU else D_ffn
        expert_w2_local: [E_local, D_ffn, H]
        coeff: scaling coefficient
        use_swiglu: if True, apply SwiGLU (silu of the first half times the second half of the
            FC1 output); otherwise apply GELU to the FC1 output.
        sequence_partition_group: optional communication group; when given, the loss is
            all-reduced in place.

    Returns:
        (loss, target_layer_number) or (None, None) when inactive.

    Raises:
        ValueError: if hidden_states is not 2-D, or if use_swiglu is set and the FC1 width is
            odd. These were previously an ``assert`` (stripped under ``python -O``) and an
            implicit tuple-unpack error.
    """
    if coeff == 0.0 or hidden_states.numel() == 0:
        return None, None
    if hidden_states.dim() != 2:
        raise ValueError(
            f"RVO expects 2-D hidden_states [T, H], got shape {tuple(hidden_states.shape)}."
        )
    E_local = expert_w2_local.shape[0]
    if E_local == 0:
        return None, None

    # FC1: Z_e = x @ W1_e with W1_e: [H, D_fc1] -> Z: [T, E, D_fc1]
    Z = torch.einsum('th,ehd->ted', hidden_states, expert_w1_local)

    # Activation
    if use_swiglu:
        D_fc1 = Z.shape[-1]
        # Validate explicitly: an assert here would disappear under `python -O`.
        if D_fc1 % 2 != 0:
            raise ValueError("SwiGLU expects fc1 dim to be even (2*ffn).")
        a, b = torch.split(Z, D_fc1 // 2, dim=-1)
        A = F.silu(a) * b
    else:
        A = F.gelu(Z)

    # FC2: Y_e = A_e @ W2_e with W2_e: [D_ffn, H] -> Y: [T, E, H]
    Y = torch.einsum('ted,edh->teh', A, expert_w2_local)

    # Normalize along H for cosine similarities.
    Y_hat = Y / torch.linalg.norm(Y, dim=-1, keepdim=True).clamp_min(1e-8)  # [T, E, H]

    # Cosine similarity matrix per token across experts: [T, E, E]
    C = torch.matmul(Y_hat, Y_hat.transpose(-1, -2))

    # Keep only pairs where both experts are activated for the token.
    M = local_routing_map.to(Y.dtype)  # [T, E]
    C = C * (M.unsqueeze(-1) * M.unsqueeze(-2))

    # Count each unordered pair once (strict upper triangle).
    triu = torch.triu(C, diagonal=1)
    loss = (triu * triu).sum() * coeff
    if sequence_partition_group is not None:
        torch.distributed.all_reduce(loss, group=sequence_partition_group)
    return loss, layer_number


def compute_and_register_hvo_loss_router(
    layer_number: int,
    hidden_states: torch.Tensor,
    local_routing_map: torch.Tensor,
    expert_w1_local: torch.Tensor,
    coeff: float,
    sequence_partition_group=None,
):
    """Compute the HVO loss for one MoE layer from its router selections.

    For every token, the SwiGLU pre-output vector of each local expert is built:
    ``A_e(x) = silu(x @ W1_e[:, :D_ffn]) * (x @ W1_e[:, D_ffn:])``.
    The loss is ``coeff`` times a sum of squared cosine similarities between these
    vectors. The sum runs over tokens and, per token, over unordered pairs of
    distinct experts that are both selected for that token.

    Args:
        layer_number: Index of the current MoE layer; returned unchanged so the
            caller can register the loss for that layer.
        hidden_states: [T, H] token activations.
        local_routing_map: [T, E_local] selection mask (bool or int) restricted to the
            local experts.
        expert_w1_local: [E_local, H, D_fc1] fc1 weights, where D_fc1 = 2 * D_ffn.
        coeff: Scaling coefficient; 0.0 disables the loss.
        sequence_partition_group: Optional process group. When given, the scalar loss
            is summed across it with an in-place all_reduce.

    Returns:
        ``(loss, layer_number)`` with a scalar ``loss`` tensor. Returns ``(None, None)``
        when coeff is 0.0, there are no tokens, there are no local experts, or D_fc1
        is odd.
    """
    if coeff == 0.0 or hidden_states.numel() == 0:
        return None, None
    # Unpacking (rather than indexing) also rejects hidden_states that are not 2-D.
    _num_tokens, _hidden_size = hidden_states.shape
    if expert_w1_local.shape[0] == 0:
        return None, None

    # Per-expert fc1 pre-activations: [T, E, D_fc1].
    fc1_out = torch.einsum('th,ehd->ted', hidden_states, expert_w1_local)
    fc1_width = fc1_out.shape[-1]
    if fc1_width % 2 != 0:
        # An odd fc1 width cannot hold a gated (SwiGLU) layout; nothing to compare.
        return None, None
    gate, up = torch.split(fc1_out, fc1_width // 2, dim=-1)
    swiglu_out = F.silu(gate) * up  # [T, E, D_ffn]

    # Unit vectors along D_ffn; the clamp guards against division by zero.
    unit = swiglu_out / torch.linalg.norm(swiglu_out, dim=-1, keepdim=True).clamp_min(1e-8)
    cosine = torch.matmul(unit, unit.transpose(-1, -2))  # [T, E, E]

    # Zero every pair in which at least one expert was not selected for the token.
    selected = local_routing_map.to(swiglu_out.dtype)  # [T, E]
    cosine = cosine * (selected.unsqueeze(-1) * selected.unsqueeze(-2))

    # The strict upper triangle counts each unordered pair once and drops self-similarity.
    loss = torch.triu(cosine, diagonal=1).square().sum() * coeff
    if sequence_partition_group is not None:
        torch.distributed.all_reduce(loss, group=sequence_partition_group)
    return loss, layer_number


def switch_load_balancing_loss_func(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    topk: int,
    moe_aux_loss_coeff: float,
    sequence_partition_group=None,
):
    """Switch-Transformer style load-balancing auxiliary loss.
    See https://arxiv.org/abs/2101.03961 for the original formulation.

    Args:
        probs (torch.Tensor): Router softmax probabilities, shape [num_tokens, num_experts].
        tokens_per_expert (torch.Tensor): Tokens assigned to each expert, shape [num_experts].
            When ``sequence_partition_group`` is given, this tensor is all-reduced IN PLACE
            across that group.
        topk (int): Number of experts selected per token.
        moe_aux_loss_coeff (float): Coefficient applied to the loss.
        sequence_partition_group (optional): Group over which the sequence is partitioned
            (e.g. sequence or context parallelism). None means no partitioning.

    Returns:
        torch.Tensor: Scalar load-balancing loss.
    """
    if sequence_partition_group is None:
        num_sub_sequence = 1
    else:
        # The gradient must reflect the full sequence. Only the token counts need to be
        # global; the per-expert prob sums can stay local because no gradient flows
        # through tokens_per_expert. This saves one all_reduce.
        num_sub_sequence = sequence_partition_group.size()
        torch.distributed.all_reduce(tokens_per_expert, group=sequence_partition_group)

    num_experts = probs.shape[1]
    num_tokens = probs.shape[0] * num_sub_sequence

    # Fused form of:
    #   sum((probs_per_expert / num_tokens) * (tokens_per_expert / (num_tokens * topk)))
    #   * num_experts * moe_aux_loss_coeff
    scale = num_experts * moe_aux_loss_coeff / (num_tokens * num_tokens * topk)
    return torch.sum(probs.sum(dim=0) * tokens_per_expert) * scale


def sequence_load_balancing_loss_func(
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    batch_size: int,
    seq_length: int,
    topk: int,
    moe_aux_loss_coeff: float,
    sequence_partition_group=None,
):
    """
    Sequence-level auxiliary load-balancing loss, computed per sample and then averaged.
    See the DeepSeek-V2 huggingface repo (https://huggingface.co/deepseek-ai/DeepSeek-V2).

    Args:
        probs (torch.Tensor): Router softmax probabilities, shape [num_tokens, num_experts],
            viewable as [seq_length, batch_size, num_experts].
        routing_map (torch.Tensor): Token-to-expert assignment, shape [num_tokens, num_experts].
        batch_size (int): Batch size.
        seq_length (int): Local sequence length.
        topk (int): Number of experts selected per token.
        moe_aux_loss_coeff (float): Coefficient applied to the loss.
        sequence_partition_group (optional): Group over which the sequence is partitioned.
            When given, probs are gathered over it and seq_length is scaled by its size.

    Returns:
        torch.Tensor: Scalar sequence auxiliary loss.
    """
    num_experts = probs.shape[1]
    probs_sbe = probs.view(seq_length, batch_size, -1)
    routing_sbe = routing_map.view(seq_length, batch_size, -1)

    if sequence_partition_group is not None:
        # Compute the gradient against the full (gathered) sequence.
        seq_length *= sequence_partition_group.size()
        probs_sbe = gather_from_sequence_parallel_region(probs_sbe, group=sequence_partition_group)

    # Per-sample expert load, normalized by the expected load seq_length * topk / num_experts.
    expert_load = routing_sbe.sum(dim=0, dtype=torch.float).div_(seq_length * topk / num_experts)
    per_sample_loss = (expert_load * probs_sbe.mean(dim=0)).sum(dim=1)
    return per_sample_loss.mean() * moe_aux_loss_coeff


def z_loss_func(logits, z_loss_coeff):
    """Router z-loss that discourages large router logits, improving stability.
    See the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf).

    Args:
        logits (torch.Tensor): Router logits; logsumexp is taken over the last dimension.
        z_loss_coeff (float): Coefficient applied to the loss.

    Returns:
        torch.Tensor: Scalar ``mean(logsumexp(logits)^2) * z_loss_coeff``.
    """
    log_partition = torch.logsumexp(logits, dim=-1)
    return log_partition.square().mean() * z_loss_coeff


def sinkhorn(cost: torch.Tensor, tol: float = 0.0001):
    """Sinkhorn-based MoE routing: balance exp(cost) with alternating scalings.

    Row scalings ``d0`` and column scalings ``d1`` are updated in turn. Iteration stops
    once the mean absolute change of ``d1`` between iterations is no longer above
    ``tol``. There is no iteration cap.

    Args:
        cost: 2-D matrix of shape [rows, cols]; it is exponentiated first.
        tol: Convergence threshold on the mean absolute change of ``d1``.

    Returns:
        torch.Tensor: ``d1 * exp(cost) * d0[:, None]``, with the same shape as ``cost``.
    """
    weights = torch.exp(cost)
    num_rows, num_cols = weights.size(0), weights.size(1)
    row_scale = torch.ones(num_rows, device=weights.device, dtype=weights.dtype)
    col_scale = torch.ones(num_cols, device=weights.device, dtype=weights.dtype)

    eps = 0.00000001
    prev_col_scale = col_scale
    delta = 1e9
    while delta > tol:
        row_scale = (1 / num_rows) / ((col_scale * weights).sum(dim=1) + eps)
        col_scale = (1 / num_cols) / ((row_scale.unsqueeze(1) * weights).sum(dim=0) + eps)
        delta = torch.mean(torch.abs(prev_col_scale - col_scale))
        prev_col_scale = col_scale
    return col_scale * weights * row_scale.unsqueeze(1)


def get_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_capacity=None):
    """
    Compute the per-expert token capacity.

    Args:
        num_tokens (int): Number of input tokens (token-expert assignments).
        num_experts (int): Number of experts.
        capacity_factor (float): Multiplier applied to the even per-expert share.
        min_capacity (int, optional): Lower bound on the capacity. Defaults to None (no bound).

    Returns:
        int: ``ceil(num_tokens / num_experts * capacity_factor)``, raised to ``min_capacity``
        when that is larger.
    """
    capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
    if min_capacity is None:
        return capacity
    return max(capacity, min_capacity)


class MoEAuxLossAutoScaler(torch.autograd.Function):
    """An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss."""

    # Class-wide scale applied to the aux-loss gradient; lazily set to 1.0 in backward.
    main_loss_backward_scale: torch.Tensor = None

    @staticmethod
    def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
        """Pass ``output`` through unchanged and keep ``aux_loss`` alive for backward.

        Args:
            output (torch.Tensor): Tensor returned as-is.
            aux_loss (torch.Tensor): Auxiliary loss saved on the context.

        Returns:
            torch.Tensor: ``output``.
        """
        ctx.save_for_backward(aux_loss)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Forward ``grad_output`` and emit a scaled unit gradient for the aux loss.

        Args:
            grad_output (torch.Tensor): Gradient flowing into ``output``.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``grad_output`` and
            ``ones_like(aux_loss) * main_loss_backward_scale``.
        """
        (saved_aux_loss,) = ctx.saved_tensors
        scaler = MoEAuxLossAutoScaler
        if scaler.main_loss_backward_scale is None:
            scaler.main_loss_backward_scale = torch.tensor(1.0, device=saved_aux_loss.device)
        aux_grad = torch.ones_like(saved_aux_loss) * scaler.main_loss_backward_scale
        return grad_output, aux_grad

    @staticmethod
    def set_loss_scale(scale: torch.Tensor):
        """Set the aux-loss gradient scale.

        The first call stores ``scale`` itself. Later calls copy into the stored tensor
        in place.

        Args:
            scale (torch.Tensor): Scale value; should match the scale of the main loss.
        """
        scaler = MoEAuxLossAutoScaler
        if scaler.main_loss_backward_scale is not None:
            scaler.main_loss_backward_scale.copy_(scale)
        else:
            scaler.main_loss_backward_scale = scale


def permute(
    tokens,
    routing_map,
    probs: Optional[torch.Tensor] = None,
    num_out_tokens: Optional[int] = None,
    fused: bool = False,
    drop_and_pad: bool = False,
):
    """Group tokens (and optionally their probs) by destination expert.

    ``routing_map`` is a [num_tokens, num_experts] mask marking the experts chosen by
    each token. Tokens bound for the same expert end up contiguous in the output.

    With ``drop_and_pad=True`` (and ``num_out_tokens`` given), every column of
    ``routing_map`` is expected to hold exactly ``capacity`` non-zeros. That allows a
    fixed-shape argsort path that is friendlier to CUDA graphs.

    Args:
        tokens (torch.Tensor): Input tokens, [num_tokens, hidden].
        routing_map (torch.Tensor): Token-to-expert mask, [num_tokens, num_experts].
        probs (torch.Tensor, optional): Routing probs, [num_tokens, num_experts].
        num_out_tokens (int, optional): Number of output tokens.
        fused (bool, optional): Use the Transformer Engine fused kernels.
        drop_and_pad (bool, optional): Whether the dispatcher drops tokens and pads each
            expert to its capacity, giving a fixed number of non-zeros per column.

    Returns:
        Tuple of (permuted_input, permuted_probs or None, sorted_indices).
    """
    if fused:
        if probs is None:
            if not HAVE_TE or fused_permute is None:
                raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.")
            permuted_input, sorted_indices = fused_permute(
                tokens, routing_map, num_out_tokens=num_out_tokens
            )
            return permuted_input, None, sorted_indices
        if not HAVE_TE or fused_permute_with_probs is None:
            raise ValueError(
                "fused_permute_with_probs is not available. Please install TE >= 2.1.0."
            )
        return fused_permute_with_probs(tokens, probs, routing_map, num_out_tokens=num_out_tokens)

    # Unpacking also enforces a 2-D tokens tensor.
    num_tokens, _hidden = tokens.shape
    num_experts = routing_map.shape[1]
    permuted_probs = None

    if drop_and_pad and num_out_tokens is not None:
        capacity = num_out_tokens // num_experts
        assert not routing_map.requires_grad
        # Expert-major layout [num_experts, num_tokens].
        expert_major = routing_map.to(dtype=torch.int8).T.contiguous()
        # A stable descending argsort puts the selected token ids first, in token order.
        # Keep `capacity` of them per expert, then flatten [num_experts, capacity] to 1-D.
        sorted_indices = (
            expert_major.argsort(dim=-1, descending=True, stable=True)[:, :capacity]
            .contiguous()
            .view(-1)
        )
        if probs is not None:
            # Offset each expert's token ids into the flattened expert-major probs.
            row_offsets = (
                torch.arange(num_experts, device=expert_major.device).unsqueeze(-1) * num_tokens
            )
            flat_positions = (row_offsets + sorted_indices.view(num_experts, capacity)).view(-1)
            permuted_probs = probs.T.contiguous().view(-1).index_select(0, flat_positions)
    else:
        expert_major = routing_map.bool().T.contiguous()
        # Dense [num_experts, num_tokens] grid of token ids, then keep the routed ones.
        token_ids = (
            torch.arange(num_tokens, device=expert_major.device).unsqueeze(0).expand(num_experts, -1)
        )
        sorted_indices = token_ids.masked_select(expert_major)
        if probs is not None:
            permuted_probs = probs.T.contiguous().masked_select(expert_major)

    permuted_input = tokens.index_select(0, sorted_indices)
    return permuted_input, permuted_probs, sorted_indices


def unpermute(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    restore_shape: torch.Size,
    probs: torch.Tensor = None,
    routing_map: torch.Tensor = None,
    fused: bool = False,
    drop_and_pad: bool = False,
):
    """
    Restore the original token order after `permute`.

    If probs are provided, each permuted token is first scaled by its routing prob.
    Copies of the same token are summed when scattered back.

    When drop_and_pad=True, the tensors have these properties:
      - In routing_map, every column holds exactly `capacity` non-zeros.
      - sorted_indices has num_experts * capacity entries. Each consecutive split of
        `capacity` holds the token indices routed to one expert.
    This function exploits these properties to use ops that support CUDA graphs.

    Args:
        permuted_tokens (torch.Tensor): The permuted token tensor.
        sorted_indices (torch.Tensor): The token index of every permuted row.
        restore_shape (torch.Size): Shape [num_tokens, hidden] of the unpermuted tensor.
        probs (torch.Tensor, optional): Unpermuted probs, [num_tokens, num_experts].
        routing_map (torch.Tensor, optional): Token-to-expert mapping, shape
            [num_tokens, num_experts]. It may be bool or integer; it is required
            when probs is given.
        fused (bool, optional): Use the Transformer Engine fused unpermute kernel.
        drop_and_pad (bool, optional): Whether the token dispatcher uses token-drop
            and pads the number of tokens to the expert capacity.

    Returns:
        torch.Tensor: Tokens in their original order, in the dtype of permuted_tokens.
    """
    if fused:
        if not HAVE_TE or fused_unpermute is None:
            raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.")
        return fused_unpermute(
            permuted_tokens, sorted_indices, merging_probs=probs, restore_shape=restore_shape
        )

    _, hidden = restore_shape
    input_dtype = permuted_tokens.dtype

    if probs is not None:
        assert routing_map is not None, "Mask must be provided to permute the probs."
        if drop_and_pad:
            num_experts = routing_map.size(1)
            num_permuted_tokens = sorted_indices.size(0)
            capacity = num_permuted_tokens // num_experts
            num_unpermuted_tokens = probs.size(0)

            # [num_unpermuted_tokens, num_experts] -> num_experts * num_unpermuted_tokens
            probs_T_1D = probs.T.contiguous().view(-1)

            # 1-D indices of the probs selected by routing_map in the expert-major layout.
            indices_dim0 = torch.arange(num_experts, device=routing_map.device).unsqueeze(-1)
            indices_dim1 = sorted_indices.view(num_experts, capacity)
            indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)

            permuted_probs = probs_T_1D.index_select(0, indices_1D)
        else:
            # masked_select only accepts a boolean mask. Normalize integer routing maps
            # the same way permute() does, so both directions accept the same inputs.
            permuted_probs = probs.T.contiguous().masked_select(routing_map.bool().T.contiguous())
        # This may promote permuted_tokens to higher precision (fp32/fp64) when probs is
        # in higher precision because moe_router_dtype is enabled. That costs extra GPU
        # memory; use --moe-permute-fusion to avoid this extra allocation.
        permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)

    output_tokens = torch.zeros(
        restore_shape, dtype=permuted_tokens.dtype, device=permuted_tokens.device
    )
    # Scatter-add every permuted row back onto its source token (duplicates are summed).
    output_tokens.scatter_add_(0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens)
    return output_tokens.to(dtype=input_dtype)


def sort_chunks_by_idxs(
    input: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_idxs: torch.Tensor,
    probs: Optional[torch.Tensor] = None,
    fused: bool = False,
):
    """Split ``input`` (and ``probs``) into chunks along dim 0 and reorder the chunks.

    Args:
        input: Tensor to split along dim 0 into chunks of ``split_sizes``.
        split_sizes: 1-D tensor of chunk lengths.
        sorted_idxs: 1-D tensor giving the output order of chunk indices.
        probs: Optional tensor split and reordered the same way as ``input``.
        fused: Use the Transformer Engine fused kernels.

    Returns:
        Tuple of (reordered input, reordered probs or None).
    """
    if fused:
        if probs is None:
            if not HAVE_TE or fused_sort_chunks_by_index is None:
                raise ValueError(
                    "fused_sort_chunks_by_index is not available. Please install TE >= 2.1.0."
                )
            return fused_sort_chunks_by_index(input, split_sizes, sorted_idxs), None
        if not HAVE_TE or fused_sort_chunks_by_index_with_probs is None:
            raise ValueError(
                "fused_sort_chunks_by_index_with_probs is not available. "
                "Please install TE >= 2.1.0."
            )
        return fused_sort_chunks_by_index_with_probs(input, probs, split_sizes, sorted_idxs)

    chunk_sizes = split_sizes.tolist()
    chunk_order = sorted_idxs.tolist()

    input_chunks = torch.split(input, chunk_sizes, dim=0)
    output = torch.cat([input_chunks[i] for i in chunk_order], dim=0)
    if probs is None:
        return output, None

    prob_chunks = torch.split(probs, chunk_sizes, dim=0)
    return output, torch.cat([prob_chunks[i] for i in chunk_order], dim=0)


def group_limited_topk(
    scores: torch.Tensor,
    topk: int,
    num_tokens: int,
    num_experts: int,
    num_groups: int,
    group_topk: int,
):
    """Top-k routing restricted to a per-token subset of expert groups.

    Steps:
    1. Experts are divided into ``num_groups`` equal-sized contiguous groups.
    2. Each group is scored per token by the sum of its top ``topk // group_topk``
       expert scores, and the ``group_topk`` best groups are kept.
    3. The ``topk`` experts are chosen from the kept groups only. Experts in other
       groups are masked to -inf.

    Typical uses:
    - Device-limited routing: ``num_groups`` = expert parallel size (EP)
      (DeepSeek-V2: https://arxiv.org/pdf/2405.04434).
    - Node-limited routing: ``num_groups`` = number of nodes in the EP group
      (DeepSeek-V3: https://arxiv.org/pdf/2412.19437).

    Args:
        scores (torch.Tensor): Router scores, [num_tokens, num_experts].
        topk (int): Number of experts selected per token.
        num_tokens (int): Number of tokens.
        num_experts (int): Number of experts.
        num_groups (int): Number of expert groups.
        group_topk (int): Number of groups selected per token.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Selected scores and expert indices, each
        [num_tokens, topk].
    """
    experts_per_group_k = topk // group_topk
    grouped_scores = scores.view(num_tokens, num_groups, -1)
    group_scores = grouped_scores.topk(experts_per_group_k, dim=-1).values.sum(dim=-1)

    _, chosen_groups = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)
    group_mask = torch.zeros_like(group_scores).scatter(1, chosen_groups, 1)

    # Broadcast the group mask to every expert inside each group.
    expert_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_groups, num_experts // num_groups)
        .reshape(num_tokens, -1)
    )
    masked_scores = scores.masked_fill(~expert_mask.bool(), float('-inf'))

    probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1)
    return probs, top_indices


def pad_routing_map(routing_map: torch.Tensor, pad_multiple: int) -> torch.Tensor:
    """Pad the routing map so every expert receives a multiple of pad_multiple tokens.

    For each expert (column), the first N unselected tokens (zeros) are switched to
    selected (1), where N is the number needed to reach the next multiple of
    pad_multiple.

    The caller's tensor is NOT modified. A new tensor is returned with the same dtype
    and memory layout. The previous implementation wrote through a transposed view
    and silently mutated the input.

    Args:
        routing_map (torch.Tensor): A boolean or integer tensor of shape [num_tokens,
            num_experts] indicating which tokens are routed to which experts.
        pad_multiple (int): The multiple to pad each expert's token count to.

    Returns:
        torch.Tensor: The padded routing map of shape [num_tokens, num_experts].
    """
    # Expert-major view [num_experts, num_tokens]: one row per expert (read-only use).
    expert_major = routing_map.transpose(0, 1)

    # Tokens still needed per expert to reach the next multiple of pad_multiple.
    num_ones = expert_major.sum(dim=1)
    num_to_pad = (-num_ones) % pad_multiple

    # Running count of zeros along each row: the k-th zero has rank k.
    is_zero = expert_major == 0
    zero_ranks = torch.cumsum(is_zero.int(), dim=1)

    # Select the first num_to_pad zeros. The mask also covers some already-selected
    # entries (rank <= num_to_pad); setting those to 1 is a no-op.
    pad_mask = zero_ranks <= num_to_pad.unsqueeze(1)

    # Out-of-place fill in the original [num_tokens, num_experts] layout.
    return routing_map.masked_fill(pad_mask.transpose(0, 1), 1)


def topk_softmax_with_capacity(
    logits: torch.Tensor,
    topk: int,
    capacity_factor: Optional[float] = None,
    pad_to_capacity: bool = False,
    drop_policy: str = "probs",
    use_pre_softmax: bool = False,
    num_groups: Optional[int] = None,
    group_topk: Optional[int] = None,
    scaling_factor: Optional[float] = None,
    deterministic_mode: bool = False,
    score_function: str = "softmax",
    expert_bias: Optional[torch.Tensor] = None,
):
    """Select the top-k experts per token, optionally enforcing an expert capacity.

    Args:
        logits (torch.Tensor): Router logits, [num_tokens, num_experts].
        topk (int): The number of experts to select for each token.
        capacity_factor (float): The capacity factor of each expert. Tokens are dropped
            when an expert receives more than its capacity. The capacity is capped at
            num_tokens, since an expert can never receive more tokens than exist.
        pad_to_capacity (bool): Whether to pad each expert to its capacity in token-drop
            mode. The probs of padded tokens are 0.
        drop_policy (str): The policy used to drop tokens: either "probs" or "position".
            With "probs", the tokens with the lowest probabilities are dropped.
            With "position", tokens later in the batch are meant to be dropped first.
            NOTE(review): this relies on how torch.topk breaks ties among equal mask
            values (sorted=False); confirm it holds on the target backend.
        use_pre_softmax (bool): Whether to apply softmax before top-k selection.
        num_groups (int): Number of groups for routed experts.
        group_topk (int): Number of selected groups for each token.
        scaling_factor (float): Scaling factor applied to the selected routing scores.
        deterministic_mode (bool): Deprecated and unused.
        score_function (str): The score function to use: either "softmax" or "sigmoid".
        expert_bias (torch.Tensor): Bias added to the sigmoid scores for expert selection
            only. The returned probs are computed without it.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            - routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing
              the routing probabilities for each token to each expert.
            - routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts]
              indicating which experts were selected for each token. True values represent
              the selected experts.
            - tokens_per_expert (torch.Tensor): A tensor of shape [num_experts] containing
              the number of local tokens assigned to each expert before dropping and padding.

    Raises:
        ValueError: On an unknown score_function, or an unknown drop_policy when
            capacity_factor is set.
    """
    assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
    num_tokens, num_experts = logits.shape

    def compute_topk(scores, topk, num_groups=None, group_topk=None):
        if group_topk:
            return group_limited_topk(
                scores=scores,
                topk=topk,
                num_tokens=num_tokens,
                num_experts=num_experts,
                num_groups=num_groups,
                group_topk=group_topk,
            )
        else:
            return torch.topk(scores, k=topk, dim=1)

    if score_function == "softmax":
        if use_pre_softmax:
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        else:
            scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits.float()).type_as(logits)
        if expert_bias is not None:
            # The bias only influences which experts are picked, not their weights.
            scores_for_routing = scores + expert_bias
            _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
            scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
        else:
            scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    if scaling_factor:
        probs = probs * scaling_factor

    # TODO Try using element-wise operations instead of scatter?
    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
    topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
    tokens_per_expert = topk_map.sum(dim=0)

    if capacity_factor is None:
        # TopK without capacity
        return topk_masked_gates, topk_map, tokens_per_expert

    # TopK with capacity
    expert_capacity = get_capacity(
        num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor
    )
    # torch.topk below requires k <= num_tokens. An expert can never hold more than
    # num_tokens tokens, so a larger capacity means the same thing as num_tokens.
    expert_capacity = min(expert_capacity, num_tokens)

    # Mask out tokens that exceed the capacity.
    if drop_policy == "probs":
        _, capacity_indices = torch.topk(
            topk_masked_gates, k=expert_capacity, dim=0, sorted=False
        )
        capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool()
    elif drop_policy == "position":
        _, capacity_indices = torch.topk(topk_map.int(), k=expert_capacity, dim=0, sorted=False)
        capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool()
    else:
        raise ValueError(f"Invalid drop_policy: {drop_policy}")

    if pad_to_capacity:
        # Every capacity slot is kept; padded slots carry prob 0 from topk_masked_gates.
        final_map = capacity_mask
    else:
        # Keep only slots that were both selected by top-k and fit within the capacity.
        final_map = torch.logical_and(topk_map, capacity_mask)
    final_probs = topk_masked_gates * final_map
    return final_probs, final_map, tokens_per_expert


def save_to_aux_losses_tracker(
    name: str,
    loss: torch.Tensor,
    layer_number: int,
    num_layers: int,
    reduce_group: torch.distributed.ProcessGroup = None,
    avg_group: torch.distributed.ProcessGroup = None,
):
    """Accumulate an auxiliary loss into the layer-wise logging tracker.

    Args:
        name (str): Key of the loss in the tracker.
        loss (torch.Tensor): Loss value; it is detached before being accumulated.
        layer_number (int): Layer number, expected to be 1-based, since the value is
            stored at index ``layer_number - 1``. ``None`` skips logging.
        num_layers (int): Length of the per-layer buffer allocated the first time
            ``name`` is seen.
        reduce_group (torch.distributed.ProcessGroup): Group whose ranks' values
            reduce_aux_losses_tracker_across_ranks sums.
        avg_group (torch.distributed.ProcessGroup): Group whose ranks' values
            reduce_aux_losses_tracker_across_ranks averages.
    """
    if layer_number is None:
        return

    tracker = get_moe_layer_wise_logging_tracker()
    if name not in tracker:
        tracker[name] = {"values": torch.zeros(num_layers, device=loss.device)}
    entry = tracker[name]
    # Accumulate, since the same layer may report several times per step.
    entry["values"][layer_number - 1] += loss.detach()
    entry["reduce_group"] = reduce_group
    entry["avg_group"] = avg_group


def clear_aux_losses_tracker():
    """Zero every tracked loss buffer and forget its reduce/avg groups.

    The buffers are zeroed in place; the tracker keys are kept.
    """
    for entry in get_moe_layer_wise_logging_tracker().values():
        entry["values"].zero_()
        entry["reduce_group"] = None
        entry["avg_group"] = None


def reduce_aux_losses_tracker_across_ranks(track_names: Optional[List[str]] = None):
    """All-reduce the tracked auxiliary losses in place across ranks.

    For each name, three reductions are applied in order:
    1. Sum over the pipeline-parallel group.
    2. Sum over the entry's ``reduce_group``, if set.
    3. Average over the entry's ``avg_group``, if set.

    Args:
        track_names: Names to reduce; defaults to every name in the tracker.
    """
    tracker = get_moe_layer_wise_logging_tracker()
    names = tracker.keys() if track_names is None else track_names
    for name in names:
        entry = tracker[name]
        values = entry["values"]
        # TODO(Hepteract): delete the usage of the global parallel_state.
        # Collect aux losses across PP.
        torch.distributed.all_reduce(
            values, group=parallel_state.get_pipeline_model_parallel_group()
        )
        reduce_group = entry.get('reduce_group')
        if reduce_group is not None:
            torch.distributed.all_reduce(values, group=reduce_group)
        avg_group = entry.get('avg_group')
        if avg_group is not None:
            torch.distributed.all_reduce(
                values, group=avg_group, op=torch.distributed.ReduceOp.AVG
            )


def track_moe_metrics(
    loss_scale: float,
    iteration: int,
    writer,
    wandb_writer=None,
    total_loss_dict=None,
    per_layer_logging=False,
    force_initialize: bool = False,
    track_names: Optional[List[str]] = None,
    num_layers: Optional[int] = None,
    moe_layer_freq: Optional[Union[int, List[int]]] = None,
    mtp_num_layers: Optional[int] = None,
    config=None,
):
    """Track the MoE metrics for logging.

    Performs one logging step:

    1. If an expert-coupling config has been captured and it enables
       ``moe_expert_coupling_analysis`` with a ``moe_expert_coupling_analysis_path``, the
       accumulated coupling statistics are written via ``save_expert_coupling_stats``.
    2. With ``force_initialize``, zero-filled tracker entries (on ``"cuda"``) are created for
       every name in ``track_names`` that is not tracked yet.
    3. The tracked values are reduced across ranks.
    4. If ``writer`` is given, each metric's sum over layers divided by the number of MoE
       layers is logged to ``writer`` (and ``wandb_writer``) and accumulated into
       ``total_loss_dict``; per-layer values are logged too when ``per_layer_logging``.
    5. The aux-loss tracker and the expert-coupling statistics are cleared.

    Args:
        loss_scale: Multiplier applied to every tracked value before logging.
        iteration: Step passed to the writers.
        writer: Object exposing ``add_scalar(name, value, step)``; nothing is logged or
            accumulated when ``None``.
        wandb_writer: Optional object exposing ``log(dict, step)``.
        total_loss_dict: Optional dict into which each metric's layer average is accumulated.
        per_layer_logging: Also log one scalar per layer.
        force_initialize: Create missing tracker entries for ``track_names``.
        track_names: Names to initialize/reduce; every tracked name when ``None``.
        num_layers: Total number of layers (length of each tracker's ``values``).
        moe_layer_freq: ``None`` (every layer counts as MoE), an int ``k`` (layers whose
            0-based index is divisible by ``k``), or a per-layer 0/1 list.
        mtp_num_layers: Number of MTP layers; added to the MoE layer count when given.
        config: Currently unused.

    Raises:
        ValueError: If ``moe_layer_freq`` is neither ``None``, an int, nor a list.
    """
    coupling_cfg = _EXPERT_COUPLING_CONFIG
    if (
        coupling_cfg is not None
        and coupling_cfg.moe_expert_coupling_analysis
        and coupling_cfg.moe_expert_coupling_analysis_path
    ):
        save_expert_coupling_stats(
            coupling_cfg.moe_expert_coupling_analysis_path,
            iteration,
            coupling_cfg.num_moe_experts,
            coupling_cfg.moe_router_topk,
        )

    # Aux loss logging
    tracker = get_moe_layer_wise_logging_tracker()
    # Initialize the tracker if force_initialize is True
    if force_initialize and track_names is not None:
        for key in track_names:
            if key not in tracker:
                tracker[key] = {
                    "values": torch.zeros(num_layers, device="cuda"),
                    "reduce_group": None,
                    "avg_group": None,
                }
    reduce_aux_losses_tracker_across_ranks(track_names)

    # Number of MoE layers, used to turn per-layer sums into an average.
    if moe_layer_freq is None:
        num_moe_layers = num_layers
    elif isinstance(moe_layer_freq, int):
        assert isinstance(num_layers, int)
        num_moe_layers = sum(1 for i in range(num_layers) if i % moe_layer_freq == 0)
    elif isinstance(moe_layer_freq, list):
        num_moe_layers = sum(moe_layer_freq)
    else:
        raise ValueError(f"Invalid moe_layer_freq: {moe_layer_freq}")

    if mtp_num_layers is not None:
        num_moe_layers += mtp_num_layers

    if writer is not None:
        for name, entry in tracker.items():
            loss_list = entry["values"].float() * loss_scale
            avg_loss = loss_list.sum() / num_moe_layers
            per_layer = loss_list.tolist() if per_layer_logging else None

            if total_loss_dict is not None:
                if name in total_loss_dict:
                    total_loss_dict[name] += avg_loss
                else:
                    # Own copy, so later in-place accumulation never aliases the logged value.
                    total_loss_dict[name] = avg_loss.clone()

            # torch.utils.tensorboard's add_scalars makes each timer its own run, which
            # pollutes the runs list, so each value is logged with add_scalar instead.
            writer.add_scalar(name, avg_loss, iteration)
            if per_layer_logging:
                for i, loss in enumerate(per_layer):
                    writer.add_scalar(f"moe/{name}_layer_{i}", loss, iteration)

            # W&B logging lacks support for logging multiple scalars simultaneously.
            # As a workaround, we log each scalar individually first, then we can create
            # a custom panel to manually group them to a single plot.
            if wandb_writer:
                # Same layer average as TensorBoard / total_loss_dict (no offset applied).
                wandb_writer.log({f"{name}": avg_loss}, iteration)
                if per_layer_logging:
                    wandb_writer.log(
                        {f"moe/{name}_layer_{i}": loss for i, loss in enumerate(per_layer)},
                        iteration,
                    )

    clear_aux_losses_tracker()
    clear_expert_coupling_stats()


def get_updated_expert_bias(tokens_per_expert, expert_bias, expert_bias_update_rate):
    """Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1#

    ``tokens_per_expert`` is first all-reduced in place over the TP x CP x DP group. Each
    bias entry then moves by ``+expert_bias_update_rate`` when its count is below the mean of
    the last dimension, by ``-expert_bias_update_rate`` when above, and stays unchanged when
    exactly equal (``torch.sign`` of zero is zero).

    Args:
        tokens_per_expert (torch.Tensor): The number of tokens assigned to each expert
            (modified in place by the all-reduce).
        expert_bias (torch.Tensor): The bias for each expert (not modified).
        expert_bias_update_rate (float): The update rate for the expert bias.

    Returns:
        torch.Tensor: A new tensor holding the updated bias.
    """
    with torch.no_grad():
        # TODO(Hepteract): delete the usage of the global parallel_state.
        reduction_group = parallel_state.get_tensor_and_data_parallel_group(
            with_context_parallel=True
        )
        torch.distributed.all_reduce(tokens_per_expert, group=reduction_group)

        num_experts = tokens_per_expert.shape[-1]
        mean_tokens = tokens_per_expert.sum(dim=-1, keepdim=True) / num_experts
        direction = torch.sign(mean_tokens - tokens_per_expert)
        return expert_bias + direction * expert_bias_update_rate


def maybe_move_tensor_to_cpu(tensor, as_numpy=False, record_stream=False):
    """Move a tensor to CPU if it is on GPU.

    Anything that is not a CUDA tensor (CPU tensors, ``None``, non-tensors) is returned
    unchanged, even when ``as_numpy`` is set.

    Args:
        tensor (torch.Tensor or None): The tensor to move to CPU.
        as_numpy (bool): Whether to convert the CPU copy to a numpy array.
        record_stream (bool): Whether to record the current CUDA stream on the source tensor,
            to prevent memory leak when the DtoH data transfer is on a side stream.

    Returns:
        The CPU copy (tensor or numpy array) for CUDA input, otherwise ``tensor`` itself.
        The copy is issued with ``non_blocking=True``. NOTE(review): callers presumably
        synchronize before reading the result; confirm at call sites.
    """
    if not (torch.is_tensor(tensor) and tensor.is_cuda):
        return tensor

    host_copy = tensor.to(torch.device("cpu"), non_blocking=True)
    result = host_copy.numpy() if as_numpy else host_copy
    if record_stream:
        tensor.record_stream(torch.cuda.current_stream())
    return result


def get_moe_layer_wise_logging_tracker():
    """Return the module-level MoE layer-wise logging tracker dict (shared, not a copy)."""
    # Reading a module global needs no ``global`` declaration.
    return _MOE_LAYER_WISE_LOGGING_TRACKER


class RandomSTE(torch.autograd.Function):
    """
    Straight-through estimator that replaces logits with random values.

    Forward ignores the values of ``logits`` and returns standard-normal samples with the
    same shape, dtype and device. The samples are drawn from a generator seeded with
    ``42 + global rank``, which is created lazily on the first call and then reused.
    Backward is the identity, so gradients reach ``logits`` unchanged.

    This is used to generate random logits of router for load-balanced benchmark.
    """

    # Lazily created generator shared by all calls in this process.
    generator = None

    @staticmethod
    def forward(ctx, logits):
        """
        Return rank-seeded standard-normal samples shaped like ``logits``.
        """
        if RandomSTE.generator is None:
            rank_seed = 42 + torch.distributed.get_rank()
            RandomSTE.generator = torch.Generator(device=logits.device)
            RandomSTE.generator.manual_seed(rank_seed)

        noise = logits.clone()
        noise.normal_(generator=RandomSTE.generator)
        return noise

    @staticmethod
    def backward(ctx, grad_output):
        """
        Identity gradient: pass ``grad_output`` straight through to ``logits``.
        """
        return grad_output


def apply_random_logits(logits):
    """
    Replace ``logits`` with random values via RandomSTE, keeping identity gradients.
    """
    randomized = RandomSTE.apply(logits)
    return randomized


class RouterGatingLinearFunction(torch.autograd.Function):
    """
    Autograd function for router gating linear.

    Computes ``inp @ weight.t()`` in ``router_dtype`` and saves only the original ``inp``
    and ``weight`` for backward, not up-cast copies. The Transformer Engine GEMM
    (``te_general_gemm``) is used when it is available and ``router_dtype`` is not float64.
    Otherwise the computation falls back to ``torch.mm``.
    """

    @staticmethod
    def _use_te_gemm(router_dtype: torch.dtype) -> bool:
        """Return True when the Transformer Engine GEMM path should be used."""
        # The try/except import at the top of this file binds ``te_general_gemm`` only when
        # the Transformer Engine import succeeds (the except branch only sets HAVE_TE), so
        # check HAVE_TE first to avoid a NameError when TE is not installed.
        return bool(HAVE_TE and te_general_gemm is not None and router_dtype != torch.float64)

    @staticmethod
    def forward(ctx, inp: torch.Tensor, weight: torch.Tensor, router_dtype: torch.dtype):
        """
        Forward pass of the RouterGatingLinearFunction function.

        Args:
            inp: Input of shape ``[..., in_features]``; viewed as 2-D for the GEMM.
            weight: Weight of shape ``[out_features, in_features]``.
            router_dtype: dtype in which the GEMM is computed.

        Returns:
            Output of shape ``[..., out_features]``.
        """
        ctx.save_for_backward(inp, weight)
        ctx.router_dtype = router_dtype
        ctx.input_dtype = inp.dtype
        ctx.weight_dtype = weight.dtype
        inp_shape = inp.shape
        inp = inp.view(-1, inp_shape[-1])

        if RouterGatingLinearFunction._use_te_gemm(router_dtype):
            output = te_general_gemm(weight, inp, router_dtype, layout="TN")[0]
        else:
            output = torch.mm(inp.to(router_dtype), weight.to(router_dtype).t())

        return output.view(*inp_shape[:-1], -1)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """
        Backward pass of the RouterGatingLinearFunction function.

        The GEMMs run in ``router_dtype``. The gradients are then cast back to the original
        input and weight dtypes. ``router_dtype`` receives no gradient.
        """
        inp, weight = ctx.saved_tensors
        inp_shape = inp.shape
        inp = inp.view(-1, inp_shape[-1])
        grad_output = grad_output.view(-1, grad_output.shape[-1])

        if RouterGatingLinearFunction._use_te_gemm(ctx.router_dtype):
            grad_input = te_general_gemm(
                weight.to(ctx.router_dtype), grad_output, ctx.router_dtype, layout="NN", grad=True
            )
            grad_weight = te_general_gemm(
                inp.to(ctx.router_dtype), grad_output, ctx.router_dtype, layout="NT", grad=True
            )
            grad_input = grad_input[0].to(ctx.input_dtype)
            grad_weight = grad_weight[0].to(ctx.weight_dtype)
        else:
            grad_input = torch.mm(grad_output, weight.to(ctx.router_dtype)).to(ctx.input_dtype)
            grad_weight = torch.mm(grad_output.t(), inp.to(ctx.router_dtype)).to(ctx.weight_dtype)

        grad_input = grad_input.view(*inp_shape)
        return grad_input, grad_weight, None


def router_gating_linear(inp: torch.Tensor, weight: torch.Tensor, router_dtype: torch.dtype):
    """
    Router gating projection ``inp @ weight.t()`` computed in ``router_dtype``.

    Wraps RouterGatingLinearFunction, which saves only the original ``inp`` and ``weight``
    for backward rather than up-cast copies. This saves memory when ``router_dtype`` is wider
    than the input dtype (e.g. bfloat16 input and weight with a float32 router).
    """
    gating_fn = RouterGatingLinearFunction.apply
    return gating_fn(inp, weight, router_dtype)


# TODO(Hepteract): delete the usage of the global parallel_state.
# Initialize process groups with the global parallel_state.
def get_default_model_comm_pgs():
    """Get the default process groups for MoE.

    Each attribute of a fresh ``ModelCommProcessGroups`` is filled from the corresponding
    global ``parallel_state`` getter, in the order listed below.

    Returns:
        ModelCommProcessGroups: The default process groups for MoE.
    """
    group_getters = (
        ("ep", parallel_state.get_expert_model_parallel_group),
        ("tp", parallel_state.get_tensor_model_parallel_group),
        ("cp", parallel_state.get_context_parallel_group),
        ("expt_tp", parallel_state.get_expert_tensor_parallel_group),
        ("expt_dp", parallel_state.get_expert_data_parallel_group),
        ("tp_ep", parallel_state.get_expert_tensor_and_model_parallel_group),
        ("tp_cp", parallel_state.get_tensor_and_context_parallel_group),
    )
    pgs = ModelCommProcessGroups()
    for attr_name, getter in group_getters:
        setattr(pgs, attr_name, getter())
    return pgs


def clear_expert_coupling_stats():
    """Clear the expert coupling stats.

    Empties the accumulated coupling statistics and the cached per-layer routing maps
    (both dicts in place) and resets the captured coupling config to ``None``.
    """
    global _EXPERT_COUPLING_CONFIG
    for cache in (_EXPERT_COUPLING_STATS, _LAST_ROUTING_MAP_ACROSS_LAYERS):
        cache.clear()
    _EXPERT_COUPLING_CONFIG = None


def update_expert_coupling_stats(layer_number: int, current_routing_map: torch.Tensor, config):
    """Update the expert coupling stats.

    Pairs ``current_routing_map`` with the cached map of the largest layer number below
    ``layer_number`` (if any). On CPU it accumulates into
    ``_EXPERT_COUPLING_STATS[(prev_layer, layer_number)]``:

    * ``co_occurrence = prev_map.T @ curr_map``, an ``[E_prev, E_curr]`` matrix;
    * ``prev_counts = prev_map.sum(dim=0)``.

    With 0/1 routing maps these are token counts: tokens routed to expert ``i`` earlier and
    expert ``j`` here, and tokens routed to expert ``i`` earlier. Only the first
    ``min(num_rows)`` rows of the two maps are compared. The current map is then cached
    (detached) under ``layer_number``.

    ``config`` becomes the module-level coupling config if none is set. When
    ``layer_number`` equals the first MoE layer (1-based, derived from
    ``config.moe_layer_freq``), the cache of earlier maps is emptied first.

    Args:
        layer_number: 1-based index of the current MoE layer.
        current_routing_map: 2-D routing map. NOTE(review): assumed
            ``[num_tokens, num_experts]`` with 0/1 entries; confirm with the router.
        config: Object providing ``moe_layer_freq`` (int or list of 0/1 flags).
    """
    global _EXPERT_COUPLING_CONFIG

    if _EXPERT_COUPLING_CONFIG is None:
        _EXPERT_COUPLING_CONFIG = config

    # 1-based number of the first MoE layer; -1 when it cannot be determined.
    freq = config.moe_layer_freq
    first_moe_layer = -1
    if isinstance(freq, int):
        if freq > 0:
            first_moe_layer = 1
    elif isinstance(freq, list) and 1 in freq:
        first_moe_layer = freq.index(1) + 1

    # Reaching the first MoE layer again (presumably a new forward pass) drops stale maps.
    if layer_number == first_moe_layer:
        _LAST_ROUTING_MAP_ACROSS_LAYERS.clear()

    prev_layer_number = max(
        (num for num in _LAST_ROUTING_MAP_ACROSS_LAYERS if num < layer_number), default=-1
    )

    if prev_layer_number != -1:
        prev_map = _LAST_ROUTING_MAP_ACROSS_LAYERS[prev_layer_number]
        n_rows = min(prev_map.shape[0], current_routing_map.shape[0])
        prev_f = prev_map[:n_rows, :].float()
        curr_f = current_routing_map[:n_rows, :].float()

        co_occurrence = prev_f.T @ curr_f
        prev_counts = prev_f.sum(dim=0)

        pair_key = (prev_layer_number, layer_number)
        if pair_key not in _EXPERT_COUPLING_STATS:
            _EXPERT_COUPLING_STATS[pair_key] = [
                torch.zeros_like(co_occurrence, device='cpu'),
                torch.zeros_like(prev_counts, device='cpu'),
            ]

        co_acc, count_acc = _EXPERT_COUPLING_STATS[pair_key]
        co_acc += co_occurrence.to('cpu')
        count_acc += prev_counts.to('cpu')

    _LAST_ROUTING_MAP_ACROSS_LAYERS[layer_number] = current_routing_map.detach()


def save_expert_coupling_stats(output_path: str, iteration: int, num_experts: int, top_k: int):
    """Save the expert coupling stats to a file.

    Does nothing if no statistics have been accumulated. Otherwise, for every
    ``(prev_layer, curr_layer)`` entry, the co-occurrence matrix and the previous-layer
    counts are copied to the current CUDA device and summed across the data-parallel
    group. Data-parallel rank 0 then writes ``co_occurrence / prev_counts[:, None]`` (zero
    counts replaced by 1) with ``%.4f`` formatting to
    ``{output_path}/coupling_iter_{iteration}_layers_{prev_layer}-{curr_layer}.txt``.
    Row ``i`` of that matrix is the fraction of tokens routed to expert ``i`` in the earlier
    layer that were routed to each expert in the later layer.

    ``num_experts`` and ``top_k`` are currently unused.
    """
    if not _EXPERT_COUPLING_STATS:
        return

    import os
    import numpy as np

    is_dp_rank0 = parallel_state.get_data_parallel_rank() == 0
    if is_dp_rank0:
        os.makedirs(output_path, exist_ok=True)

    dp_group = parallel_state.get_data_parallel_group()
    device = torch.cuda.current_device()
    for (prev_layer, curr_layer), (co_counts, prev_counts) in _EXPERT_COUPLING_STATS.items():
        co_dev = co_counts.to(device)
        prev_dev = prev_counts.to(device)
        # Every data-parallel rank must join both collectives.
        torch.distributed.all_reduce(co_dev, group=dp_group)
        torch.distributed.all_reduce(prev_dev, group=dp_group)

        if not is_dp_rank0:
            continue

        co_total = co_dev.cpu()
        denom = prev_dev.cpu()
        denom[denom == 0] = 1.0  # avoid division by zero for experts never routed to
        prob_matrix = co_total / denom.unsqueeze(1)
        filename = os.path.join(
            output_path, f'coupling_iter_{iteration}_layers_{prev_layer}-{curr_layer}.txt'
        )
        np.savetxt(filename, prob_matrix.numpy(), fmt='%.4f')
