import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
from pytorch_metric_learning import miners, losses
from scipy.spatial import distance
import numpy as np

def binarize(T, nb_classes):
    """Convert integer class labels into a dense one-hot float matrix.

    Args:
        T: tensor of class indices (any shape; it is flattened), or a 2-D
            indicator matrix that already has ``nb_classes`` columns.
        nb_classes: total number of classes, i.e. the number of output columns.

    Returns:
        A float tensor of shape ``(num_labels, nb_classes)`` with a 1 in the
        column of each label. Labels outside ``[0, nb_classes)`` produce an
        all-zero row. The result is moved to the default CUDA device when one
        is available (the historical behaviour of this helper) and otherwise
        stays on the device of ``T``.
    """
    labels = torch.as_tensor(T)

    # An (n, nb_classes) matrix with more than one column is already an
    # indicator matrix; only its dtype needs to change.
    if labels.dim() == 2 and labels.size(1) == nb_classes and nb_classes > 1:
        one_hot = labels.float()
    else:
        # Compare every label against every class id. Unlike
        # sklearn.preprocessing.label_binarize this always yields nb_classes
        # columns, including the two-class case.
        class_ids = torch.arange(nb_classes, device=labels.device).view(1, -1)
        one_hot = (labels.reshape(-1, 1) == class_ids).float()

    if torch.cuda.is_available():
        one_hot = one_hot.cuda()
    return one_hot

def l2_norm(input):
    """Scale each row of ``input`` to unit Euclidean length.

    The squared entries are summed along dim 1 and 1e-12 is added before the
    square root, so an all-zero row maps to zeros instead of dividing by zero.
    The result has the same shape as ``input``.
    """
    original_shape = input.size()
    squared_sum = torch.sum(torch.pow(input, 2), 1).add_(1e-12)
    row_norm = squared_sum.sqrt().view(-1, 1)
    scaled = input / row_norm.expand_as(input)
    return scaled.view(original_shape)

class Proxy_Anchor(torch.nn.Module):
    """Proxy-based loss with one learnable proxy vector per class.

    ``forward`` compares L2-normalised embeddings with L2-normalised proxies
    and combines a positive term (samples of a proxy's own class) with a
    negative term (samples of every other class).
    """

    def __init__(self, nb_classes, sz_embed, mrg = 0.1, alpha = 32):
        super().__init__()
        self.nb_classes = nb_classes
        self.sz_embed = sz_embed
        self.mrg = mrg
        self.alpha = alpha

        # Proxies are created on the GPU and then re-initialised in place.
        initial = torch.randn(nb_classes, sz_embed).cuda()
        self.proxies = torch.nn.Parameter(initial)
        nn.init.kaiming_normal_(self.proxies, mode='fan_out')

    def forward(self, X, T):
        """Return the scalar loss for embeddings ``X`` with class labels ``T``."""
        # Cosine similarity between every sample and every proxy.
        similarity = F.linear(l2_norm(X), l2_norm(self.proxies))

        pos_mask = binarize(T = T, nb_classes = self.nb_classes)
        neg_mask = 1 - pos_mask

        pos_scores = torch.exp(-self.alpha * (similarity - self.mrg))
        neg_scores = torch.exp(self.alpha * (similarity + self.mrg))
        zeros = torch.zeros_like(pos_scores)

        # Per-proxy sums over the batch, restricted to positives / negatives.
        pos_sum = torch.where(pos_mask == 1, pos_scores, zeros).sum(dim=0)
        neg_sum = torch.where(neg_mask == 1, neg_scores, zeros).sum(dim=0)

        # Proxies whose class occurs at least once in this batch.
        proxies_with_positives = torch.nonzero(pos_mask.sum(dim = 0) != 0).squeeze(dim = 1)
        num_valid_proxies = len(proxies_with_positives)

        pos_term = torch.log(1 + pos_sum).sum() / num_valid_proxies
        neg_term = torch.log(1 + neg_sum).sum() / self.nb_classes
        return pos_term + neg_term

class NCAloss(nn.Module):
    """Normalised-softmax loss over class proxies assembled from a codebook.

    Each class proxy is the concatenation of one codebook entry per
    partitioning. Proxies and embeddings are L2-normalised and the resulting
    cosine logits are divided by ``temperature``.

    Args:
        partitionings: per-class codewords. ``partitionings[i]`` is used as
            the partition index for every partitioning of class ``i``, so it
            must hold ``num_partitionings`` entries.
        proxies: codebook indexed as ``[partitioning, partition]``; each entry
            is a ``sz_embed``-dimensional vector.
        nb_classes: number of classes (rows of the assembled proxy matrix).
        num_partitionings: number of partitionings (codebooks).
        sz_embed: dimensionality of one codebook entry.
        temperature: divisor applied to the cosine logits.
    """
    def __init__(self,
                 partitionings,
                 proxies,
                 nb_classes,
                 num_partitionings,
                 sz_embed,
                 temperature=0.05):
        super(NCAloss, self).__init__()

        self.partitionings = partitionings
        self.num_partitionings = num_partitionings
        self.sz_embed = sz_embed
        self.nb_classes = nb_classes
        self.proxies = proxies
        self.temperature = temperature
        # Filled in by forward(): (nb_classes, num_partitionings * sz_embed).
        self.global_proxies = None

    def loss_fn(self, logits, instance_targets):
        """Cross-entropy for hard (index) targets or soft (2-D) targets.

        2-D targets are normalised to sum to one per row and used as a
        distribution. A 0-dim target is treated as a batch of one.
        """
        if instance_targets.dim() > 1:
            soft_targets = instance_targets.float() / instance_targets.sum(dim=1).unsqueeze(1)
            loss = -1 * F.log_softmax(logits, dim=1) * soft_targets
            return loss.sum(1).mean()
        if instance_targets.dim() == 0:
            instance_targets = instance_targets.reshape(1).long().to(logits.device)
        return F.cross_entropy(logits, instance_targets)

    def forward(self, embeddings, instance_targets=None, return_prob=False):
        """Return the loss, or class probabilities when ``return_prob`` is set.

        Args:
            embeddings: ``(batch, num_partitionings * sz_embed)`` descriptors.
            instance_targets: class targets; only required for the loss.
            return_prob: if true, return the softmax over the scaled logits
                and ignore ``instance_targets``.

        Raises:
            ValueError: if a loss is requested without ``instance_targets``.
        """
        partition_axis = torch.arange(self.num_partitionings)
        per_class = [
            self.proxies[partition_axis, self.partitionings[i]]
            for i in range(self.nb_classes)
        ]
        self.global_proxies = torch.stack(per_class).view(
            self.nb_classes, self.num_partitionings * self.sz_embed)

        norm_weight = F.normalize(self.global_proxies, dim=1)
        norm_embeddings = F.normalize(embeddings, dim=1)
        scaled_logits = F.linear(norm_embeddings, norm_weight) / self.temperature

        # Targets are deliberately not touched before this point, so the
        # inference path works with the default instance_targets=None.
        if return_prob:
            return F.softmax(scaled_logits, dim=1)

        if instance_targets is None:
            raise ValueError('instance_targets is required to compute the loss.')
        return self.loss_fn(scaled_logits, instance_targets)


