from __future__ import print_function
import torch
import torch.nn as nn

class InstanceLoss(nn.Module):
    """Instance-level contrastive loss between two batches of embeddings.

    The two batches are concatenated into ``2 * B`` rows. For every row the
    positive is the row at the same batch index in the other batch, and every
    other row except itself is a negative. The summed cross-entropy over all
    ``2 * B`` anchors is divided by ``2 * B``.

    Args:
        temperature: divisor applied to the pairwise dot products.
    """

    def __init__(self, temperature):
        super(InstanceLoss, self).__init__()
        self.temperature = temperature

        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def mask_correlated_samples(self, batch_size, device=None):
        """Return a ``(2B, 2B)`` bool mask that selects the negative pairs.

        Entries are ``False`` on the main diagonal (self pairs) and at
        ``(i, B + i)`` / ``(B + i, i)`` (positive pairs), ``True`` elsewhere.

        Args:
            batch_size: number of rows per batch (``B``).
            device: optional device for the mask; defaults to the CPU.
        """
        # Tiling a BxB identity 2x2 marks exactly the self and positive
        # pairs; the complement is the set of negatives.
        excluded = torch.eye(batch_size, dtype=torch.bool, device=device).repeat(2, 2)
        return ~excluded

    def forward(self, z_i, z_j):
        """Compute the loss.

        Args:
            z_i: 2-D tensor, one embedding per row.
            z_j: 2-D tensor of the same shape as ``z_i``; row ``k`` is the
                positive partner of row ``k`` of ``z_i``.

        Returns:
            Scalar tensor.
        """
        # Build everything on the input's own device; the default CUDA device
        # is wrong when the inputs live on another GPU.
        self.device = z_i.device
        self.batch_size = z_i.size(0)
        N = 2 * self.batch_size
        z = torch.cat((z_i, z_j), dim=0)
        self.mask = self.mask_correlated_samples(self.batch_size, self.device)

        sim = torch.matmul(z, z.T) / self.temperature
        # The positives lie on the +/- B off-diagonals.
        sim_i_j = torch.diag(sim, self.batch_size)
        sim_j_i = torch.diag(sim, -self.batch_size)

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        negative_samples = sim[self.mask].reshape(N, -1)

        # The positive logit is placed in column 0, so every target is 0.
        labels = torch.zeros(N, dtype=torch.long, device=sim.device)
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        return loss / N

class SupConLoss(nn.Module):
    """Supervised contrastive loss averaged over every anchor in the batch.

    For each anchor, every other sample with the same label is a positive.
    The softmax denominator runs over all samples except the anchor itself.

    Note: an anchor with no other sample of its label has an empty positive
    set, so its term is 0/0 and the returned loss is NaN. ``SupConLoss_clear``
    in this module excludes such anchors instead.

    Args:
        temperature: divisor applied to the pairwise dot products.
    """

    def __init__(self, temperature=0.01):
        super(SupConLoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: 2-D tensor with one embedding per row (presumably
                L2-normalised by the caller -- TODO confirm).
            labels: one label per row of ``features``.

        Returns:
            Scalar tensor.
        """
        # Use the input's own device rather than the default CUDA device.
        device = features.device

        batch_size = features.shape[0]
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)

        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.temperature)

        # Subtracting the (detached) row max is a numerical-stability shift;
        # it cancels out in log_prob below.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # 1 everywhere except the diagonal, so an anchor is never compared
        # with itself.
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size, device=device).view(-1, 1),
            0
        )
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood over each anchor's positives.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        loss = -mean_log_prob_pos
        return loss.mean()

#
class SupConLoss_clear(nn.Module):
    """Supervised contrastive loss that skips anchors without positives.

    Same per-anchor term as a standard supervised contrastive loss (every
    other sample sharing the anchor's label is a positive, and the denominator
    runs over all samples except the anchor). Anchors whose label appears only
    once are excluded from the average. If no anchor has a positive, the loss
    is 0.

    Args:
        temperature: divisor applied to the pairwise dot products.
    """

    def __init__(self, temperature=0.07):
        super(SupConLoss_clear, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: 2-D tensor with one embedding per row (presumably
                L2-normalised by the caller -- TODO confirm).
            labels: one label per row of ``features``.

        Returns:
            Scalar tensor.
        """
        # Use the input's own device rather than the default CUDA device.
        device = features.device

        batch_size = features.shape[0]
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)

        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.temperature)

        # Numerical-stability shift; cancels out in log_prob.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Exclude self-comparisons (the diagonal).
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size, device=device).view(-1, 1),
            0
        )
        mask = mask * logits_mask
        # 1.0 for anchors that have no positive at all.
        single_samples = (mask.sum(1) == 0).float()

        exp_logits = torch.exp(logits) * logits_mask
        denom = exp_logits.sum(1, keepdim=True)
        # An empty denominator row (e.g. batch_size == 1) would give
        # log(0) = -inf and then 0 * inf = NaN in the masked sum below;
        # substituting 1 keeps such rows finite without touching other rows.
        denom = torch.where(denom > 0, denom, torch.ones_like(denom))
        log_prob = logits - torch.log(denom)

        # "+ single_samples" turns the 0/0 of positive-less anchors into 0/1.
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + single_samples)

        loss = -mean_log_prob_pos * (1 - single_samples)
        n_valid = loss.shape[0] - single_samples.sum()
        # With no valid anchor the division would be 0/0 = NaN.
        if n_valid > 0:
            return loss.sum() / n_valid
        return loss.sum()

