from __future__ import print_function

import torch
import torch.nn as nn
import sys
import numpy as np
import torch.nn.functional as F

class AugmentedTripletLoss(nn.Module):
    """Batch-hard triplet loss whose negatives may also be prototypes.

    For every anchor in the batch the hardest (farthest) same-label sample
    is the positive and the hardest (closest) different-label sample is the
    negative.  When prototypes (``center``) are supplied, they are
    L2-normalised and the anchor's distance to the nearest prototype
    replaces the negative distance whenever it is smaller.

    Args:
        margin: Margin of the ranking hinge loss.  It is also used to build
            a surrogate negative distance for anchors that have no
            different-label sample in the batch.
        norm: Order ``p`` of the norm used for anchor-to-prototype
            distances (sample-to-sample distances are always Euclidean).
    """

    def __init__(self, margin=1.0, norm=2):
        super(AugmentedTripletLoss, self).__init__()
        self.margin = margin
        self.norm = norm
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    @staticmethod
    def _normalized_prototypes(center, like):
        """Return unit-norm prototypes on ``like``'s device/dtype, or None.

        ``center`` may be ``None``, an empty sequence, a 2-D array-like or a
        tensor.  Rows that cannot be normalised (zero or non-finite norm)
        are dropped; the previous per-element ``min`` silently ignored the
        NaN distances such rows produced, so the resulting loss is the same.
        """
        if center is None:
            return None
        if isinstance(center, torch.Tensor):
            # Prototypes are treated as constants, as with NumPy input.
            center = center.detach().cpu().numpy()
        center = np.asarray(center)
        if center.size == 0:
            return None

        row_norms = np.linalg.norm(center, axis=1, keepdims=True)
        usable = np.isfinite(row_norms[:, 0]) & (row_norms[:, 0] > 0)
        if not usable.any():
            return None

        unit = center[usable] / row_norms[usable]
        # Cast to the embedding dtype so the loss is not silently promoted
        # to float64 when the prototypes come from a float64 NumPy array.
        return torch.from_numpy(unit).to(device=like.device, dtype=like.dtype)

    def forward(self, inputs, targets, center):
        """Compute the mean hinge loss over all anchors of the batch.

        Args:
            inputs: Embeddings, one row per sample (2-D tensor).
            targets: Class label of each row of ``inputs``.
            center: Prototypes, one row per prototype with the same feature
                size as ``inputs`` (NumPy array, list of arrays or tensor),
                or ``None`` / empty to disable the prototype term.

        Returns:
            Scalar tensor produced by ``nn.MarginRankingLoss``.
        """
        n = inputs.size(0)  # batch size

        # Pairwise Euclidean distances: ||a||^2 + ||b||^2 - 2 a.b
        sq_norms = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        # Keyword beta/alpha: the positional (beta, alpha, mat1, mat2)
        # overload is deprecated and rejected by current PyTorch releases.
        dist = torch.addmm(sq_norms + sq_norms.t(), inputs, inputs.t(),
                           beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # mask[i, j] is True when samples i and j share a label.
        labels = targets.expand(n, n)
        mask = labels.eq(labels.t())

        # Hardest positive: farthest sample carrying the anchor's label.
        dist_ap = dist.masked_fill(~mask, float('-inf')).max(dim=1)[0]
        # Hardest negative: closest sample carrying another label.
        dist_an = dist.masked_fill(mask, float('inf')).min(dim=1)[0]
        # Anchors without any in-batch negative get a surrogate distance
        # that makes their hinge term vanish unless a prototype is closer.
        no_negative = mask.all(dim=1)
        dist_an = torch.where(no_negative, dist_ap + self.margin, dist_an)

        prototypes = self._normalized_prototypes(center, inputs)
        if prototypes is not None:
            # (n, num_proto) matrix of anchor-to-prototype distances.
            offsets = inputs.unsqueeze(1) - prototypes.unsqueeze(0)
            proto_dist = torch.norm(offsets, p=self.norm, dim=2)
            proto_dist = proto_dist.clamp(min=1e-12)
            dist_an = torch.min(dist_an, proto_dist.min(dim=1)[0])

        # Ranking hinge loss: push dist_an above dist_ap by the margin.
        y = torch.ones_like(dist_an)
        return self.ranking_loss(dist_an, dist_ap, y)
    

class CrossEntropyLoss(nn.Module):
    """Cross-entropy on L2-normalised logits scaled by a temperature.

    Each logit vector is divided by its L2 norm (plus a small epsilon so an
    all-zero vector does not divide by zero) and then by ``tau`` before
    being handed to ``F.cross_entropy``.

    Args:
        tau: Temperature; smaller values sharpen the softmax distribution.
    """

    def __init__(self, tau=0.1):
        super(CrossEntropyLoss, self).__init__()
        self.tau = tau

    def forward(self, x, target):
        """Return the cross-entropy of the normalised logits ``x``.

        Args:
            x: Raw logits; normalisation runs over the last dimension.
            target: Targets in the form accepted by ``F.cross_entropy``.
        """
        eps = 1e-7  # keeps the division finite for zero vectors
        scale = x.norm(p=2, dim=-1, keepdim=True) + eps
        scaled_logits = (x / scale) / self.tau
        return F.cross_entropy(scaled_logits, target)