class NormSoftmaxLoss(nn.Module):
    """Normalised-softmax loss over class proxies assembled from a codebook.

    Each class proxy is the concatenation of one codebook entry per
    partitioning. Proxies and embeddings are L2-normalised and the resulting
    cosine logits are divided by ``temperature``.

    Args:
        partitionings: per-class codewords. ``partitionings[i]`` is used as
            the partition index for every partitioning of class ``i``, so it
            must hold ``num_partitionings`` entries.
        proxies: codebook indexed as ``[partitioning, partition]``; each entry
            is a ``sz_embed``-dimensional vector.
        nb_classes: number of classes (rows of the assembled proxy matrix).
        num_partitionings: number of partitionings (codebooks).
        sz_embed: dimensionality of one codebook entry.
        temperature: divisor applied to the cosine logits.
    """
    def __init__(self,
                 partitionings,
                 proxies,
                 nb_classes,
                 num_partitionings,
                 sz_embed,
                 temperature=0.05):
        super(NormSoftmaxLoss, self).__init__()

        self.partitionings = partitionings
        self.num_partitionings = num_partitionings
        self.sz_embed = sz_embed
        self.nb_classes = nb_classes
        self.proxies = proxies
        self.temperature = temperature
        # Filled in by forward(): (nb_classes, num_partitionings * sz_embed).
        self.global_proxies = None

    def loss_fn(self, logits, instance_targets):
        """Cross-entropy for soft (2-D), scalar (0-dim) or index (1-D) targets.

        2-D targets are normalised to sum to one per row and used as a
        distribution. A 0-dim target is promoted to a batch of one and placed
        on the same device as ``logits``, so the code also runs without CUDA.
        """
        if instance_targets.dim() > 1:
            soft_targets = instance_targets.float() / instance_targets.sum(dim=1).unsqueeze(1)
            loss = -1 * F.log_softmax(logits, dim=1) * soft_targets
            return loss.sum(1).mean()
        if instance_targets.dim() == 0:
            instance_targets = instance_targets.reshape(1).long().to(logits.device)
        return F.cross_entropy(logits, instance_targets)

    def forward(self, embeddings, instance_targets=None, return_prob=False):
        """Return the loss, or class probabilities when ``return_prob`` is set.

        Args:
            embeddings: ``(batch, num_partitionings * sz_embed)`` descriptors.
            instance_targets: class targets; only used for the loss.
            return_prob: if true, return the softmax over the scaled logits.
        """
        partition_axis = torch.arange(self.num_partitionings)
        per_class = [
            self.proxies[partition_axis, self.partitionings[i]]
            for i in range(self.nb_classes)
        ]
        self.global_proxies = torch.stack(per_class).view(
            self.nb_classes, self.num_partitionings * self.sz_embed)

        norm_weight = F.normalize(self.global_proxies, dim=1)
        norm_embeddings = F.normalize(embeddings, dim=1)
        scaled_logits = F.linear(norm_embeddings, norm_weight) / self.temperature

        if return_prob:
            return F.softmax(scaled_logits, dim=1)
        return self.loss_fn(scaled_logits, instance_targets)

class N_PQ(nn.Module):
    """Normalised-softmax classifier with one learnable proxy per class.

    Proxies and embeddings are L2-normalised; the cosine logits are divided
    by ``temperature`` and scored with cross-entropy.
    """

    def __init__(self,
                 nb_classes,
                 num_partitionings,
                 sz_embed,
                 temperature=0.05):
        super().__init__()

        # One proxy per class, as wide as the concatenated sub-embeddings.
        weight = torch.Tensor(nb_classes, sz_embed*num_partitionings).cuda()
        self.proxies = torch.nn.Parameter(weight)
        nn.init.xavier_uniform_(self.proxies)

        self.temperature = temperature
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, embeddings, instance_targets):
        """Return the cross-entropy of the temperature-scaled cosine logits."""
        unit_proxies = F.normalize(self.proxies, dim=1)
        unit_embeddings = F.normalize(embeddings, dim=1)
        cosine_logits = F.linear(unit_embeddings, unit_proxies)
        return self.loss_fn(cosine_logits / self.temperature, instance_targets)

class NormSoftmaxLossDisjoint(nn.Module):
    """Softmax loss whose logits are averaged over disjoint sub-embeddings.

    Each class proxy and each embedding is split into ``num_partitionings``
    chunks of ``sz_embed`` values. Proxy chunks are L2-normalised, one logit
    matrix is computed per chunk, and the matrices are averaged before the
    temperature-scaled cross-entropy.
    """

    def __init__(self,
                 nb_classes,
                 num_partitionings,
                 sz_embed,
                 temperature=0.0625):
        super().__init__()
        self.num_partitionings = num_partitionings
        self.sz_embed = sz_embed

        weight = torch.Tensor(nb_classes, num_partitionings * sz_embed).cuda()
        self.proxies = torch.nn.Parameter(weight)
        nn.init.xavier_uniform_(self.proxies)

        self.temperature = temperature
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, embeddings, instance_targets):
        """Return the cross-entropy of the partition-averaged logits."""
        chunked_proxies = self.proxies.view(-1, self.num_partitionings, self.sz_embed)
        chunked_embeddings = embeddings.view(-1, self.num_partitionings, self.sz_embed)
        unit_proxies = F.normalize(chunked_proxies, dim=2)

        # One (batch, nb_classes) logit matrix per partitioning.
        per_partition = [
            F.linear(chunked_embeddings[:, p, :], unit_proxies[:, p, :])
            for p in range(self.num_partitionings)
        ]
        averaged = torch.stack(per_partition, dim=2).mean(2)
        return self.loss_fn(averaged / self.temperature, instance_targets)


class CombiNormSoftmaxLoss(nn.Module):
    """Combinatorial metric-learning loss over several partitionings.

    For every partitioning a temperature-scaled log-softmax over its
    partitions is computed from the embedding and the L2-normalised proxies.
    A class's score is the sum, over partitionings, of the log-probability of
    the partition that class belongs to.
    """

    def __init__(self,
                 proxies,
                 nb_classes,
                 sz_embed,
                 num_partitions,
                 num_partitionings,
                 temperature=0.05):
        super().__init__()

        self.proxies = proxies
        self.num_classes = nb_classes
        self.num_partitions = num_partitions
        self.num_partitionings = num_partitionings
        self.temperature = temperature
        self.loss_fn = nn.NLLLoss()

        # Filled with -1 until set_partitionings() is called.
        unset = -torch.ones(num_partitionings, nb_classes).long()
        self.register_buffer('partitionings', unset)

    def set_partitionings(self, partitionings_map):
        """Store the class -> partition map as flat indices.

        ``partitionings_map`` is transposed to (partitioning, class) and each
        row is offset by ``row_index * num_partitions`` so that it indexes the
        flattened (partitioning, partition) axis, e.g. 01110 -> 01110, 23332.
        """
        mapping = torch.LongTensor(partitionings_map).t()
        self.partitionings.copy_(mapping)
        offsets = torch.arange(self.num_partitionings).view(-1, 1).type_as(self.partitionings)
        self.partitionings.add_(offsets * self.num_partitions)

    def rescale_grad(self):
        """Multiply the proxy gradients by ``num_partitionings`` in place."""
        self.proxies.grad.mul_(self.num_partitionings)

    def forward(self, embeddings, instance_targets, return_meta_dist=False):
        """Return the loss, or the per-partitioning log-probabilities.

        With ``return_meta_dist`` set, the (batch, num_partitionings,
        num_partitions) log-softmax tensor is returned instead of the loss.
        """
        assert self.partitionings.sum() > 0, 'Partitionings is never given to the module.'
        unit_proxies = F.normalize(self.proxies, dim=2)

        # (batch, sz_embed) x (sz_embed, partitioning, partition)
        meta_logits = torch.einsum('ab,bcd->acd', embeddings, unit_proxies.permute(2, 0, 1))
        meta_logits = meta_logits.view(-1, self.num_partitionings, self.num_partitions)
        meta_log_probs = F.log_softmax(meta_logits / self.temperature, dim=2)

        if return_meta_dist:
            return meta_log_probs

        flat_log_probs = meta_log_probs.view(-1, self.num_partitionings * self.num_partitions)
        gathered = flat_log_probs.index_select(1, self.partitionings.view(-1))
        class_scores = gathered.view(-1, self.num_partitionings, self.num_classes).sum(1)

        return self.loss_fn(class_scores, instance_targets) / self.num_partitionings

