import math
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F

from models.config import MiMoEConfig
from models.registry import register_compactness_loss


# Symmetrized Kullback-Leibler divergence between two isotropic, unit-variance Gaussians (the embeddings are their means)
@register_compactness_loss("skl_normal")
class SKLNormalDistributionLoss(nn.Module):
    """Compactness loss treating each embedding as the mean of an isotropic Gaussian.

    The per-pair term is the squared L2 distance over the last dimension,
    averaged over all remaining (batch) positions. For unit-variance Gaussians
    N(emb1, I) and N(emb2, I) this squared distance equals
    KL(p || q) + KL(q || p). No variance is modelled here.

    Subclasses override ``_compute_loss`` and reuse ``forward``.
    """

    def __init__(self, config: MiMoEConfig):
        # ``config`` is accepted for registry compatibility; nothing is read from it.
        super().__init__()

    def _compute_loss(self, emb1: Tensor, emb2: Tensor, labels: Tensor) -> Tensor:
        """Return the mean over the batch of ``||emb1 - emb2||^2`` (last dim summed).

        ``labels`` is unused by this base implementation; it is part of the
        signature so that subclasses can use class information.
        The result is a 0-dim tensor (not a Python float).
        """
        per_sample_sq_dist = (emb1 - emb2).pow(2).sum(dim=-1)
        return per_sample_sq_dist.mean()

    def forward(self, emb1: Tensor, emb2: Tensor=None, targets: Tensor=None) -> Tensor:
        """Symmetrised loss between ``emb1`` and ``emb2``.

        With ``emb2`` given, returns ``0.5 * (L(emb1, emb2) + L(emb2, emb1))``.
        Without it, returns ``L(emb1, emb1)``. That is zero for this class, but
        subclasses override ``L`` (``_compute_loss``) so it may not be zero for them.
        """
        if emb2 is not None:
            forward_term = self._compute_loss(emb1, emb2, targets)
            reverse_term = self._compute_loss(emb2, emb1, targets)
            return 0.5 * (forward_term + reverse_term)
        return self._compute_loss(emb1, emb1, targets)
    

@register_compactness_loss("skl_vmf")
class SKLVonMisesFisherLoss(SKLNormalDistributionLoss):
    """Class-conditional directional compactness loss on the unit hypersphere.

    Both embedding sets are L2-normalised along the last dimension. For every
    label occurring at least twice in the batch:
    - the normalised mean direction of that class in ``emb2`` is formed;
    - the class loss is the negative mean cosine similarity between the class's
      ``emb1`` members and that direction.

    The result is the unweighted mean over such classes, so each class counts
    equally regardless of its size.

    The inherited ``forward`` averages ``(emb1, emb2)`` and ``(emb2, emb1)``.
    With ``emb2=None`` it compares each sample with its own class's mean
    direction in ``emb1``.
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__(config)

    def _compute_loss(self, emb1: Tensor, emb2: Tensor, labels: Tensor) -> Tensor:
        """Class-averaged negative cosine alignment of ``emb1`` to ``emb2`` class means.

        Args:
            emb1: embeddings whose alignment is scored (normalised along the last dim).
            emb2: embeddings used to build the per-class mean directions; must share
                the leading (sample) dimension with ``emb1``.
            labels: class id per sample, matched against dim 0 of the embeddings.

        Returns:
            A 0-dim tensor. When no label occurs at least twice, a zero with the
            dtype/device of ``emb1`` and no autograd history is returned.

        Raises:
            ValueError: if ``labels`` is None; this loss needs class ids.
        """
        if labels is None:
            raise ValueError("SKLVonMisesFisherLoss requires class labels (targets)")

        # normalize to unit hypersphere
        emb1 = emb1 / (emb1.norm(dim=-1, keepdim=True) + 1e-8)
        emb2 = emb2 / (emb2.norm(dim=-1, keepdim=True) + 1e-8)

        # Count all labels once and keep only those seen at least twice. This avoids
        # a separate ``mask.sum()`` (a device sync on GPU) for every class.
        unique_labels, counts = labels.unique(return_counts=True)

        loss_terms = []
        for y in unique_labels[counts > 1]:
            mask = labels == y
            mean_dir = emb2[mask].mean(dim=0, keepdim=True)
            mean_dir = mean_dir / (mean_dir.norm(dim=-1, keepdim=True) + 1e-8)

            # Alignment with class mean direction (μ(x)^T μ̄_y)
            sim = (emb1[mask] * mean_dir).sum(dim=-1)
            loss_terms.append(-sim.mean())

        if not loss_terms:
            # Match the dtype/device of the non-empty path.
            return emb1.new_zeros(())
        return torch.stack(loss_terms).mean()


@register_compactness_loss("skl_cls_vmf")
class SKLCLSVonMisesFisherLoss(SKLNormalDistributionLoss):
    """Class-conditional directional compactness loss on the first (CLS) token.

    Only ``emb[:, 0]`` of each sample is used. Those vectors are L2-normalised
    along the last dimension. For every label occurring at least twice in the
    batch, the class loss is the negative mean cosine similarity between:
    - the class's ``emb1`` CLS vectors, and
    - the normalised mean of its ``emb2`` CLS vectors.

    The result is the unweighted mean over such classes.
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__(config)

    def _compute_loss(self, emb1: Tensor, emb2: Tensor, labels: Tensor) -> Tensor:
        """Class-averaged negative cosine alignment of CLS tokens.

        Args:
            emb1: per-token embeddings whose CLS alignment is scored.
                NOTE(review): presumably shaped (batch, tokens, dim) with the CLS
                token at index 0 — confirm against callers.
            emb2: per-token embeddings used to build per-class mean CLS directions.
            labels: class id per sample, matched against dim 0 of the embeddings.

        Returns:
            A 0-dim tensor. When no label occurs at least twice, a zero with the
            dtype/device of ``emb1`` and no autograd history is returned.

        Raises:
            ValueError: if ``labels`` is None; this loss needs class ids.
        """
        if labels is None:
            raise ValueError("SKLCLSVonMisesFisherLoss requires class labels (targets)")

        # CLS token with a singleton token axis kept, e.g. (B, T, D) -> (B, 1, D).
        cls1 = emb1[:, 0].unsqueeze(1)
        cls2 = emb2[:, 0].unsqueeze(1)

        # normalize to unit hypersphere
        cls1 = cls1 / (cls1.norm(dim=-1, keepdim=True) + 1e-8)
        cls2 = cls2 / (cls2.norm(dim=-1, keepdim=True) + 1e-8)

        # Count all labels once and keep only those seen at least twice. This avoids
        # a separate ``mask.sum()`` (a device sync on GPU) for every class.
        unique_labels, counts = labels.unique(return_counts=True)

        loss_terms = []
        for y in unique_labels[counts > 1]:
            mask = labels == y
            mean_dir = cls2[mask].mean(dim=0, keepdim=True)
            mean_dir = mean_dir / (mean_dir.norm(dim=-1, keepdim=True) + 1e-8)

            # Alignment with class mean direction (μ(x)^T μ̄_y)
            sim = (cls1[mask] * mean_dir).sum(dim=-1)
            loss_terms.append(-sim.mean())

        if not loss_terms:
            # Match the dtype/device of the non-empty path.
            return emb1.new_zeros(())
        return torch.stack(loss_terms).mean()


