# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from abc import ABC, abstractmethod
from typing import Optional, Dict

import torch
import math
import logging
import torch.distributed as dist

from megatron.core.tensor_parallel import (
    reduce_from_tensor_model_parallel_region,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core import parallel_state
from megatron.core.transformer.moe.moe_utils import (
    MoEAuxLossAutoScaler,
    apply_random_logits,
    apply_router_token_dropping,
    compute_routing_scores_for_aux_loss,
    router_gating_linear,
    save_to_aux_losses_tracker,
    router_metrics_push,
    sinkhorn,
    get_capacity,
    switch_load_balancing_loss_func,
    topk_routing_with_score_function,
    z_loss_func,
)
from megatron.core.process_groups_config import ProcessGroupCollection
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Dirichlet
from megatron.core.transformer.transformer_config import TransformerConfig


# -----------------------------
# Metrics helpers (distributed-safe)
# -----------------------------
def _default_group():
    """Return the tensor-and-context parallel group, or None if it cannot be obtained.

    Any exception raised by the parallel-state lookup is swallowed on purpose:
    the metric helpers treat a ``None`` group as "skip the collective".
    """
    group = None
    try:
        group = parallel_state.get_tensor_and_context_parallel_group()
    except Exception:
        # Best-effort lookup: every failure means "no group available".
        pass
    return group


def allreduce_(tensor: torch.Tensor, group=None, op=dist.ReduceOp.SUM):
    """All-reduce ``tensor`` in place over ``group`` and return it.

    The collective is skipped (and ``tensor`` returned untouched) when no group
    is given or when torch.distributed is unavailable or not initialized.

    Args:
        tensor: Tensor to reduce in place.
        group: Process group to reduce over; ``None`` disables the reduction.
        op: Reduction operator passed to ``dist.all_reduce``.

    Returns:
        The same ``tensor`` object.
    """
    if group is None:
        return tensor
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    dist.all_reduce(tensor, group=group, op=op)
    return tensor


def global_scalar(value: float, device) -> torch.Tensor:
    """Wrap ``value`` in a 0-dim float32 tensor placed on ``device``."""
    as_float = float(value)
    return torch.tensor(as_float, dtype=torch.float32, device=device)


# -----------------------------
# §2: Efficiency / timing
# -----------------------------
from contextlib import contextmanager


@contextmanager
def cuda_timer():
    """Time a region of CUDA work with CUDA events.

    Use as::

        with cuda_timer() as t:
            ...
        ms = t.ms()

    The start event is recorded on entry. The end event is recorded when the
    ``with`` block exits (also on an exception), so calling ``t.ms()`` after the
    block reports the time of the block itself rather than of everything
    enqueued up to the ``ms()`` call. Calling ``t.ms()`` inside the block is
    still supported and reports the time elapsed so far.

    Yields:
        A handle exposing ``ms()``, which synchronizes on the end event and
        returns the elapsed time in milliseconds.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    finished = False

    class _Timer:
        """Handle yielded to the caller; ``ms`` works on the class or an instance."""

        @staticmethod
        def ms():
            if not finished:
                # Still inside the timed region: measure up to this point.
                end.record()
            end.synchronize()
            return start.elapsed_time(end)

    start.record()
    try:
        yield _Timer
    finally:
        # Close the measured interval at block exit, not at the later ms() call.
        end.record()
        finished = True


def step_efficiency_metrics(
    step_tokens: int,
    step_time_s: float,
    a2a_ms: Optional[float] = None,
    expert_ms: Optional[float] = None,
) -> Dict[str, torch.Tensor]:
    """Build throughput metrics for one step, plus optional kernel timings.

    Args:
        step_tokens: Number of tokens processed in the step.
        step_time_s: Step duration in seconds.
        a2a_ms: Optional all-to-all time in milliseconds.
        expert_ms: Optional expert compute time in milliseconds.

    Returns:
        Dict with ``eff_throughput_tokens_per_s`` and ``eff_step_time_s``, and
        ``eff_alltoall_ms`` / ``eff_expert_ms`` when the matching argument is given.
    """
    # Floor the duration at 1e-6 s so a zero timing cannot divide by zero.
    safe_time_s = max(1e-6, step_time_s)
    metrics = {
        "eff_throughput_tokens_per_s": torch.tensor(step_tokens / safe_time_s),
        "eff_step_time_s": torch.tensor(step_time_s),
    }
    optional_timings = (("eff_alltoall_ms", a2a_ms), ("eff_expert_ms", expert_ms))
    for key, value in optional_timings:
        if value is not None:
            metrics[key] = torch.tensor(value)
    return metrics


# -----------------------------
# §3: Sparsity & calibration
# -----------------------------
def simpson_index_r(r_probs: torch.Tensor) -> torch.Tensor:
    """
    Batch-averaged Simpson index H = sum_i r_i^2.

    r_probs: [N, E] routing weights on simplex.
    Squares are summed over the last dimension, then averaged over the rest.
    """
    per_token_index = (r_probs ** 2).sum(dim=-1)
    return per_token_index.mean()


def routing_cardinality_stats(routing_map: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Count active experts per token and summarize the counts.

    routing_map: [N, E] boolean map of executed experts per token.
    Returns a dict with ``k_avg`` (mean count) and ``k_max`` (largest count),
    both as float32 scalars.
    """
    active_counts = routing_map.sum(dim=1).to(dtype=torch.float32)
    k_avg = active_counts.mean()
    k_max = active_counts.max().to(dtype=torch.float32)
    return {"k_avg": k_avg, "k_max": k_max}


def expected_k_error(z_soft: torch.Tensor, target_k: float) -> torch.Tensor:
    """
    Mean squared deviation of the per-token gate sum from ``target_k``.

    z_soft: [N, E] relaxed inclusion gates in (0,1).
    Returns mean over tokens of (sum_i z_i - k)^2.
    """
    deviation = z_soft.sum(dim=-1) - target_k
    squared_deviation = deviation * deviation
    return squared_deviation.mean()


def leakage_metrics(
    r_probs: torch.Tensor,
    z_scores: torch.Tensor,
    k: int,
    use_topk_on: str = "z",
) -> Dict[str, torch.Tensor]:
    """
    Estimate routing mass on the active (top-k) set vs the inactive set.

    Args:
        r_probs: Routing weights; the last dimension indexes experts.
        z_scores: Gate scores with the same shape as ``r_probs``.
        k: Size of the active set. ``k <= 0`` means no expert is active;
            ``k`` larger than the number of experts is clamped to it, so every
            expert is active instead of ``topk`` raising.
        use_topk_on: "z" (gates) selects the active indices from ``z_scores``;
            any other value selects them from ``r_probs``.

    Returns:
        Dict with ``mass_active_mean`` (mean over tokens of the mass on the
        active set) and ``leak_mean`` (mean of ``1 - mass_on_active``).
    """
    if k <= 0:
        return {
            "mass_active_mean": torch.tensor(0.0, device=r_probs.device),
            "leak_mean": torch.tensor(1.0, device=r_probs.device),
        }
    selector = z_scores if use_topk_on == "z" else r_probs
    # torch.topk rejects k greater than the size of the selected dimension.
    k = min(k, selector.size(-1))
    idx = selector.topk(k, dim=-1).indices
    active_mask = torch.zeros_like(r_probs, dtype=torch.bool)
    # Scatter along the same (last) dimension that topk selected from.
    active_mask.scatter_(-1, idx, True)
    mass_active = (r_probs * active_mask.to(r_probs.dtype)).sum(dim=-1)  # [N]
    mass_inactive = 1.0 - mass_active
    return {
        "mass_active_mean": mass_active.mean(),
        "leak_mean": mass_inactive.mean(),
    }


# -----------------------------
# §4: No-LB advantage (balance, capacity)
# -----------------------------
def global_load_and_importance(
    routing_map: torch.Tensor,   # [N,E] bool
    r_probs: torch.Tensor,       # [N,E] weights (not necessarily per-row normalized)
    group=None,
) -> Dict[str, torch.Tensor]:
    """
    Compute GLOBAL load fractions f_i and importance fractions p_i.

    - f_i: routed-token count of expert i divided by the total routed-token
           count over all experts (robust to variable-k).
    - p_i: executed routing mass of expert i divided by the total executed
           mass. Before summing, ``r_probs`` is masked by ``routing_map`` and
           renormalized per token, so each token with at least one executed
           expert contributes a total mass of 1; tokens with no positive
           executed mass contribute 0.

    Args:
        routing_map: [N,E] boolean map of executed experts per token.
        r_probs: [N,E] routing weights.
        group: Process group for the sum-reduction. When None, the default
            group lookup is used; if that also yields None the statistics are
            computed from local data only.

    Returns:
        Dict with ``f`` and ``p`` ([E] float32), plus the scalars
        ``total_routed`` and ``total_mass`` (both clamped to at least 1e-6).
    """
    if group is None:
        group = _default_group()

    # Keep only the weights of executed experts, then renormalize each token's row.
    exec_mass = r_probs * routing_map.to(r_probs.dtype)
    token_mass = exec_mass.sum(dim=-1, keepdim=True)
    exec_mass = torch.where(
        token_mass > 0,
        exec_mass / token_mass.clamp_min(1e-6),
        torch.zeros_like(exec_mass),
    )

    # Local per-expert tallies (executed tokens and executed routing mass).
    tokens_per_expert = routing_map.sum(dim=0).to(torch.float32).clone()
    mass_per_expert = exec_mass.sum(dim=0).to(torch.float32).clone()

    # Reduce in place across the group (skipped when the group is None).
    allreduce_(tokens_per_expert, group)
    allreduce_(mass_per_expert, group)

    # Clamp the denominators so an empty batch yields zeros rather than NaNs.
    total_routed = tokens_per_expert.sum().clamp_min(1e-6)
    total_mass = mass_per_expert.sum().clamp_min(1e-6)
    f = tokens_per_expert / total_routed
    p = mass_per_expert / total_mass
    return {"f": f, "p": p, "total_routed": total_routed, "total_mass": total_mass}


def switch_load_balance_loss(
    routing_map: torch.Tensor,
    r_probs: torch.Tensor,
    group=None,
) -> torch.Tensor:
    """
    Switch-style load-balancing loss: L = E * sum_i f_i * p_i  (Fedus et al., 2021).

    routing_map: [N,E] bool
    r_probs:     [N,E] probs on simplex
    group:       process group for the global statistics; the default group
                 lookup is used when None.

    f and p come from ``global_load_and_importance``; E is the number of
    entries in f.
    """
    reduce_group = _default_group() if group is None else group
    fractions = global_load_and_importance(routing_map, r_probs, reduce_group)
    load = fractions["f"]
    importance = fractions["p"]
    num_experts = torch.tensor(load.numel(), device=load.device, dtype=load.dtype)
    return num_experts * (load * importance).sum()


def capacity_stats(
    routing_map: torch.Tensor,      # [N,E] bool
    capacity_factor: float,
    k: int,
    group=None,
) -> Dict[str, torch.Tensor]:
    """
    Drop-rate and padding-waste under capacity control.

    With N_global the token count summed over the group and E the number of
    experts:
      cap           = ceil(CF * (N_global * max(1, k)) / max(1, E))
      drop_rate     = (#tokens above cap, summed over experts) / max(1, N_global * max(1, k))
      padding_waste = 1 - sum_i min(count_i, cap) / (cap * E + 1e-6)

    Returns a dict with ``cap_per_expert``, ``drop_rate``, ``padding_waste``
    and ``tokens_per_expert_mean``.
    """
    if group is None:
        group = _default_group()
    device = routing_map.device
    num_experts = routing_map.size(1)

    # Per-expert counts and token count, each summed across the group.
    # The order of the two collectives is kept: counts first, then N.
    tokens_per_expert = routing_map.sum(dim=0).to(torch.float32).clone()  # [E]
    token_count = global_scalar(float(routing_map.size(0)), device)
    allreduce_(tokens_per_expert, group)
    allreduce_(token_count, group)
    n_global = float(token_count.item())

    total_routed = n_global * max(1, k)
    cap = math.ceil(float(capacity_factor) * total_routed / max(1, num_experts))
    cap_t = torch.tensor(cap, dtype=torch.float32, device=device)

    # Tokens beyond capacity are dropped; slots below capacity are padding.
    overflow = (tokens_per_expert - cap_t).clamp(min=0)
    drop_rate = overflow.sum() / max(1.0, total_routed)
    filled = tokens_per_expert.clamp(max=cap_t).sum()
    total_slots = cap_t * num_experts
    padding_waste = 1.0 - (filled / (total_slots + 1e-6))

    return {
        "cap_per_expert": cap_t,
        "drop_rate": drop_rate,
        "padding_waste": padding_waste.to(torch.float32),
        "tokens_per_expert_mean": tokens_per_expert.mean(),
    }


class _RMSNormLocal(torch.nn.Module):
    """Minimal RMSNorm kept local to this module to avoid import cycles.

    Computes ``y = weight * x / rms(x)`` where ``rms(x)`` is taken over the last
    dimension with ``eps`` added under the square root. The weight parameter is
    tagged with a ``sequence_parallel`` attribute.
    """

    def __init__(self, dim: int, eps: float = 1e-6, sequence_parallel: bool = False):
        super().__init__()
        self.eps = eps
        # Scale starts at one, so the layer is a pure normalization initially.
        scale = torch.ones(dim)
        self.weight = torch.nn.Parameter(scale)
        self.weight.sequence_parallel = sequence_parallel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = (x ** 2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = x * inv_rms
        return normalized * self.weight


class Router(ABC, MegatronModule):
    """Base Router class.

    Owns the gating weight and a pre-gating RMSNorm, and provides the helpers
    shared by concrete routers (gating, input jitter, aux-loss attachment).
    Subclasses implement ``routing`` and ``forward``.
    """

    def __init__(
        self, config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None
    ) -> None:
        """
        Initialize the Router module.

        Args:
            config (TransformerConfig): Configuration object for the Transformer model.
            pg_collection (ProcessGroupCollection, optional): Process groups for MoE operations.
                NOTE: although the parameter defaults to None, its ``tp``, ``cp``, ``tp_cp``
                and ``tp_dp_cp`` attributes are read unconditionally below, so a collection
                must be supplied.
        """
        super().__init__(config)
        self.config = config
        self.num_experts = self.config.num_moe_experts
        self.moe_aux_loss_func = None
        self.layer_number = None
        self.tp_group = pg_collection.tp
        self.cp_group = pg_collection.cp
        self.tp_cp_group = pg_collection.tp_cp
        self.tp_dp_cp_group = pg_collection.tp_dp_cp

        # Initialize the gate weights.
        # TODO: Add support for GPU initialization, which requires updating the golden values.
        self.weight = torch.nn.Parameter(
            torch.empty((self.config.num_moe_experts, self.config.hidden_size), dtype=torch.float32)
        )
        # Pre-gating normalization to stabilize router logits (RMSNorm)
        self.pre_gating_rmsnorm = _RMSNormLocal(
            dim=self.config.hidden_size,
            eps=getattr(self.config, 'moe_gating_rmsnorm_eps', 1e-6),
            sequence_parallel=self.config.sequence_parallel,
        )
        # If calculate per token loss, we need to scale up moe aux loss by the number of tokens.
        # So we need to know if the model is configured to calculate per token loss.
        self.calculate_per_token_loss = self.config.calculate_per_token_loss
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the router parameters.

        Runs the configured init method on the gate weight (when initialization
        is enabled), casts the gate and RMSNorm weights to ``config.params_dtype``
        and tags both with the ``sequence_parallel`` flag.
        """
        if self.config.perform_initialization:
            self.config.init_method(self.weight)
        self.weight.data = self.weight.data.to(dtype=self.config.params_dtype)
        setattr(self.weight, 'sequence_parallel', self.config.sequence_parallel)
        # Align RMSNorm parameter dtype and sequence_parallel flag
        rmsnorm = getattr(self, 'pre_gating_rmsnorm', None)
        if rmsnorm is not None:
            rmsnorm.weight.data = rmsnorm.weight.data.to(dtype=self.config.params_dtype)
            setattr(rmsnorm.weight, 'sequence_parallel', self.config.sequence_parallel)

    def gating(self, input: torch.Tensor):
        """Forward pass of the router gate.

        The input is normalized by the pre-gating RMSNorm (when present) and
        projected onto the expert dimension with the gate weight.

        Args:
            input (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Logits tensor.
        """
        # Parameters created on CPU are moved lazily. They follow the input's
        # device instead of assuming the current CUDA device, so the gate and
        # its input always end up on the same device.
        needs_move = input.device.type != 'cpu'
        rmsnorm = getattr(self, 'pre_gating_rmsnorm', None)
        if rmsnorm is not None:
            if needs_move and rmsnorm.weight.device.type == 'cpu':
                rmsnorm.to(device=input.device)
            input = rmsnorm(input)
        if needs_move and self.weight.device.type == 'cpu':
            self.weight.data = self.weight.data.to(device=input.device)
        # Convert to specified datatype for routing computation if enabled
        router_dtype = input.dtype
        if self.config.moe_router_dtype == 'fp32':
            router_dtype = torch.float32
        elif self.config.moe_router_dtype == 'fp64':
            router_dtype = torch.float64
        logits = router_gating_linear(input, self.weight, router_dtype)
        return logits

    @abstractmethod
    def routing(self, logits: torch.Tensor):
        """Routing function.

        Args:
            logits (torch.Tensor): Logits tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment
            probabilities and mapping.
        """
        raise NotImplementedError("Routing function not implemented.")

    @abstractmethod
    def forward(self, input: torch.Tensor):
        """
        Forward pass of the router.

        Args:
            input (torch.Tensor): Input tensor.
        """
        raise NotImplementedError("Forward function not implemented.")

    def set_layer_number(self, layer_number: int):
        """Set the layer number for the router."""
        self.layer_number = layer_number

    def _maintain_float32_expert_bias(self):
        """Ensure expert_bias buffer stays in float32 if present."""
        expert_bias = getattr(self, 'expert_bias', None)
        if expert_bias is not None and expert_bias.dtype != torch.float32:
            self.expert_bias.data = self.expert_bias.data.to(torch.float32)

    def _load_from_state_dict(self, *args, **kwargs):
        """Load the state dict of the router, ensuring proper dtypes for buffers."""
        self._maintain_float32_expert_bias()
        return super()._load_from_state_dict(*args, **kwargs)

    def _save_to_state_dict(self, *args, **kwargs):
        """Save the state dict of the router, ensuring proper dtypes for buffers."""
        self._maintain_float32_expert_bias()
        return super()._save_to_state_dict(*args, **kwargs)

    def apply_input_jitter(self, input: torch.Tensor):
        """Add noise to the input tensor. Refer to https://arxiv.org/abs/2101.03961.

        When ``config.moe_input_jitter_eps`` is set, the input is multiplied
        element-wise by samples from Uniform(1 - eps, 1 + eps); otherwise it is
        returned unchanged.

        Args:
            input (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Jittered input.
        """
        eps = getattr(self.config, 'moe_input_jitter_eps', None)
        if eps is None:
            return input
        if getattr(self, 'input_jitter', None) is None:
            # Build the sampler once; its bounds live on the first input's device.
            self.input_jitter = torch.distributions.uniform.Uniform(
                torch.tensor(1.0 - eps, device=input.device),
                torch.tensor(1.0 + eps, device=input.device),
            ).rsample
        return input * self.input_jitter(input.shape)

    def attach_and_log_load_balancing_loss(
        self,
        activation: torch.Tensor,
        aux_loss_coeff: float,
        aux_loss: torch.Tensor,
        aux_loss_name: str,
        reduce_group: torch.distributed.ProcessGroup,
    ):
        """Attach aux loss function to activation and add to logging.

        Matches new-style TopKRouter behavior and supports per-token-loss scaling.

        Args:
            activation (torch.Tensor): Tensor the aux loss is attached to.
            aux_loss_coeff (float): Coefficient already applied to ``aux_loss``; the
                logged value is ``aux_loss / aux_loss_coeff``, so it must be non-zero.
            aux_loss (torch.Tensor): The (scaled) auxiliary loss.
            aux_loss_name (str): Name under which the loss is tracked.
            reduce_group (torch.distributed.ProcessGroup): Group passed to the tracker.

        Returns:
            torch.Tensor: The activation with the aux loss attached.
        """
        num_layers = self.config.num_layers
        if getattr(self.config, 'mtp_num_layers', None) is not None:
            num_layers += self.config.mtp_num_layers
        save_to_aux_losses_tracker(
            aux_loss_name,
            aux_loss / aux_loss_coeff,
            self.layer_number,
            num_layers,
            reduce_group=reduce_group,
        )
        if self.calculate_per_token_loss:
            # Per-token loss mode: scale the aux loss by the number of local tokens.
            activation = MoEAuxLossAutoScaler.apply(activation, aux_loss * activation.shape[0])
        else:
            activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
        return activation


class TopKRouter(Router):
    """Route each token to the top-k experts.

    The workflow of TopKRouter is as follows:
    (1) Calculate the logits by the router gating network.
    (2) Calculate the routing probabilities and map for top-k selection with score function.
    (3) [Optional] Apply token dropping to top-k expert selection.
    (4) [Optional] Apply the auxiliary load balancing loss for the given scores and routing map.

    Naming convention:
        logits: The output logits by the router gating network.
        scores: The scores after score function used to select the experts and calculate aux loss.
        probs: The topk weights used to combined the experts' outputs.
        routing_map: The masked routing map between tokens and experts.

    ``_maintain_float32_expert_bias`` and ``apply_input_jitter`` are inherited
    from ``Router``.
    """

    def __init__(
        self, config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None
    ) -> None:
        """Initialize the top-k router.

        Args:
            config (TransformerConfig): The configuration for the transformer model.
            pg_collection (ProcessGroupCollection, optional): Process groups for MoE operations.
        """
        super().__init__(config=config, pg_collection=pg_collection)
        self.topk = self.config.moe_router_topk
        self.routing_type = self.config.moe_router_load_balancing_type
        self.score_function = self.config.moe_router_score_function
        self.input_jitter = None

        self.enable_expert_bias = self.config.moe_router_enable_expert_bias
        if self.enable_expert_bias:
            # Non-persistent counter of tokens routed to each expert.
            self.register_buffer(
                'local_tokens_per_expert',
                torch.zeros(
                    self.config.num_moe_experts,
                    dtype=torch.float32,
                    device=torch.cuda.current_device(),
                ),
                persistent=False,
            )
            self.register_buffer(
                'expert_bias',
                torch.zeros(
                    self.config.num_moe_experts,
                    dtype=torch.float32,
                    device=torch.cuda.current_device(),
                ),
            )
        else:
            self.local_tokens_per_expert = None
            self.expert_bias = None

        # Initialize global tokens per expert for global aux loss
        if self.get_aux_loss_coeff("global_aux_loss") > 0:
            self.register_buffer(
                'global_tokens_per_expert',
                torch.zeros(
                    self.config.num_moe_experts,
                    dtype=torch.float32,
                    device=torch.cuda.current_device(),
                ),
                persistent=False,
            )
            self.register_buffer(
                'ga_steps',
                torch.tensor(0, dtype=torch.float32, device=torch.cuda.current_device()),
                persistent=False,
            )
        else:
            self.global_tokens_per_expert = None
            self.ga_steps = None

    def sinkhorn_load_balancing(self, logits: torch.Tensor):
        """Apply sinkhorn routing to the logits tensor.

        Args:
            logits (torch.Tensor): The logits tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment
            probabilities and mask.
        """

        def _sinkhorn_activation(logits):
            if self.topk == 1:
                logits = torch.sigmoid(logits)
            else:  # k > 1
                logits = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            return logits

        assert self.config.moe_aux_loss_coeff == 0, "Sinkhorn routing does not support aux loss."
        if self.training:
            # Training: experts are chosen from the sinkhorn-normalized logits
            # (no gradient), while the returned scores come from the activation.
            with torch.no_grad():
                norm_logits = sinkhorn(
                    logits.to(dtype=torch.float32)
                )  # explicit fp32 conversion for stability
                _, indices = torch.topk(norm_logits, k=self.topk, dim=1)
            logits = _sinkhorn_activation(logits)
        else:
            logits = _sinkhorn_activation(logits)
            _, indices = torch.topk(logits, k=self.topk, dim=1)
        routing_map = torch.zeros_like(logits).int().scatter(1, indices, 1).bool()
        scores = logits * routing_map
        return scores, routing_map

    def get_aux_loss_coeff(self, aux_loss_type: str) -> float:
        """Return the aux loss coeff for the given auxiliary loss type.
        If the auxiliary loss type is not found, return 0.0.

        When ``routing_type`` is a string, ``config.moe_aux_loss_coeff`` is returned
        as-is on a match. When it is a list, the coefficient at the matching
        position of ``config.moe_aux_loss_coeff`` is returned.
        """
        if isinstance(self.routing_type, str):
            if self.routing_type == aux_loss_type:
                return self.config.moe_aux_loss_coeff
        if isinstance(self.routing_type, list):
            try:
                idx = self.routing_type.index(aux_loss_type)
                return self.config.moe_aux_loss_coeff[idx]
            except ValueError:
                return 0.0
        return 0.0

    def is_aux_loss_enabled(self) -> bool:
        """Check if the auxiliary loss is enabled."""
        for aux_loss_type in ["aux_loss", "seq_aux_loss", "global_aux_loss"]:
            if self.get_aux_loss_coeff(aux_loss_type) > 0:
                return True
        return False

    def _apply_aux_loss(
        self, probs: torch.Tensor, scores_for_aux_loss: torch.Tensor, routing_map: torch.Tensor
    ):
        """Apply the auxiliary loss for the given scores and routing map.

        Returns ``probs`` unchanged when the "aux_loss" coefficient is 0.
        """
        aux_loss_coeff = self.get_aux_loss_coeff("aux_loss")
        if aux_loss_coeff == 0:
            return probs
        tokens_per_expert = routing_map.sum(dim=0)
        tokens_per_expert = reduce_from_tensor_model_parallel_region(
            tokens_per_expert, self.tp_cp_group
        )
        num_tokens = routing_map.shape[0]
        total_num_tokens = num_tokens * self.tp_cp_group.size()

        aux_loss = switch_load_balancing_loss_func(
            probs=scores_for_aux_loss,
            tokens_per_expert=tokens_per_expert,
            total_num_tokens=total_num_tokens,
            topk=self.topk,
            num_experts=self.config.num_moe_experts,
            moe_aux_loss_coeff=aux_loss_coeff,
            fused=self.config.moe_router_fusion,
        )
        probs = self.attach_and_log_load_balancing_loss(
            probs, aux_loss_coeff, aux_loss, "load_balancing_loss", self.tp_cp_group
        )
        return probs

    def _apply_seq_aux_loss(
        self,
        probs: torch.Tensor,
        scores_for_aux_loss: torch.Tensor,
        routing_map: torch.Tensor,
        seq_length: int,
        bsz: int,
    ):
        """Apply the sequence-level auxiliary loss for the given scores and routing map.

        To calculate the sequence-level aux loss, we reshape the batch_size dimension to
        experts dimension. The resulted loss by switch_load_balancing_loss_func is equal
        to the sum of aux loss for each sequence in the batch. And then we divide the aux
        loss by the batch size to get averaged aux loss.

        Returns ``probs`` unchanged when the "seq_aux_loss" coefficient is 0.
        """
        seq_aux_loss_coeff = self.get_aux_loss_coeff("seq_aux_loss")
        if seq_aux_loss_coeff == 0:
            return probs

        scores_for_aux_loss = scores_for_aux_loss.reshape(seq_length, -1)
        tokens_per_expert = routing_map.reshape(seq_length, -1).sum(dim=0)
        tokens_per_expert = reduce_from_tensor_model_parallel_region(
            tokens_per_expert, self.tp_cp_group
        )

        total_num_tokens = seq_length * self.tp_cp_group.size()

        aux_loss = (
            switch_load_balancing_loss_func(
                probs=scores_for_aux_loss,
                tokens_per_expert=tokens_per_expert,
                total_num_tokens=total_num_tokens,
                topk=self.topk,
                num_experts=self.config.num_moe_experts,
                moe_aux_loss_coeff=seq_aux_loss_coeff,
                fused=self.config.moe_router_fusion,
            )
            / bsz
        )
        probs = self.attach_and_log_load_balancing_loss(
            probs, seq_aux_loss_coeff, aux_loss, "seq_load_balancing_loss", self.tp_cp_group
        )
        return probs

    def _apply_global_aux_loss(
        self, probs: torch.Tensor, scores_for_aux_loss: torch.Tensor, routing_map: torch.Tensor
    ):
        """Apply the global auxiliary loss for the given scores and routing map.

        Per-expert token counts are reduced over ``tp_dp_cp_group``, accumulated
        in ``global_tokens_per_expert`` and averaged over ``ga_steps`` (both are
        cleared by ``reset_global_aux_loss_tracker``). Returns ``probs`` unchanged
        when the "global_aux_loss" coefficient is 0.
        """
        global_aux_loss_coeff = self.get_aux_loss_coeff("global_aux_loss")
        if global_aux_loss_coeff == 0:
            return probs

        tokens_per_expert = routing_map.sum(dim=0)
        tokens_per_expert = reduce_from_tensor_model_parallel_region(
            tokens_per_expert, self.tp_dp_cp_group
        )

        self.global_tokens_per_expert += tokens_per_expert
        self.ga_steps += 1
        averaged_tokens_per_expert = self.global_tokens_per_expert / self.ga_steps

        num_tokens = scores_for_aux_loss.shape[0]
        total_num_tokens = num_tokens * self.tp_dp_cp_group.size()

        global_aux_loss = switch_load_balancing_loss_func(
            probs=scores_for_aux_loss,
            tokens_per_expert=averaged_tokens_per_expert,
            total_num_tokens=total_num_tokens,
            topk=self.topk,
            num_experts=self.config.num_moe_experts,
            moe_aux_loss_coeff=global_aux_loss_coeff,
            fused=self.config.moe_router_fusion,
        )
        probs = self.attach_and_log_load_balancing_loss(
            probs,
            global_aux_loss_coeff,
            global_aux_loss,
            "global_load_balancing_loss",
            self.tp_dp_cp_group,
        )
        return probs

    def apply_z_loss(self, logits):
        """Encourages the router's logits to remain small to enhance stability.
        Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

        Args:
            logits (torch.Tensor): The logits of the router.

        Returns:
            torch.Tensor: The logits after applying the z-loss.
        """
        if self.config.moe_z_loss_coeff is not None and self.training and torch.is_grad_enabled():
            # Skip Z loss calculations when using torch.no_grad() or checkpointing.
            moe_z_loss_coeff = self.config.moe_z_loss_coeff / self.tp_cp_group.size()
            z_loss = z_loss_func(logits, moe_z_loss_coeff)
            if self.calculate_per_token_loss:
                # The expected final scaling for z_loss gradients is
                # 1/(num_micro_batches * dp_size).
                # After commit 02648000, Megatron started using the number of total tokens
                # to scale gradients under the argument of calculate_per_token_loss,
                # which scales both the main_loss gradient and z_loss gradient by
                # 1/(num_local_tokens * dp_size * num_micro_batches) in finalize_model_grads().
                # To correct this scaling, we need to scale the z_loss by num_local_tokens here.
                logits = MoEAuxLossAutoScaler.apply(logits, z_loss * logits.shape[0])
            else:
                logits = MoEAuxLossAutoScaler.apply(logits, z_loss)

            num_layers = self.config.num_layers
            # Same optional-attribute guard as Router.attach_and_log_load_balancing_loss.
            if getattr(self.config, 'mtp_num_layers', None) is not None:
                num_layers += self.config.mtp_num_layers
            save_to_aux_losses_tracker(
                "z_loss", z_loss / moe_z_loss_coeff, self.layer_number, num_layers
            )
        return logits

    def routing(self, logits: torch.Tensor):
        """Top-k routing function

        Args:
            logits (torch.Tensor): Logits tensor after gating.

        Returns:
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mapping of token to experts assignment,
                with shape [num_tokens, num_experts].
        """
        # The first two dims are read as (seq_length, bsz) before flattening to tokens.
        seq_length, bsz = logits.shape[:2]
        logits = logits.view(-1, self.config.num_moe_experts)

        # Apply Z-Loss
        logits = self.apply_z_loss(logits)

        # Calculate probs and routing_map for token dispatching
        if self.routing_type == "sinkhorn":
            probs, routing_map = self.sinkhorn_load_balancing(logits)
        else:
            probs, routing_map = topk_routing_with_score_function(
                logits,
                self.topk,
                use_pre_softmax=self.config.moe_router_pre_softmax,
                num_groups=self.config.moe_router_num_groups,
                group_topk=self.config.moe_router_group_topk,
                scaling_factor=self.config.moe_router_topk_scaling_factor,
                score_function=self.score_function,
                expert_bias=self.expert_bias,
                fused=self.config.moe_router_fusion,
            )

        # Apply token dropping to probs and routing_map.
        if self.config.moe_expert_capacity_factor is not None:
            probs, routing_map = apply_router_token_dropping(
                probs,
                routing_map,
                router_topk=self.topk,
                capacity_factor=self.config.moe_expert_capacity_factor,
                drop_policy=self.config.moe_token_drop_policy,
                pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
            )

        # Apply each aux loss type and attach aux loss autograd function to probs
        if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled():
            # Calculate scores and routing_map for aux loss
            routing_map_for_aux_loss, scores_for_aux_loss = compute_routing_scores_for_aux_loss(
                logits, self.topk, self.score_function, fused=self.config.moe_router_fusion
            )
            probs = self._apply_aux_loss(probs, scores_for_aux_loss, routing_map_for_aux_loss)
            probs = self._apply_seq_aux_loss(
                probs, scores_for_aux_loss, routing_map_for_aux_loss, seq_length, bsz
            )
            probs = self._apply_global_aux_loss(
                probs, scores_for_aux_loss, routing_map_for_aux_loss
            )

        # Update expert bias and tokens_per_expert
        # Prevent extra local tokens accumulation on evaluation or activation recomputation
        if self.enable_expert_bias and torch.is_grad_enabled():
            with torch.no_grad():
                self.local_tokens_per_expert += routing_map.sum(dim=0)

        return probs, routing_map

    def reset_global_aux_loss_tracker(self):
        """Reset the global aux loss tracker."""
        if self.global_tokens_per_expert is not None:
            self.global_tokens_per_expert.zero_()
            self.ga_steps.zero_()

    def forward(self, input: torch.Tensor):
        """
        Forward pass of the router.

        Args:
            input (torch.Tensor): Input tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``probs`` and ``routing_map`` as
            returned by ``routing``.
        """
        # When using bf16/fp16, the expert bias gets converted to lower precision in
        # Float16Module. Keep it in float32 to avoid routing errors when updating it.
        self._maintain_float32_expert_bias()

        # Apply input jitter
        input = self.apply_input_jitter(input)
        logits = self.gating(input)

        if self.config.moe_router_force_load_balancing:
            # Apply force load balancing with random logits for benchmark
            logits = apply_random_logits(logits)

        probs, routing_map = self.routing(logits)

        return probs, routing_map


class ReLURouter(Router):
    """Router that sends a token to every expert whose ReLU-gated logit is positive.

    The gate value ``relu(logit)`` is used directly as the routing weight. While
    training with gradients enabled, an optional load-balancing penalty (its
    coefficient is read from ``config.moe_relu_l1_reg_coeff``) is attached to the
    weights and the fraction of inactive token/expert pairs is recorded under
    the ``relu_sparsity`` tracker key.
    """

    def __init__(self, config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None) -> None:
        """Initialize the base router and cache ``moe_router_topk`` from the config."""
        super().__init__(config=config, pg_collection=pg_collection)
        self.topk = self.config.moe_router_topk
        self.input_jitter = None

    # Input jitter is inherited from Router.apply_input_jitter.

    def l1_reg_load_balancing(self, logits: torch.Tensor):
        """Gate logits with ReLU and, during training, attach the balancing penalty.

        Args:
            logits (torch.Tensor): Gating logits, shape [num_tokens, num_experts].

        Returns:
            probs (torch.Tensor): ``relu(logits)``, shape [num_tokens, num_experts];
                wrapped by the load-balancing autograd hook when the penalty is active.
            routing_map (torch.Tensor): Boolean mask ``relu(logits) > 0`` of the same shape.
        """
        gate = torch.relu(logits)
        active = gate > 0

        # Outside of training (or under no-grad) only the gating itself is needed.
        if not (self.training and torch.is_grad_enabled()):
            return gate, active

        group = self.tp_cp_group
        # The reduce is issued unconditionally so every rank runs the same collectives.
        expert_counts = reduce_from_tensor_model_parallel_region(active.sum(dim=0), group)
        global_tokens = active.shape[0] * group.size()

        coeff = float(getattr(self.config, 'moe_relu_l1_reg_coeff', 0.0))
        if coeff > 0.0:
            penalty = switch_load_balancing_loss_func(
                probs=gate,
                tokens_per_expert=expert_counts,
                total_num_tokens=global_tokens,
                topk=self.topk,
                num_experts=self.config.num_moe_experts,
                moe_aux_loss_coeff=coeff,
                fused=self.config.moe_router_fusion,
            )
            gate = self.attach_and_log_load_balancing_loss(
                gate, coeff, penalty, "l1_reg_loss", group
            )

        # Sparsity = share of (token, expert) pairs that are not routed.
        with torch.no_grad():
            inactive_fraction = 1.0 - active.sum().float() / float(active.numel())

        layer_count = self.config.num_layers
        mtp_layers = getattr(self.config, 'mtp_num_layers', None)
        if mtp_layers is not None:
            layer_count += mtp_layers
        save_to_aux_losses_tracker(
            "relu_sparsity", inactive_fraction, self.layer_number, layer_count, reduce_group=group
        )
        return gate, active

    def routing(self, logits: torch.Tensor):
        """Flatten logits to [num_tokens, num_experts] and apply the ReLU gating."""
        flat_logits = logits.view(-1, self.config.num_moe_experts)
        return self.l1_reg_load_balancing(flat_logits)

    def forward(self, input: torch.Tensor):
        """Forward pass of the ReLU router: jitter, gate, optionally randomize, route."""
        gate_logits = self.gating(self.apply_input_jitter(input))

        # Benchmark mode: replace the learned logits with random ones.
        if self.config.moe_router_force_load_balancing:
            gate_logits = apply_random_logits(gate_logits)

        expert_probs, expert_map = self.routing(gate_logits)
        return expert_probs, expert_map


class DirVAERouter(Router):
    """Dirichlet VAE-based router producing soft probabilities with top-k-style mask.

    Ports your DirVAERouter to the updated router API. Attaches reconstruction/KL losses
    via MoEAuxLossAutoScaler and logs metrics through save_to_aux_losses_tracker.
    """

    def __init__(self, config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None) -> None:
        """Build the encoder, Dirichlet concentration heads, decoder and annealing state.

        Args:
            config (TransformerConfig): Model configuration. Optional DirVAE settings
                (``dirvae_hidden_size``, ``router_tau_z``, ``dirichlet_prior_alpha_lo``,
                ``dirichlet_prior_lambda``) are read with defaults when absent.
            pg_collection (Optional[ProcessGroupCollection]): Forwarded to the base router.
        """
        super().__init__(config=config, pg_collection=pg_collection)
        cfg = self.config
        self.topk = cfg.moe_router_topk
        self.input_jitter = None

        model_dim = cfg.hidden_size
        num_experts = cfg.num_moe_experts
        vae_dim = int(getattr(cfg, 'dirvae_hidden_size', 512))

        # Encoder: hidden state -> VAE feature space (two Linear+ReLU stages).
        encoder_layers = [
            nn.Linear(model_dim, vae_dim, bias=True),
            nn.ReLU(),
            nn.Linear(vae_dim, vae_dim, bias=True),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Two heads give the "high" and "low" posterior concentrations per expert.
        self.fc_alpha_hi = nn.Linear(vae_dim, num_experts, bias=True)
        self.fc_alpha_lo = nn.Linear(vae_dim, num_experts, bias=True)
        self.softplus = nn.Softplus(beta=0.5)

        # Decoder: expert weights -> reconstruction of the hidden state.
        decoder_layers = [
            nn.Linear(num_experts, vae_dim, bias=True),
            nn.ReLU(),
            nn.Linear(vae_dim, model_dim, bias=True),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

        # Starting points for the annealing schedules (plain floats, not buffers).
        self.alpha_lo_base = float(getattr(cfg, 'dirichlet_prior_alpha_lo', 1e-2))
        self.lambda_p_base = float(getattr(cfg, 'dirichlet_prior_lambda', 1.0))

        # All buffers are non-persistent, i.e. excluded from the state dict.
        initial_tau = torch.tensor(float(getattr(cfg, 'router_tau_z', 1.0)), dtype=torch.float32)
        self.register_buffer('prior_alpha_lo_vec', torch.full((1, num_experts), 1e-2), persistent=False)
        self.register_buffer('prior_alpha_hi_vec', torch.full((1, num_experts), 1.0), persistent=False)
        self.register_buffer('tau_z', initial_tau, persistent=False)
        self.register_buffer('anneal_step', torch.tensor(0, dtype=torch.long), persistent=False)
        self.register_buffer('anneal_updates', torch.tensor(0, dtype=torch.long), persistent=False)

    def _compute_alpha_q(self, h: torch.Tensor, z_mask: torch.Tensor):
        alpha_hi_hat = self.softplus(self.fc_alpha_hi(h)).clamp(min=1e-6, max=100.0)
        alpha_lo_hat = self.softplus(self.fc_alpha_lo(h)).clamp(min=1e-6, max=10.0)
        alpha_q = float(getattr(self.config, 'router_lambda_q', 1.0)) * (
            z_mask * alpha_hi_hat + (1.0 - z_mask) * alpha_lo_hat
        )
        return alpha_q

    def _compute_alpha_p(self, z_mask: torch.Tensor):
        with torch.no_grad():
            ratio = (self.config.dirichlet_prior_mass_m / max(1e-8, 1.0 - self.config.dirichlet_prior_mass_m)) * (
                (self.config.num_moe_experts - max(1, self.topk)) / max(1, self.topk)
            )
            alpha_lo = torch.full_like(self.prior_alpha_lo_vec, self.config.dirichlet_prior_alpha_lo)
            alpha_hi = ratio * alpha_lo
            self.prior_alpha_lo_vec.copy_(alpha_lo)
            self.prior_alpha_hi_vec.copy_(alpha_hi)
        # Stop-gradient on gates for prior as in Eq. (alpha^(p)(\tilde z_sg))
        z_sg = z_mask.detach()
        alpha_p = float(getattr(self.config, 'dirichlet_prior_lambda', 1.0)) * (
            z_sg * self.prior_alpha_hi_vec + (1.0 - z_sg) * self.prior_alpha_lo_vec
        )
        return alpha_p

    def anneal_prior(self, new_alpha_lo: float, new_lambda_p: float):
        """Install annealed prior hyper-parameters on the buffers and the config.

        The low concentration is floored at ``config.dirichlet_alpha_floor``
        (default 1e-4). The high concentration is the low one times a ratio
        derived from ``config.dirichlet_prior_mass_m``, the expert count and
        ``topk``.

        Args:
            new_alpha_lo (float): Requested low prior concentration (before flooring).
            new_lambda_p (float): New prior scale, stored as ``config.dirichlet_prior_lambda``.
        """
        cfg = self.config
        floor = float(getattr(cfg, 'dirichlet_alpha_floor', 1e-4))
        new_alpha_lo = float(max(floor, new_alpha_lo))

        k_eff = max(1, self.topk)
        mass = cfg.dirichlet_prior_mass_m
        ratio = (mass / max(1e-8, 1.0 - mass)) * ((cfg.num_moe_experts - k_eff) / k_eff)

        with torch.no_grad():
            # full_like already inherits device and dtype from the buffer.
            alpha_lo_p = torch.full_like(self.prior_alpha_lo_vec, new_alpha_lo)
            self.prior_alpha_lo_vec.copy_(alpha_lo_p)
            self.prior_alpha_hi_vec.copy_(ratio * alpha_lo_p)

        # _compute_alpha_p rebuilds both prior buffers from
        # config.dirichlet_prior_alpha_lo on every routing call. Without writing
        # the annealed value back here, the buffers set above would be
        # overwritten immediately and alpha_lo annealing would have no effect.
        cfg.dirichlet_prior_alpha_lo = new_alpha_lo
        cfg.dirichlet_prior_lambda = float(new_lambda_p)

    def maybe_update_annealing(self):
        """Advance the annealing step counter and, on update steps, anneal tau and the prior.

        Does nothing unless the module is in training mode and
        ``config.anneal_enabled`` is truthy. An update happens when
        ``config.anneal_every`` (default 100) is non-positive or divides the
        current step. On an update:

        * ``tau_z`` is multiplied by ``tau_z_rho`` (default 0.999), floored at
          ``tau_z_min`` (default 0.2), and mirrored to ``config.router_tau_z``;
        * ``anneal_prior`` is called with the base values scaled by
          ``alpha_lo_gamma`` / ``lambda_p_eta`` raised to the update count;
        * ``config.hard_routing`` is set once the step reaches a non-negative
          ``config.hard_switch_step``.
        """
        cfg = self.config
        if not (self.training and bool(getattr(cfg, 'anneal_enabled', False))):
            return

        step = int(self.anneal_step.item())
        every = int(getattr(cfg, 'anneal_every', 100))
        if every <= 0 or step % every == 0:
            # Smooth multiplicative decay from the current value to avoid large jumps.
            tau_floor = float(getattr(cfg, 'tau_z_min', 0.2))
            decay = float(getattr(cfg, 'tau_z_rho', 0.999))
            annealed_tau = max(tau_floor, float(self.tau_z.item()) * decay)
            self.tau_z.fill_(annealed_tau)
            cfg.router_tau_z = annealed_tau

            n_updates = int(self.anneal_updates.item())
            gamma = float(getattr(cfg, 'alpha_lo_gamma', 1.0))
            eta = float(getattr(cfg, 'lambda_p_eta', 1.0))
            self.anneal_prior(
                new_alpha_lo=self.alpha_lo_base * (gamma ** n_updates),
                new_lambda_p=self.lambda_p_base * (eta ** n_updates),
            )

            switch_at = int(getattr(cfg, 'hard_switch_step', -1))
            if 0 <= switch_at <= step:
                cfg.hard_routing = True
            self.anneal_updates += 1
        self.anneal_step += 1

    def _dirichlet_kl(self, alpha_q: torch.Tensor, alpha_p: torch.Tensor):
        alpha_q = alpha_q.clamp_min(1e-6)
        alpha_p = alpha_p.clamp_min(1e-6)
        sum_q = alpha_q.sum(dim=-1, keepdim=True)
        sum_p = alpha_p.sum(dim=-1, keepdim=True)
        log_beta_q = torch.lgamma(sum_q) - torch.lgamma(alpha_q).sum(dim=-1, keepdim=True)
        log_beta_p = torch.lgamma(sum_p) - torch.lgamma(alpha_p).sum(dim=-1, keepdim=True)
        digamma_q = torch.digamma(alpha_q)
        digamma_sum_q = torch.digamma(sum_q)
        kld = (log_beta_q - log_beta_p + ((alpha_q - alpha_p) * (digamma_q - digamma_sum_q)).sum(dim=-1, keepdim=True)).squeeze(-1)
        return kld.mean()

    def routing(self, logits: torch.Tensor, input: torch.Tensor):
        """Sample DirVAE routing weights and the executed-expert mask for a batch of tokens.

        Steps, in order:

        1. Optionally replace ``logits`` with random logits (benchmark mode) and
           advance the annealing schedule.
        2. Center the logits per token in float32, apply optional bias calibration,
           and draw Binary-Concrete gates ``z`` at temperature ``tau_z``.
        3. Build the posterior/prior Dirichlet concentrations from the gates, draw
           ``theta`` with a reparameterized sample, and mask it with the thresholded
           gates to obtain the routing weights.
        4. During training with gradients enabled, attach the auxiliary loss
           (reconstruction + weighted KL + weighted expected-k) to the returned
           probabilities via ``MoEAuxLossAutoScaler``.
        5. Push diagnostics through ``router_metrics_push`` (runs in eval as well).

        Args:
            logits (torch.Tensor): Gating logits; ``forward`` passes them flattened
                to [num_tokens, num_experts].
            input (torch.Tensor): Router input; ``forward`` passes it flattened to
                [num_tokens, hidden_size]. Used by the encoder and as the
                reconstruction target.

        Returns:
            probs (torch.Tensor): Routing weights cast to ``input.dtype``.
            routing_map (torch.Tensor): Boolean mask of executed token/expert pairs,
                with at least one expert selected per token.
        """
        if self.config.moe_router_force_load_balancing:
            logits = apply_random_logits(logits)

        self.maybe_update_annealing()
        h = self.encoder(input)

        # Stabilize logits: fp32 + per-token centering across experts.
        logits_fp32 = logits.to(torch.float32)
        # The mean is taken from detached logits so autograd never traverses the collective.
        m_local = logits_fp32.detach().mean(dim=-1, keepdim=True)
        if (
            self.tp_cp_group is not None
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        ):
            # Work on a detached copy throughout.
            m = m_local.clone()
            with torch.no_grad():
                torch.distributed.all_reduce(m, group=self.tp_cp_group)
                world_size = self.tp_cp_group.size()
                m /= max(1, world_size)
        else:
            m = m_local
        logits_fp32 = logits_fp32 - m.detach()

        # Optional static bias calibration toward E[z] ~= k/E.
        bias_cfg = getattr(self.config, 'gate_logit_bias_calib', None)
        if bias_cfg is not None:
            logits_fp32 = logits_fp32 + float(bias_cfg)
        elif bool(getattr(self.config, 'gate_auto_bias_to_expected_k', False)):
            # The auto bias is computed once and cached on the module.
            if not hasattr(self, '_gate_auto_bias_value'):
                k_auto = float(max(1, self.topk))
                E = float(self.config.num_moe_experts)
                tau0 = float(getattr(self.config, 'gate_tau0_for_bias', float(self.tau_z.item())))
                ratio = max(1e-6, k_auto) / max(1e-6, (E - k_auto))
                b = tau0 * math.log(ratio)
                self._gate_auto_bias_value = torch.tensor(b, dtype=torch.float32)
            logits_fp32 = logits_fp32 + self._gate_auto_bias_value.to(device=logits_fp32.device)

        # Logistic noise for the Binary-Concrete sample; shared with the calibration below.
        tau = max(float(self.tau_z.item()), 1e-6)
        u = torch.rand_like(logits_fp32)
        g = torch.log(u + 1e-9) - torch.log(1.0 - u + 1e-9)
        # Optional per-batch calibration: two Newton steps on a per-token offset so
        # that the gates sum to the target k. The offset is detached before use.
        if bool(getattr(self.config, 'gate_expected_k_calibration', False)):
            delta = torch.zeros((logits_fp32.size(0), 1), dtype=logits_fp32.dtype, device=logits_fp32.device)
            target = float(max(1, self.topk))
            for _ in range(2):
                s = torch.sigmoid((logits_fp32 + g + delta) / tau)
                f = s.sum(dim=-1, keepdim=True) - target
                fp = (s * (1.0 - s) / tau).sum(dim=-1, keepdim=True) + 1e-6
                delta = delta - f / fp
            logits_for_sampling = logits_fp32 + delta.detach()
        else:
            logits_for_sampling = logits_fp32

        # Log pre-sigmoid statistics of the sampling logits.
        pre_sig = logits_for_sampling
        l_mean = pre_sig.mean()
        l_sum = pre_sig.sum()
        l_std = pre_sig.float().std()
        l_min = pre_sig.min()
        l_max = pre_sig.max()
        pos_count = (pre_sig > 0).sum().to(dtype=torch.float32)
        total_count = torch.tensor(float(pre_sig.numel()), device=pre_sig.device)
        pos_frac = pos_count / (total_count + 1e-6)
        router_metrics_push("dirvae_logits_mean", l_mean)
        router_metrics_push("dirvae_logits_sum", l_sum)
        router_metrics_push("dirvae_logits_std", l_std)
        router_metrics_push("dirvae_logits_min", l_min)
        router_metrics_push("dirvae_logits_max", l_max)
        router_metrics_push("dirvae_logits_pos_count", pos_count)
        router_metrics_push("dirvae_logits_pos_frac", pos_frac)

        # Sample Binary-Concrete gates from the calibrated/centered logits.
        z_tilde = torch.sigmoid((logits_for_sampling + g) / tau)
        # Straight-through hard top-k mask; only taken in eval mode with hard_routing set.
        if bool(getattr(self.config, 'hard_routing', False)) and (not self.training) and self.topk > 0:
            gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-9) + 1e-9)
            scores = (logits + gumbel) / max(float(getattr(self.config, 'hard_tau', 0.5)), 1e-6)
            topk_idx = scores.topk(self.topk, dim=-1).indices
            hard_mask = torch.zeros_like(logits)
            hard_mask.scatter_(dim=-1, index=topk_idx, value=1.0)
            z_mask = hard_mask + z_tilde - z_tilde.detach()
        else:
            z_mask = z_tilde

        alpha_q = self._compute_alpha_q(h, z_mask)
        alpha_p = self._compute_alpha_p(z_mask)
        # Compute Dirichlet quantities in float32 for stability/backward support.
        alpha_q_fp32 = alpha_q.to(torch.float32)
        alpha_p_fp32 = alpha_p.to(torch.float32)
        theta_fp32 = Dirichlet(alpha_q_fp32.clamp(min=1e-6)).rsample()

        routing_threshold = float(getattr(self.config, 'routing_threshold', 0.0))
        z_mask_fp32 = z_mask.to(torch.float32)
        z_mask_thresh = z_mask_fp32 > routing_threshold
        # Masked weights are NOT renormalized here; renormalization only happens
        # inside the epsilon block below when a leak or floor is configured.
        probs_fp32 = z_mask_thresh * theta_fp32
        # Optional epsilon adjustments (merged leak and floor) followed by renormalization.
        leak_eps = float(getattr(self.config, 'router_eps_leak', 0.0))
        eps_w = float(getattr(self.config, 'router_eps_w', 0.0))
        if (leak_eps > 0.0) or (eps_w > 0.0):
            if leak_eps > 0.0:
                probs_fp32 = probs_fp32 + (leak_eps / max(1, probs_fp32.size(-1)))
            if eps_w > 0.0:
                probs_fp32 = torch.clamp(probs_fp32, min=eps_w)
            # Single renormalization.
            probs_fp32 = probs_fp32 / (probs_fp32.sum(dim=-1, keepdim=True) + 1e-6)

        # The routing map is derived from the gate values, not from the (possibly
        # epsilon-floored) probabilities, so flooring does not destroy sparsity.
        routing_map = z_mask_fp32 > routing_threshold
        no_sel_rows = ~routing_map.any(dim=1)
        if no_sel_rows.any():
            # Guarantee one expert per token. Boolean-mask indexing returns a copy,
            # so an in-place scatter_ on routing_map[no_sel_rows] would be lost;
            # build a one-hot fallback and merge it into the affected rows instead.
            # The weights of such rows stay as computed above (zero unless a
            # leak/floor epsilon is configured).
            max_idx = z_mask_fp32.argmax(dim=1, keepdim=True)
            fallback = torch.zeros_like(routing_map).scatter_(1, max_idx, True)
            routing_map = torch.where(no_sel_rows.unsqueeze(1), fallback, routing_map)

        # Cast probs to the input dtype for the decoder and the return path.
        probs = probs_fp32.to(dtype=input.dtype)
        recon_x = self.decoder(probs)
        recon_loss = F.mse_loss(recon_x, input, reduction='mean')
        # Compute KL in float32.
        kld = self._dirichlet_kl(alpha_q_fp32, alpha_p_fp32).to(torch.float32)
        sigma2 = float(getattr(self.config, 'dirvae_sigma2', 1.0))
        rec_term = (0.5 / max(sigma2, 1e-6)) * recon_loss.to(torch.float32)
        beta_theta = float(getattr(self.config, 'dirvae_kl_weight', 1.0))
        # Soft cardinality penalty: squared distance of the gate sum from the target k.
        k_hat = z_mask.sum(dim=-1)
        k_target = float(max(1, self.topk))
        expected_k_loss = ((k_hat - k_target) ** 2).mean()
        lambda_card = float(getattr(self.config, 'expected_k_weight', 0.0))
        total_aux = rec_term + beta_theta * kld + lambda_card * expected_k_loss

        if self.training and torch.is_grad_enabled():
            # Emit through the step-metrics buffer; avoid direct tracker logging from the router.
            router_metrics_push("dirvae_kl_loss", kld)
            router_metrics_push("dirvae_recon_loss", rec_term)
            # Simpson index H = sum_i r_i^2 as a sparsity/concentration metric.
            with torch.no_grad():
                simpson_theta = (theta_fp32 ** 2).sum(dim=-1).mean()
                simpson_probs = (probs_fp32 ** 2).sum(dim=-1).mean()
            router_metrics_push("dirvae_simpson_theta", simpson_theta)
            router_metrics_push("dirvae_simpson_probs", simpson_probs)
            if self.calculate_per_token_loss:
                probs = MoEAuxLossAutoScaler.apply(probs, total_aux * probs.shape[0])
            else:
                probs = MoEAuxLossAutoScaler.apply(probs, total_aux)

        # Extended metrics logging.
        # Choose the sequence partition group consistent with the dispatcher type.
        if getattr(self.config, 'moe_token_dispatcher_type', None) == "alltoall_seq":
            sequence_partition_group = parallel_state.get_context_parallel_group()
        else:
            sequence_partition_group = parallel_state.get_tensor_and_context_parallel_group()
        metric_group = self.tp_dp_cp_group if self.tp_dp_cp_group is not None else self.tp_cp_group
        with torch.no_grad():
            # Expert load metrics.
            num_tokens = probs_fp32.size(0)
            tokens_per_expert = routing_map.sum(dim=0).to(dtype=torch.float32)
            total_routed = tokens_per_expert.sum()
            if metric_group is not None:
                allreduce_(tokens_per_expert, metric_group)
                allreduce_(total_routed, metric_group)
            total_routed = total_routed.clamp_min(torch.tensor(1.0, device=probs.device))
            expert_load_fraction = tokens_per_expert / total_routed

            # Triage metrics.
            # 1) Max of probs.
            router_metrics_push("triage_probs_max", probs_fp32.max())

            # 2) Padding waste (only meaningful when a capacity factor is configured).
            capacity_factor = getattr(self.config, 'moe_expert_capacity_factor', None)
            if capacity_factor is not None:
                cap = int(get_capacity(
                    num_tokens=num_tokens * max(1, self.topk),
                    num_experts=self.config.num_moe_experts,
                    capacity_factor=float(capacity_factor),
                ))
                used = torch.clamp(tokens_per_expert, max=cap).sum()
                denom = (cap * self.config.num_moe_experts) + 1e-6
                waste_frac = 1.0 - (used / denom)
                router_metrics_push("triage_capacity_waste", waste_frac)
            else:
                router_metrics_push("triage_capacity_waste", torch.tensor(0.0, device=probs.device))

            # 3) Thresholding/selection diagnostics.
            router_metrics_push("triage_threshold", torch.tensor(routing_threshold, device=probs.device))
            router_metrics_push("triage_avg_z_mean", z_mask_fp32.mean())
            router_metrics_push("triage_avg_r_top1", probs_fp32.topk(1, dim=-1).values.mean())
            avg_k = routing_map.sum(dim=1).float().mean()

            # Build a top-k mask ONLY for metrics/LB (does not change dispatch).
            k_metric = max(1, int(self.topk))
            metric_topk_idx = z_tilde.topk(k_metric, dim=-1).indices
            routing_map_metric = torch.zeros_like(z_tilde, dtype=torch.bool)
            routing_map_metric.scatter_(1, metric_topk_idx, True)

            # Debug: executed set size and row mass under the metric mask.
            dbg_avg_k_for_lb = routing_map_metric.sum(dim=1).float().mean()
            router_metrics_push("dbg_avg_k_for_lb", dbg_avg_k_for_lb)
            row_sums = (probs_fp32 * routing_map_metric.to(probs_fp32.dtype)).sum(dim=-1)
            router_metrics_push("dbg_row_sum_mean", row_sums.mean())
            router_metrics_push("dbg_row_sum_min", row_sums.min())
            router_metrics_push("dbg_row_sum_max", row_sums.max())

            # Load-balance metric (monitoring only) on the metric mask, with
            # per-token normalization under that mask applied first.
            pm = probs_fp32 * routing_map_metric.to(probs_fp32.dtype)
            pm = pm / (pm.sum(dim=-1, keepdim=True) + 1e-6)
            lb_metric_probs = switch_load_balance_loss(
                routing_map=routing_map_metric,
                r_probs=pm.detach(),
                group=metric_group,
            )
            tm = theta_fp32 * routing_map_metric.to(theta_fp32.dtype)
            tm = tm / (tm.sum(dim=-1, keepdim=True) + 1e-6)
            lb_metric_theta = switch_load_balance_loss(
                routing_map=routing_map_metric,
                r_probs=tm.detach(),
                group=metric_group,
            )
            router_metrics_push("dirvae_avg_k", avg_k)

            # Annealing and prior scalars.
            tau_z_val = torch.tensor(float(self.tau_z.item()), device=probs.device)
            lambda_p_val = torch.tensor(float(self.config.dirichlet_prior_lambda), device=probs.device)
            lambda_q_val = torch.tensor(float(getattr(self.config, 'router_lambda_q', 1.0)), device=probs.device)
            alpha_lo_mean = self.prior_alpha_lo_vec.mean()
            alpha_hi_mean = self.prior_alpha_hi_vec.mean()

            # Max k across ranks. Both reductions are best-effort: a failure leaves
            # the local (or partially reduced) maximum in place rather than aborting.
            local_max_k = routing_map.sum(dim=1).to(dtype=torch.int32).max()
            max_k = local_max_k.to(dtype=torch.float32)
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                try:
                    torch.distributed.all_reduce(max_k, group=sequence_partition_group, op=torch.distributed.ReduceOp.MAX)
                except Exception:
                    logging.getLogger(__name__).debug(
                        "dirvae max_k reduction over sequence partition group skipped", exc_info=True
                    )
                try:
                    torch.distributed.all_reduce(max_k, group=parallel_state.get_pipeline_model_parallel_group(), op=torch.distributed.ReduceOp.MAX)
                except Exception:
                    logging.getLogger(__name__).debug(
                        "dirvae max_k reduction over pipeline group skipped", exc_info=True
                    )
            num_experts_effective = torch.tensor(float(self.config.num_moe_experts), device=probs.device)
            tp_ep_ws = torch.tensor(float(parallel_state.get_tensor_and_context_parallel_world_size()), device=probs.device)
            pp_ws = torch.tensor(float(parallel_state.get_pipeline_model_parallel_world_size()), device=probs.device)

            router_metrics_push("dirvae_tau_z", tau_z_val)
            router_metrics_push("dirvae_lambda_p", lambda_p_val)
            router_metrics_push("dirvae_lambda_q", lambda_q_val)
            router_metrics_push("dirvae_max_k", max_k)
            router_metrics_push("dirvae_effective_num_experts", num_experts_effective)
            router_metrics_push("dirvae_metric_group_world_size", tp_ep_ws)
            router_metrics_push("dirvae_pipeline_world_size", pp_ws)

            # Means of annealing-affected quantities per batch.
            alpha_q_mean = alpha_q.mean()
            alpha_q_sum_mean = alpha_q.sum(dim=-1).mean()
            alpha_p_mean = alpha_p.mean()
            alpha_p_sum_mean = alpha_p.sum(dim=-1).mean()
            z_mean = z_mask.mean()
            probs_mean = probs.mean()
            router_metrics_push("dirvae_alpha_q_mean", alpha_q_mean)
            router_metrics_push("dirvae_alpha_q_sum_mean", alpha_q_sum_mean)
            router_metrics_push("dirvae_alpha_p_mean", alpha_p_mean)
            router_metrics_push("dirvae_alpha_p_sum_mean", alpha_p_sum_mean)
            router_metrics_push("dirvae_z_mean", z_mean)
            router_metrics_push("dirvae_probs_mean", probs_mean)
            router_metrics_push("dirvae_prior_alpha_lo_mean", alpha_lo_mean)
            router_metrics_push("dirvae_prior_alpha_hi_mean", alpha_hi_mean)
            router_metrics_push("dirvae_lb_metric_probs", lb_metric_probs)
            router_metrics_push("dirvae_lb_metric_theta", lb_metric_theta)
            # 0-1 scaled imbalance for dashboards: 0 best -> 1 worst.
            E_val = torch.tensor(float(self.config.num_moe_experts), device=probs.device)
            denom = (E_val - 1.0).clamp_min(1e-6)
            imbalance_probs_01 = ((lb_metric_probs - 1.0) / denom).clamp(0.0, 1.0)
            imbalance_theta_01 = ((lb_metric_theta - 1.0) / denom).clamp(0.0, 1.0)
            router_metrics_push("dirvae_lb_metric_probs_01", imbalance_probs_01)
            router_metrics_push("dirvae_lb_metric_theta_01", imbalance_theta_01)

            # "Gold" Switch LB with GLOBAL normalization: load fractions come from
            # expert_load_fraction (already reduced over metric_group above);
            # importance is the globally averaged probability mass per expert.
            E_gold = probs_fp32.size(-1)
            N_local = probs_fp32.size(0)
            gold_group = metric_group
            p_sum_g = probs_fp32.sum(dim=0).clone()
            N = torch.tensor(float(N_local), device=probs_fp32.device)
            if dist.is_available() and dist.is_initialized() and gold_group is not None:
                dist.all_reduce(p_sum_g, group=gold_group)
                dist.all_reduce(N, group=gold_group)
            f_gold = expert_load_fraction
            p_gold = p_sum_g / (N + 1e-6)
            lb_gold = torch.tensor(float(E_gold), device=probs_fp32.device) * (f_gold * p_gold).sum()

            router_metrics_push("lb_gold_value", lb_gold)
            router_metrics_push("lb_gold_sum_f", f_gold.sum())
            router_metrics_push("lb_gold_sum_p", p_gold.sum())
            router_metrics_push("lb_gold_f_max", f_gold.max())
            router_metrics_push("lb_gold_p_max", p_gold.max())
            # Sanity: sums of f and p from executed mass should be ~1; also min/max and entropies.
            stats_probs = global_load_and_importance(
                routing_map=routing_map,
                r_probs=probs_fp32.detach(),
                group=metric_group,
            )
            f = stats_probs["f"].clamp_min(1e-12)
            p = stats_probs["p"].clamp_min(1e-12)
            router_metrics_push("dirvae_lb_f_sum", f.sum())
            router_metrics_push("dirvae_lb_p_sum", p.sum())
            router_metrics_push("dirvae_lb_f_min", f.min())
            router_metrics_push("dirvae_lb_f_max", f.max())
            router_metrics_push("dirvae_lb_p_min", p.min())
            router_metrics_push("dirvae_lb_p_max", p.max())
            H_f = (-f * torch.log(f)).sum()
            H_p = (-p * torch.log(p)).sum()
            router_metrics_push("dirvae_lb_f_entropy", H_f)
            router_metrics_push("dirvae_lb_p_entropy", H_p)

            # Soft expected-k and its loss.
            expected_k = k_hat.mean()
            router_metrics_push("dirvae_expected_k", expected_k.detach())
            router_metrics_push("dirvae_expected_k_loss", expected_k_loss.detach())

            # Per-expert load fractions.
            for i, load_fraction in enumerate(expert_load_fraction):
                router_metrics_push(f"dirvae_expert_load_e{i}", load_fraction)
        return probs, routing_map

    def forward(self, input: torch.Tensor):
        # Log that DirVAE router is active for this layer
        if getattr(self.config, 'moe_token_dispatcher_type', None) == "alltoall_seq":
            sequence_partition_group = parallel_state.get_context_parallel_group()
        else:
            sequence_partition_group = parallel_state.get_tensor_and_context_parallel_group()
        num_layers = self.config.num_layers
        if getattr(self.config, 'mtp_num_layers', None) is not None:
            num_layers += self.config.mtp_num_layers
        router_metrics_push(
            "dirvae_router_active",
            torch.tensor(1.0, device=input.device),
        )
        router_metrics_push("dirvae_router_active", torch.tensor(1.0, device=input.device))

        input = self.apply_input_jitter(input)
        logits = self.gating(input)
        logits = logits.view(-1, self.config.num_moe_experts)
        input2d = input.view(-1, self.config.hidden_size)
        probs, routing_map = self.routing(logits, input2d)
        # Metrics logging for DirVAE (already logs many metrics; add requested ones)
        metric_group = self.tp_dp_cp_group if self.tp_dp_cp_group is not None else self.tp_cp_group

        with torch.no_grad():
            probs_fp32 = probs.to(torch.float32)
            # Ensure per-token normalization on [S,B,E] before LB diagnostics
            E = int(self.config.num_moe_experts)
            # input here is still the unflattened tensor from forward()
            # Recover (S,B) if shape matches; otherwise skip
            total_tokens = probs_fp32.numel() // max(1, E)
            # Derive seq_len and bsz from the surrounding scope if available via logits/input shapes
            # We prefer reading from logits shape set in forward()
            # logits is [S*B, E], so cannot directly recover S,B; use input instead
            # Guard: only reshape if divisible
            if hasattr(input, 'shape') and len(input.shape) >= 2:
                seq_len = int(input.shape[0])
                bsz = int(input.shape[1])
                if seq_len * bsz == total_tokens:
                    pv = probs_fp32.view(seq_len, bsz, E)
                    pv = pv / (pv.sum(dim=-1, keepdim=True) + 1e-6)
                    probs_fp32 = pv.view(-1, E)
            H = simpson_index_r(probs_fp32)
            card = routing_cardinality_stats(routing_map)
            # Build gate-like scores from logits and tau for consistency
            logits_flat_fp32 = logits.view(-1, self.config.num_moe_experts).to(torch.float32)
            tau = max(float(self.tau_z.item()), 1e-6)
            z_scores = torch.sigmoid(logits_flat_fp32 / tau)
            ek_err = expected_k_error(z_scores, target_k=float(max(1, int(self.topk))))
            leak = leakage_metrics(r_probs=probs_fp32, z_scores=z_scores, k=max(1, int(self.topk)), use_topk_on="z")
            # Use top-k metric mask here as well to match intended executed set
            try:
                k_metric2 = max(1, int(self.topk))
                metric_topk_idx2 = z_scores.topk(k_metric2, dim=-1).indices
                routing_map_metric2 = torch.zeros_like(routing_map, dtype=torch.bool)
                routing_map_metric2.scatter_(1, metric_topk_idx2, True)
            except Exception:
                routing_map_metric2 = routing_map
            lb_loss = switch_load_balance_loss(
                routing_map=routing_map_metric2,
                r_probs=probs_fp32,
                group=metric_group,
            )
            router_metrics_push("m_sparsity_simpson_H", H)
            router_metrics_push("m_card_k_avg", card["k_avg"]) 
            router_metrics_push("m_card_k_max", card["k_max"]) 
            router_metrics_push("m_expected_k_error", ek_err)
            router_metrics_push("m_leak_mean", leak["leak_mean"]) 
            router_metrics_push("m_mass_active_mean", leak["mass_active_mean"]) 
            try:
                cap_factor = getattr(self.config, 'moe_expert_capacity_factor', None)
                if cap_factor is not None:
                    cap_stats = capacity_stats(
                        routing_map=routing_map,
                        capacity_factor=float(cap_factor),
                        k=max(1, int(self.topk)),
                        group=metric_group,
                    )
                    router_metrics_push("m_cap_drop_rate", cap_stats["drop_rate"]) 
                    router_metrics_push("m_cap_padding_waste", cap_stats["padding_waste"]) 
                    router_metrics_push("m_cap_tokens_per_expert_mean", cap_stats["tokens_per_expert_mean"]) 
                else:
                    router_metrics_push("m_cap_drop_rate", torch.tensor(0.0, device=probs.device))
                    router_metrics_push("m_cap_padding_waste", torch.tensor(0.0, device=probs.device))
            except Exception:
                pass
            router_metrics_push("m_lb_switch", lb_loss)
        return probs, routing_map