class CombiNormSoftmaxLossDisjoint(nn.Module):
    """Combinatorial metric-learning loss on disjoint sub-embeddings.

    The embedding is split into ``num_partitionings`` chunks of ``sz_embed``
    values and chunk ``i`` is scored only against the proxies of partitioning
    ``i``. No normalisation is applied here; callers normalise beforehand if
    they want cosine logits. A class's score is the sum, over partitionings,
    of the log-probability of the partition that class belongs to.

    Args:
        proxies: codebook indexed as ``[partitioning, partition]``; each entry
            is a ``sz_embed``-dimensional vector.
        nb_classes: number of classes.
        sz_embed: dimensionality of one sub-embedding.
        num_partitions: number of partitions per partitioning.
        num_partitionings: number of partitionings.
        temperature: divisor applied to the logits before the log-softmax.
    """
    def __init__(self,
                 proxies,
                 nb_classes,
                 sz_embed,
                 num_partitions,
                 num_partitionings,
                 temperature=0.05):
        super(CombiNormSoftmaxLossDisjoint, self).__init__()

        self.num_partitions = num_partitions
        self.num_partitionings = num_partitionings
        self.sz_embed = sz_embed
        self.num_classes = nb_classes
        self.proxies = proxies
        self.temperature = temperature

        # Filled with -1 until set_partitionings() is called.
        self.register_buffer('partitionings', -torch.ones(num_partitionings, nb_classes).long())

    def set_partitionings(self, partitionings_map):
        """Store the class -> partition map, both raw and as flat indices.

        ``partitionings_map`` is transposed to (partitioning, class). The raw
        copy is kept in ``orig_partitionings``; the buffer copy has each row
        offset by ``row_index * num_partitions`` so that it indexes the
        flattened (partitioning, partition) axis, e.g. 01110 -> 01110, 23332.
        """
        mapping = torch.LongTensor(partitionings_map).t()
        self.partitionings.copy_(mapping)
        self.orig_partitionings = mapping
        offsets = torch.arange(self.num_partitionings).view(-1, 1).type_as(self.partitionings)
        self.partitionings.add_(offsets * self.num_partitions)

    def rescale_grad(self):
        """Multiply the proxy gradients by ``num_partitionings`` in place."""
        self.proxies.grad.mul_(self.num_partitionings)

    def loss_fn(self, logits, instance_targets):
        """Negative log-likelihood for soft (2-D), scalar or index targets.

        ``logits`` are expected to be log-probabilities already. 2-D targets
        are normalised to sum to one per row. A 0-dim target is promoted to a
        batch of one on the same device as ``logits``, so no CUDA device is
        required.
        """
        if instance_targets.dim() > 1:
            soft_targets = instance_targets.float() / instance_targets.sum(dim=1).unsqueeze(1)
            loss = -1 * logits * soft_targets
            return loss.sum(1).mean()
        if instance_targets.dim() == 0:
            instance_targets = instance_targets.reshape(1).long().to(logits.device)
        return F.nll_loss(logits, instance_targets)

    def forward(self, embeddings, instance_targets, return_meta_dist=False, return_dist=False):
        """Return the loss or one of the intermediate score tensors.

        Args:
            embeddings: tensor that can be viewed as
                ``(batch, num_partitionings, sz_embed)``.
            instance_targets: class targets; only used for the loss.
            return_meta_dist: return the per-partitioning log-probabilities,
                shape ``(batch, num_partitionings, num_partitions)``.
            return_dist: return the summed per-class scores,
                shape ``(batch, nb_classes)``.
        """
        assert self.partitionings.sum() > 0, 'Partitionings is never given to the module.'
        embeddings = embeddings.view(-1, self.num_partitionings, self.sz_embed)

        # Chunk i of the embedding only sees the proxies of partitioning i.
        per_partitioning = [
            F.linear(embeddings[:, i, :], self.proxies[i, :, :])
            for i in range(self.num_partitionings)
        ]
        meta_logits = torch.stack(per_partitioning, dim=1)
        meta_logits = meta_logits.view(-1, self.num_partitionings, self.num_partitions)
        meta_log_probs = F.log_softmax(meta_logits / self.temperature, dim=2)

        if return_meta_dist:
            return meta_log_probs

        flat_log_probs = meta_log_probs.view(-1, self.num_partitionings * self.num_partitions)
        gathered = flat_log_probs.index_select(1, self.partitionings.view(-1))
        output = gathered.view(-1, self.num_partitionings, self.num_classes).sum(1)

        if return_dist:
            return output
        return self.loss_fn(output, instance_targets) / self.num_partitionings

class CombiNormSoftmaxLossWithLinear(nn.Module):
    """
    Combinatorial metric learning with an additional classifier on the
    original class space.

    Owns the meta-proxy codebook and wires it into a CombiNormSoftmaxLoss;
    forward() returns the meta loss and the original-class loss as a tuple
    rather than a weighted sum (``alpha`` is stored but not applied here).

    NOTE(review): as written the constructor cannot run -- see the note on
    the NormSoftmaxLoss(...) call below.
    """
    def __init__(self,
                 nb_classes,
                 sz_embed,
                 num_partitions,
                 num_partitionings,
                 meta_temperature=0.05,
                 orig_temperature=0.01,
                 alpha=1.0):
        super(CombiNormSoftmaxLossWithLinear, self).__init__()
        # Codebook: one sz_embed-dimensional proxy per (partitioning, partition).
        self.proxies = torch.nn.Parameter(torch.Tensor(num_partitionings, num_partitions, sz_embed))
        # Initialization from nn.Linear (https://github.com/pytorch/pytorch/blob/v1.0.0/torch/nn/modules/linear.py#L129)
        stdv = 1. / math.sqrt(self.proxies.size(2))
        self.proxies.data.uniform_(-stdv, stdv)

        # NOTE(review): NormSoftmaxLoss.__init__ in this file also requires
        # `partitionings`, `proxies` and `num_partitionings`, none of which has
        # a default, so this call raises TypeError. Which values should be
        # passed cannot be determined from this file -- confirm with the author.
        self.normsoftmax = NormSoftmaxLoss(nb_classes=nb_classes, sz_embed=sz_embed, temperature=orig_temperature)
        # NOTE(review): nothing in this class calls set_partitionings() on the
        # combinatorial classifier, whose forward() asserts that it was called;
        # presumably the caller is expected to do so -- verify.
        self.combinatorial_classifiers = CombiNormSoftmaxLoss(proxies=self.proxies, nb_classes=nb_classes, sz_embed=sz_embed, num_partitions=num_partitions,
                                                              num_partitionings=num_partitionings, temperature=meta_temperature)
        self.alpha = alpha

    def forward(self, embeddings, instance_targets, return_meta_dist=False):
        """Return ``(meta_loss, orig_loss)``.

        When ``return_meta_dist`` is true the first element is whatever the
        combinatorial classifier returns for that flag instead of a loss.
        """
        orig_loss = self.normsoftmax(embeddings, instance_targets)
        meta_loss = self.combinatorial_classifiers(embeddings, instance_targets, return_meta_dist=return_meta_dist)
        # loss = meta_loss + self.alpha * orig_loss
        # return loss
        return (meta_loss, orig_loss)

