import math
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from utils.registry import CONNECTOR


class DeepseekMLP(nn.Module):
    """
    Gated MLP block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free linear layers. The activation is looked
    up in ``ACT2FN`` by the config's ``hidden_act`` entry.
    """

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        """
        Initialize the Deepseek MLP module.

        Args:
            config (dict): Configuration mapping. Keys read here: ``pretraining_tp``,
                ``hidden_size``, ``intermediate_size``, ``output_size``, ``hidden_act``.
            hidden_size (int, optional): Input width. If None, uses config value.
            intermediate_size (int, optional): Width of the gated intermediate layer.
                If None, uses config value.
        """
        super().__init__()
        self.config = config

        # A missing ``pretraining_tp`` means "no slicing"; leaving it as None would
        # make the ``> 1`` comparison in forward() raise TypeError.
        pretraining_tp = config.get("pretraining_tp")
        self.pretraining_tp = 1 if pretraining_tp is None else pretraining_tp

        self.hidden_size = config.get("hidden_size") if hidden_size is None else hidden_size
        self.intermediate_size = (
            config.get("intermediate_size") if intermediate_size is None else intermediate_size
        )

        # Without an explicit ``output_size`` the block maps back to its input width.
        output_size = config.get("output_size")
        self.output_size = self.hidden_size if output_size is None else output_size

        # Linear projection layers for the gating mechanism
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.output_size, bias=False)
        self.act_fn = ACT2FN[config.get("hidden_act")]

    def forward(self, x):
        """
        Forward pass of the MLP module.

        When ``pretraining_tp > 1`` the projection weights are split along the
        intermediate dimension and the result is assembled slice by slice; the
        value computed is the same as the unsliced path up to floating-point
        summation order.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is ``hidden_size``.

        Returns:
            torch.Tensor: Tensor with the same leading dimensions as ``x`` and a
            last dimension of ``output_size``.
        """
        if self.pretraining_tp <= 1:
            # Standard path: activation(gate) * up, then down projection.
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        # Ceil division so that a non-divisible intermediate size yields at most
        # ``pretraining_tp`` slices (the last one shorter) instead of an extra
        # remainder slice.
        slice_size = -(-self.intermediate_size // self.pretraining_tp)
        gate_slices = self.gate_proj.weight.split(slice_size, dim=0)
        up_slices = self.up_proj.weight.split(slice_size, dim=0)
        down_slices = self.down_proj.weight.split(slice_size, dim=1)

        # Iterate over every slice actually produced by split() rather than
        # over range(pretraining_tp), so no part of the weights is dropped.
        gate_out = torch.cat([F.linear(x, w) for w in gate_slices], dim=-1)
        up_out = torch.cat([F.linear(x, w) for w in up_slices], dim=-1)

        # Gate the up projection, then re-split so each chunk meets its matching
        # column slice of the down projection; partial outputs are summed.
        intermediate_slices = (self.act_fn(gate_out) * up_out).split(slice_size, dim=-1)
        return sum(F.linear(h, w) for h, w in zip(intermediate_slices, down_slices))

class MoEGate(nn.Module):
    """
    Mixture of Experts (MoE) gating mechanism.

    Scores every token against a learned ``[n_routed_experts, hidden_size]``
    weight matrix, keeps the ``top_k`` highest-scoring experts per token and,
    in training mode with a positive ``aux_loss_alpha``, also returns a
    load-balancing auxiliary loss.
    """

    def __init__(self, config, num_modality=2):
        """
        Initialize the MoE gate.

        Args:
            config (dict): Configuration mapping. Keys read here: ``fusion_type``,
                ``num_experts_per_tok``, ``n_routed_experts``, ``scoring_func``,
                ``aux_loss_alpha``, ``seq_aux``, ``norm_topk_prob``, ``hidden_size``.
            num_modality (int, optional): Number of modalities the routed experts
                are divided between when ``fusion_type == "disjoint"``. Defaults
                to 2, which reproduces the previously hard-coded halving.
        """
        super().__init__()
        self.config = config
        self.fusion_type = config.get("fusion_type")    # "joint", "per_modality", or "disjoint"
        self.top_k = config.get("num_experts_per_tok")  # Number of experts selected per token

        # In disjoint mode each gate only routes over its own modality's share
        # of the experts; otherwise it routes over all of them.
        n_routed_experts = config.get("n_routed_experts")
        if self.fusion_type == "disjoint":
            n_routed_experts = n_routed_experts // num_modality
        self.n_routed_experts = n_routed_experts

        self.scoring_func = config.get("scoring_func")  # Only "softmax" is implemented

        # A missing coefficient disables the auxiliary loss; leaving it as None
        # would make the ``> 0.0`` comparison in forward() raise TypeError.
        alpha = config.get("aux_loss_alpha")
        self.alpha = 0.0 if alpha is None else alpha
        self.seq_aux = config.get("seq_aux")            # Per-sequence vs. global auxiliary loss

        # Top-k selection parameters
        self.norm_topk_prob = config.get("norm_topk_prob")  # Renormalize the kept weights
        self.gating_dim = config.get("hidden_size")         # Gating network input dimension
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the weight matrix using Kaiming uniform initialization."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """
        Compute expert selection and routing weights for the input tokens.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape
                [batch_size, seq_len, hidden_size].

        Returns:
            tuple: A tuple containing:
                - topk_idx (torch.Tensor): Selected expert indices, [batch_size*seq_len, top_k]
                - topk_weight (torch.Tensor): Weights of the selected experts,
                  [batch_size*seq_len, top_k]
                - aux_loss (torch.Tensor or None): Load-balancing loss; None unless
                  the module is in training mode and ``alpha > 0``.

        Raises:
            NotImplementedError: If ``scoring_func`` is not ``'softmax'``.
        """
        bsz, seq_len, h = hidden_states.shape

        # Flatten to one row per token and score each token against every expert.
        tokens = hidden_states.reshape(-1, h)
        logits = F.linear(tokens, self.weight, None)

        if self.scoring_func != 'softmax':
            raise NotImplementedError(f'Unsupported scoring function for MoE gating: {self.scoring_func}')
        scores = logits.softmax(dim=-1)

        # Select top-k experts (order within the k is not guaranteed).
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        # Renormalize the kept weights so they sum to 1; the epsilon guards the division.
        if self.top_k > 1 and self.norm_topk_prob:
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)

        if not (self.training and self.alpha > 0.0):
            return topk_idx, topk_weight, None

        if self.seq_aux:
            # Sequence-level loss: per-sequence expert counts, scaled so that a
            # perfectly uniform assignment gives 1 for every expert.
            idx_per_seq = topk_idx.reshape(bsz, -1)
            ce = torch.zeros(bsz, self.n_routed_experts, device=tokens.device)
            ce.scatter_add_(
                1, idx_per_seq, torch.ones(bsz, seq_len * self.top_k, device=tokens.device)
            )
            ce.div_(seq_len * self.top_k / self.n_routed_experts)
            mean_scores = scores.reshape(bsz, seq_len, -1).mean(dim=1)
            aux_loss = (ce * mean_scores).sum(dim=1).mean() * self.alpha
        else:
            # Global loss: selection frequency and mean score over the whole batch.
            selected = F.one_hot(topk_idx.reshape(-1), num_classes=self.n_routed_experts)
            fi = selected.float().mean(0) * self.n_routed_experts
            Pi = scores.mean(0)
            aux_loss = (Pi * fi).sum() * self.alpha

        return topk_idx, topk_weight, aux_loss


class AddAuxiliaryLoss(torch.autograd.Function):
    """
    Identity autograd op that hooks an auxiliary loss into the backward pass.

    The forward pass returns ``x`` untouched. During backward the incoming
    gradient is passed straight through to ``x`` and, if the loss tracks
    gradients, the loss receives a gradient of one.
    """

    @staticmethod
    def forward(ctx, x, loss):
        """
        Record what backward needs to know about ``loss`` and return ``x``.

        Args:
            ctx: Autograd context used to carry state to the backward pass
            x (torch.Tensor): Tensor returned as-is
            loss (torch.Tensor): Auxiliary loss; must hold exactly one element

        Returns:
            torch.Tensor: ``x``, unchanged
        """
        assert loss.numel() == 1
        # Only the dtype and the requires_grad flag are needed later; the loss
        # value itself is not saved.
        ctx.required_aux_loss = loss.requires_grad
        ctx.dtype = loss.dtype
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """
        Pass ``grad_output`` through and seed the auxiliary loss gradient.

        Args:
            ctx: Autograd context populated by ``forward``
            grad_output (torch.Tensor): Gradient flowing in from later layers

        Returns:
            tuple: ``(grad for x, grad for loss)``; the second entry is None when
            the loss did not require gradients.
        """
        if not ctx.required_aux_loss:
            return grad_output, None

        # A gradient of one for the loss, created on the same device as the
        # incoming gradient.
        unit_grad = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, unit_grad


@CONNECTOR.register("deepseek_moe")
class DeepseekMoE(nn.Module):
    """
    Deepseek Mixture of Experts (MoE) module with support for multiple fusion strategies.

    The module takes a list of per-modality hidden-state tensors and returns a
    single tensor concatenated along dim=1. Supported strategies:
    - joint: inputs are concatenated first; one router and one expert pool
    - per_modality: one shared expert pool, one router per modality
    - disjoint: one expert pool and one router per modality

    If ``n_shared_experts`` is configured, the output of an always-on shared MLP
    is added to the routed output.
    """

    def __init__(self, config, num_modality, device):
        """
        Initialize the Deepseek MoE module.

        Args:
            config (dict): Configuration mapping; the ``"deepseek_moe"`` entry
                holds the MoE parameters used by this module
            num_modality (int): Number of input modalities
            device: Stored on the instance as ``self.device``; not otherwise
                used in this class

        Raises:
            ValueError: If ``fusion_type`` is not one of the supported values, or
                if in disjoint mode a gate routes over more experts than a
                modality owns.
        """
        super().__init__()
        self.device = device
        self.config = config.get("deepseek_moe")
        self.num_experts_per_tok = self.config.get("num_experts_per_tok")
        self.moe_intermediate_size = self.config.get("moe_intermediate_size")
        self.n_routed_experts = self.config.get("n_routed_experts")
        self.n_shared_experts = self.config.get("n_shared_experts")
        self.num_modality = num_modality
        self.fusion_type = self.config.get("fusion_type", "disjoint")

        if self.fusion_type not in ("joint", "per_modality", "disjoint"):
            raise ValueError(f"Invalid fusion_type: {self.fusion_type}")

        # The gate reads ``fusion_type`` from its config without a default. When
        # the key is absent, hand it the value resolved above so that the gate
        # and this module agree on the fusion strategy.
        gate_config = self.config
        if self.config.get("fusion_type") is None:
            gate_config = dict(self.config, fusion_type=self.fusion_type)

        if self.fusion_type == "disjoint":
            # Separate expert pool per modality: a ModuleList of ModuleLists.
            experts_per_modality = self.n_routed_experts // self.num_modality
            self.experts = nn.ModuleList(
                [self._build_experts(experts_per_modality) for _ in range(self.num_modality)]
            )
        else:
            # joint / per_modality: a single pool shared by all modalities.
            self.experts = self._build_experts(self.n_routed_experts)

        if self.fusion_type == "joint":
            self.gate = MoEGate(gate_config)
        else:
            self.gates = nn.ModuleList([MoEGate(gate_config) for _ in range(self.num_modality)])

        if self.fusion_type == "disjoint" and len(self.gates) > 0:
            # A gate that can pick an expert index the modality does not own
            # leaves uninitialised rows in training and raises IndexError at
            # inference, so reject the configuration up front.
            gate_experts = self.gates[0].n_routed_experts
            if gate_experts > experts_per_modality:
                raise ValueError(
                    f"disjoint fusion: each gate routes over {gate_experts} experts but each "
                    f"modality only owns {experts_per_modality} "
                    f"(n_routed_experts={self.n_routed_experts}, num_modality={self.num_modality})"
                )

        # Initialize shared experts if specified
        if self.n_shared_experts is not None:
            intermediate_size = self.moe_intermediate_size * self.n_shared_experts
            self.shared_experts = DeepseekMLP(config=self.config, intermediate_size=intermediate_size)

    def _build_experts(self, count):
        """Create a ModuleList of ``count`` routed expert MLPs."""
        return nn.ModuleList([
            DeepseekMLP(self.config, intermediate_size=self.moe_intermediate_size)
            for _ in range(count)
        ])

    def forward(self, hidden_states):
        """
        Forward pass through the MoE module.

        Dispatches to the implementation matching ``fusion_type``.

        Args:
            hidden_states (list): List of hidden state tensors, one per modality,
                each of shape [batch_size, seq_len_i, hidden_size]

        Returns:
            torch.Tensor: Fused output tensor combining all modalities along dim=1

        Raises:
            ValueError: If ``fusion_type`` is not a supported value.
        """
        if self.fusion_type == "joint":
            return self._forward_joint(hidden_states)
        if self.fusion_type == "per_modality":
            return self._forward_per_modality(hidden_states)
        if self.fusion_type == "disjoint":
            return self._forward_disjoint(hidden_states)
        raise ValueError(f"Invalid fusion_type: {self.fusion_type}")

    def _route(self, hidden_state, gate, experts, ind=0):
        """
        Route one [batch, seq, hidden] tensor through ``experts`` as chosen by ``gate``.

        Args:
            hidden_state (torch.Tensor): Input of shape [batch_size, seq_len, hidden_size]
            gate (MoEGate): Router producing top-k expert indices and weights
            experts (nn.ModuleList): Experts indexed by the gate's expert indices
            ind (int): Modality index, forwarded to ``moe_infer`` (used in disjoint mode)

        Returns:
            torch.Tensor: Routed output with the same shape as ``hidden_state``
        """
        orig_shape = hidden_state.shape
        topk_idx, topk_weight, aux_loss = gate(hidden_state)
        tokens = hidden_state.reshape(-1, orig_shape[-1])
        flat_topk_idx = topk_idx.reshape(-1)

        if not self.training:
            routed = self.moe_infer(tokens, flat_topk_idx, topk_weight.reshape(-1, 1), ind=ind)
            return routed.reshape(*orig_shape)

        # One row per (token, selected expert) pair, in the same order as flat_topk_idx.
        tokens = tokens.repeat_interleave(self.num_experts_per_tok, dim=0)
        y = torch.empty_like(tokens)
        for j, expert in enumerate(experts):
            selected = flat_topk_idx == j
            # Cast to the buffer dtype so the assignment also works when the
            # expert emits a different dtype than its input.
            y[selected] = expert(tokens[selected]).to(y.dtype)

        # Weighted sum over each token's top-k expert outputs.
        y = (y.reshape(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
        y = y.reshape(*orig_shape)

        # The gate returns no auxiliary loss when its coefficient is not
        # positive; there is nothing to attach in that case.
        if aux_loss is not None:
            y = AddAuxiliaryLoss.apply(y, aux_loss)
        return y

    def _fuse_with_shared(self, hidden_states, routed_outputs):
        """Concatenate per-modality routed outputs along dim=1 and add the shared-expert output."""
        fused_output = torch.cat(routed_outputs, dim=1)

        if self.n_shared_experts is not None:
            # The shared MLP is applied to each modality separately, then fused the same way.
            shared_fused_output = torch.cat(
                [self.shared_experts(hidden_state) for hidden_state in hidden_states], dim=1
            )
            fused_output = fused_output + shared_fused_output

        return fused_output

    def _forward_joint(self, hidden_states):
        """
        Joint experts & router forward pass.

        Concatenates all modality inputs along dim=1 and routes the result
        through the single gate and expert pool.

        Args:
            hidden_states (list): List of hidden state tensors for each modality

        Returns:
            torch.Tensor: Output tensor after joint processing
        """
        hidden_states = torch.cat(hidden_states, dim=1)
        y = self._route(hidden_states, self.gate, self.experts)

        # Add shared expert output if available
        if self.n_shared_experts is not None:
            y = y + self.shared_experts(hidden_states)

        return y

    def _forward_per_modality(self, hidden_states):
        """
        Per-modality routers forward pass.

        Uses the shared expert pool but a separate router for each modality.

        Args:
            hidden_states (list): List of hidden state tensors for each modality

        Returns:
            torch.Tensor: Fused output tensor after per-modality routing
        """
        routed_outputs = [
            self._route(hidden_state, self.gates[i], self.experts)
            for i, hidden_state in enumerate(hidden_states)
        ]
        return self._fuse_with_shared(hidden_states, routed_outputs)

    def _forward_disjoint(self, hidden_states):
        """
        Disjoint experts & routers forward pass.

        Each modality is routed by its own gate through its own expert pool.

        Args:
            hidden_states (list): List of hidden state tensors for each modality

        Returns:
            torch.Tensor: Fused output tensor after disjoint processing
        """
        routed_outputs = [
            self._route(hidden_state, self.gates[i], self.experts[i], ind=i)
            for i, hidden_state in enumerate(hidden_states)
        ]
        return self._fuse_with_shared(hidden_states, routed_outputs)

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights, ind=0):
        """
        MoE inference path, executed without gradient tracking.

        Sorts the (token, expert) assignments by expert so that every expert
        runs once on all of its tokens, then accumulates the weighted outputs
        per token.

        Args:
            x (torch.Tensor): Flattened tokens, one row per token
            flat_expert_indices (torch.Tensor): Flattened expert index per (token, k) pair
            flat_expert_weights (torch.Tensor): Matching routing weights, shape [N, 1]
            ind (int): Modality index; selects the expert pool in disjoint mode

        Returns:
            torch.Tensor: Output tensor with the same shape as ``x``
        """
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        # Cumulative counts give the end offset of each expert's run inside ``idxs``.
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        # Each token occupies num_experts_per_tok consecutive flat slots.
        token_idxs = idxs // self.num_experts_per_tok

        experts = self.experts[ind] if self.fusion_type == "disjoint" else self.experts

        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                # No token was routed to this expert.
                continue

            exp_token_idx = token_idxs[start_idx:end_idx]
            # Cast to the accumulator dtype so scatter_reduce_ sees matching dtypes.
            expert_out = experts[i](x[exp_token_idx]).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_reduce_(
                0, exp_token_idx.reshape(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum'
            )
        return expert_cache
    












