import torch
import torch.nn.functional as F
from torch import nn
from losses.gather import GatherLayer
from losses.cross_entry_smooth import CrossEntropyWithLabelSmooth

class TripletLoss(nn.Module):
    """Batch-hard triplet loss on cosine distance.

    Each sample in ``inputs`` acts as an anchor. The gallery is whatever
    ``GatherLayer.apply`` returns for the normalized inputs, concatenated
    along dim 0 (presumably the samples of every distributed process --
    confirm against ``losses.gather``). For every anchor the farthest
    same-label gallery entry and the closest different-label gallery entry
    are selected and fed to a margin ranking loss.

    Args:
        margin: Margin of the ranking hinge loss.
    """

    # Offset used to push masked-out entries far outside the [0, 2] range of
    # cosine distances so they can never win the max / min selection.
    _MASK_OFFSET = 99999999.

    def __init__(self, margin=0.5):
        super().__init__()
        self.m = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """Compute the loss for one batch.

        Args:
            inputs: 2-D feature tensor, one row per sample.
            targets: Identity labels, one per row of ``inputs``.

        Returns:
            The value produced by ``nn.MarginRankingLoss`` for the hardest
            negative / hardest positive distances of every anchor.
        """
        inputs = F.normalize(inputs, p=2, dim=1)
        gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
        gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)

        # Rows are unit-length, so 1 - dot product is the cosine distance.
        dist = 1 - torch.matmul(inputs, gallery_inputs.t())

        targets = targets.view(-1, 1)
        gallery_targets = gallery_targets.view(-1, 1)
        # Keep the masks on the same device as the distances instead of
        # assuming CUDA, so the loss also works on CPU and on non-default GPUs.
        mask_pos = torch.eq(targets, gallery_targets.t()).float().to(dist.device)
        mask_neg = 1 - mask_pos

        # For each anchor, find the hardest positive and negative pairs.
        dist_ap, _ = torch.max(dist - mask_neg * self._MASK_OFFSET, dim=1)
        dist_an, _ = torch.min(dist + mask_pos * self._MASK_OFFSET, dim=1)

        # Ranking hinge loss: y = 1 asks for dist_an to exceed dist_ap by the margin.
        y = torch.ones_like(dist_an)
        return self.ranking_loss(dist_an, dist_ap, y)

class ContrastiveLoss4(nn.Module):
    """Classification-style loss against a per-identity feature memory bank.

    The module keeps one feature vector per identity in ``feature_memory``.
    Every forward pass first refreshes the bank with the gathered batch, then
    scores the L2-normalized inputs against the L2-normalized bank and passes
    the scaled similarities to ``CrossEntropyWithLabelSmooth``.

    Until every identity has been seen at least once (``label_memory`` has no
    ``-1`` left) the forward pass only fills the bank and returns zero.

    Args:
        num_pids: Number of identities, i.e. rows of the memory bank.
        feat_dim: Feature dimensionality, i.e. columns of the memory bank.
        margin: Stored as ``self.m``; not used by the code in this class.
        momentum: Weight of the old memory entry in the moving-average update.
        scale: Multiplier applied to the similarities before the loss.
        epsilon: Stored as ``self.epsilon``; not used by the code in this class.
    """

    def __init__(self, num_pids, feat_dim, margin, momentum, scale, epsilon):
        super().__init__()
        self.m = margin
        self.num_pids = num_pids
        self.feat_dim = feat_dim
        self.momentum = momentum
        self.epsilon = epsilon
        self.scale = scale

        self.register_buffer('feature_memory', torch.zeros((num_pids, feat_dim)))
        # -1 marks an identity whose memory slot has not been written yet.
        self.register_buffer('label_memory', torch.zeros(num_pids, dtype=torch.int64) - 1)
        self.has_been_filled = False
        self.ranking_loss = CrossEntropyWithLabelSmooth()

    def forward(self, inputs, targets):
        """Update the memory bank and compute the loss for one batch.

        Args:
            inputs: 2-D feature tensor, one row per sample.
            targets: Identity labels indexing rows of the memory bank.

        Returns:
            The loss from ``CrossEntropyWithLabelSmooth``, or a zero scalar
            (carrying no gradient) while the memory bank is still being filled.
        """
        # Gather all samples from different GPUs as gallery for the memory update.
        gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
        gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)
        self._update_memory(gallery_inputs.detach(), gallery_targets)

        if not self.has_been_filled:
            if bool((self.label_memory == -1).any()):
                # Some identities have no memory entry yet: skip the loss.
                return inputs.new_zeros(())
            self.has_been_filled = True
            print('Memory bank is full')

        # L2-normalize both sides so the product is a cosine similarity.
        inputs = F.normalize(inputs, p=2, dim=1)
        memory_norm = F.normalize(self.feature_memory.detach(), p=2, dim=1).to(inputs.device)
        logits = torch.matmul(inputs, memory_norm.t()) * self.scale

        return self.ranking_loss(logits, targets)

    @torch.no_grad()
    def _update_memory(self, features, labels):
        """Write the per-identity mean feature of a batch into the memory bank.

        While the bank is being filled the mean replaces the slot and the
        identity is recorded in ``label_memory``; afterwards the slot becomes
        ``momentum * old + (1 - momentum) * mean``.

        Args:
            features: 2-D feature tensor, one row per sample.
            labels: Identity label of every row of ``features``.
        """
        # Group by plain Python ints: tensors hash by identity, so using the
        # 0-dim label tensors as dict keys would never merge equal labels.
        grouped = {}
        for feat, label in zip(features, labels.view(-1).tolist()):
            grouped.setdefault(label, []).append(feat)

        memory = self.feature_memory
        for label, group in grouped.items():
            mean_feat = torch.stack(group).mean(dim=0).to(device=memory.device, dtype=memory.dtype)
            if self.has_been_filled:
                memory[label] = self.momentum * memory[label] + (1. - self.momentum) * mean_feat
            else:
                memory[label] = mean_feat
                self.label_memory[label] = label