class CombiNormSoftmaxLossWithLinearDisjoint(nn.Module):
    """Combinatorial loss on disjoint sub-embeddings plus a class-level loss.

    The module owns a codebook of meta proxies and combines three terms:

    * a combinatorial loss (CombiNormSoftmaxLossDisjoint) on the sub-embeddings,
    * a class-level loss (NormSoftmaxLoss or Proxy_Anchor, chosen by
      ``metric_mode``) on a soft-assigned descriptor rebuilt from the codebook,
    * ``NM_loss`` evaluated on the class-level proxies.

    Args:
        nb_global_proxies: number of classes seen by the class-level loss.
        nb_classes: number of classes seen by the combinatorial loss.
        sz_embed: dimensionality of one sub-embedding.
        num_partitions: number of partitions per partitioning.
        num_partitionings: number of partitionings.
        partitionings: class -> partition map; passed both to the class-level
            loss and to ``set_partitionings`` of the combinatorial loss.
        meta_temperature: temperature of the combinatorial loss.
        orig_temperature: temperature of the NormSoftmax class-level loss.
        zeta: sharpness of the soft assignment (multiplies the logits).
        is_norm: L2-normalise sub-embeddings and proxies on every forward.
        metric_mode: ``'NormSoftmax'`` or ``'Proxy_Anchor'``.
        NM_weight: stored as ``nm_weight``; not used inside this class.
        NM_train: if true, ``k`` becomes a learnable parameter initialised to
            ``nb_classes``; otherwise the constant ``k`` is used.
        k: value forwarded to ``NM_loss`` when ``NM_train`` is false.

    Raises:
        AssertionError: if ``metric_mode`` is not one of the supported values.
    """
    def __init__(self,
                 nb_global_proxies,
                 nb_classes,
                 sz_embed,
                 num_partitions,
                 num_partitionings,
                 partitionings,
                 meta_temperature=0.05,
                 orig_temperature=0.01,
                 zeta=20.0,
                 is_norm=True,
                 metric_mode='NormSoftmax',
                 NM_weight=1,
                 NM_train=False,
                 k=3):
        super(CombiNormSoftmaxLossWithLinearDisjoint, self).__init__()
        self.nb_global_proxies = nb_global_proxies
        self.is_norm = is_norm
        self.num_partitionings = num_partitionings
        self.sz_embed = sz_embed
        self.proxies = torch.nn.Parameter(torch.Tensor(num_partitionings, num_partitions, sz_embed))
        self.nb_classes = nb_classes
        self.nm_weight = NM_weight
        self.partitionings = partitionings
        if NM_train:
            self.k = torch.nn.Parameter(torch.Tensor(1))
            self.k.data.fill_(nb_classes)
        else:
            self.k = k

        nn.init.xavier_uniform_(self.proxies)
        if metric_mode == 'NormSoftmax':
            self.normsoftmax = NormSoftmaxLoss(partitionings=partitionings, proxies=self.proxies, nb_classes=nb_global_proxies, num_partitionings=num_partitionings, sz_embed=sz_embed, temperature=orig_temperature)
        elif metric_mode == 'Proxy_Anchor':
            self.normsoftmax = Proxy_Anchor(nb_classes=nb_global_proxies, sz_embed= self.sz_embed * self.num_partitionings, mrg = 0.1, alpha = 32)
        else:
            # Raised explicitly so the check survives `python -O`, which would
            # strip a bare `assert False` and leave the module half-built.
            raise AssertionError('Unsupported metric_mode: {!r}'.format(metric_mode))

        print('Metric mode is ', metric_mode)

        self.combinatorial_classifiers = CombiNormSoftmaxLossDisjoint(proxies=self.proxies, nb_classes=nb_classes, sz_embed=sz_embed, num_partitions=num_partitions,
                                                              num_partitionings=num_partitionings, temperature=meta_temperature)
        self.combinatorial_classifiers.set_partitionings(self.partitionings)
        self.zeta = zeta

    def forward(self, embeddings, instance_targets, return_meta_dist=False):
        """Return ``(meta_loss, orig_loss, nm_loss)``.

        When ``return_meta_dist`` is true the first element is whatever the
        combinatorial classifier returns for that flag instead of a loss.
        """
        embeddings = embeddings.view(-1, self.num_partitionings, self.sz_embed)
        if self.is_norm:
            embeddings = F.normalize(embeddings, p=2, dim=2)
            # Re-project the stored codebook onto the unit sphere in place.
            self.proxies.data = F.normalize(self.proxies.data, p=2, dim=2)

        meta_loss = self.combinatorial_classifiers(embeddings, instance_targets, return_meta_dist=return_meta_dist)

        descriptor = self.SoftAssignment(self.proxies, embeddings, self.num_partitionings, self.zeta)
        orig_loss = self.normsoftmax(descriptor, instance_targets)

        # Proxy_Anchor keeps its class proxies in `.proxies`; only
        # NormSoftmaxLoss exposes the assembled `.global_proxies`.
        if isinstance(self.normsoftmax, Proxy_Anchor):
            class_proxies = self.normsoftmax.proxies
        else:
            class_proxies = self.normsoftmax.global_proxies
        global_proxies = class_proxies.view(-1, self.num_partitionings * self.sz_embed)
        nm_loss = NM_loss(global_proxies, k=self.k)
        return (meta_loss, orig_loss, nm_loss)

    def SoftAssignment(self, z, x, num_partitionings, zeta):
        """Rebuild each sub-embedding as a softmax-weighted mix of proxies.

        Args:
            z: codebook indexed as ``[partitioning, partition]``.
            x: embeddings viewable as ``(batch, num_partitionings, sz_embed)``.
            num_partitionings: number of leading partitionings to process.
            zeta: factor applied to the logits before the softmax.

        Returns:
            ``(batch, num_partitionings * sz_embed)`` descriptor whose
            sub-vectors are L2-normalised.
        """
        x = x.view(-1, self.num_partitionings, self.sz_embed)

        soft_desc = []
        for i in range(num_partitionings):
            codebook = z[i, :, :]
            # (batch, num_partitions) assignment weights for partitioning i.
            weights = F.softmax(F.linear(x[:, i, :], codebook) * zeta, dim=1)
            soft_desc.append(torch.matmul(weights, codebook).unsqueeze(1))
        soft_desc = torch.cat(soft_desc, dim=1)
        soft_desc = F.normalize(soft_desc, p=2, dim=2)
        return soft_desc.view(-1, self.num_partitionings * self.sz_embed)




# class CombiNormSoftmaxLossWithLinear(nn.Module):
#     """
#     Combinatorial metric learning with additional linear classifier on original class space
#     L2 normalize weights and apply temperature scaling on logits.
#     """
#     def __init__(self,
#                  nb_classes,
#                  sz_embed,
#                  num_partitions,
#                  num_partitionings,
#                  meta_temperature=0.05,
#                  orig_temperature=0.01):
#         super(CombiNormSoftmaxLossWithLinear, self).__init__()
#
#         self.num_partitions = num_partitions
#         self.num_partitionings = num_partitionings
#         self.num_classes = nb_classes
#
#         # Meta proxies for combinatorial learning
#         self.meta_proxies = torch.nn.Parameter(torch.Tensor(num_partitions * num_partitionings, sz_embed).cuda())
#         # Initialization from nn.Linear (https://github.com/pytorch/pytorch/blob/v1.0.0/torch/nn/modules/linear.py#L129)
#         stdv = 1. / math.sqrt(self.meta_proxies.size(1))
#         self.meta_proxies.data.uniform_(-stdv, stdv)
#
#         self.meta_temperature = meta_temperature
#         self.meta_loss_fn = nn.NLLLoss()
#
#         self.register_buffer('partitionings', -torch.ones(num_partitionings, nb_classes).long())
#
#         # additional proxy on original class space
#         self.orig_proxies = torch.nn.Parameter(torch.Tensor(nb_classes, sz_embed).cuda())
#         stdv = 1. / math.sqrt(self.orig_proxies.size(1))
#         self.orig_proxies.data.uniform_(-stdv, stdv)
#
#         self.orig_temperature = orig_temperature
#         self.orig_loss_fn = nn.CrossEntropyLoss()
#
#     def set_partitionings(self, partitionings_map):
#         self.partitionings.copy_(torch.LongTensor(partitionings_map).t())
#         arange = torch.arange(self.num_partitionings).view(-1, 1).type_as(self.partitionings)
#         # Add the arange offset (index * num_partitions) to each row, e.g. 01110 -> 01110, 23332
#         self.partitionings.add_(arange * self.num_partitions)
#
#     def rescale_grad(self, total_num_partitionings=None):
#         # for params in self.proxies.parameters():
#         #     if total_num_partitionings is None:
#         #         params.grad.mul_(self.num_partitionings)
#         #     else:
#         #         params.grad.mul_(total_num_partitionings)
#         self.proxies.grad.mul_(self.num_partitionings)
#
#     def forward(self, embeddings, instance_targets, return_meta_dist=False):
#         assert self.partitionings.sum() > 0, 'Partitionings is never given to the module.'
#
#         # normalize weight
#         norm_meta_weight = nn.functional.normalize(self.meta_proxies, dim=1)
#         norm_orig_weight = nn.functional.normalize(self.orig_proxies, dim=1)
#
#         # combinatorial logit with meta_weights
#         prediction_logits = nn.functional.linear(embeddings, norm_meta_weight)
#         prediction_logits = prediction_logits.view(-1, self.num_partitionings, self.num_partitions)
#         prediction_logits = F.log_softmax(prediction_logits / self.temperature, dim=2)
#
#         if return_meta_dist:
#             return prediction_logits
#
#         prediction_logits = prediction_logits.view(-1, self.num_partitionings * self.num_partitions)
#         output = prediction_logits.index_select(1, self.partitionings.view(-1))
#         output = output.view(-1, self.num_partitionings, self.num_classes)
#
#         meta_output = output.sum(1)
#
#         meta_loss = self.loss_fn(meta_output, instance_targets) / self.num_partitionings
#
#         # Orignal proxy loss
#
#
#         return loss



# The following classes are thin wrappers around the PyTorch Metric Learning library.
# See https://github.com/KevinMusgrave/pytorch-metric-learning for details.
class Proxy_NCA(torch.nn.Module):
    """Wrapper around ``losses.ProxyNCALoss``, created on the GPU."""

    def __init__(self, nb_classes, sz_embed, scale=32):
        super().__init__()
        self.nb_classes = nb_classes
        self.sz_embed = sz_embed
        self.scale = scale
        proxy_nca = losses.ProxyNCALoss(
            num_classes = self.nb_classes,
            embedding_size = self.sz_embed,
            softmax_scale = self.scale,
        )
        self.loss_func = proxy_nca.cuda()

    def forward(self, embeddings, labels):
        """Return the wrapped loss for ``embeddings`` and ``labels``."""
        return self.loss_func(embeddings, labels)
    
