import torch
import torch.nn as nn
import torch.nn.functional as F
from .build import LOSS


@LOSS.register_module()
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning loss: https://arxiv.org/pdf/2004.11362.pdf

    Args:
        temperature: divisor applied to the cosine similarities before the
            softmax.
        base_temperature: reference temperature; the final loss is scaled by
            ``temperature / base_temperature``.
        replace_nonfinite: if True, a non-finite loss is replaced by a constant
            tensor (which carries no gradient) instead of being returned.
        replacement_value: value of that constant.
    """

    def __init__(self, temperature=0.07,
                 base_temperature=0.07,
                 replace_nonfinite=True,
                 replacement_value=0.0):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.replace_nonfinite = replace_nonfinite
        self.replacement_value = replacement_value

    def _sanitize(self, loss):
        """Return ``loss``, or a detached constant if it is non-finite and
        replacement is enabled (avoids breaking training on a bad batch)."""
        if self.replace_nonfinite and not torch.isfinite(loss):
            return torch.full_like(loss, self.replacement_value)
        return loss

    def forward(self, features, labels, mask=None):
        """Compute the supervised contrastive loss.

        Args:
            features: hidden vectors of shape [npoints, ...]. Trailing
                dimensions are flattened into one vector per point.
            labels: ground truth with npoints elements (e.g. shape [npoints]).
                Only read when ``mask`` is None, so it may be None otherwise.
            mask: optional contrastive mask of shape [npoints, npoints],
                mask_{i,j}=1 if sample j has the same class as sample i. Can be
                asymmetric. When None it is built from ``labels``.

        Returns:
            A loss scalar. Anchors with no positive pair contribute 0 to the
            average (they are still counted in the mean).

        Raises:
            ValueError: if ``labels`` / ``mask`` do not match the number of
                points in ``features``.
        """
        n = features.size(0)
        if n == 0:
            # The mean over zero anchors is undefined (NaN); route it through
            # the same non-finite handling as any other degenerate batch.
            empty = features.new_full((), float('nan'), dtype=torch.float32)
            return self._sanitize(empty)

        # Normalize to keep dot-products bounded and compute in fp32 for stability.
        features = F.normalize(features.float().flatten(1), dim=1)
        device = features.device

        # Build (or validate) the positive-pair mask.
        if mask is None:
            labels = labels.contiguous().view(-1, 1)
            if labels.size(0) != n:
                raise ValueError(
                    f"labels has {labels.size(0)} elements but features has {n} points")
            mask = torch.eq(labels, labels.T)
        elif mask.shape != (n, n):
            raise ValueError(f"mask must have shape ({n}, {n}), got {tuple(mask.shape)}")
        # labels / mask may live on a different device than features.
        mask = mask.float().to(device)

        # Cosine similarity scaled by temperature.
        logits = torch.matmul(features, features.T) / self.temperature

        # Exclude self-contrast (the diagonal) from positives and from the denominator.
        logits_mask = torch.ones_like(mask)
        logits_mask.fill_diagonal_(0)
        mask = mask * logits_mask
        off_diag = logits_mask.bool()

        # Row-wise max over off-diagonal entries for a stable log-sum-exp. A row
        # with no off-diagonal entries (npoints == 1) yields -inf; use 0 there.
        masked_logits = logits.masked_fill(~off_diag, float('-inf'))
        row_max = masked_logits.max(dim=1, keepdim=True).values
        row_max = torch.where(torch.isfinite(row_max), row_max, torch.zeros_like(row_max))
        row_max = row_max.detach()
        logits = logits - row_max

        # Exponentiate with the diagonal already at -inf, so the (possibly large)
        # self-logit can never overflow to inf and produce inf * 0 = NaN.
        exp_logits = torch.exp(masked_logits - row_max)
        # Safe denominator to avoid log(0).
        denom = exp_logits.sum(1, keepdim=True).clamp_min(1e-12)
        log_prob = logits - torch.log(denom)

        # Mean log-likelihood over positives; the clamped count avoids 0/0 for
        # anchors without positives (their sum is 0, so they contribute 0).
        pos_count = mask.sum(1).clamp_min(1.0)
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_count

        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        return self._sanitize(loss.mean())


@LOSS.register_module()
class SelfInfoNCE(nn.Module):
    """InfoNCE loss of a feature set against itself.

    Row ``i`` is the target class for row ``i``, and every other row acts as a
    negative. Minimizing the loss therefore lowers the similarity between
    distinct rows.

    Args:
        temperature: divisor applied to the cosine similarities.
        replace_nonfinite: if True, a non-finite loss is replaced by a constant
            tensor (which carries no gradient).
        replacement_value: value of that constant.
        **kwargs: accepted and ignored.
    """

    def __init__(self, temperature=4, replace_nonfinite=True, replacement_value=0.0, **kwargs):
        super().__init__()
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()
        self.replace_nonfinite = replace_nonfinite
        self.replacement_value = replacement_value

    def forward(self, fea_class):
        """Compute the loss.

        Args:
            fea_class: 2-D tensor (N, D); rows are L2-normalized over dim 1.

        Returns:
            A scalar loss tensor.
        """
        # fp32 + unit-norm rows keep the similarity scale under control.
        normed = F.normalize(fea_class.float(), dim=1)
        similarity = normed @ normed.transpose(1, 0)
        targets = torch.arange(normed.size(0), device=normed.device)
        loss = self.cross_entropy(similarity / self.temperature, targets)
        if self.replace_nonfinite and not torch.isfinite(loss):
            return torch.full_like(loss, self.replacement_value)
        return loss


@LOSS.register_module()
class PointInfoNCE(nn.Module):
    """Point-wise InfoNCE loss between two feature sets.

    Row ``i`` of ``feas1`` is an anchor whose positive is row ``i`` of
    ``feas2``; all other rows of ``feas2`` act as negatives.

    Args:
        temperature: divisor applied to the cosine similarities.
        replace_nonfinite: if True, a non-finite loss is replaced by a constant
            tensor (which carries no gradient).
        replacement_value: value of that constant.
        **kwargs: accepted and ignored.
    """

    def __init__(self, temperature=4, replace_nonfinite=True, replacement_value=0.0, **kwargs):
        super(PointInfoNCE, self).__init__()
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()
        self.replace_nonfinite = replace_nonfinite
        self.replacement_value = replacement_value

    def forward(self, feas1, feas2):
        """Compute the InfoNCE loss.

        Args:
            feas1: anchor features of shape (npoints, feature_dims).
            feas2: candidate features of shape (npoints2, feature_dims) with
                npoints2 >= npoints; row ``i`` is the positive for ``feas1[i]``.

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: if ``feas2`` has fewer rows than ``feas1``. Some anchors
                would then have no positive, which would otherwise surface as an
                out-of-range target inside cross-entropy.
        """
        n_anchor, n_cand = feas1.size(0), feas2.size(0)
        if n_cand < n_anchor:
            raise ValueError(
                f"feas2 must have at least as many rows as feas1 "
                f"(got {n_cand} < {n_anchor})")

        # Compute in fp32 and normalize so logits are bounded cosine similarities.
        feas1 = F.normalize(feas1.float(), dim=1)
        feas2 = F.normalize(feas2.float(), dim=1)
        labels = torch.arange(n_anchor, device=feas1.device)
        logits = torch.matmul(feas1, feas2.transpose(1, 0)) / self.temperature
        loss = self.cross_entropy(logits, labels)
        if self.replace_nonfinite and not torch.isfinite(loss):
            loss = torch.full_like(loss, self.replacement_value)
        return loss
