import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F
import logging
from models.config import MiMoEConfig
from models.registry import register_pretrain_loss

from deprecated import deprecated


@register_pretrain_loss("instance_infonce")
class InstanceLevelInfoNCELoss(nn.Module):
    """Token-level InfoNCE loss with a learnable temperature.

    Every token of every sequence is treated as its own instance. The token at
    flattened position ``i`` of view 1 is contrasted against all ``B * T``
    tokens of view 2. The token at the same flattened position is the
    positive; every other token in the batch is a negative.

    Attributes:
        log_temp: learnable log of the softmax temperature, initialised from
            ``config.infonce_temperature``.
        proj: linear projection ``config.hidden_dim -> config.critic_dim``
            applied to the embeddings before L2 normalisation.
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__()
        init_temp = config.infonce_temperature
        # Learn the temperature in log space so exp() keeps it strictly positive.
        self.log_temp = nn.Parameter(torch.log(torch.tensor(init_temp)))
        self.proj = nn.Linear(config.hidden_dim, config.critic_dim)

    def _compute_loss(self, emb1: Tensor, emb2: Tensor) -> Tensor:
        """Return the one-directional InfoNCE loss from ``emb1`` to ``emb2``.

        Args:
            emb1: anchor embeddings of shape ``[B, T, hidden_dim]``.
            emb2: embeddings of the same shape; ``emb2[b, t]`` is the positive
                for ``emb1[b, t]``.

        Returns:
            Scalar mean loss over all ``B * T`` anchors.

        Raises:
            ValueError: if an input is not 3-D or the shapes differ.
        """
        if emb1.shape != emb2.shape:
            raise ValueError(
                f"emb1 and emb2 must have the same shape, "
                f"got {tuple(emb1.shape)} and {tuple(emb2.shape)}"
            )
        B, T, _ = emb1.shape
        n = B * T

        # Every position is used (no token is excluded). Flatten batch and
        # time so each token is a separate instance.
        feature1 = F.normalize(self.proj(emb1).reshape(n, -1), dim=-1)
        feature2 = F.normalize(self.proj(emb2).reshape(n, -1), dim=-1)

        temperature = torch.exp(self.log_temp)
        logits = torch.matmul(feature1, feature2.T) / temperature  # [B * T, B * T]
        # .item() forces a host-device sync; only pay for it when the record
        # will actually be emitted.
        if logging.getLogger().isEnabledFor(logging.INFO):
            logging.info(f"temperature: {temperature.item():.4f}")

        # -log(exp(s_ii) / sum_j exp(s_ij)) is cross-entropy with the diagonal
        # as the target class. F.cross_entropy evaluates it with log-sum-exp,
        # so it does not overflow when 1/temperature is large or under
        # reduced-precision autocast, unlike exponentiating the logits directly.
        labels = torch.arange(n, device=logits.device)
        return F.cross_entropy(logits, labels)

    def forward(self, emb1: Tensor, emb2: Tensor = None, targets: Tensor = None) -> Tensor:
        """Return the InfoNCE loss between two views.

        Args:
            emb1: first view, ``[B, T, hidden_dim]``.
            emb2: second view of the same shape. If ``None``, ``emb1`` is
                contrasted with itself.
            targets: unused here; presumably kept so all registered pretrain
                losses share one call signature.

        Returns:
            Scalar loss. With two views, it is the average of both contrast
            directions.
        """
        if emb2 is None:
            return self._compute_loss(emb1, emb1)
        return 0.5 * (self._compute_loss(emb1, emb2) + self._compute_loss(emb2, emb1))


@register_pretrain_loss("cls_infonce")
class CLSInfoNCELoss(InstanceLevelInfoNCELoss):
    """Sequence-level InfoNCE loss on the position-0 (CLS) embedding.

    Each sequence in the batch is one instance. The position-0 embedding of
    sequence ``b`` in view 1 is contrasted against the position-0 embeddings
    of all ``B`` sequences in view 2, and sequence ``b`` is the positive.

    NOTE(review): ``self.proj`` is created by the parent ``__init__`` but is
    never applied in this class, so its parameters receive no gradient. Under
    DistributedDataParallel this usually requires
    ``find_unused_parameters=True``. Confirm whether the projection was meant
    to be applied here.
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__(config)

    def _compute_loss(self, emb1: Tensor, emb2: Tensor) -> Tensor:
        """Return the one-directional InfoNCE loss from ``emb1`` to ``emb2``.

        Args:
            emb1: anchor embeddings of shape ``[B, D]``.
            emb2: embeddings of the same shape; ``emb2[b]`` is the positive for
                ``emb1[b]``.

        Returns:
            Scalar mean loss over the ``B`` anchors.

        Raises:
            ValueError: if an input is not 2-D or the shapes differ.
        """
        if emb1.dim() != 2 or emb1.shape != emb2.shape:
            raise ValueError(
                f"expected two [B, D] tensors of equal shape, "
                f"got {tuple(emb1.shape)} and {tuple(emb2.shape)}"
            )
        B, _ = emb1.shape
        feature1 = F.normalize(emb1, dim=-1)  # [B, D]
        feature2 = F.normalize(emb2, dim=-1)

        temperature = torch.exp(self.log_temp)
        logits = torch.matmul(feature1, feature2.T) / temperature  # [B, B]
        # .item() forces a host-device sync; only pay for it when the record
        # will actually be emitted.
        if logging.getLogger().isEnabledFor(logging.INFO):
            logging.info(f"temperature: {temperature.item():.4f}")

        # Equivalent to -log(exp(s_ii) / sum_j exp(s_ij)), computed stably via
        # log-sum-exp instead of exponentiating logits that can overflow.
        labels = torch.arange(B, device=logits.device)
        return F.cross_entropy(logits, labels)

    def forward(self, emb1: Tensor, emb2: Tensor = None, targets: Tensor = None) -> Tensor:
        """Return the InfoNCE loss between the CLS embeddings of two views.

        Args:
            emb1: first view; position 0 along dim 1 is taken as the CLS token
                (assumes ``[B, T, D]`` layout — TODO confirm with callers).
            emb2: second view. If ``None``, ``emb1`` is contrasted with itself.
            targets: unused here; presumably kept for a uniform loss signature.

        Returns:
            Scalar loss. With two views, it is the average of both contrast
            directions.
        """
        cls1 = emb1[:, 0]  # CLS token
        if emb2 is None:
            return self._compute_loss(cls1, cls1)
        cls2 = emb2[:, 0]
        return 0.5 * (self._compute_loss(cls1, cls2) + self._compute_loss(cls2, cls1))

        