class MultiSimilarityLoss(torch.nn.Module):
    """Multi-similarity loss fed by a multi-similarity pair miner.

    Hyper-parameters are fixed: thresh 0.5, epsilon 0.1, scale_pos 2,
    scale_neg 50.
    """

    def __init__(self):
        super().__init__()
        self.thresh = 0.5
        self.epsilon = 0.1
        self.scale_pos = 2
        self.scale_neg = 50

        self.miner = miners.MultiSimilarityMiner(epsilon=self.epsilon)
        self.loss_func = losses.MultiSimilarityLoss(self.scale_pos, self.scale_neg, self.thresh)

    def forward(self, embeddings, labels):
        """Mine pairs from the batch, then score them with the loss."""
        mined_pairs = self.miner(embeddings, labels)
        return self.loss_func(embeddings, labels, mined_pairs)
    
class ContrastiveLoss(nn.Module):
    """Wrapper around ``losses.ContrastiveLoss``; ``margin`` is its ``neg_margin``.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, margin=0.5, **kwargs):
        super().__init__()
        self.margin = margin
        self.loss_func = losses.ContrastiveLoss(neg_margin=self.margin)

    def forward(self, embeddings, labels):
        """Return the wrapped loss for ``embeddings`` and ``labels``."""
        return self.loss_func(embeddings, labels)
    
class TripletLoss(nn.Module):
    """Triplet-margin loss fed by a semi-hard triplet miner.

    The same ``margin`` is used for mining and for the loss. Extra keyword
    arguments are accepted and ignored.
    """

    def __init__(self, margin=0.1, **kwargs):
        super().__init__()
        self.margin = margin
        self.miner = miners.TripletMarginMiner(margin, type_of_triplets = 'semihard')
        self.loss_func = losses.TripletMarginLoss(margin = self.margin)

    def forward(self, embeddings, labels):
        """Mine triplets from the batch, then score them with the loss."""
        mined_triplets = self.miner(embeddings, labels)
        return self.loss_func(embeddings, labels, mined_triplets)
    
class NPairLoss(nn.Module):
    """Wrapper around ``losses.NPairsLoss`` with embedding normalisation off."""

    def __init__(self, l2_reg=0):
        super().__init__()
        self.l2_reg = l2_reg
        self.loss_func = losses.NPairsLoss(
            l2_reg_weight=self.l2_reg,
            normalize_embeddings=False,
        )

    def forward(self, embeddings, labels):
        """Return the wrapped loss for ``embeddings`` and ``labels``."""
        return self.loss_func(embeddings, labels)




def PseudoCodewordLoss(x ,y, reduction='mean'):
    """Negative log-likelihood of codeword indices ``y`` under ``x``.

    Dims 1 and 2 of ``x`` are swapped before calling ``F.nll_loss``, which
    expects the class axis at dim 1.
    """
    class_first = torch.transpose(x, 1, 2)
    return F.nll_loss(class_first, y, reduction=reduction)

def PseudoCodewordLoss_l2(x,y, reduction='mean'):
    """Squared Euclidean distance between L2-normalised ``x`` and ``y``.

    For unit vectors ``||x - y||^2 = 2 - 2 * <x, y>``, which is what is
    computed along the last dimension.

    Args:
        x, y: tensors of matching (broadcastable) shape; normalised along
            the last dimension.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``. With
            ``'none'`` one distance per vector is returned.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    sq_dist = 2 - 2 * (x * y).sum(dim=-1)

    if reduction == 'mean':
        return sq_dist.mean()
    if reduction == 'sum':
        return sq_dist.sum()
    if reduction == 'none':
        return sq_dist
    raise ValueError("reduction should be 'mean', 'sum' or 'none'")

def PseudoCodewordLoss_KLDIV(x, y, reduction='batchmean'):
    """KL divergence between predicted log-probabilities and target probabilities.

    Args:
        x: predicted distribution as log-probabilities.
        y: target distribution as probabilities.
        reduction: reduction passed to ``F.kl_div``. ``'batchmean'`` divides
            the summed divergence by the batch size.

    Returns:
        The divergence as a tensor (a scalar unless ``reduction='none'``).
    """
    # nn.KLDivLoss(x, y) would only construct a module, with the tensors
    # taken as its constructor flags; the functional form computes the value.
    return F.kl_div(x, y, reduction=reduction)


def PseudoCodewordLossSoft(x,y, reduction='mean'):
    """Negated element-wise product of ``x`` and ``y``.

    ``reduction='mean'`` averages over all elements, ``'none'`` returns the
    element-wise values. Any other value raises ``NameError``.
    """
    if reduction not in ('mean', 'none'):
        raise NameError('reduction should be mean or none')

    product = x * y
    if reduction == 'none':
        return - product
    return - product.mean()

def loss_fn(logit_mat, label_mat, mask=None):
    """Soft-label cross-entropy over rows, optionally ignoring masked columns.

    Args:
        logit_mat: ``(batch, n)`` logits.
        label_mat: ``(batch, n)`` target weights; each row is multiplied with
            the row's log-probabilities and summed.
        mask: optional ``(batch, n)`` 0/1 tensor. Entries with 0 are excluded
            from the softmax normaliser of their row.

    Returns:
        Scalar loss: minus the label-weighted log-probabilities, summed over
        columns and averaged over rows.
    """
    if mask is not None:
        eps = 1e-12
        # Masked logits become 0, so they contribute exp(0) == 1 to the sum.
        exp_logits = torch.exp(logit_mat * mask)
        # Remove those 1s again. keepdim keeps the normaliser as a column so
        # each row is divided by its own sum.
        denominator = exp_logits.sum(dim=1, keepdim=True) - (1 - mask).sum(dim=1, keepdim=True)
        logits = torch.log(exp_logits.div(denominator) + eps)
    else:
        logits = F.log_softmax(logit_mat, dim=1)

    loss = -1 * (logits * label_mat).sum(dim=1).mean()
    return loss

def loss_fn2(logit_mat, label_mat, mask=None):
    """Soft-label cross-entropy averaged over columns, with optional mask.

    Args:
        logit_mat: ``(batch, n)`` logits.
        label_mat: ``(batch, n)`` target weights; each row is multiplied with
            the row's log-probabilities.
        mask: optional ``(batch, n)`` 0/1 tensor. Entries with 0 are excluded
            from the softmax normaliser of their row.

    Returns:
        Scalar loss: minus the label-weighted log-probabilities, averaged
        over columns and then over rows.
    """
    if mask is not None:
        eps = 1e-12
        # Masked logits become 0, so they contribute exp(0) == 1 to the sum.
        exp_logits = torch.exp(logit_mat * mask)
        # Remove those 1s again. keepdim keeps the normaliser as a column so
        # each row is divided by its own sum.
        denominator = exp_logits.sum(dim=1, keepdim=True) - (1 - mask).sum(dim=1, keepdim=True)
        logits = torch.log(exp_logits.div(denominator) + eps)
    else:
        logits = F.log_softmax(logit_mat, dim=1)

    loss = -1 * (logits * label_mat).mean(dim=1).mean()
    return loss
## Need to check l2 loss and FQ_similarity
def N_PQ_loss(labels_similarity, embeddings_x, embeddings_q, n_book, global_proxies=None, reg_lambda=0.002, mask=None):
    """N-pair product-quantization loss with an L2 penalty.

    Cosine similarities between ``embeddings_x`` and ``embeddings_q`` (and,
    when given, ``global_proxies``, appended as extra columns) are scaled by
    ``n_book`` and scored against ``labels_similarity`` with ``loss_fn``.

    :return: the similarity loss plus ``0.25 * reg_lambda`` times the summed
        mean squared norms of the un-normalised inputs.
    """
    use_proxies = global_proxies is not None

    # Penalty on the raw (un-normalised) vectors.
    squared_norms = embeddings_x.square().sum(dim=1).mean() + embeddings_q.square().sum(dim=1).mean()
    if use_proxies:
        squared_norms = squared_norms + global_proxies.square().sum(1).mean()
    l2_loss = squared_norms * (0.25 * reg_lambda)

    unit_x = F.normalize(embeddings_x, p=2, dim=1)
    unit_q = F.normalize(embeddings_q, p=2, dim=1)
    similarity = torch.matmul(unit_x, unit_q.t()) * n_book

    if use_proxies:
        unit_proxies = F.normalize(global_proxies, dim=1)
        proxy_similarity = F.linear(unit_x, unit_proxies) * n_book
        similarity = torch.cat([similarity, proxy_similarity], dim=1)

    return loss_fn(similarity, labels_similarity, mask=mask) + l2_loss

