# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import torch
from typing import List, Dict, Tuple, Optional

from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region

# Module-level registry of per-layer probability tensors, keyed by layer number.
# compute_and_register_pathguard_dp_loss() stores one entry per call (overwriting any
# previous entry for the same layer number) and drops the lowest-numbered entries once
# the registry outgrows its retention cap.
# NOTE(review): no reset between forward passes is visible in this file, so entries
# persist until they are overwritten or pruned -- confirm this is intended.
_PATHGUARD_DP_PROBS_TRACKER: Dict[int, torch.Tensor] = {}


def _calc_pathguard_dp_loss(
    probs_dict: Dict[int, torch.Tensor], coeff: float, k: int = 1
) -> torch.Tensor:
    """
    Compute the PathGuard_DP loss over the best expert paths through all given layers.

    Layers are visited in ascending key order. For each pair of consecutive layers a
    transition matrix ``joint[e, c] = mean_t(probs_prev[t, e] * probs_curr[t, c])`` is
    built. A path score for token ``t`` is ``probs_first[t, e0]`` multiplied by the
    transition weights along the path. The loss is
    ``-log(sum(top-k path scores of every token) + eps) * coeff``; note that the sum
    runs over all tokens as well as over the k retained paths.

    Only the top-k scores per (token, expert) are carried from one layer to the next.
    Because every score is scaled by non-negative transition weights, ordering among
    paths that share a prefix endpoint is preserved, so this pruning yields exactly the
    same final top-k as keeping every path while keeping memory bounded by ``k``.

    Args:
        probs_dict: Dictionary mapping layer indices to probability tensors
            [num_tokens, num_experts]. All layers are assumed to share the same shape
            and to hold non-negative values.
        coeff: Scaling coefficient for the loss.
        k: Number of top paths to keep; expected to be >= 1.

    Returns:
        torch.Tensor: Scalar PathGuard_DP loss (scaled by coeff). A zero tensor is
        returned when fewer than two layers are supplied (including an empty dict).
    """
    layer_indices = sorted(probs_dict)
    if not layer_indices:
        # Nothing registered: there is no tensor to borrow a device from.
        return torch.tensor(0.0)

    first_probs = probs_dict[layer_indices[0]]
    if len(layer_indices) < 2:
        # Need at least two layers to form a path.
        return torch.tensor(0.0, device=first_probs.device)

    num_tokens = first_probs.shape[0]

    # path_probs[t, e, p]: score of the p-th retained path that ends at expert e of the
    # most recently processed layer, for token t. Starts with one path per expert.
    path_probs = first_probs.unsqueeze(-1)

    for prev_layer, curr_layer in zip(layer_indices, layer_indices[1:]):
        # joint[e, c] = 1/T * sum_t g_{te}^{(prev)} g_{tc}^{(curr)}
        joint = probs_dict[prev_layer].T @ probs_dict[curr_layer]
        joint = joint / max(num_tokens, 1)  # guard against an empty token batch

        # candidates[t, c, e, p] = path_probs[t, e, p] * joint[e, c]
        candidates = path_probs.unsqueeze(1) * joint.T.unsqueeze(0).unsqueeze(-1)
        # Merge the (source expert, path) axes: every way of reaching expert c.
        candidates = candidates.flatten(start_dim=2)

        # Keep only the k best ways of reaching each expert.
        path_probs, _ = torch.topk(candidates, k=min(k, candidates.shape[-1]), dim=-1)

    # Pool the retained paths of every last-layer expert and take the top-k per token.
    final_path_probs = path_probs.flatten(start_dim=1)
    topk_final_probs, _ = torch.topk(
        final_path_probs, k=min(k, final_path_probs.shape[1]), dim=1
    )

    eps = 1e-9  # Small epsilon to avoid log(0)
    return -torch.log(torch.sum(topk_final_probs) + eps) * coeff


def compute_and_register_pathguard_dp_loss(
    layer_number: int,
    probs: torch.Tensor,
    coeff: float,
    k: int = 1,
    sequence_partition_group=None,
    max_layers_to_keep: int = 20,
) -> Tuple[Optional[torch.Tensor], Optional[int]]:
    """
    Register the current layer's probabilities and compute the PathGuard_DP loss when due.

    The probabilities are stored in the module-level tracker under ``layer_number``
    (replacing any earlier entry for that layer). The loss is computed only when the
    tracker holds more than one layer and ``layer_number`` is the highest layer number
    currently tracked; it then covers every tracked layer.

    Args:
        layer_number: Index of current MoE layer (starting from 0).
        probs: Soft probabilities of current layer [num_tokens, num_experts].
        coeff: Scaling coefficient.
        k: Number of top paths to consider.
        sequence_partition_group: Parallel group for sequence partition. When it has
            more than one member, tracked probabilities are gathered across it before
            the loss is computed.
        max_layers_to_keep: Retention cap for the tracker. Once exceeded, only the
            entries with the highest layer numbers are kept.

    Returns:
        Tuple[Optional[torch.Tensor], Optional[int]]: (loss tensor if computed else None,
        the layer index the loss is associated with, else None)
    """
    # Mutate the shared dict in place so that any other reference to it stays valid.
    tracker = _PATHGUARD_DP_PROBS_TRACKER
    tracker[layer_number] = probs

    loss = None
    target_layer = None

    # Compute only when this is the highest-numbered MoE layer seen so far.
    if len(tracker) > 1 and layer_number == max(tracker):
        # Gather lazily: the collective is only issued when the loss is actually needed.
        if sequence_partition_group is not None and sequence_partition_group.size() > 1:
            probs_dict = {
                l_idx: gather_from_sequence_parallel_region(
                    l_probs, group=sequence_partition_group
                )
                for l_idx, l_probs in tracker.items()
            }
        else:
            probs_dict = dict(tracker)

        loss = _calc_pathguard_dp_loss(probs_dict, coeff, k)
        target_layer = layer_number

    # Bound memory usage: drop everything but the highest-numbered layers.
    if len(tracker) > max_layers_to_keep:
        for stale_layer in sorted(tracker, reverse=True)[max_layers_to_keep:]:
            del tracker[stale_layer]

    return loss, target_layer