import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
from pytorch_metric_learning import miners, losses

def binarize(T, nb_classes, device=None):
    """Return a float one-hot encoding of the integer class labels ``T``.

    The encoding is built directly with tensor operations. The previous
    implementation went through ``sklearn.preprocessing.label_binarize``,
    which forced a device -> host -> device copy and collapsed the output to
    a single column when ``nb_classes == 2``.

    Args:
        T: tensor of class indices; it is flattened to one label per row.
        nb_classes: number of columns in the result (classes ``0..nb_classes-1``).
        device: device of the returned tensor. When ``None`` the result is
            placed on CUDA if it is available (the historical behaviour) and
            on ``T``'s own device otherwise.

    Returns:
        Float tensor of shape ``(T.numel(), nb_classes)``. A label outside
        ``[0, nb_classes)`` yields an all-zero row.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else T.device
    labels = T.to(device).reshape(-1, 1)
    class_ids = torch.arange(nb_classes, device=device)
    # Broadcasting (n, 1) against (nb_classes,) gives the (n, nb_classes) mask.
    return (labels == class_ids).float()

def l2_norm(input):
    """Divide every row of ``input`` by its Euclidean length.

    Squares are summed along dimension 1 and ``1e-12`` is added to that sum
    before the square root, so an all-zero row maps to zeros instead of NaN.

    Args:
        input: tensor to normalise (the callers in this file pass 2-D
            ``(rows, features)`` tensors).

    Returns:
        Tensor with the same size as ``input``.
    """
    original_shape = input.size()
    squared_sum = input.pow(2).sum(1) + 1e-12
    row_norm = squared_sum.sqrt().view(-1, 1)
    scaled = input / row_norm.expand_as(input)
    return scaled.view(original_shape)

class Proxy_Anchor(torch.nn.Module):
    """Proxy-anchor loss with one learnable proxy vector per class.

    Cosine similarities between L2-normalised embeddings and L2-normalised
    proxies feed two terms: a positive term over samples whose label matches
    the proxy and a negative term over all other samples. Each term is
    ``log(1 + sum of exponentials)`` per proxy.

    Args:
        nb_classes: number of classes, i.e. number of proxies.
        sz_embed: dimensionality of each proxy (and of the embeddings).
        mrg: margin applied inside both exponentials.
        alpha: scale multiplying the margin-shifted cosine similarity.
    """

    def __init__(self, nb_classes, sz_embed, mrg = 0.1, alpha = 32):
        torch.nn.Module.__init__(self)
        # Proxies live on CUDA when it is available (the historical behaviour);
        # fall back to CPU instead of crashing on machines without a GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.proxies = torch.nn.Parameter(torch.randn(nb_classes, sz_embed).to(device))
        nn.init.kaiming_normal_(self.proxies, mode='fan_out')

        self.nb_classes = nb_classes
        self.sz_embed = sz_embed
        self.mrg = mrg
        self.alpha = alpha

    def forward(self, X, T):
        """Compute the loss for a batch.

        Args:
            X: embeddings, one row per sample (presumably
                ``(batch, sz_embed)`` -- confirm against the caller).
            T: integer class label for each row of ``X``.

        Returns:
            Scalar tensor: positive term averaged over the proxies that have
            at least one positive in the batch, plus negative term averaged
            over all ``nb_classes`` proxies.
        """
        P = self.proxies

        # Cosine similarity between every sample and every proxy.
        cos = F.linear(l2_norm(X), l2_norm(P))

        # Build the positive/negative masks on the similarity's device. This
        # avoids a CPU round-trip on every step and stays two-column-correct
        # for nb_classes == 2. Out-of-range labels match no proxy.
        class_ids = torch.arange(self.nb_classes, device=cos.device)
        pos_mask = T.to(cos.device).reshape(-1, 1) == class_ids
        neg_mask = ~pos_mask

        pos_exp = torch.exp(-self.alpha * (cos - self.mrg))
        neg_exp = torch.exp(self.alpha * (cos + self.mrg))

        # Number of proxies that have at least one positive sample in the
        # batch; clamped to 1 so a batch without any positive does not
        # produce 0 / 0 = NaN.
        num_valid_proxies = max(int(pos_mask.any(dim=0).sum()), 1)

        # torch.where (rather than multiplying by the mask) keeps inf * 0 out
        # of the sums.
        P_sim_sum = torch.where(pos_mask, pos_exp, torch.zeros_like(pos_exp)).sum(dim=0)
        N_sim_sum = torch.where(neg_mask, neg_exp, torch.zeros_like(neg_exp)).sum(dim=0)

        pos_term = torch.log(1 + P_sim_sum).sum() / num_valid_proxies
        neg_term = torch.log(1 + N_sim_sum).sum() / self.nb_classes
        return pos_term + neg_term


class NormSoftmaxLoss(nn.Module):
    """
    Softmax cross-entropy against L2-normalised class proxies, with
    temperature scaling on the logits.

    Only the proxies are normalised here; embeddings are used as given.

    Args:
        nb_classes: number of classes (rows of the proxy matrix).
        sz_embed: embedding dimensionality (columns of the proxy matrix).
        temperature: divisor applied to the logits before the softmax.
    """
    def __init__(self,
                 nb_classes,
                 sz_embed,
                 temperature=0.05):
        super(NormSoftmaxLoss, self).__init__()
        self.nb_classes = nb_classes

        # Proxies live on CUDA when it is available (the historical behaviour);
        # fall back to CPU instead of crashing on machines without a GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.proxies = torch.nn.Parameter(torch.empty(nb_classes, sz_embed).to(device))
        # Initialization from nn.Linear (https://github.com/pytorch/pytorch/blob/v1.0.0/torch/nn/modules/linear.py#L129)
        nn.init.kaiming_uniform_(self.proxies, a=math.sqrt(5))

        self.temperature = temperature

    def loss_fn(self, logits, intance_targets):
        """Cross-entropy for hard (1-D) or soft (2-D) targets.

        Args:
            logits: ``(batch, nb_classes)`` scores, already temperature-scaled.
            intance_targets: class indices when 1-D. Otherwise per-class
                weights, which are normalised to sum to one per row before
                being applied to the log-softmax.

        Returns:
            Scalar mean loss over the batch.
        """
        if len(intance_targets.size()) > 1:
            # Soft targets: turn each row into a distribution first.
            intance_targets = intance_targets.float() / intance_targets.sum(dim=1).unsqueeze(1)
            loss = -1 * F.log_softmax(logits, dim=1) * intance_targets
            return loss.sum(1).mean()
        return F.cross_entropy(logits, intance_targets)

    def forward(self, embeddings, instance_targets):
        """Return the loss of ``embeddings`` against the normalised proxies."""
        norm_weight = nn.functional.normalize(self.proxies, dim=1)

        prediction_logits = nn.functional.linear(embeddings, norm_weight)

        loss = self.loss_fn(prediction_logits / self.temperature, instance_targets)
        return loss

class NormSoftmaxLossDisjoint(nn.Module):
    """
    Normalised-softmax loss where embedding and proxies are split into
    ``num_partitionings`` disjoint chunks of ``sz_embed`` features.

    Each chunk of a proxy is L2-normalised on its own. The per-chunk logits
    are averaged over the chunks, divided by the temperature and passed to
    cross-entropy.

    Args:
        nb_classes: number of classes (rows of the proxy matrix).
        num_partitionings: number of chunks the feature axis is split into.
        sz_embed: size of one chunk.
        temperature: divisor applied to the averaged logits.
    """
    def __init__(self,
                 nb_classes,
                 num_partitionings,
                 sz_embed,
                 temperature=0.0625):
        super(NormSoftmaxLossDisjoint, self).__init__()
        self.num_partitionings = num_partitionings
        self.sz_embed = sz_embed

        # Proxies live on CUDA when it is available (the historical behaviour);
        # fall back to CPU instead of crashing on machines without a GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.proxies = torch.nn.Parameter(
            torch.empty(nb_classes, num_partitionings * sz_embed).to(device))
        nn.init.xavier_uniform_(self.proxies)

        self.temperature = temperature
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, embeddings, instance_targets):
        """Return the cross-entropy of the chunk-averaged logits.

        Args:
            embeddings: tensor viewable as
                ``(batch, num_partitionings, sz_embed)``.
            instance_targets: class indices accepted by ``nn.CrossEntropyLoss``.
        """
        proxies = self.proxies.view(-1, self.num_partitionings, self.sz_embed)
        embeddings = embeddings.view(-1, self.num_partitionings, self.sz_embed)
        norm_weight = nn.functional.normalize(proxies, dim=2)

        # One contraction over the feature axis gives the per-chunk logits as
        # (batch, classes, chunks); averaging the last axis matches the former
        # per-chunk loop followed by cat + mean.
        logits = torch.einsum('bpd,cpd->bcp', embeddings, norm_weight).mean(2)
        return self.loss_fn(logits / self.temperature, instance_targets)


class CombiNormSoftmaxLoss(nn.Module):
    """
    Combinatorial metric learning loss.

    Every partitioning assigns each class to one of ``num_partitions``
    meta-classes. Proxies are L2-normalised, logits are temperature-scaled,
    and a class score is the sum over partitionings of the log-probability of
    the meta-class that class belongs to.

    Args:
        proxies: meta-class proxies, indexed as
            ``(num_partitionings, num_partitions, sz_embed)``.
        nb_classes: number of original classes.
        sz_embed: accepted for interface symmetry; not used by this class.
        num_partitions: number of meta-classes per partitioning.
        num_partitionings: number of partitionings.
        temperature: divisor applied to the logits before the log-softmax.
    """
    def __init__(self,
                 proxies,
                 nb_classes,
                 sz_embed,
                 num_partitions,
                 num_partitionings,
                 temperature=0.05):
        super().__init__()

        self.num_partitions = num_partitions
        self.num_partitionings = num_partitionings
        self.num_classes = nb_classes
        self.proxies = proxies
        self.temperature = temperature
        self.loss_fn = nn.NLLLoss()

        # Filled with -1 until set_partitionings() is called; forward() uses
        # the sign of the sum to detect that.
        unset = torch.full((num_partitionings, nb_classes), -1, dtype=torch.long)
        self.register_buffer('partitionings', unset)

    def set_partitionings(self, partitionings_map):
        """Store the class -> meta-class map as flat column indices.

        Args:
            partitionings_map: nested sequence indexed ``[class][partitioning]``
                holding the meta-class id of each class.
        """
        per_partitioning = torch.LongTensor(partitionings_map).t()
        self.partitionings.copy_(per_partitioning)
        # Shift row i by i * num_partitions so ids index into the flattened
        # (partitioning, meta-class) axis, e.g. 01110 / 01110 -> 01110 / 23332.
        row_ids = torch.arange(self.num_partitionings).view(-1, 1).type_as(self.partitionings)
        self.partitionings.add_(row_ids * self.num_partitions)

    def rescale_grad(self):
        """Multiply the proxy gradient in place by ``num_partitionings``."""
        self.proxies.grad.mul_(self.num_partitionings)

    def forward(self, embeddings, instance_targets, return_meta_dist=False):
        """Compute the combinatorial loss, or the meta-class log-probabilities.

        Args:
            embeddings: ``(batch, sz_embed)`` tensor.
            instance_targets: original class indices for ``nn.NLLLoss``.
            return_meta_dist: when true, return the
                ``(batch, num_partitionings, num_partitions)`` log-softmax
                instead of the loss.
        """
        assert self.partitionings.sum() > 0, 'Partitionings is never given to the module.'
        unit_proxies = F.normalize(self.proxies, dim=2)

        # (batch, emb) x (emb, partitionings, partitions)
        similarities = torch.einsum('ab,bcd->acd', embeddings, unit_proxies.permute(2, 0, 1))
        similarities = similarities.view(-1, self.num_partitionings, self.num_partitions)
        meta_log_probs = F.log_softmax(similarities / self.temperature, dim=2)

        if return_meta_dist:
            return meta_log_probs

        flat_log_probs = meta_log_probs.view(-1, self.num_partitionings * self.num_partitions)
        gathered = flat_log_probs.index_select(1, self.partitionings.view(-1))
        class_scores = gathered.view(-1, self.num_partitionings, self.num_classes).sum(1)

        return self.loss_fn(class_scores, instance_targets) / self.num_partitionings

class CombiNormSoftmaxLossDisjoint(nn.Module):
    """
    Combinatorial metric learning loss over disjoint embedding chunks.

    The embedding is viewed as ``num_partitionings`` chunks of ``sz_embed``
    features; chunk ``i`` is scored only against the L2-normalised proxies of
    partitioning ``i``. Logits are temperature-scaled, and a class score is
    the sum over partitionings of the log-probability of the meta-class that
    class belongs to.

    Args:
        proxies: meta-class proxies, indexed as
            ``(num_partitionings, num_partitions, sz_embed)``.
        nb_classes: number of original classes.
        sz_embed: size of one embedding chunk.
        num_partitions: number of meta-classes per partitioning.
        num_partitionings: number of partitionings (and of chunks).
        temperature: divisor applied to the logits before the log-softmax.
    """
    def __init__(self,
                 proxies,
                 nb_classes,
                 sz_embed,
                 num_partitions,
                 num_partitionings,
                 temperature=0.05):
        super().__init__()

        self.num_partitions = num_partitions
        self.num_partitionings = num_partitionings
        self.sz_embed = sz_embed
        self.num_classes = nb_classes
        self.proxies = proxies
        self.temperature = temperature
        self.loss_fn = nn.NLLLoss()

        # Filled with -1 until set_partitionings() is called; forward() uses
        # the sign of the sum to detect that.
        unset = torch.full((num_partitionings, nb_classes), -1, dtype=torch.long)
        self.register_buffer('partitionings', unset)

    def set_partitionings(self, partitionings_map):
        """Store the class -> meta-class map as flat column indices.

        Args:
            partitionings_map: nested sequence indexed ``[class][partitioning]``
                holding the meta-class id of each class.
        """
        per_partitioning = torch.LongTensor(partitionings_map).t()
        self.partitionings.copy_(per_partitioning)
        # Shift row i by i * num_partitions so ids index into the flattened
        # (partitioning, meta-class) axis, e.g. 01110 / 01110 -> 01110 / 23332.
        row_ids = torch.arange(self.num_partitionings).view(-1, 1).type_as(self.partitionings)
        self.partitionings.add_(row_ids * self.num_partitions)

    def rescale_grad(self):
        """Multiply the proxy gradient in place by ``num_partitionings``."""
        self.proxies.grad.mul_(self.num_partitionings)

    def forward(self, embeddings, instance_targets, return_meta_dist=False):
        """Compute the combinatorial loss, or the meta-class log-probabilities.

        Args:
            embeddings: tensor viewable as
                ``(batch, num_partitionings, sz_embed)``.
            instance_targets: original class indices for ``nn.NLLLoss``.
            return_meta_dist: when true, return the
                ``(batch, num_partitionings, num_partitions)`` log-softmax
                instead of the loss.
        """
        assert self.partitionings.sum() > 0, 'Partitionings is never given to the module.'
        unit_proxies = F.normalize(self.proxies, dim=2)
        chunks = embeddings.view(-1, self.num_partitionings, self.sz_embed)

        # Chunk i is compared with the proxies of partitioning i only.
        per_partitioning = [
            F.linear(chunks[:, i, :], unit_proxies[i, :, :])
            for i in range(self.num_partitionings)
        ]
        similarities = torch.stack(per_partitioning, dim=1)
        similarities = similarities.view(-1, self.num_partitionings, self.num_partitions)
        meta_log_probs = F.log_softmax(similarities / self.temperature, dim=2)

        if return_meta_dist:
            return meta_log_probs

        flat_log_probs = meta_log_probs.view(-1, self.num_partitionings * self.num_partitions)
        gathered = flat_log_probs.index_select(1, self.partitionings.view(-1))
        class_scores = gathered.view(-1, self.num_partitionings, self.num_classes).sum(1)

        return self.loss_fn(class_scores, instance_targets) / self.num_partitionings

class CombiNormSoftmaxLossWithLinear(nn.Module):
    """
    Combinatorial metric learning paired with a normalised-softmax classifier
    on the original class space.

    Owns the meta-class proxies, a ``NormSoftmaxLoss`` for the original
    classes and a ``CombiNormSoftmaxLoss`` that shares the meta-class proxies.

    Args:
        nb_classes: number of original classes.
        sz_embed: embedding dimensionality.
        num_partitions: number of meta-classes per partitioning.
        num_partitionings: number of partitionings.
        meta_temperature: temperature of the combinatorial loss.
        orig_temperature: temperature of the original-class loss.
        alpha: stored on the module; not used in ``forward``.
    """
    def __init__(self,
                 nb_classes,
                 sz_embed,
                 num_partitions,
                 num_partitionings,
                 meta_temperature=0.05,
                 orig_temperature=0.01,
                 alpha=1.0):
        super().__init__()
        self.proxies = torch.nn.Parameter(torch.Tensor(num_partitionings, num_partitions, sz_embed))
        # Initialization from nn.Linear (https://github.com/pytorch/pytorch/blob/v1.0.0/torch/nn/modules/linear.py#L129)
        bound = 1. / math.sqrt(self.proxies.size(2))
        self.proxies.data.uniform_(-bound, bound)

        self.normsoftmax = NormSoftmaxLoss(
            nb_classes=nb_classes,
            sz_embed=sz_embed,
            temperature=orig_temperature,
        )
        self.combinatorial_classifiers = CombiNormSoftmaxLoss(
            proxies=self.proxies,
            nb_classes=nb_classes,
            sz_embed=sz_embed,
            num_partitions=num_partitions,
            num_partitionings=num_partitionings,
            temperature=meta_temperature,
        )
        self.alpha = alpha

    def forward(self, embeddings, instance_targets, return_meta_dist=False):
        """Return ``(meta_loss, orig_loss)`` without combining them.

        With ``return_meta_dist=True`` the first element is whatever the
        combinatorial classifier returns in that mode rather than a loss.
        """
        orig_loss = self.normsoftmax(embeddings, instance_targets)
        meta_loss = self.combinatorial_classifiers(
            embeddings, instance_targets, return_meta_dist=return_meta_dist)
        return (meta_loss, orig_loss)

class CombiNormSoftmaxLossWithLinearDisjoint(nn.Module):
    """
    Disjoint-chunk combinatorial metric learning paired with a
    normalised-softmax classifier on the original class space.

    Owns the meta-class proxies, a ``NormSoftmaxLossDisjoint`` for the
    original classes and a ``CombiNormSoftmaxLossDisjoint`` that shares the
    meta-class proxies.

    Args:
        nb_classes: number of original classes.
        sz_embed: size of one embedding chunk.
        num_partitions: number of meta-classes per partitioning.
        num_partitionings: number of partitionings (and of chunks).
        meta_temperature: temperature of the combinatorial loss.
        orig_temperature: temperature of the original-class loss.
        alpha: stored on the module; not used in ``forward``.
        is_norm: L2-normalise every embedding chunk before computing losses.
    """
    def __init__(self,
                 nb_classes,
                 sz_embed,
                 num_partitions,
                 num_partitionings,
                 meta_temperature=0.05,
                 orig_temperature=0.01,
                 alpha=1.0,
                 is_norm=True):
        super().__init__()
        self.is_norm = is_norm
        self.num_partitionings = num_partitionings
        self.sz_embed = sz_embed
        self.proxies = torch.nn.Parameter(torch.Tensor(num_partitionings, num_partitions, sz_embed))
        self.nb_classes = nb_classes
        nn.init.xavier_uniform_(self.proxies)

        self.normsoftmax = NormSoftmaxLossDisjoint(
            nb_classes=nb_classes,
            num_partitionings=num_partitionings,
            sz_embed=sz_embed,
            temperature=orig_temperature,
        )
        self.combinatorial_classifiers = CombiNormSoftmaxLossDisjoint(
            proxies=self.proxies,
            nb_classes=nb_classes,
            sz_embed=sz_embed,
            num_partitions=num_partitions,
            num_partitionings=num_partitionings,
            temperature=meta_temperature,
        )
        self.alpha = alpha

    def forward(self, embeddings, instance_targets, return_meta_dist=False):
        """Return ``(meta_loss, orig_loss)`` without combining them.

        Embeddings are viewed as ``(batch, num_partitionings, sz_embed)`` and,
        when ``is_norm`` is set, each chunk is L2-normalised first. With
        ``return_meta_dist=True`` the first element is whatever the
        combinatorial classifier returns in that mode rather than a loss.
        """
        chunks = embeddings.view(-1, self.num_partitionings, self.sz_embed)
        if self.is_norm:
            chunks = F.normalize(chunks, p=2, dim=2)

        orig_loss = self.normsoftmax(chunks, instance_targets)
        meta_loss = self.combinatorial_classifiers(
            chunks, instance_targets, return_meta_dist=return_meta_dist)
        return (meta_loss, orig_loss)

    def SoftAssignment(self, num_partitionings):
        """Overwrite the meta-class proxies with soft averages of class proxies.

        For every partitioning, each meta-class proxy is replaced by a
        weighted sum of the (unnormalised) class-proxy chunks. The weights are
        a softmax over classes of ``-20 * (1 - cosine similarity)`` between
        the meta-class proxy and the class-proxy chunk. The result is written
        to ``self.proxies.data``.

        Args:
            num_partitionings: number of partitionings to process; also used
                to reshape the class proxies.
        """
        meta_proxies = self.proxies
        class_proxies = self.normsoftmax.proxies
        class_proxies = class_proxies.view(self.nb_classes, num_partitionings, self.sz_embed)

        blocks = []
        for idx in range(num_partitionings):
            meta_i = meta_proxies[idx]
            class_i = class_proxies[:, idx, :]
            n_meta = meta_i.size(0)
            n_class = class_i.size(0)

            # Tile both sets of unit vectors to (n_meta, sz_embed, n_class) so
            # their element-wise product summed over dim 1 is the cosine matrix.
            meta_dirs = F.normalize(meta_i, p=2, dim=1).unsqueeze(-1)
            meta_dirs = meta_dirs.repeat(1, 1, n_class)

            class_dirs = F.normalize(class_i, p=2, dim=1).unsqueeze(-1)
            class_dirs = class_dirs.repeat(1, 1, n_meta)
            class_dirs = class_dirs.permute(2, 1, 0)

            cos_dist = 1 - (meta_dirs * class_dirs).sum(dim=1)
            weights = F.softmax(cos_dist * (-1 * 20), dim=1)

            blocks.append(torch.mm(weights, class_i))

        # Concatenate along features, then bring the partitioning axis first to
        # match the layout of self.proxies.
        descriptor = torch.cat(blocks, dim=1)
        descriptor = descriptor.view(-1, num_partitionings, self.sz_embed)
        descriptor = descriptor.permute(1, 0, 2)
        self.proxies.data = descriptor



# class CombiNormSoftmaxLossWithLinear(nn.Module):
#     """
#     Combinatorial metric learning with additional linear classifier on original class space
#     L2 normalize weights and apply temperature scaling on logits.
#     """
#     def __init__(self,
#                  nb_classes,
#                  sz_embed,
#                  num_partitions,
#                  num_partitionings,
#                  meta_temperature=0.05,
#                  orig_temperature=0.01):
#         super(CombiNormSoftmaxLossWithLinear, self).__init__()
#
#         self.num_partitions = num_partitions
#         self.num_partitionings = num_partitionings
#         self.num_classes = nb_classes
#
#         # Meta proxies for combinatorial learning
#         self.meta_proxies = torch.nn.Parameter(torch.Tensor(num_partitions * num_partitionings, sz_embed).cuda())
#         # Initialization from nn.Linear (https://github.com/pytorch/pytorch/blob/v1.0.0/torch/nn/modules/linear.py#L129)
#         stdv = 1. / math.sqrt(self.meta_proxies.size(1))
#         self.meta_proxies.data.uniform_(-stdv, stdv)
#
#         self.meta_temperature = meta_temperature
#         self.meta_loss_fn = nn.NLLLoss()
#
#         self.register_buffer('partitionings', -torch.ones(num_partitionings, nb_classes).long())
#
#         # additional proxy on original class space
#         self.orig_proxies = torch.nn.Parameter(torch.Tensor(nb_classes, sz_embed).cuda())
#         stdv = 1. / math.sqrt(self.orig_proxies.size(1))
#         self.orig_proxies.data.uniform_(-stdv, stdv)
#
#         self.orig_temperature = orig_temperature
#         self.orig_loss_fn = nn.CrossEntropyLoss()
#
#     def set_partitionings(self, partitionings_map):
#         self.partitionings.copy_(torch.LongTensor(partitionings_map).t())
#         arange = torch.arange(self.num_partitionings).view(-1, 1).type_as(self.partitionings)
#         # add the arange offsets? -> 01110, 23332
#         self.partitionings.add_(arange * self.num_partitions)
#
#     def rescale_grad(self, total_num_partitionings=None):
#         # for params in self.proxies.parameters():
#         #     if total_num_partitionings is None:
#         #         params.grad.mul_(self.num_partitionings)
#         #     else:
#         #         params.grad.mul_(total_num_partitionings)
#         self.proxies.grad.mul_(self.num_partitionings)
#
#     def forward(self, embeddings, instance_targets, return_meta_dist=False):
#         assert self.partitionings.sum() > 0, 'Partitionings is never given to the module.'
#
#         # normalize weight
#         norm_meta_weight = nn.functional.normalize(self.meta_proxies, dim=1)
#         norm_orig_weight = nn.functional.normalize(self.orig_proxies, dim=1)
#
#         # combinatorial logit with meta_weights
#         prediction_logits = nn.functional.linear(embeddings, norm_meta_weight)
#         prediction_logits = prediction_logits.view(-1, self.num_partitionings, self.num_partitions)
#         prediction_logits = F.log_softmax(prediction_logits / self.temperature, dim=2)
#
#         if return_meta_dist:
#             return prediction_logits
#
#         prediction_logits = prediction_logits.view(-1, self.num_partitionings * self.num_partitions)
#         output = prediction_logits.index_select(1, self.partitionings.view(-1))
#         output = output.view(-1, self.num_partitionings, self.num_classes)
#
#         meta_output = output.sum(1)
#
#         meta_loss = self.loss_fn(meta_output, instance_targets) / self.num_partitionings
#
#         # Orignal proxy loss
#
#
#         return loss