def N_PQ_loss2(labels_similarity, embeddings_x, embeddings_q, n_book, global_proxies=None, reg_lambda=0.002, mask=None):
    """Variant of the N-pair product-quantization loss scored with ``loss_fn2``.

    Cosine similarities between ``embeddings_x`` and ``embeddings_q`` (and,
    when given, ``global_proxies``, appended as extra columns) are scaled by
    ``n_book`` and scored against ``labels_similarity``.

    :return: the similarity loss plus ``0.25 * reg_lambda`` times the summed
        mean squared norms of the un-normalised inputs.
    """
    has_proxies = global_proxies is not None

    # Penalty on the raw (un-normalised) vectors.
    penalty = embeddings_x.square().sum(dim=1).mean() + embeddings_q.square().sum(dim=1).mean()
    if has_proxies:
        penalty = penalty + global_proxies.square().sum(1).mean()
    l2_loss = penalty * (0.25 * reg_lambda)

    anchors = F.normalize(embeddings_x, p=2, dim=1)
    positives = F.normalize(embeddings_q, p=2, dim=1)
    scores = torch.matmul(anchors, positives.t()) * n_book

    if has_proxies:
        proxy_scores = F.linear(anchors, F.normalize(global_proxies, dim=1)) * n_book
        scores = torch.cat([scores, proxy_scores], dim=1)

    return loss_fn2(scores, labels_similarity, mask=mask) + l2_loss

## Need to check l2 loss and FQ_similarity
def PMSE_loss(labels_similarity, embeddings_x, embeddings_q, n_book, global_proxies=None, reg_lambda=0.002, mask=None):
    """
    Squared-error loss pulling the cosine similarity of labelled pairs towards 1.

    For every entry where ``labels_similarity`` is non-zero the term
    ``(1 - cosine)^2`` is accumulated; terms are summed per row and averaged
    over rows. An L2 penalty on the raw ``embeddings_x`` / ``embeddings_q`` is
    added (``global_proxies`` is not part of the penalty here).

    :param labels_similarity: matrix whose non-zero entries mark positive pairs.
    :param embeddings_x: first set of embeddings; L2-normalised along dim 1.
    :param embeddings_q: second set of embeddings; L2-normalised along dim 1.
    :param n_book: accepted for signature compatibility; not used.
    :param global_proxies: optional extra rows whose similarities to
        ``embeddings_x`` are appended as additional columns.
    :param reg_lambda: penalty weight (the penalty is scaled by ``0.25 * reg_lambda``).
    :param mask: accepted for signature compatibility; not used.
    :return: scalar loss plus the L2 penalty.
    """
    # Penalty uses the un-normalised embeddings.
    sq_norm_total = embeddings_x.square().sum(dim=1).mean() + embeddings_q.square().sum(dim=1).mean()
    penalty = sq_norm_total * (0.25 * reg_lambda)

    anchors = F.normalize(embeddings_x, p=2, dim=1)
    queries = F.normalize(embeddings_q, p=2, dim=1)

    similarity = torch.matmul(anchors, queries.t())
    if global_proxies is not None:
        unit_proxies = F.normalize(global_proxies, dim=1)
        similarity = torch.cat([similarity, F.linear(anchors, unit_proxies)], dim=1)

    is_positive = (labels_similarity != 0).long()
    per_pair = is_positive * (1 - similarity).pow(2)

    return per_pair.sum(1).mean() + penalty

def PMSE_loss_soft(labels_similarity, embeddings_x, embeddings_q, n_book, global_proxies=None, reg_lambda=0.002, mask=None):
    """
    Squared-error loss matching each positive similarity to the row's own ``[i][i]`` similarity.

    For row ``i`` the reference value is ``similarity[i][i]``; for every column
    where ``labels_similarity[i]`` is non-zero the squared difference between
    the reference and that column's cosine similarity is accumulated and
    averaged over the row's positives. Row losses are summed and divided by
    the number of rows of ``embeddings_x``. An L2 penalty on the raw inputs
    is added.

    :param labels_similarity: matrix whose non-zero entries mark positive
        pairs (indexed per row, so it is expected to be 2-D).
    :param embeddings_x: first set of embeddings; L2-normalised along dim 1.
    :param embeddings_q: second set of embeddings; L2-normalised along dim 1.
    :param n_book: accepted for signature compatibility; not used.
    :param global_proxies: optional extra rows; their similarities to
        ``embeddings_x`` are appended as columns and their squared norm joins
        the penalty.
    :param reg_lambda: penalty weight (the penalty is scaled by ``0.25 * reg_lambda``).
    :param mask: accepted for signature compatibility; not used.
    :return: scalar loss plus the L2 penalty.
    """
    loss = 0
    u_batch_size = embeddings_x.size(0)

    # Penalty is measured on the inputs before normalisation.
    reg_anchor = embeddings_x.square().sum(dim=1).mean()
    reg_positive = embeddings_q.square().sum(dim=1).mean()
    if global_proxies is not None:
        reg_proxies = global_proxies.square().sum(1).mean()
        l2_loss = (reg_anchor + reg_positive + reg_proxies) * (0.25 * reg_lambda)
    else:
        l2_loss = (reg_anchor + reg_positive) * (0.25 * reg_lambda)

    embeddings_x = F.normalize(embeddings_x, p=2, dim=1)
    embeddings_q = F.normalize(embeddings_q, p=2, dim=1)

    fq_similarity = torch.matmul(embeddings_x, embeddings_q.t())

    if global_proxies is not None:
        global_proxies = F.normalize(global_proxies, dim=1)
        fp_similarity = F.linear(embeddings_x, global_proxies)
        fq_similarity = torch.cat([fq_similarity, fp_similarity], dim=1)

    for i in range(u_batch_size):
        # as_tuple keeps the index vector 1-D even when a row has exactly one
        # positive; a bare .squeeze() would collapse it to a 0-d tensor that
        # can neither be iterated nor asked for .size(0).
        pos_idxs = (labels_similarity[i] != 0).nonzero(as_tuple=True)[0]
        nb_pos = pos_idxs.numel()
        if nb_pos == 0:
            # A row without positives contributes nothing (avoids 0 / 0).
            continue
        diffs = fq_similarity[i, i] - fq_similarity[i, pos_idxs]
        loss = loss + diffs.pow(2).sum() / nb_pos

    loss = loss / u_batch_size

    return loss + l2_loss

def Consistency_loss(descriptor, descriptor_prime):
    """
    Mean squared cosine distance between matching rows of two descriptor sets.

    Both inputs are L2-normalised along dim 1; row ``i`` of ``descriptor`` is
    compared with row ``i`` of ``descriptor_prime`` and ``(1 - cosine)^2`` is
    averaged over rows.

    :param descriptor: 2-D tensor of descriptors.
    :param descriptor_prime: 2-D tensor with the same number of rows as
        ``descriptor`` (row ``i`` is paired with row ``i``).
    :return: scalar loss tensor on the same device as the inputs.
    """
    normed_des = F.normalize(descriptor, dim=1)
    normed_des_p = F.normalize(descriptor_prime, dim=1)

    cos_dis = (1 - torch.matmul(normed_des, normed_des_p.t())).pow(2)

    # Only the matching (diagonal) pairs count. Reading the diagonal directly
    # avoids building an identity mask with a hard-coded .cuda(), which broke
    # CPU execution and inputs living on a different device.
    return torch.diagonal(cos_dis).mean()

