import torch
from torch import nn
from torch import Tensor
from typing import List

from models.config import MiMoEConfig
from models.registry import register_auxiliary_loss


@register_auxiliary_loss("importance")
class ImportanceLoss(nn.Module):
    """Auxiliary loss: squared coefficient of variation of summed gate scores.

    Other auxiliary losses in this file subclass it, override
    ``_compute_loss`` and reuse ``forward``.
    """

    def __init__(self, config: MiMoEConfig):
        # ``config`` is accepted for a uniform constructor signature; it is
        # not read by this class.
        super().__init__()

    def _compute_loss(self, gate_score: Tensor, gate_logit: Tensor) -> Tensor:
        """Return ``(std / mean) ** 2`` of the per-expert score totals.

        Args:
            gate_score: Gate scores; summed over the first two dimensions
                (presumably ``[B, T, E]`` -- confirm against the caller).
            gate_logit: Unused here; kept so subclasses share one signature.

        Returns:
            A scalar tensor.
        """
        per_expert = gate_score.sum(dim=(0, 1))  # [E]
        spread = per_expert.std(unbiased=False)
        center = per_expert.mean()
        # Epsilon keeps the ratio finite when the mean is zero.
        coeff_of_variation = spread / (center + 1e-9)
        return coeff_of_variation ** 2

    def forward(self, gate_scores: List[Tensor], gate_logits: List[Tensor]) -> Tensor:
        """Average ``_compute_loss`` over the paired entries of both lists.

        Args:
            gate_scores: One gate-score tensor per entry.
            gate_logits: One gate-logit tensor per entry, paired by position.

        Returns:
            The summed per-entry losses divided by ``len(gate_scores)``.
        """
        total = 0
        for score, logit in zip(gate_scores, gate_logits):
            total = total + self._compute_loss(score, logit)
        # The divisor is len(gate_scores) even if zip stopped early because
        # gate_logits was shorter.
        return total / len(gate_scores)
    
    
@register_auxiliary_loss("load")
class LoadLoss(ImportanceLoss):
    """Auxiliary loss: squared coefficient of variation of a smooth per-expert load.

    The load of an expert is the sum, over all tokens, of a Gaussian-CDF
    estimate that the expert's logit clears a noisy top-k threshold.
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__(config)
        # Used as the top-k count when picking the per-token threshold in
        # ``_compute_loss``; read from ``config.granularity``.
        self.moe_top_k = config.granularity

    def _compute_loss(self, gate_score: Tensor, gate_logit: Tensor) -> Tensor:
        """Return ``(std / mean) ** 2`` of the estimated per-expert load.

        Args:
            gate_score: Unused here; kept for the shared signature.
            gate_logit: Router logits of shape ``[B, T, E]``. May be
                non-contiguous.

        Returns:
            A scalar tensor. The value is stochastic because fresh Gaussian
            noise is drawn on every call.
        """
        # Unpacking three names also enforces that the logits are rank 3.
        _, _, num_experts = gate_logit.shape
        # reshape (not view) so non-contiguous logits, e.g. the result of a
        # transpose/permute, are flattened instead of raising RuntimeError.
        flat_logits = gate_logit.reshape(-1, num_experts)  # [B * T, E]
        noise = torch.randn_like(flat_logits)
        # The k-th smallest value with k = E - top_k + 1 is the top_k-th
        # largest noisy logit of each token.
        threshold, _ = torch.kthvalue(
            flat_logits + noise, k=num_experts - self.moe_top_k + 1, dim=-1
        )

        # NOTE(review): the threshold above is drawn with unit-variance noise
        # while the CDF below assumes std = 1 / E -- confirm this is intended.
        sigma = 1.0 / num_experts
        diff = threshold.unsqueeze(-1) - flat_logits  # [B * T, E]
        # 0.5 * erfc(x / (sigma * sqrt(2))) is the probability that a
        # N(0, sigma^2) variable exceeds x.
        p_i = 0.5 * torch.erfc(diff / (sigma * (2 ** 0.5)))  # [B * T, E]

        load = p_i.sum(dim=0)  # [E]
        # Epsilon keeps the ratio finite when the mean load is zero.
        return (load.std(unbiased=False) / (load.mean() + 1e-9)) ** 2
    
    
@register_auxiliary_loss("local_entropy")
class LocalEntropyLoss(ImportanceLoss):
    """Auxiliary loss: mean entropy of the gate scores along the last dimension."""

    def __init__(self, config: MiMoEConfig):
        super().__init__(config)

    def _compute_loss(self, gate_score: Tensor, gate_logit: Tensor) -> Tensor:
        """Return the entropy over the last dimension, averaged over the rest.

        Args:
            gate_score: Gate scores; entropy is taken along the last
                dimension (presumably the expert axis -- confirm).
            gate_logit: Unused here; kept for the shared signature.

        Returns:
            A scalar tensor.
        """
        # Epsilon keeps log finite for zero scores.
        log_scores = torch.log(gate_score + 1e-9)
        weighted = gate_score * log_scores
        per_position_entropy = -weighted.sum(dim=-1)  # [B, T]
        return per_position_entropy.mean()


@register_auxiliary_loss("global_entropy")
class GlobalEntropyLoss(ImportanceLoss):
    """Auxiliary loss: negated entropy of the batch-averaged gate scores."""

    def __init__(self, config: MiMoEConfig):
        super().__init__(config)

    def _compute_loss(self, gate_score: Tensor, gate_logit: Tensor) -> Tensor:
        """Return ``sum(p * log(p + 1e-9))`` for the mean score vector ``p``.

        Args:
            gate_score: Gate scores of shape ``[B, T, E]``.
            gate_logit: Unused here; kept for the shared signature.

        Returns:
            A scalar tensor equal to minus the entropy of the averaged scores.
        """
        # Unpacking three names enforces that the scores are rank 3.
        _batch, _steps, _experts = gate_score.shape
        mean_scores = gate_score.mean(dim=(0, 1))  # [E]
        log_mean_scores = torch.log(mean_scores + 1e-9)
        # sum(p * log p) is the entropy with its sign flipped, which is
        # exactly the value this loss returns.
        negative_entropy = (mean_scores * log_mean_scores).sum()
        return negative_entropy