# === schedulers === #

class Scheduler:
    """Base class for value schedules, used as ``scheduler(iteration) -> value``.

    Subclasses must override ``__call__``.
    """

    def __call__(self, *args, **kwargs):
        # NotImplementedError is the exception to raise here. ``NotImplemented`` is
        # a sentinel constant for binary operators; calling it raises an unrelated
        # TypeError.
        raise NotImplementedError(f"{type(self).__name__} must implement __call__")


class LinearScheduler(Scheduler):
    """Piecewise-linear schedule: hold, ramp, hold.

    Phases:
    - up to and including ``start_iteration``: returns ``start_value``;
    - over the next ``n_iterations`` iterations: interpolates linearly towards
      ``end_value``;
    - from ``start_iteration + n_iterations`` onwards: returns ``end_value``.

    ``n_iterations == 0`` makes the schedule a step at ``start_iteration``.

    Raises:
        ValueError: if ``n_iterations`` is negative.
    """

    def __init__(self, start_value, end_value, n_iterations, start_iteration=0):
        if n_iterations < 0:
            raise ValueError(f"n_iterations must be non-negative, got {n_iterations}")
        self.start_value = start_value
        self.end_value = end_value
        self.n_iterations = n_iterations
        self.start_iteration = start_iteration
        # Slope per iteration. A zero-length ramp never interpolates, so it gets no
        # slope instead of a ZeroDivisionError.
        self.m = (end_value - start_value) / n_iterations if n_iterations else 0.0

    def __call__(self, iteration):
        if iteration <= self.start_iteration:
            return self.start_value
        if iteration >= self.start_iteration + self.n_iterations:
            # Return the endpoint itself. Recomputing it as n * m + start can be off
            # by floating-point rounding.
            return self.end_value
        return (iteration - self.start_iteration) * self.m + self.start_value


class ExponentialScheduler(LinearScheduler):
    """Geometric schedule: linear interpolation in log-``base`` space.

    Holds ``start_value`` up to and including ``start_iteration``. It then moves
    geometrically to ``end_value`` over ``n_iterations`` iterations and holds
    ``end_value`` afterwards.

    ``start_value``, ``end_value`` and ``base`` must be positive; ``math.log``
    raises ValueError otherwise. The inherited ``start_value``/``end_value``
    attributes store the logarithms, not the original values.
    """

    def __init__(self, start_value, end_value, n_iterations, start_iteration=0, base=10):
        self.base = base
        # Keep the exact endpoints: base ** log(v, base) need not round-trip to v.
        # For example, math.log(1000, 10) == 2.9999999999999996, and 10 ** that is
        # not exactly 1000.
        self._exact_start_value = start_value
        self._exact_end_value = end_value

        super(ExponentialScheduler, self).__init__(start_value=math.log(start_value, base),
                                                   end_value=math.log(end_value, base),
                                                   n_iterations=n_iterations,
                                                   start_iteration=start_iteration)

    def __call__(self, iteration):
        # Outside the ramp, return the caller's endpoints unchanged.
        if iteration <= self.start_iteration:
            return self._exact_start_value
        if iteration >= self.start_iteration + self.n_iterations:
            return self._exact_end_value
        linear_value = super(ExponentialScheduler, self).__call__(iteration)
        return self.base ** linear_value