def Meta_N_PQ_loss(meta_labels_similarity, embeddings_x, embeddings_q, n_book, reg_lambda=0.002):
    """
    Meta N-pair product-quantization loss (soft-label cross-entropy per book).

    For each of the ``n_book`` slices along dim 1, the normalised
    ``embeddings_x`` slice is multiplied with the (un-normalised)
    ``embeddings_q`` slice and scaled by ``n_book``. The stacked similarities
    go through ``log_softmax`` over the last dimension and are weighted by
    ``meta_labels_similarity``.

    :param meta_labels_similarity: weights multiplied element-wise with the
        stacked log-probabilities (must broadcast against them).
    :param embeddings_x: 3-D tensor indexed as ``[:, book, :]``; normalised
        along dim 2.
    :param embeddings_q: 3-D tensor indexed as ``[:, book, :]``; used as-is.
    :param n_book: number of slices and similarity scale factor.
    :param reg_lambda: accepted for signature compatibility; the returned
        value does not include any L2 term.
    :return: scalar cross-entropy loss.
    """
    anchors = F.normalize(embeddings_x, p=2, dim=2)

    per_book = [
        torch.matmul(anchors[:, book, :], embeddings_q[:, book, :].t()) * n_book
        for book in range(n_book)
    ]

    log_probs = F.log_softmax(torch.stack(per_book), dim=2)
    return -1 * (log_probs * meta_labels_similarity).sum(dim=2).mean()

def BCE(labels_similarity, m_u, descriptor_u, descriptor_l, nb_positive):
    """
    Binary cross-entropy between ``m_u @ [descriptor_u; descriptor_l]^T`` and binary labels.

    ``descriptor_u`` and ``descriptor_l`` are concatenated along dim 0 and
    multiplied with ``m_u`` to form the predictions. The targets are
    ``labels_similarity * nb_positive`` rounded to the nearest integer.

    NOTE(review): ``F.binary_cross_entropy`` expects predictions in [0, 1];
    nothing here enforces that, so the caller is presumably responsible for
    supplying suitably scaled ``m_u`` / descriptors -- confirm.

    :param labels_similarity: label matrix, presumably row-normalised so that
        multiplying by ``nb_positive`` restores 0/1 entries -- TODO confirm.
    :param m_u: left operand of the similarity product.
    :param descriptor_u: first block of descriptor rows.
    :param descriptor_l: second block of descriptor rows.
    :param nb_positive: factor multiplied into ``labels_similarity``.
    :return: scalar binary cross-entropy.
    :raises AssertionError: if the rescaled labels and the similarity matrix
        differ in size.
    """
    descriptor = torch.cat([descriptor_u, descriptor_l], dim=0)
    similarity = torch.matmul(m_u, descriptor.t())

    labels_similarity = labels_similarity * nb_positive

    assert labels_similarity.size() == similarity.size()

    # binary_cross_entropy needs a floating-point target of the input's dtype;
    # a long target raises a dtype error. Rounding (rather than truncating via
    # .long()) keeps a rescaled value such as 0.9999999 at 1 instead of 0.
    targets = labels_similarity.round().to(similarity.dtype)

    loss = F.binary_cross_entropy(similarity, targets)
    return loss

def NM_loss(proxies, k=10):
    """
    Negative mean of the summed squared cosine distances to near neighbours.

    Rows of ``proxies`` are L2-normalised and the pairwise ``(1 - cosine)^2``
    matrix is built. For each row the ``k + 1`` smallest entries are taken and
    the very smallest one is dropped (presumably the row's zero distance to
    itself); the remaining ``k`` are summed. The loss is minus the mean of
    those sums, so minimising it increases the distances.

    :param proxies: 2-D tensor with at least ``k + 1`` rows.
    :param k: number of neighbours kept per row.
    :return: scalar loss (non-positive).
    """
    unit = F.normalize(proxies, p=2, dim=1)
    sq_dist = (1 - torch.matmul(unit, unit.t())).square()

    # topk already returns the selected distances, ordered smallest first.
    nearest = torch.topk(sq_dist, k=k + 1, largest=False, dim=1).values
    neighbour_dist = nearest[:, 1:]

    return -1 * neighbour_dist.sum(1).mean()

def GenerateLabelMatrix_FD(logits_u, k=5, threshold=0.8):
    """
    Build a soft label matrix and a mask from pairwise cosine similarity.

    :param logits_u: 2-D feature tensor; rows are L2-normalised before the
        pairwise cosine similarity is computed.
    :param k: number of highest-similarity columns labelled per row.
    :param threshold: entries with cosine ``<= threshold`` are set in the mask.
    :return: ``(label_matrix, mask)``. ``label_matrix`` is float32 with weight
        ``1 / k`` on each row's top-``k`` columns. ``mask`` is 1 where the
        cosine is ``<= threshold`` or the column is among the row's top-``k``,
        and 0 elsewhere.
    """
    unit = F.normalize(logits_u, dim=1)
    cosine = torch.matmul(unit, unit.t())

    top_idx = torch.topk(cosine, k=k, dim=1).indices

    # Low-similarity entries first, then force the top-k columns on as well.
    mask = (cosine <= threshold).to(cosine.dtype)
    mask.scatter_(1, top_idx, 1.0)

    label_matrix = torch.zeros_like(cosine).scatter_(1, top_idx, 1.0).float()
    label_matrix = label_matrix / label_matrix.sum(1, keepdim=True)

    return label_matrix, mask

def GenerateLabelMatrix_FD_with_proxies(logits_u, proxies,  k=2, threshold=0.8):
    """
    Build a soft label matrix and a mask over instance and proxy columns.

    The similarity matrix has one row per row of ``logits_u``. Its columns are
    the instance-to-instance cosines followed by the instance-to-proxy
    cosines.

    :param logits_u: 2-D feature tensor; rows are L2-normalised.
    :param proxies: tensor reshaped to ``(-1, logits_u.size(1))`` and
        L2-normalised row-wise.
    :param k: number of highest-similarity columns labelled per row.
    :param threshold: entries with cosine ``<= threshold`` are set in the mask.
    :return: ``(label_matrix, mask)``. ``label_matrix`` is float32 with weight
        ``1 / k`` on each row's top-``k`` columns. ``mask`` is 1 where the
        cosine is ``<= threshold`` or the column is among the row's top-``k``.
    """
    unit_feat = F.normalize(logits_u, dim=1)
    unit_prox = F.normalize(proxies.view(-1, logits_u.size(1)), dim=1)

    cosine = torch.cat(
        [torch.matmul(unit_feat, unit_feat.t()), torch.matmul(unit_feat, unit_prox.t())],
        dim=1,
    )

    top_idx = torch.topk(cosine, k=k, dim=1).indices

    # Low-similarity entries first, then force the top-k columns on as well.
    mask = (cosine <= threshold).to(cosine.dtype)
    mask.scatter_(1, top_idx, 1.0)

    label_matrix = torch.zeros_like(cosine).scatter_(1, top_idx, 1.0).float()
    label_matrix = label_matrix / label_matrix.sum(1, keepdim=True)

    return label_matrix, mask


def GenerateLabelMatrix_FD_thresholding(logits_u, k=2, positive_threshold=0.9, negative_threshold=0.7, return_nb_positive=False):
    """
    Label matrix from top-``k`` neighbours plus a positive threshold, with a negative mask.

    :param logits_u: 2-D feature tensor; rows are L2-normalised before the
        pairwise cosine similarity is computed.
    :param k: number of highest-similarity columns always labelled per row.
    :param positive_threshold: columns with cosine ``>=`` this value are
        labelled as well.
    :param negative_threshold: entries with cosine ``<=`` this value are set
        in the mask.
    :param return_nb_positive: when true, also return the per-row count of
        labelled columns (shape ``(rows, 1)``).
    :return: ``(label_matrix, mask)`` or ``(label_matrix, mask, nb_positive)``;
        each row of ``label_matrix`` is divided by its count of labelled
        columns.
    """
    unit = F.normalize(logits_u, dim=1)
    cosine = torch.matmul(unit, unit.t())

    mask = (cosine <= negative_threshold).to(cosine.dtype)

    top_idx = torch.topk(cosine, k=k, dim=1).indices
    label_matrix = torch.zeros_like(cosine).scatter_(1, top_idx, 1.0).float()
    label_matrix[cosine >= positive_threshold] = 1

    # Count positives before the rows are rescaled.
    nb_positive = label_matrix.sum(1, keepdim=True)
    label_matrix = label_matrix / nb_positive

    if return_nb_positive:
        return label_matrix, mask, nb_positive
    return label_matrix, mask