# The loss classes below are built on the PyTorch Metric Learning library.
# Please refer to "https://github.com/KevinMusgrave/pytorch-metric-learning" for details.
class Proxy_NCA(torch.nn.Module):
    """Wrapper around ``losses.ProxyNCALoss`` placed on CUDA.

    Args:
        nb_classes: forwarded as ``num_classes``.
        sz_embed: forwarded as ``embedding_size``.
        scale: forwarded as ``softmax_scale``.
    """

    def __init__(self, nb_classes, sz_embed, scale=32):
        super().__init__()
        self.nb_classes = nb_classes
        self.sz_embed = sz_embed
        self.scale = scale
        proxy_nca = losses.ProxyNCALoss(
            num_classes = self.nb_classes,
            embedding_size = self.sz_embed,
            softmax_scale = self.scale,
        )
        self.loss_func = proxy_nca.cuda()

    def forward(self, embeddings, labels):
        """Return the wrapped loss for ``embeddings`` and ``labels``."""
        return self.loss_func(embeddings, labels)
    
class MultiSimilarityLoss(torch.nn.Module):
    """Multi-similarity loss with pairs mined by ``MultiSimilarityMiner``.

    Hyper-parameters are fixed on the instance: ``thresh=0.5``,
    ``epsilon=0.1``, ``scale_pos=2`` and ``scale_neg=50``.
    """

    def __init__(self):
        super().__init__()
        self.thresh = 0.5
        self.epsilon = 0.1
        self.scale_pos = 2
        self.scale_neg = 50

        self.miner = miners.MultiSimilarityMiner(epsilon=self.epsilon)
        self.loss_func = losses.MultiSimilarityLoss(self.scale_pos, self.scale_neg, self.thresh)

    def forward(self, embeddings, labels):
        """Mine pairs from the batch, then evaluate the loss on them."""
        mined_pairs = self.miner(embeddings, labels)
        return self.loss_func(embeddings, labels, mined_pairs)
    
