import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    """Margin ranking loss over cosine distances between two feature sets.

    For every sample ``i`` the loss compares two cosine distances
    (``1 - cosine similarity`` of L2-normalized features):

    * the smallest distance from ``inputs[i]`` to another sample of
      ``inputs`` that carries the same label, and
    * the distance from ``inputs[i]`` to its counterpart ``inputs2[i]``.

    Both are fed to ``nn.MarginRankingLoss`` with target ``y = 1``.

    Args:
        margin (float): margin passed to ``nn.MarginRankingLoss``.
    """

    def __init__(self, margin=0.3):
        super().__init__()
        self.m = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, inputs2, targets):
        """Compute the loss for one batch.

        Args:
            inputs: sample features (before classifier) with shape
                (batch_size, feat_dim).
            inputs2: second set of features, row-aligned with ``inputs``
                (row ``i`` is paired with ``inputs[i]``); same shape.
            targets: ground truth labels with shape (batch_size).

        Returns:
            Scalar tensor: the mean margin ranking loss over the batch.
        """
        # l2-normalize so that the dot product is the cosine similarity
        inputs = F.normalize(inputs, p=2, dim=1)
        inputs2 = F.normalize(inputs2, p=2, dim=1)

        # Build every mask on the features' device instead of forcing CUDA,
        # so the loss also runs on CPU or on a non-default GPU.
        device = inputs.device

        # cosine distances: within `inputs`, and between `inputs` and `inputs2`
        dist = 1 - torch.matmul(inputs, inputs.t())
        dist2 = 1 - torch.matmul(inputs, inputs2.t())

        # pairwise label-equality mask and identity (self-pair) mask
        targets = targets.view(-1, 1)
        mask_ap = torch.eye(inputs.size(0), device=device)
        mask_pos = torch.eq(targets, targets.t()).float().to(device)

        # NOTE(review): despite its name, `mask_neg` selects SAME-label pairs
        # with the diagonal removed (label-equality mask minus identity), so
        # `dist_an` is the closest same-label neighbour, not a negative.
        # Behaviour is kept as in the original; confirm this is intended.
        mask_neg = mask_pos - mask_ap

        # Entries outside the mask are pushed to a huge constant so they never
        # win the row-wise min; a row with no selected entry yields that
        # constant, which makes its hinge term zero.
        dist_an, _ = torch.min(mask_neg * dist + (1 - mask_neg) * 99999999., dim=1)

        # Only the diagonal of dist2 survives the mask; off-diagonal entries
        # become 0, so the row-wise max is dist2[i, i] (floored at 0 whenever
        # the batch has more than one sample).
        dist_ap, _ = torch.max(mask_ap * dist2, dim=1)

        # y = 1  ->  loss = mean(max(0, dist_ap - dist_an + margin))
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)

        return loss