def GenerateLabelMatrix_FD_with_proxies_thresholding(logits_u,  proxies, k=2, positive_threshold=0.9, negative_threshold=0.7, return_nb_positive=False, normalize=True):
    """
    Thresholded label matrix and negative mask over instance and proxy columns.

    Columns are the instance-to-instance cosines followed by the
    instance-to-proxy cosines.

    :param logits_u: 2-D feature tensor; rows are L2-normalised.
    :param proxies: tensor reshaped to ``(-1, logits_u.size(1))`` and
        L2-normalised row-wise.
    :param k: number of highest-similarity columns always labelled per row.
    :param positive_threshold: columns with cosine ``>=`` this value are
        labelled as well.
    :param negative_threshold: entries with cosine ``<=`` this value are set
        in the mask.
    :param return_nb_positive: when true, also return the per-row count of
        labelled columns (shape ``(rows, 1)``).
    :param normalize: when true, divide each row of the label matrix by its
        count of labelled columns; otherwise the matrix stays 0/1.
    :return: ``(label_matrix, mask)`` or ``(label_matrix, mask, nb_positive)``.
    """
    unit_feat = F.normalize(logits_u, dim=1)
    unit_prox = F.normalize(proxies.view(-1, logits_u.size(1)), dim=1)

    cosine = torch.cat(
        [torch.matmul(unit_feat, unit_feat.t()), torch.matmul(unit_feat, unit_prox.t())],
        dim=1,
    )

    mask = (cosine <= negative_threshold).to(cosine.dtype)

    top_idx = torch.topk(cosine, k=k, dim=1).indices
    label_matrix = torch.zeros_like(cosine).scatter_(1, top_idx, 1.0).float()
    label_matrix[cosine >= positive_threshold] = 1

    # Count positives before any rescaling.
    nb_positive = label_matrix.sum(1, keepdim=True)
    if normalize:
        label_matrix = label_matrix / nb_positive

    if return_nb_positive:
        return label_matrix, mask, nb_positive
    return label_matrix, mask

def GenerateLabelMatrix_with_thresholding(logit1, logit2, positive_threshold=0.9, return_nb_positive=False, normalize=False):
    """
    Label every pair whose cosine similarity reaches ``positive_threshold``.

    :param logit1: 2-D tensor; rows are L2-normalised and index the rows of
        the result.
    :param logit2: tensor reshaped to ``(-1, logit1.size(1))`` and
        L2-normalised; its rows index the columns of the result.
    :param positive_threshold: pairs with cosine ``>=`` this value are set to 1.
    :param return_nb_positive: when true, also return the per-row count of
        positives (shape ``(rows, 1)``; may contain zeros).
    :param normalize: when true, divide each row by its positive count. Rows
        without any positive stay all-zero.
    :return: ``label_matrix`` (float32) or ``(label_matrix, nb_positive)``.
    """
    f_dim = logit1.size(1)
    logit2 = logit2.view(-1, f_dim)
    normed_logit1 = F.normalize(logit1, dim=1)
    normed_logit2 = F.normalize(logit2, dim=1)
    cosine_similarity_mat = torch.matmul(normed_logit1, normed_logit2.t())

    label_matrix = (cosine_similarity_mat >= positive_threshold).float()
    nb_positive = label_matrix.sum(1, keepdim=True)

    if normalize:
        # Unlike the top-k variants, a row here can have zero positives;
        # dividing by 0 would turn the whole row into NaN, so the divisor is
        # clamped to 1 and such rows remain zero.
        label_matrix = label_matrix / nb_positive.clamp(min=1)

    if return_nb_positive:
        return label_matrix, nb_positive
    return label_matrix


def GenerateLabelMatrix_FD_with_proxies_thresholding_stochastic(logits_u,  proxies, k=2, positive_threshold=0.9, negative_threshold=0.7, return_nb_positive=False):
    """
    Thresholded label matrix where one randomly drawn top-``k`` neighbour is labelled.

    Columns are the instance-to-instance cosines followed by the
    instance-to-proxy cosines. For every row, one column is drawn uniformly
    from positions ``1 .. k-1`` of that row's top-``k`` list. Position 0, the
    single highest similarity (presumably the row's match with itself), is
    never drawn. The drawn column and every column with cosine
    ``>= positive_threshold`` are labelled, and rows are normalised to sum to 1.

    :param logits_u: 2-D feature tensor; rows are L2-normalised.
    :param proxies: tensor reshaped to ``(-1, logits_u.size(1))`` and
        L2-normalised row-wise.
    :param k: size of the candidate list; must be at least 2 because the draw
        is taken from ``[1, k)``.
    :param positive_threshold: columns with cosine ``>=`` this value are
        labelled as well.
    :param negative_threshold: entries with cosine ``<=`` this value are set
        in the mask.
    :param return_nb_positive: when true, also return the per-row count of
        labelled columns (shape ``(rows, 1)``).
    :return: ``(label_matrix, mask)`` or ``(label_matrix, mask, nb_positive)``.
    """
    f_dim = logits_u.size(1)
    proxies = proxies.view(-1, f_dim)
    normed_proxies = F.normalize(proxies, dim=1)
    normed_feature = F.normalize(logits_u, dim=1)
    cosine_similarity_mat_instances = torch.matmul(normed_feature, normed_feature.t())
    cosine_similarity_mat_proxies = torch.matmul(normed_feature, normed_proxies.t())

    cosine_similarity_mat = torch.cat([cosine_similarity_mat_instances, cosine_similarity_mat_proxies], dim=1)

    mask = (cosine_similarity_mat <= negative_threshold).to(cosine_similarity_mat.dtype)

    _, indices = torch.topk(cosine_similarity_mat, k=k, dim=1)

    # Draw one candidate position per row and keep only that column. The
    # selection result was previously computed and thrown away, so every
    # top-k column ended up labelled and the function was not stochastic.
    idxs = torch.randint(1, indices.size(1), (indices.size(0),))
    indices = indices[range(indices.size(0)), idxs].unsqueeze(1)

    label_matrix = torch.zeros_like(cosine_similarity_mat).scatter_(1, indices, torch.ones_like(cosine_similarity_mat)).float()
    label_matrix[cosine_similarity_mat >= positive_threshold] = 1
    nb_positive = label_matrix.sum(1).unsqueeze(1)
    label_matrix /= label_matrix.sum(1).unsqueeze(1)

    if return_nb_positive:
        return label_matrix, mask, nb_positive
    return label_matrix, mask

def GenerateLabelMatrix_FD_with_proxies_thresholding_last(logits_u,  proxies, k=2, positive_threshold=0.9, negative_threshold=0.7, return_nb_positive=False):
    """
    Thresholded label matrix that labels only the second-to-last top-``k`` neighbour.

    Columns are the instance-to-instance cosines followed by the
    instance-to-proxy cosines. From each row's top-``k`` list (ordered by
    decreasing similarity) only the entry at position ``-2`` is labelled,
    together with every column whose cosine is ``>= positive_threshold``.
    Rows are then divided by their count of labelled columns.

    :param logits_u: 2-D feature tensor; rows are L2-normalised.
    :param proxies: tensor reshaped to ``(-1, logits_u.size(1))`` and
        L2-normalised row-wise.
    :param k: length of the top-``k`` list the single neighbour is taken from.
    :param positive_threshold: columns with cosine ``>=`` this value are
        labelled as well.
    :param negative_threshold: entries with cosine ``<=`` this value are set
        in the mask.
    :param return_nb_positive: when true, also return the per-row count of
        labelled columns (shape ``(rows, 1)``).
    :return: ``(label_matrix, mask)`` or ``(label_matrix, mask, nb_positive)``.
    """
    unit_feat = F.normalize(logits_u, dim=1)
    unit_prox = F.normalize(proxies.view(-1, logits_u.size(1)), dim=1)

    cosine = torch.cat(
        [torch.matmul(unit_feat, unit_feat.t()), torch.matmul(unit_feat, unit_prox.t())],
        dim=1,
    )

    mask = (cosine <= negative_threshold).to(cosine.dtype)

    # Keep a single column per row: the one just before the end of the top-k list.
    chosen = torch.topk(cosine, k=k, dim=1).indices[:, -2:-1]

    label_matrix = torch.zeros_like(cosine).scatter_(1, chosen, 1.0).float()
    label_matrix[cosine >= positive_threshold] = 1

    nb_positive = label_matrix.sum(1, keepdim=True)
    label_matrix = label_matrix / nb_positive

    if return_nb_positive:
        return label_matrix, mask, nb_positive
    return label_matrix, mask




