import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
import logging

from models.config import MiMoEConfig
from models.registry import register_feature_selection_loss


@register_feature_selection_loss("supercon")
class SuperCon(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.

    Inputs are 3-D embeddings indexed ``[B, T, D]`` (see ``_compute_loss``):
    ``B`` samples, each contributing ``T`` views along dim 1. Every view of
    every anchor sample is an anchor. Positives are contrast views whose
    sample shares the anchor's label, or, when ``labels`` is ``None``, views
    of the very same sample (SimCLR).

    The temperature is learnable and stored as ``log_temp`` so that
    ``exp(log_temp)`` is always strictly positive.
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__()
        init_temp = config.infonce_temperature
        # Optimise log(temperature) so the effective temperature stays > 0.
        self.log_temp = nn.Parameter(torch.log(torch.tensor(init_temp)))

    def _compute_loss(self, _emb1: Tensor, _emb2: Tensor, labels: Tensor = None) -> Tensor:
        """Return the mean SupCon loss with anchors from ``_emb1`` and contrasts from ``_emb2``.

        Args:
            _emb1: anchor embeddings, ``[B, T1, D]``.
            _emb2: contrast embeddings, ``[B, T2, D]``. Pass the *same tensor
                object* as ``_emb1`` for the single-view case: only then is
                each anchor's similarity with itself excluded. For two
                distinct views, pair (i, i) is a sample against its other
                view and counts as a positive, as in the SupCon paper.
            labels: per-sample labels with ``B`` entries, or ``None`` for the
                unsupervised (SimCLR) loss where only views of the same
                sample are positives.

        Returns:
            Scalar loss averaged over all ``B * T1`` anchors. Anchors with no
            positive contribute 0 to the average.

        Raises:
            ValueError: if ``labels`` does not have exactly ``B`` entries.
        """
        B, anchor_count = _emb1.shape[0], _emb1.shape[1]
        contrast_count = _emb2.shape[1]
        device = _emb1.device
        # Anchor and contrast are literally the same views -> (i, i) is a self-pair.
        same_views = _emb1 is _emb2

        emb1 = F.normalize(_emb1, dim=-1)
        emb2 = F.normalize(_emb2, dim=-1)

        temperature = torch.exp(self.log_temp)

        # Sample-level positive mask [B, B].
        if labels is None:
            # SimCLR: a sample is only positive with (other views of) itself.
            mask = torch.eye(B, dtype=torch.float32, device=device)
        else:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != B:
                raise ValueError("Num of labels does not match num of features")
            mask = torch.eq(labels, labels.T).float().to(device)

        # .item() forces a device->host sync; only pay for it when INFO is emitted.
        if logging.getLogger().isEnabledFor(logging.INFO):
            logging.info("[Temp] temperature=%.6f", temperature.item())

        # Flatten views view-major: row t * B + b holds view t of sample b.
        anchor_feature = torch.cat(torch.unbind(emb1, dim=1), dim=0)
        contrast_feature = torch.cat(torch.unbind(emb2, dim=1), dim=0)

        # compute logits
        anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), temperature)

        # for numerical stability (a per-row shift does not change the log-softmax)
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile the sample-level mask to every (anchor view, contrast view) pair
        mask = mask.repeat(anchor_count, contrast_count)
        if same_views:
            # mask-out self-contrast cases (each anchor against itself)
            logits_mask = torch.scatter(
                torch.ones_like(mask), 1,
                torch.arange(B * anchor_count, device=device).view(-1, 1), 0,
            )
        else:
            # distinct views: nothing is a self-pair, keep every contrast
            logits_mask = torch.ones_like(mask)
        mask = mask * logits_mask

        # compute log_prob against all non-excluded contrasts
        exp_logits = torch.exp(logits) * logits_mask
        sum_exp = exp_logits.sum(1, keepdim=True)
        log_prob = logits - torch.log(sum_exp + 1e-9)

        # average log-likelihood over positives; avoid 0/0 for anchors without positives
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss: mean over all anchor views of all samples
        return -mean_log_prob_pos.mean()

    def forward(self, emb1: Tensor, emb2: Tensor=None, labels: Tensor=None) -> Tensor:
        """Return the contrastive loss for one or two views.

        With ``emb2`` omitted, ``emb1`` is contrasted against itself (self-pairs
        excluded). Otherwise the loss is symmetrised:
        ``0.5 * (L(emb1 -> emb2) + L(emb2 -> emb1))``.
        """
        if emb2 is None:
            return self._compute_loss(emb1, emb1, labels)
        return 0.5 * (self._compute_loss(emb1, emb2, labels) + self._compute_loss(emb2, emb1, labels))


@register_feature_selection_loss("cls_supercon")
class CLSSuperCon(SuperCon):
    """SupCon loss computed only on position 0 along dim 1 (the CLS token).

    Construction, including the learnable ``log_temp`` parameter, is
    inherited unchanged from :class:`SuperCon`. The parent's ``__init__``
    already creates it, so there is no override here.
    """

    def forward(self, emb1: Tensor, emb2: Tensor=None, labels: Tensor=None) -> Tensor:
        """Apply :meth:`SuperCon.forward` to the first token of each input.

        Args:
            emb1: token embeddings; ``emb1[:, 0]`` is taken as the CLS token.
            emb2: optional second view, treated the same way as ``emb1``.
            labels: per-sample labels, or ``None`` for the unsupervised loss.
        """
        if emb2 is emb1:
            # The same tensor twice is the single-view case; this also avoids
            # computing the identical loss twice.
            emb2 = None
        cls1 = emb1[:, 0].unsqueeze(1)  # CLS token [B, 1, D]
        cls2 = None if emb2 is None else emb2[:, 0].unsqueeze(1)
        # Parent handles both the single-view and the symmetrised two-view loss.
        return super().forward(cls1, cls2, labels)


@register_feature_selection_loss("global_supercon")
class GlobalSuperCon(SuperCon):
    """SupCon loss computed on the mean of each input over dim 1.

    Construction, including the learnable ``log_temp`` parameter, is
    inherited unchanged from :class:`SuperCon`. The parent's ``__init__``
    already creates it, so there is no override here.
    """

    def forward(self, emb1: Tensor, emb2: Tensor=None, labels: Tensor=None) -> Tensor:
        """Apply :meth:`SuperCon.forward` to mean-pooled embeddings.

        Args:
            emb1: token embeddings, mean-pooled over dim 1 to ``[B, 1, D]``.
            emb2: optional second view, pooled the same way.
            labels: per-sample labels, or ``None`` for the unsupervised loss.

        NOTE(review): the mean covers every position along dim 1 and no
        padding mask is applied; confirm inputs are unpadded or that padding
        is acceptable in the pooled representation.
        """
        if emb2 is emb1:
            # The same tensor twice is the single-view case; this also avoids
            # computing the identical loss twice.
            emb2 = None
        global1 = emb1.mean(dim=1, keepdim=True)  # [B, 1, D]
        global2 = None if emb2 is None else emb2.mean(dim=1, keepdim=True)
        # Parent handles both the single-view and the symmetrised two-view loss.
        return super().forward(global1, global2, labels)