import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class SupPCLoss(nn.Module):
    '''
    Combination of supervised (sample-sample) and proxy (sample-proxy)
    contrastive loss.

    For an anchor i with label y_i and every other sample j != i that shares
    its label, the per-pair term is

        log( (exp(z_i.z_j / t) + exp(z_i.p_{y_i} / t))
             / (sum_{k != i} exp(z_i.z_k / t) + sum_c exp(z_i.p_c / t)) )

    i.e. (sample positive pair + proxy positive pair) / (sample pair + proxy pair),
    where z are L2-normalized features, p are L2-normalized proxies and t is
    the temperature.

    NOTE(review): forward() returns the mean of these log-ratios as-is (it is
    not negated), exactly like the previous implementation. The value is <= 0,
    so a caller that minimizes it presumably negates it - confirm at the call
    site.
    '''
    def __init__(self, args, temperature=0.07):
        '''
        args: object exposing `num_classes` (number of proxies / classes)
        temperature: divisor applied to every cosine similarity
        '''
        super().__init__()
        self.args = args
        self.temperature = temperature
        # Class indices 0..C-1. Created without a hard-coded device; forward()
        # moves it to wherever the features live, so the module works on CPU
        # as well as on GPU.
        self.class_label = torch.arange(args.num_classes, dtype=torch.long)

    def forward(self, features, proxy, labels):
        '''
        features: (N, dim) or (2N, dim) if forAug
        proxy: (C, dim)
        labels: (N) or (2N) if forAug

        returns: scalar tensor, the mean over anchors of the average
                 per-positive-pair log-ratio described in the class docstring.
                 Anchors without any same-label sample in the batch
                 contribute 0 instead of NaN.
        '''
        features = F.normalize(features, dim=1)
        proxy = F.normalize(proxy, dim=1)
        device = features.device
        batch_size = features.size(0)

        labels = labels.contiguous().view(-1, 1).to(device)
        class_label = self.class_label.to(device)

        # temperature-scaled cosine similarities
        sample_logits = torch.matmul(features, features.T) / self.temperature  # (N, N)
        proxy_logits = torch.matmul(features, proxy.T) / self.temperature  # (N, C)

        # mask-out self-contrast cases
        self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
        # mask for all positive sample-sample pairs (same label, not self)
        positive_mask = (torch.eq(labels, labels.T) & ~self_mask).to(sample_logits.dtype)
        # mask for all positive sample-proxy pairs
        proxy_mask = torch.eq(labels, class_label.unsqueeze(0)).to(proxy_logits.dtype)  # (N, C)

        # log of the sum over all pairs (sample-sample without self, plus
        # sample-proxy). Done in log-space so that both kinds of pairs share
        # one stabilising shift and the denominator is always positive.
        all_logits = torch.cat(
            [sample_logits.masked_fill(self_mask, float('-inf')), proxy_logits],
            dim=1,
        )  # (N, N + C)
        log_all = torch.logsumexp(all_logits, dim=1, keepdim=True)  # (N, 1)

        # log of (sample pair + positive proxy pair) for every (i, j);
        # only the positive positions are kept by positive_mask below.
        positive_proxy_logit = (proxy_logits * proxy_mask).sum(1, keepdim=True)  # (N, 1)
        log_positive = torch.logaddexp(sample_logits, positive_proxy_logit)  # (N, N)

        # log_prob: log(positive pair / all pair), zero outside positive pairs
        log_prob = (log_positive - log_all) * positive_mask  # (N, N)

        # mean of log-likelihood over positives; clamp so an anchor with no
        # positive yields 0 rather than 0 / 0 = NaN
        positive_count = positive_mask.sum(1).clamp(min=1)
        mean_log_prob_pos = log_prob.sum(1) / positive_count  # (N,)

        # loss
        loss = mean_log_prob_pos.mean()
        return loss


        
        