class ContrastiveLoss(nn.Module):
    """Wrapper around ``losses.ContrastiveLoss``.

    Args:
        margin: forwarded as ``neg_margin``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, margin=0.5, **kwargs):
        super().__init__()
        self.margin = margin
        self.loss_func = losses.ContrastiveLoss(neg_margin=self.margin)

    def forward(self, embeddings, labels):
        """Return the wrapped loss for ``embeddings`` and ``labels``."""
        return self.loss_func(embeddings, labels)
    
class TripletLoss(nn.Module):
    """Triplet margin loss on triplets mined by ``TripletMarginMiner``.

    The miner is created with ``type_of_triplets='semihard'``.

    Args:
        margin: margin given to both the miner and the loss.
        **kwargs: accepted and ignored.
    """

    def __init__(self, margin=0.1, **kwargs):
        super().__init__()
        self.margin = margin
        self.miner = miners.TripletMarginMiner(margin, type_of_triplets = 'semihard')
        self.loss_func = losses.TripletMarginLoss(margin = self.margin)

    def forward(self, embeddings, labels):
        """Mine triplets from the batch, then evaluate the loss on them."""
        mined_triplets = self.miner(embeddings, labels)
        return self.loss_func(embeddings, labels, mined_triplets)
    
class NPairLoss(nn.Module):
    """Wrapper around ``losses.NPairsLoss`` with ``normalize_embeddings=False``.

    Args:
        l2_reg: forwarded as ``l2_reg_weight``.
    """

    def __init__(self, l2_reg=0):
        super().__init__()
        self.l2_reg = l2_reg
        self.loss_func = losses.NPairsLoss(l2_reg_weight=self.l2_reg, normalize_embeddings=False)

    def forward(self, embeddings, labels):
        """Return the wrapped loss for ``embeddings`` and ``labels``."""
        return self.loss_func(embeddings, labels)




def PseudoCodewordLoss(x ,y, reduction='mean'):
    """Negative log-likelihood of ``y`` under the log-probabilities ``x``.

    Dimensions 1 and 2 of ``x`` are swapped first because ``F.nll_loss``
    expects the class axis at dimension 1.

    Args:
        x: log-probabilities with the class axis last (dimension 2).
        y: target indices for ``F.nll_loss``.
        reduction: forwarded unchanged to ``F.nll_loss``.
    """
    class_first = torch.transpose(x, 1, 2)
    return F.nll_loss(class_first, y, reduction=reduction)

def PseudoCodewordLoss_l2(x,y, reduction='mean'):
    """Squared Euclidean distance between L2-normalised ``x`` and ``y``.

    Both inputs are normalised along the last dimension, so the squared
    distance reduces to ``2 - 2 * <x, y>``.

    Args:
        x: first tensor.
        y: second tensor, broadcastable against ``x``.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``. The argument
            was previously accepted but ignored; the default is unchanged.

    Returns:
        Scalar tensor for ``'mean'``/``'sum'``; the per-vector distances (last
        dimension removed) for ``'none'``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)

    sq_dist = 2 - 2 * (x * y).sum(dim=-1)
    if reduction == 'mean':
        return sq_dist.mean()
    if reduction == 'sum':
        return sq_dist.sum()
    if reduction == 'none':
        return sq_dist
    raise ValueError("reduction should be 'mean', 'sum' or 'none'")