#
class nce_supervised_hard(nn.Module):
    """Supervised NCE loss with importance-weighted negatives.

    Let ``p_ij = exp(<out_i, out_j> / temperature)``. For anchor ``i``,
    ``m_i`` is the mean of ``p_ik`` over its negatives ``k`` (different
    label), weighted by ``p_ik ** beta``. For every positive ``j`` (same
    label, ``j != i``) the term is ``-log(p_ij / (p_ij + (B - 2) * m_i))``.
    The loss is the mean over all (anchor, positive) pairs, or 0 when there
    are no such pairs. The constant ``B - 2`` is used regardless of how many
    negatives an anchor actually has.

    Args:
        temperature: divisor applied to the pairwise dot products.
        beta: exponent of the importance weights; for ``beta > 0`` more
            similar negatives receive larger weights.
        gradient_imp: when ``False`` the importance weights are detached;
            otherwise gradients also flow through them.
    """

    def __init__(self, temperature, beta, gradient_imp):
        super(nce_supervised_hard, self).__init__()
        self.temperature = temperature
        self.beta = beta
        self.gradient_imp = gradient_imp

    def forward(self, out, label):
        """Compute the loss.

        Args:
            out: 2-D tensor with one embedding per row.
            label: one label per row of ``out``.

        Returns:
            Scalar tensor.
        """
        # Use the input's own device (the old code hard-coded .cuda()).
        device = out.device
        batch = label.shape[0]
        cost = torch.exp(torch.div(
            torch.matmul(out, out.T),
            self.temperature))

        label = label.contiguous().view(-1, 1)
        same_label = torch.eq(label, label.T).float().to(device)
        neg_index = 1 - same_label
        # Positives exclude the anchor itself.
        pos_index = same_label - torch.eye(batch, device=device)
        neg = neg_index * cost

        # Importance weights cost ** beta on negative pairs only. Raising
        # ``cost`` (positive) rather than ``neg`` (zero off the negative
        # pairs) avoids log(0)/0**beta and stays finite for any beta.
        imp = neg_index * cost.pow(self.beta)
        if self.gradient_imp == False:  # noqa: E712 -- preserve original semantics
            imp = imp.detach()

        # Weighted mean of the negatives per anchor. Rows without negatives
        # (one-class batches) would be 0/0; their mean is taken as 0 instead.
        imp_sum = imp.sum(dim=-1)
        imp_sum = torch.where(imp_sum > 0, imp_sum, torch.ones_like(imp_sum))
        neg_exp_sum = (imp * neg).sum(dim=-1) / imp_sum

        # Evaluate the ratio only at (anchor, positive) pairs; computing it on
        # the full matrix produced 0/0 at non-positive entries and NaN grads.
        rows, cols = torch.nonzero(pos_index, as_tuple=True)
        pos = cost[rows, cols]
        ratio = pos / (pos + (batch - 2) * neg_exp_sum[rows])
        if ratio.numel() == 0:
            # No positive pairs: return a graph-connected zero instead of NaN.
            return ratio.sum()
        return (-torch.log(ratio)).mean()