@register_pretrain_loss("global_avg_infonce")
class GlobalAveragedInfoNCELoss(CLSInfoNCELoss):
    """InfoNCE loss on mean-pooled sequence embeddings.

    Each sequence is summarised by averaging its embeddings over dimension 1.
    The pooled vectors are then contrasted with the inherited
    ``_compute_loss``.

    NOTE(review): the average covers every position, including position 0,
    which ``CLSInfoNCELoss`` treats as a CLS token, and any padding.
    Confirm this is intended.
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__(config)

    def forward(self, emb1: Tensor, emb2: Tensor=None, targets: Tensor=None) -> Tensor:
        """Return the InfoNCE loss between mean-pooled views.

        Args:
            emb1: first view, pooled by ``mean(dim=1)``.
            emb2: second view, pooled the same way. If ``None``, the pooled
                ``emb1`` is contrasted with itself.
            targets: unused here; presumably kept for a uniform loss signature.

        Returns:
            Scalar loss. With two views, it is the average of both contrast
            directions.
        """
        pooled_a = emb1.mean(dim=1)
        if emb2 is not None:
            pooled_b = emb2.mean(dim=1)
            a_to_b = self._compute_loss(pooled_a, pooled_b)
            b_to_a = self._compute_loss(pooled_b, pooled_a)
            return (a_to_b + b_to_a) * 0.5
        return self._compute_loss(pooled_a, pooled_a)