def PseudoCodewordLoss_KLDIV(x, y, reduction='batchmean'):
    """KL divergence between the target distribution ``y`` and ``x``.

    The previous body called ``nn.KLDivLoss(x, y)``, which passes the tensors
    as constructor flags instead of evaluating the loss. The functional form
    is used here.

    Args:
        x: predicted distribution as log-probabilities (the input convention
            of ``F.kl_div``).
        y: target distribution as probabilities.
        reduction: forwarded to ``F.kl_div``. The default ``'batchmean'``
            divides the summed divergence by the batch size.

    Returns:
        The reduced divergence, or the element-wise terms for ``'none'``.
    """
    return F.kl_div(x, y, reduction=reduction)


def PseudoCodewordLossSoft(x,y, reduction='mean'):
    """Negated element-wise product of ``x`` and ``y``.

    Args:
        x: first tensor.
        y: second tensor, broadcastable against ``x``.
        reduction: ``'mean'`` returns the negated mean of the product;
            ``'none'`` returns the negated product itself.

    Raises:
        NameError: if ``reduction`` is neither ``'mean'`` nor ``'none'``.
    """
    if reduction not in ('mean', 'none'):
        raise NameError('reduction should be mean or none')
    product = x * y
    if reduction == 'none':
        return - product
    return - product.mean()

def EntropyLoss(descriptor, cls_proxies, temperature=0.2):
    """Mean entropy of the softmax over cosine similarities to the proxies.

    Args:
        descriptor: ``(n, dim)`` vectors, L2-normalised here along dim 1.
        cls_proxies: ``(classes, dim)`` proxies, L2-normalised here along dim 1.
        temperature: divisor applied to the cosine logits.

    Returns:
        Scalar tensor: per-row entropy of the class distribution, averaged
        over the rows of ``descriptor``.
    """
    unit_descriptor = F.normalize(descriptor, dim=1)
    unit_proxies = F.normalize(cls_proxies, dim=1)

    scaled_logits = F.linear(unit_descriptor, unit_proxies) / temperature
    probs = F.softmax(scaled_logits, dim=1)
    log_probs = F.log_softmax(scaled_logits, dim=1)
    per_row_entropy = (-1 * probs * log_probs).sum(1)
    return per_row_entropy.mean()