class SupConLoss_selective(nn.Module):
    """Supervised contrastive loss that discards pairs by a similarity threshold.

    A positive pair (same label, not self) is dropped when its raw dot
    product exceeds ``threshold``. A negative pair is dropped when its dot
    product is below ``1 - threshold``. Dropped pairs are removed from both
    the positive set and the softmax denominator. Anchors left without
    positives are excluded from the average; if none remain the loss is 0.

    Args:
        temperature: divisor applied to the pairwise dot products.
        threshold: pruning threshold described above.
    """

    def __init__(self, temperature=0.07, threshold=0.8):
        super(SupConLoss_selective, self).__init__()
        self.temperature = temperature
        self.threshold = threshold

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: 2-D tensor with one embedding per row (presumably
                L2-normalised so the thresholds act on cosine similarity --
                TODO confirm).
            labels: one label per row of ``features``.

        Returns:
            Scalar tensor.
        """
        # Use the input's own device rather than the default CUDA device.
        device = features.device

        batch_size = features.shape[0]
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)
        anchor_dot = torch.matmul(features, features.T)
        anchor_dot_contrast = torch.div(anchor_dot, self.temperature)

        # Numerical-stability shift; cancels out in log_prob.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Exclude self-comparisons (the diagonal).
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size, device=device).view(-1, 1),
            0
        )
        mask = mask * logits_mask
        negative_mask = logits_mask - mask

        # Vectorised element-wise pruning: each (i, j) decision depends only
        # on its own entries, so this matches a per-element loop exactly.
        sim = anchor_dot.detach()
        drop_pos = (mask == 1) & (sim > self.threshold)
        drop_neg = (negative_mask == 1) & (sim < (1 - self.threshold))
        mask = mask.masked_fill(drop_pos, 0)
        logits_mask = logits_mask.masked_fill(drop_pos | drop_neg, 0)

        # 1.0 for anchors left with no positive.
        single_samples = (mask.sum(1) == 0).float()

        exp_logits = torch.exp(logits) * logits_mask
        denom = exp_logits.sum(1, keepdim=True)
        # A fully pruned row would give log(0) = -inf and then 0 * inf = NaN
        # in the masked sum below; substituting 1 keeps it finite without
        # touching other rows.
        denom = torch.where(denom > 0, denom, torch.ones_like(denom))
        log_prob = logits - torch.log(denom)

        # "+ single_samples" turns 0/0 into 0/1 for positive-less anchors.
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + single_samples)

        loss = -mean_log_prob_pos * (1 - single_samples)
        n_valid = loss.shape[0] - single_samples.sum()
        if n_valid > 0:
            return loss.sum() / n_valid
        return loss.sum()

class SupConLoss_hard(nn.Module):
    """Supervised contrastive loss restricted to the harder pairs.

    For each anchor, the ``int(n_pos * threshold1)`` positives with the
    highest dot product are discarded. So are the ``int(n_neg * threshold2)``
    negatives with the lowest dot product (``n_pos``/``n_neg`` are the
    anchor's positive/negative counts). The remaining pairs form the positive
    set and the softmax denominator. Anchors left without positives are
    excluded from the average; if none remain the loss is 0.

    Args:
        temperature: divisor applied to the pairwise dot products.
        threshold1: fraction of positives to discard per anchor.
        threshold2: fraction of negatives to discard per anchor.
    """

    def __init__(self, temperature=0.07, threshold1=0.3, threshold2=0.3):
        super(SupConLoss_hard, self).__init__()
        self.temperature = temperature
        self.threshold1 = threshold1
        self.threshold2 = threshold2

    @staticmethod
    def _drop_ranked(select, scores, counts, descending):
        """Zero the first ``counts[i]`` selected entries of each row, ranked by ``scores``.

        Unselected entries get a +/-inf sentinel, so they always sort after
        every selected entry and are never chosen in place of a real one.
        """
        fill = float("-inf") if descending else float("inf")
        ranked = scores.masked_fill(select == 0, fill)
        _, order = torch.sort(ranked, dim=1, descending=descending)
        ranks = torch.arange(order.shape[1], device=order.device).expand_as(order)
        # Flag the first counts[i] sorted positions, then scatter the flags
        # back to their original columns.
        hit = (ranks < counts.view(-1, 1)).to(select.dtype)
        drop = torch.zeros_like(select).scatter(1, order, hit)
        return select * (1 - drop)

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: 2-D tensor with one embedding per row.
            labels: one label per row of ``features``.

        Returns:
            Scalar tensor.
        """
        # Use the input's own device rather than the default CUDA device.
        device = features.device

        batch_size = features.shape[0]
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)
        anchor_dot = torch.matmul(features, features.T)
        anchor_dot_contrast = torch.div(anchor_dot, self.temperature)

        # Numerical-stability shift; cancels out in log_prob.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Exclude self-comparisons (the diagonal).
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size, device=device).view(-1, 1),
            0
        )
        mask = mask * logits_mask
        negative_mask = logits_mask - mask

        # Both counts are taken from the unpruned masks.
        sim = anchor_dot.detach()
        mask_nums = (mask.sum(1) * self.threshold1).int()
        negative_nums = (negative_mask.sum(1) * self.threshold2).int()
        # Drop the most similar positives and the least similar negatives.
        mask = self._drop_ranked(mask, sim, mask_nums, descending=True)
        negative_mask = self._drop_ranked(negative_mask, sim, negative_nums, descending=False)
        logits_mask = mask + negative_mask

        # 1.0 for anchors left with no positive.
        single_samples = (mask.sum(1) == 0).float()

        exp_logits = torch.exp(logits) * logits_mask
        denom = exp_logits.sum(1, keepdim=True)
        # An empty denominator row (e.g. batch_size == 1) would give
        # log(0) = -inf and then 0 * inf = NaN in the masked sum below;
        # substituting 1 keeps it finite without touching other rows.
        denom = torch.where(denom > 0, denom, torch.ones_like(denom))
        log_prob = logits - torch.log(denom)

        # "+ single_samples" turns 0/0 into 0/1 for positive-less anchors.
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + single_samples)

        loss = -mean_log_prob_pos * (1 - single_samples)
        n_valid = loss.shape[0] - single_samples.sum()
        if n_valid > 0:
            return loss.sum() / n_valid
        return loss.sum()
