# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/utils/loss.py

from torch.nn import DataParallel
from torch import autograd
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
import numpy as np

from utils.style_ops import conv2d_gradfix
import utils.ops as ops

# ========================================================
from geomloss import SamplesLoss
from torch.autograd import Function

from torch.autograd import Variable
import math
# ========================================================


class GatherLayer(torch.autograd.Function):
    """
    All-gather that keeps the autograd graph intact.

    Adapted from
    https://github.com/open-mmlab/OpenSelfSup/blob/master/openselfsup/models/utils/gather_layer.py

    Forward collects the input tensor from every process of the default
    process group; backward returns only the gradient slice that belongs to
    the calling rank.
    """
    @staticmethod
    def forward(ctx, input):
        # saved only so that backward can allocate a gradient of matching shape/dtype
        ctx.save_for_backward(input)
        world_size = dist.get_world_size()
        gathered = []
        for _ in range(world_size):
            gathered.append(torch.zeros_like(input))
        dist.all_gather(gathered, input)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        local_input, = ctx.saved_tensors
        local_grad = torch.zeros_like(local_input)
        # one incoming gradient per gathered output; keep this rank's own
        local_grad[:] = grads[dist.get_rank()]
        return local_grad


class CrossEntropyLoss(torch.nn.Module):
    """Softmax cross-entropy between classifier logits and integer labels."""
    def __init__(self):
        super().__init__()
        self.ce_loss = torch.nn.CrossEntropyLoss()

    def forward(self, cls_output, label, **_):
        """Return the scalar cross-entropy of `cls_output` against `label`.

        Extra keyword arguments are accepted and ignored.
        """
        loss = self.ce_loss(cls_output, label)
        return loss.mean()

class AdvCrossEntropyLoss(torch.nn.Module):
    """Cross-entropy on logits that are first divided by a fixed temperature."""
    def __init__(self, cls_adv_temp):
        super().__init__()
        self.ce_loss = torch.nn.CrossEntropyLoss()
        self.t = cls_adv_temp

    def forward(self, cls_output, label, **_):
        """Return the scalar cross-entropy of `cls_output / t` against `label`.

        Extra keyword arguments are accepted and ignored.
        """
        tempered_logits = cls_output / self.t
        loss = self.ce_loss(tempered_logits, label)
        return loss.mean()

class MiCrossEntropyLoss(torch.nn.Module):
    """Softmax cross-entropy applied to the `mi_cls_output` logits."""
    def __init__(self):
        super().__init__()
        self.ce_loss = torch.nn.CrossEntropyLoss()

    def forward(self, mi_cls_output, label, **_):
        """Return the scalar cross-entropy of `mi_cls_output` against `label`.

        Extra keyword arguments are accepted and ignored.
        """
        loss = self.ce_loss(mi_cls_output, label)
        return loss.mean()


class ConditionalContrastiveLoss(torch.nn.Module):
    """Contrastive loss over sample embeddings and their class proxies.

    For every sample, the numerator collects exp(cos/temperature) for its own
    proxy and for all *other* samples sharing its label; the denominator
    collects the proxy term plus all other samples regardless of label.

    Args:
        num_classes: number of classes (size of the label vocabulary).
        temperature: divisor applied to cosine similarities before exp.
        master_rank: stored for backward compatibility; masks are now created
            on the device of the tensors they are applied to.
        DDP: if True, embeddings/proxies/labels are all-gathered first.
    """
    def __init__(self, num_classes, temperature, master_rank, DDP):
        super(ConditionalContrastiveLoss, self).__init__()
        self.num_classes = num_classes
        self.temperature = temperature
        self.master_rank = master_rank
        self.DDP = DDP
        self.calculate_similarity_matrix = self._calculate_similarity_matrix()
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1)

    def _make_neg_removal_mask(self, labels):
        """Return a (num_classes, n_samples) long mask with 1 where labels[j] == c.

        Built by broadcasting on the labels' device, so no host round trip
        or per-class Python loop is needed. `labels` is expected to be 1-D.
        """
        labels = labels.detach()
        class_ids = torch.arange(self.num_classes, device=labels.device).unsqueeze(1)
        return (class_ids == labels.unsqueeze(0)).type(torch.long)

    def _calculate_similarity_matrix(self):
        """Return the callable used to compute pairwise similarities."""
        return self._cosine_simililarity_matrix

    def _remove_diag(self, M):
        """Drop the diagonal of a square (h, h) matrix, returning (h, h - 1)."""
        h, w = M.shape
        assert h == w, "h and w should be same"
        off_diag = ~torch.eye(h, dtype=torch.bool, device=M.device)
        return M[off_diag].view(h, -1)

    def _cosine_simililarity_matrix(self, x, y):
        """Pairwise cosine similarity between rows of x and rows of y."""
        v = self.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, embed, proxy, label, **_):
        """Compute the mean loss for a batch.

        Args:
            embed: sample embeddings, one row per sample.
            proxy: class proxy embedding of each sample, same shape as embed.
            label: 1-D integer class labels in [0, num_classes).

        Returns:
            Scalar tensor with the batch-averaged loss.
        """
        if self.DDP:
            embed = torch.cat(GatherLayer.apply(embed), dim=0)
            proxy = torch.cat(GatherLayer.apply(proxy), dim=0)
            label = torch.cat(GatherLayer.apply(label), dim=0)

        # exp(cos/T) between every sample and every *other* sample
        sim_matrix = self.calculate_similarity_matrix(embed, embed)
        sim_matrix = torch.exp(self._remove_diag(sim_matrix) / self.temperature)
        # mask[label] -> (n, n) with 1 where two samples share a label
        same_class = self._remove_diag(self._make_neg_removal_mask(label)[label])
        sim_pos_only = same_class.to(sim_matrix.device) * sim_matrix

        emb2proxy = torch.exp(self.cosine_similarity(embed, proxy) / self.temperature)

        numerator = emb2proxy + sim_pos_only.sum(dim=1)
        denominator = emb2proxy + sim_matrix.sum(dim=1)
        return -torch.log(numerator / denominator).mean()


class MiConditionalContrastiveLoss(torch.nn.Module):
    """Contrastive loss over the `mi_` embeddings and their class proxies.

    For every sample, the numerator collects exp(cos/temperature) for its own
    proxy and for all *other* samples sharing its label; the denominator
    collects the proxy term plus all other samples regardless of label.

    Args:
        num_classes: number of classes (size of the label vocabulary).
        temperature: divisor applied to cosine similarities before exp.
        master_rank: stored for backward compatibility; masks are now created
            on the device of the tensors they are applied to.
        DDP: if True, embeddings/proxies/labels are all-gathered first.
    """
    def __init__(self, num_classes, temperature, master_rank, DDP):
        super(MiConditionalContrastiveLoss, self).__init__()
        self.num_classes = num_classes
        self.temperature = temperature
        self.master_rank = master_rank
        self.DDP = DDP
        self.calculate_similarity_matrix = self._calculate_similarity_matrix()
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1)

    def _make_neg_removal_mask(self, labels):
        """Return a (num_classes, n_samples) long mask with 1 where labels[j] == c.

        Built by broadcasting on the labels' device, so no host round trip
        or per-class Python loop is needed. `labels` is expected to be 1-D.
        """
        labels = labels.detach()
        class_ids = torch.arange(self.num_classes, device=labels.device).unsqueeze(1)
        return (class_ids == labels.unsqueeze(0)).type(torch.long)

    def _calculate_similarity_matrix(self):
        """Return the callable used to compute pairwise similarities."""
        return self._cosine_simililarity_matrix

    def _remove_diag(self, M):
        """Drop the diagonal of a square (h, h) matrix, returning (h, h - 1)."""
        h, w = M.shape
        assert h == w, "h and w should be same"
        off_diag = ~torch.eye(h, dtype=torch.bool, device=M.device)
        return M[off_diag].view(h, -1)

    def _cosine_simililarity_matrix(self, x, y):
        """Pairwise cosine similarity between rows of x and rows of y."""
        v = self.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, mi_embed, mi_proxy, label, **_):
        """Compute the mean loss for a batch.

        Args:
            mi_embed: sample embeddings, one row per sample.
            mi_proxy: class proxy embedding of each sample, same shape as mi_embed.
            label: 1-D integer class labels in [0, num_classes).

        Returns:
            Scalar tensor with the batch-averaged loss.
        """
        if self.DDP:
            mi_embed = torch.cat(GatherLayer.apply(mi_embed), dim=0)
            mi_proxy = torch.cat(GatherLayer.apply(mi_proxy), dim=0)
            label = torch.cat(GatherLayer.apply(label), dim=0)

        # exp(cos/T) between every sample and every *other* sample
        sim_matrix = self.calculate_similarity_matrix(mi_embed, mi_embed)
        sim_matrix = torch.exp(self._remove_diag(sim_matrix) / self.temperature)
        # mask[label] -> (n, n) with 1 where two samples share a label
        same_class = self._remove_diag(self._make_neg_removal_mask(label)[label])
        sim_pos_only = same_class.to(sim_matrix.device) * sim_matrix

        emb2proxy = torch.exp(self.cosine_similarity(mi_embed, mi_proxy) / self.temperature)

        numerator = emb2proxy + sim_pos_only.sum(dim=1)
        denominator = emb2proxy + sim_matrix.sum(dim=1)
        return -torch.log(numerator / denominator).mean()


class Data2DataCrossEntropyLoss(torch.nn.Module):
    """Data-to-data cross-entropy over sample embeddings and class proxies.

    Combines a positive attraction term (sample vs. its own proxy, with
    margin `m_p`) and a negative repulsion term (sample vs. other samples of
    a *different* label).

    Args:
        num_classes: number of classes (size of the label vocabulary).
        temperature: divisor applied to similarities.
        m_p: positive margin used in both terms.
        master_rank: stored for backward compatibility; masks are now created
            on the device of the tensors they are applied to.
        DDP: if True, embeddings/proxies/labels are all-gathered first.
    """
    def __init__(self, num_classes, temperature, m_p, master_rank, DDP):
        super(Data2DataCrossEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.temperature = temperature
        self.m_p = m_p
        self.master_rank = master_rank
        self.DDP = DDP
        self.calculate_similarity_matrix = self._calculate_similarity_matrix()
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1)

    def _calculate_similarity_matrix(self):
        """Return the callable used to compute pairwise similarities."""
        return self._cosine_simililarity_matrix

    def _cosine_simililarity_matrix(self, x, y):
        """Pairwise cosine similarity between rows of x and rows of y."""
        v = self.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def make_index_matrix(self, labels):
        """Return a (num_classes, n_samples) long mask with 0 where labels[j] == c, else 1.

        Built by broadcasting on the labels' device, so no host round trip
        or per-class Python loop is needed. `labels` is expected to be 1-D.
        """
        labels = labels.detach()
        class_ids = torch.arange(self.num_classes, device=labels.device).unsqueeze(1)
        return (class_ids != labels.unsqueeze(0)).type(torch.long)

    def remove_diag(self, M):
        """Drop the diagonal of a square (h, h) matrix, returning (h, h - 1)."""
        h, w = M.shape
        assert h==w, "h and w should be same"
        off_diag = ~torch.eye(h, dtype=torch.bool, device=M.device)
        return M[off_diag].view(h, -1)

    def forward(self, embed, proxy, label, **_):
        """Compute the mean loss for a batch.

        Args:
            embed: sample embeddings, one row per sample.
            proxy: class proxy embedding of each sample, same shape as embed.
            label: 1-D integer class labels in [0, num_classes).

        Returns:
            Scalar tensor with the batch-averaged criterion.
        """
        # If training a GAN through DDP, gather the data from all ranks
        if self.DDP:
            embed = torch.cat(GatherLayer.apply(embed), dim=0)
            proxy = torch.cat(GatherLayer.apply(proxy), dim=0)
            label = torch.cat(GatherLayer.apply(label), dim=0)

        # calculate similarities between sample embeddings
        sim_matrix = self.calculate_similarity_matrix(embed, embed) + self.m_p - 1
        # remove diagonal terms
        sim_matrix = self.remove_diag(sim_matrix/self.temperature)
        # for numerical stability
        sim_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
        sim_matrix = F.relu(sim_matrix) - sim_max.detach()

        # calculate similarities between sample embeddings and the corresponding proxies
        smp2proxy = self.cosine_similarity(embed, proxy)
        # false-negative removal: 0 for pairs that share a label, 1 otherwise
        removal_fn = self.remove_diag(self.make_index_matrix(label)[label])
        # apply the negative removal to the similarity matrix
        improved_sim_matrix = removal_fn.to(sim_matrix.device)*torch.exp(sim_matrix)

        # compute positive attraction term
        pos_attr = F.relu((self.m_p - smp2proxy)/self.temperature)
        # compute negative repulsion term
        neg_repul = torch.log(torch.exp(-pos_attr) + improved_sim_matrix.sum(dim=1))
        # compute data to data cross-entropy criterion
        criterion = pos_attr + neg_repul
        return criterion.mean()


class MiData2DataCrossEntropyLoss(torch.nn.Module):
    """Data-to-data cross-entropy over the `mi_` embeddings and class proxies.

    Combines a positive attraction term (sample vs. its own proxy, with
    margin `m_p`) and a negative repulsion term (sample vs. other samples of
    a *different* label).

    Args:
        num_classes: number of classes (size of the label vocabulary).
        temperature: divisor applied to similarities.
        m_p: positive margin used in both terms.
        master_rank: stored for backward compatibility; masks are now created
            on the device of the tensors they are applied to.
        DDP: if True, embeddings/proxies/labels are all-gathered first.
    """
    def __init__(self, num_classes, temperature, m_p, master_rank, DDP):
        super(MiData2DataCrossEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.temperature = temperature
        self.m_p = m_p
        self.master_rank = master_rank
        self.DDP = DDP
        self.calculate_similarity_matrix = self._calculate_similarity_matrix()
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1)

    def _calculate_similarity_matrix(self):
        """Return the callable used to compute pairwise similarities."""
        return self._cosine_simililarity_matrix

    def _cosine_simililarity_matrix(self, x, y):
        """Pairwise cosine similarity between rows of x and rows of y."""
        v = self.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def make_index_matrix(self, labels):
        """Return a (num_classes, n_samples) long mask with 0 where labels[j] == c, else 1.

        Built by broadcasting on the labels' device, so no host round trip
        or per-class Python loop is needed. `labels` is expected to be 1-D.
        """
        labels = labels.detach()
        class_ids = torch.arange(self.num_classes, device=labels.device).unsqueeze(1)
        return (class_ids != labels.unsqueeze(0)).type(torch.long)

    def remove_diag(self, M):
        """Drop the diagonal of a square (h, h) matrix, returning (h, h - 1)."""
        h, w = M.shape
        assert h==w, "h and w should be same"
        off_diag = ~torch.eye(h, dtype=torch.bool, device=M.device)
        return M[off_diag].view(h, -1)

    def forward(self, mi_embed, mi_proxy, label, **_):
        """Compute the mean loss for a batch.

        Args:
            mi_embed: sample embeddings, one row per sample.
            mi_proxy: class proxy embedding of each sample, same shape as mi_embed.
            label: 1-D integer class labels in [0, num_classes).

        Returns:
            Scalar tensor with the batch-averaged criterion.
        """
        # If training a GAN through DDP, gather the data from all ranks
        if self.DDP:
            mi_embed = torch.cat(GatherLayer.apply(mi_embed), dim=0)
            mi_proxy = torch.cat(GatherLayer.apply(mi_proxy), dim=0)
            label = torch.cat(GatherLayer.apply(label), dim=0)

        # calculate similarities between sample embeddings
        sim_matrix = self.calculate_similarity_matrix(mi_embed, mi_embed) + self.m_p - 1
        # remove diagonal terms
        sim_matrix = self.remove_diag(sim_matrix/self.temperature)
        # for numerical stability
        sim_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
        sim_matrix = F.relu(sim_matrix) - sim_max.detach()

        # calculate similarities between sample embeddings and the corresponding proxies
        smp2proxy = self.cosine_similarity(mi_embed, mi_proxy)
        # false-negative removal: 0 for pairs that share a label, 1 otherwise
        removal_fn = self.remove_diag(self.make_index_matrix(label)[label])
        # apply the negative removal to the similarity matrix
        improved_sim_matrix = removal_fn.to(sim_matrix.device)*torch.exp(sim_matrix)

        # compute positive attraction term
        pos_attr = F.relu((self.m_p - smp2proxy)/self.temperature)
        # compute negative repulsion term
        neg_repul = torch.log(torch.exp(-pos_attr) + improved_sim_matrix.sum(dim=1))
        # compute data to data cross-entropy criterion
        criterion = pos_attr + neg_repul
        return criterion.mean()


class PathLengthRegularizer:
    """Path-length penalty with a running mean of observed path lengths.

    `pl_mean` is a scalar tensor on `device` that is updated in place on
    every call of `cal_pl_reg`.
    """
    def __init__(self, device, pl_decay=0.01, pl_weight=2, pl_no_weight_grad=False):
        self.pl_no_weight_grad = pl_no_weight_grad
        self.pl_weight = pl_weight
        self.pl_decay = pl_decay
        self.pl_mean = torch.zeros([], device=device)

    def cal_pl_reg(self, fake_images, ws):
        """Return the weighted path-length penalty for a batch.

        Args:
            fake_images: generated images; dims 2 and 3 are used as the
                spatial size for scaling the noise.
            ws: tensor the gradient is taken with respect to (the original
                comment calls it "weight style"); indexed as 3-D below.

        Note: the original comments state that the caller passes freshly
        generated fake images for this regularizer.
        """
        height, width = fake_images.shape[2], fake_images.shape[3]
        pl_noise = torch.randn_like(fake_images) / np.sqrt(height * width)
        with conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad):
            projection = (fake_images * pl_noise).sum()
            pl_grads, = torch.autograd.grad(outputs=[projection], inputs=[ws],
                                            create_graph=True, only_inputs=True)
        pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
        # exponential moving average of the mean path length
        updated_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
        self.pl_mean.copy_(updated_mean.detach())
        pl_penalty = (pl_lengths - updated_mean).square()
        return (pl_penalty * self.pl_weight).mean(0)


def enable_allreduce(dict_):
    """Return a zero-valued term that still depends on every tensor in `dict_`.

    Each non-None value (except the one stored under "label") contributes
    `value.mean() * 0`, so the result is numerically zero while staying
    connected to those tensors in the autograd graph. Returns the int 0 when
    nothing qualifies.
    """
    total = 0
    for name, tensor in dict_.items():
        if tensor is None or name == "label":
            continue
        total += tensor.mean()*0
    return total


def d_vanilla(d_logit_real, d_logit_fake, DDP):
    """Discriminator loss: mean softplus(-real) plus mean softplus(fake).

    `DDP` is accepted for a uniform signature and is not used.
    """
    real_term = F.softplus(-d_logit_real).mean()
    fake_term = F.softplus(d_logit_fake).mean()
    return real_term + fake_term


def g_vanilla(d_logit_fake, DDP):
    """Generator loss: mean softplus(-fake). `DDP` is not used."""
    per_sample = F.softplus(-d_logit_fake)
    return per_sample.mean()


def d_logistic(d_logit_real, d_logit_fake, DDP):
    """Discriminator logistic loss.

    Adds softplus(-real) and softplus(fake) element-wise, then averages.
    `DDP` is not used.
    """
    real_term = F.softplus(-d_logit_real)
    fake_term = F.softplus(d_logit_fake)
    return (real_term + fake_term).mean()


def g_logistic(d_logit_fake, DDP):
    """Generator logistic loss: mean softplus(-fake), same formula as g_vanilla.

    `DDP` is not used.
    """
    per_sample = F.softplus(-d_logit_fake)
    return per_sample.mean()


def d_ls(d_logit_real, d_logit_fake, DDP):
    """Least-squares discriminator loss.

    Pushes real logits towards 1 and fake logits towards 0; the two halves
    are added element-wise before averaging. `DDP` is not used.
    """
    real_term = 0.5 * (d_logit_real - 1).pow(2)
    fake_term = 0.5 * d_logit_fake.pow(2)
    return (real_term + fake_term).mean()


def g_ls(d_logit_fake, DDP):
    """Least-squares generator loss: mean of 0.5 * (fake - 1)^2. `DDP` is not used."""
    residual = d_logit_fake - 1
    return (0.5 * residual.pow(2)).mean()


def d_hinge(d_logit_real, d_logit_fake, DDP):
    """Hinge discriminator loss: mean relu(1 - real) + mean relu(1 + fake).

    `DDP` is not used.
    """
    real_term = F.relu(1. - d_logit_real).mean()
    fake_term = F.relu(1. + d_logit_fake).mean()
    return real_term + fake_term


def g_hinge(d_logit_fake, DDP):
    """Hinge generator loss: negated mean of the fake logits. `DDP` is not used."""
    mean_fake = d_logit_fake.mean()
    return -mean_fake

# ------------------------------------------------------------

def d_acg(d_logit_real, d_logit_fake, adv_pro, y, batch_size, DDP):
    """Discriminator loss with a clipped fake term.

    Real part: mean softplus(-real). Fake part: mean of
    relu(-log(adv_pro) - softplus(-fake)). `y`, `batch_size` and `DDP` are
    accepted but not used (the conditional-entropy term is disabled).
    """
    real_term = F.softplus(-d_logit_real).mean()
    threshold = -math.log(adv_pro)
    fake_term = F.relu(threshold - F.softplus(-d_logit_fake)).mean()
    d_loss = real_term + fake_term
    # d_loss -= conditional_entropy(y, batch_size)
    return d_loss


def g_acg(d_logit_fake, y, batch_size, DDP):
    """Generator loss: mean softplus(-fake).

    `y`, `batch_size` and `DDP` are accepted but not used (the
    conditional-entropy term is disabled).
    """
    per_sample = F.softplus(-d_logit_fake)
    # g_loss += conditional_entropy(y, batch_size)
    return per_sample.mean()



# def d_hinge(d_logit_real, d_logit_fake, DDP):
#     return torch.mean(F.relu(1. - d_logit_real)) + torch.mean((1 - torch.sigmoid(d_logit_fake)).pow(2)*F.relu(1. + d_logit_fake))

# def g_hinge(d_logit_fake, DDP):
#     return -torch.mean(d_logit_fake)


def d_wasserstein(d_logit_real, d_logit_fake, DDP):
    """Wasserstein critic loss: mean of (fake - real). `DDP` is not used."""
    gap = d_logit_fake - d_logit_real
    return gap.mean()


def g_wasserstein(d_logit_fake, DDP):
    """Wasserstein generator loss: negated mean of the fake logits. `DDP` is not used."""
    mean_fake = d_logit_fake.mean()
    return -mean_fake


# ========================================================

# https://github.com/dlmacedo/entropic-out-of-distribution-detection/blob/HEAD/losses/isomax.py

class IsoMaxLoss(nn.Module):
    """Stand-in for nn.CrossEntropyLoss() that uses an entropic scale."""
    def __init__(self, entropic_scale=10.0):
        super().__init__()
        self.entropic_scale = entropic_scale

    def forward(self, cls_output, label, **_):
        """Return the mean negative log-probability of the target classes.

        Probabilities and logarithms are calculated separately and
        sequentially on purpose, therefore nn.CrossEntropyLoss() must not be
        used to calculate the loss.
        """
        distances = -cls_output
        scaled_logits = -self.entropic_scale * distances
        probabilities_for_training = F.softmax(scaled_logits, dim=1)
        batch_rows = range(distances.size(0))
        probabilities_at_targets = probabilities_for_training[batch_rows, label]
        return -torch.log(probabilities_at_targets).mean()


# https://github.com/Leethony/Additive-Margin-Softmax-Loss-Pytorch/blob/master/AdMSLoss.py

class AdMSoftmaxLoss(nn.Module):
    """Additive-margin softmax loss applied directly to class scores.

    The input is used as-is (no internal linear layer or normalisation); the
    margin `m` is subtracted from the target-class score and everything is
    scaled by `s` before a softmax cross-entropy.
    """

    def __init__(self, s=30.0, m=0.4):
        '''
        AM Softmax Loss

        Args:
            s: scale applied to all scores.
            m: additive margin subtracted from the target-class score.
        '''
        super(AdMSoftmaxLoss, self).__init__()
        self.s = s
        self.m = m

    def forward(self, x, labels):
        '''
        Args:
            x: class scores with shape (N, num_classes).
            labels: integer target classes with shape (N,).

        Returns:
            Scalar tensor: mean over the batch of
            -(s*(x_y - m) - log(exp(s*(x_y - m)) + sum_{j != y} exp(s*x_j))).
        '''
        wf = x
        target_idx = torch.as_tensor(labels, device=wf.device).long().view(-1, 1)
        # gather picks x[i, labels[i]] without building an N x N intermediate
        numerator = self.s * (wf.gather(1, target_idx).squeeze(1) - self.m)
        # scaled scores with the margin applied only at the target position
        margin_logits = (self.s * wf).scatter(1, target_idx, numerator.unsqueeze(1))
        # logsumexp keeps the denominator finite where a plain exp would overflow
        log_denominator = torch.logsumexp(margin_logits, dim=1)
        L = numerator - log_denominator
        return -torch.mean(L)

# Modified from https://github.com/KaiyangZhou/pytorch-center-loss/blob/master/center_loss.py

class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.

    The learnable `centers` parameter is placed on CUDA when available and on
    the CPU otherwise.
    """
    def __init__(self, num_classes, feat_dim):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # sample on the CPU first so the initial values come from the CPU RNG
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).to(device))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).

        Returns:
            Scalar tensor: clamped squared distances between each feature and
            its class center, summed and divided by batch_size.
        """
        batch_size = x.size(0)
        # ||x||^2 + ||c||^2 - 2 x.c for every (sample, class) pair
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # keyword beta/alpha: the positional (beta, alpha, mat1, mat2) form is deprecated
        distmat = torch.addmm(distmat, x, self.centers.t(), beta=1, alpha=-2)

        classes = torch.arange(self.num_classes, device=x.device).long()
        labels = labels.to(x.device).unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        # keep only the distance to each sample's own class center
        masked_dist = distmat * mask.float()
        loss = masked_dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

        return loss


# ETFLoss--------------------
class ETFLoss(nn.Module):
    """Cross-entropy on margin-transformed logits.

    Non-target entries become relu(logit); the target entry becomes
    relu(m_p - logit). The result is divided by a temperature and passed to
    nn.CrossEntropyLoss.

    Args:
        cls_logit_temp: temperature dividing the transformed logits.
        m_p: margin value placed at the target position.
    """

    def __init__(self, cls_logit_temp, m_p):
        super(ETFLoss, self).__init__()
        self.t = cls_logit_temp
        self.ce_loss = torch.nn.CrossEntropyLoss()
        self.m_p = m_p

    def forward(self, cls_output, label, **_):
        """Return the scalar loss for logits `cls_output` (N, C) and labels (N,).

        All work happens on the device of `cls_output`; no CPU/GPU round trip.
        """
        label_index = torch.unsqueeze(label, 1).to(cls_output.device)

        # 2*logit everywhere, then overwrite the target column with m_p, so that
        # subtracting the logits leaves `logit` for non-targets and
        # `m_p - logit` for the target
        output_m = (2*cls_output).scatter(1, label_index, self.m_p)
        cls_m = F.relu(output_m - cls_output)

        cls_m = cls_m / self.t
        return self.ce_loss(cls_m, label).mean()


# LogitNormLoss--------------------
class LogitNormLoss(nn.Module):
    """Cross-entropy on L2-normalised logits divided by a temperature."""

    def __init__(self, cls_logit_temp):
        super().__init__()
        self.ce_loss = torch.nn.CrossEntropyLoss()
        self.t = cls_logit_temp

    def forward(self, cls_output, label, **_):
        """Return the scalar loss; 1e-7 is added to the norm to avoid division by zero."""
        l2_norm = torch.norm(cls_output, p=2, dim=-1, keepdim=True) + 1e-7
        normalised = torch.div(cls_output, l2_norm)
        scaled = normalised / self.t
        return self.ce_loss(scaled, label).mean()

# wrong--------------------------------
class G_ProLabel(nn.Module):
    """Cross-entropy of `g_prob` against the (detached) distribution `d_prob`."""

    def forward(self, g_prob, d_prob, **_):
        """Return mean over the batch of -sum_c log(g_prob) * d_prob.

        `d_prob` is detached, so no gradient flows into it.
        """
        target = d_prob.detach()
        per_sample = -torch.sum(torch.log(g_prob) * target, dim=1)
        return per_sample.mean()

# use pre_prob if classified correctly, else one-hot ----------------------------------------------
def g_learnLoss(fake_cls_output, new_label, LogitNorm, logit_temp):
    """Soft-label cross-entropy of classifier logits against `new_label`.

    Args:
        fake_cls_output: logits, classes on the last dimension.
        new_label: per-class target weights with the same shape as the logits.
        LogitNorm: if truthy, L2-normalise the logits and divide by
            `logit_temp` before the log-softmax.
        logit_temp: temperature, only used when `LogitNorm` is truthy.

    Returns:
        Scalar tensor, mean over the batch.
    """
    logits = fake_cls_output
    if LogitNorm:
        norms = torch.norm(fake_cls_output, p=2, dim=-1, keepdim=True) + 1e-7
        logits = torch.div(fake_cls_output, norms) / logit_temp

    log_probs = F.log_softmax(logits, dim=-1)
    per_sample = -torch.sum(log_probs * new_label, dim=1)
    return per_sample.mean()

class LabelSmoothingCrossEntropy(nn.Module):

    """
    Negative log-likelihood with label smoothing.
    """
    def __init__(self, smoothing):
        """
        Build the loss module.
        :param smoothing: weight given to the uniform (smoothed) part; the
            target class keeps weight 1 - smoothing
        """
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1. - self.smoothing

    def forward(self, cls_output, label, **_):
        """Return confidence * NLL(target) + smoothing * mean(-log p), batch-averaged."""
        logprobs = F.log_softmax(cls_output, dim=-1)
        target_logprob = logprobs.gather(dim=-1, index=label.unsqueeze(1)).squeeze(1)
        nll_loss = -target_logprob
        smooth_loss = -logprobs.mean(dim=-1)
        blended = self.confidence * nll_loss + self.smoothing * smooth_loss
        return blended.mean()
    
class NegCrossEntropy(nn.Module):

    """
    KL-style penalty between the model's softmax output and a fixed target
    distribution.

    The target distribution puts probability `pro` on each sample's label and
    spreads the remaining mass uniformly over the other classes. It is built
    once in the constructor, on the same device as `label`.

    Args:
        pro: probability assigned to the labelled class.
        label: integer tensor of shape (current_batch, 1) used as the
            scatter index.
        current_batch: number of rows of the target distribution.
        num_class: number of classes (columns).
    """
    def __init__(self, pro, label, current_batch, num_class):

        super(NegCrossEntropy, self).__init__()

        off_target = (1. - pro) / (num_class - 1)
        # allocate directly on the label's device instead of forcing CUDA
        uniform_dist = torch.full((current_batch, num_class), off_target,
                                  dtype=torch.float32, device=label.device)

        self.label_pro = uniform_dist.scatter_(1, label, pro)

    def forward(self, cls_output, **_):
        """Return mean over the batch of sum_c p_c * (log p_c - log target_c)."""
        logprobs = F.log_softmax(cls_output, dim=-1)
        probs = torch.exp(logprobs)
        loss = (probs.mul(logprobs) - probs.mul(torch.log(self.label_pro))).sum(dim=-1)

        return loss.mean()
    

class LabelCrossEntropy(nn.Module):

    """
    Label-smoothed NLL whose smoothing factor is derived from a target
    probability.
    """
    def __init__(self, pro, num_class):
        """
        Build the loss module.
        :param pro: value used to derive the smoothing factor as
            (1 - pro) / (1 - 1 / num_class)
        :param num_class: number of classes
        """
        super().__init__()
        self.smoothing = (1 - pro) / (1 - 1 / num_class)
        self.confidence = 1. - self.smoothing

    def forward(self, cls_output, label, **_):
        """Return confidence * NLL(target) + smoothing * mean(-log p), batch-averaged."""
        logprobs = F.log_softmax(cls_output, dim=-1)
        target_logprob = logprobs.gather(dim=-1, index=label.unsqueeze(1)).squeeze(1)
        nll_loss = -target_logprob
        smooth_loss = -logprobs.mean(dim=-1)
        blended = self.confidence * nll_loss + self.smoothing * smooth_loss
        return blended.mean()

# https://github1s.com/helmy-elrais/Semi_Supervised_Learning
# marginalized entropy
def marginal_entropy(y):
    """Entropy of the batch-averaged distribution.

    `y` is averaged over dim 0 and the entropy -sum p*log(p + 1e-7) of that
    average is returned.
    """
    mean_probs = y.mean(0)
    log_mean = torch.log(mean_probs + 1e-7)
    return -torch.sum(mean_probs * log_mean)

def conditional_entropy(y, batch_size):
    """Sum of -y*log(y + 1e-7) over all entries, divided by `batch_size`.

    Per the original comment, `y` is a softmax output.
    """
    entropy_terms = -y*torch.log(y + 1e-7)
    scale = 1.0/batch_size
    return scale*entropy_terms.sum()
    

class SOFTLoss(nn.Module):
    """Softmax-style loss of real logits against the pooled fake logits.

    Logits are divided by a temperature and shifted by a detached per-row
    maximum before exponentiation (for numerical stability).
    """

    def __init__(self, logit_temp):
        super().__init__()
        self.t = logit_temp

    def forward(self, valid_real, valid_fake):
        """Return the scalar loss for 2-D real and fake logit tensors."""
        batch = valid_real.shape[0]
        real_scaled = valid_real / self.t
        fake_scaled = valid_fake / self.t

        # per-row shift: max over the row's real logits and the column-wise fake maxima
        fake_peak, _ = torch.max(fake_scaled, dim=0, keepdim=True)
        candidates = torch.cat((real_scaled, torch.tile(fake_peak, (batch, 1))), 1)
        shift, _ = torch.max(candidates, dim=1, keepdim=True)
        shift = shift.detach()

        real_shifted = real_scaled - shift
        # all shifted fake logits are pooled into one scalar
        fake_mass = torch.sum(torch.exp(fake_scaled - shift))
        log_total = torch.log(torch.exp(real_shifted) + fake_mass)
        return -torch.mean(real_shifted - log_total)


#  rsgan------------------
def rs_gloss(d_logit_real, d_logit_fake, type, DDP):
    """Relativistic generator loss on z = real - fake.

    Args:
        d_logit_real: discriminator logits for real samples.
        d_logit_fake: discriminator logits for fake samples (same shape).
        type: 'log' for log(1 + exp(z)) computed in a numerically stable way,
            or 'hinge' for relu(1 + z).
        DDP: accepted for a uniform signature; not used.

    Returns:
        Scalar tensor, mean over all elements.

    Raises:
        ValueError: if `type` is neither 'log' nor 'hinge'.
    """
    if type == 'log':
        z = d_logit_real - d_logit_fake
        # zeros_like keeps the computation on z's device (no hard-coded CUDA)
        z_star = torch.max(z, torch.zeros_like(z))
        return (z_star + torch.log(torch.exp(z - z_star) + torch.exp(0 - z_star))).mean()
    elif type == 'hinge':
        return (F.relu(1 + (d_logit_real - d_logit_fake))).mean()
    raise ValueError("type must be 'log' or 'hinge', got {!r}".format(type))

def rs_dloss(d_logit_real, d_logit_fake, type, DDP):
    """Relativistic discriminator loss on z = fake - real.

    Args:
        d_logit_real: discriminator logits for real samples.
        d_logit_fake: discriminator logits for fake samples (same shape).
        type: 'log' for log(1 + exp(z)) computed in a numerically stable way,
            or 'hinge' for relu(1 + z).
        DDP: accepted for a uniform signature; not used.

    Returns:
        Scalar tensor, mean over all elements.

    Raises:
        ValueError: if `type` is neither 'log' nor 'hinge'.
    """
    if type == 'log':
        z = d_logit_fake - d_logit_real
        # zeros_like keeps the computation on z's device (no hard-coded CUDA)
        z_star = torch.max(z, torch.zeros_like(z))
        return (z_star + torch.log(torch.exp(z - z_star) + torch.exp(0 - z_star))).mean()
    elif type == 'hinge':
        return (F.relu(1 + (d_logit_fake - d_logit_real))).mean()
    raise ValueError("type must be 'log' or 'hinge', got {!r}".format(type))

# WOOD -------------------------------------------------------------
    # https://github.com/wyn430/WOOD.


def label_2_onehot(label, C, device):
    """Turn integer labels into a float one-hot matrix of shape (N, C).

    Labels are taken modulo `C` first, so out-of-range values wrap around.
    `label` may be 1-D (N,) or already a column (N, 1).
    """
    num_rows = label.shape[0]
    # scatter_ needs the index as a column
    index = torch.unsqueeze(label, 1) if label.dim() == 1 else label
    index = index % C

    onehot = torch.FloatTensor(num_rows, C).to(device)
    onehot.zero_()
    onehot.scatter_(1, index, 1)
    return onehot


class NLLWOOD_Loss(Function):
    """Custom autograd Function scoring OOD inputs with a geomloss Sinkhorn loss.

    Only the out-of-distribution branch is active; the in-distribution (InD)
    branch of the code this was adapted from is commented out. Both forward
    and backward move helper tensors to 'cuda' explicitly.
    """

    # Note that both forward and backward are @staticmethods
    @staticmethod
    
    # def forward(ctx, Input, Target, C, beta, device):
    # def forward(ctx, pre_inputs, Target, C, beta, device):
    def forward(ctx, ood_inputs, C, beta, device):
    
        """
        Args:
            ood_inputs: (N, C) tensor, N is the batch size, C is the number of classes
                (rank 2 follows from the unsqueeze(-1) / [:, :, 0] indexing below)
            C: the number of classes
            beta: scalar weight; the returned loss is -beta * mean Sinkhorn value
            device: device used for the class-index helper tensors

        Returns:
            1-element tensor holding -beta * mean(OOD Sinkhorn loss).
        """
        LongTensor = torch.cuda.LongTensor 
        FloatTensor = torch.cuda.FloatTensor

        # wrap the scalars as CUDA tensors so they can go through save_for_backward
        C = torch.LongTensor([C]).to('cuda')
        C = Variable(C.type(LongTensor))
        beta = torch.Tensor([float(beta)]).to('cuda')
        beta = Variable(beta.type(FloatTensor))
         
        # batch_size = pre_inputs.shape[0] // 2
        
        # pre_inputs = pre_inputs.clone()
        OOD_input = ood_inputs.clone()
        # target = Target.clone()
        
        # InD_input = pre_inputs[:batch_size,:]
        # OOD_input = pre_inputs[batch_size:,:]

        # InD_label = target ##only InD samples have labels
        # range(1) -> a single class index, 0
        all_class = torch.LongTensor([i for i in range(1)]).to(device)
        
        ##transform the InD labels into one-hot vector
        # InD_label_onehot = label_2_onehot(InD_label, C, device)
        
         ##Loss value for InD samples
        # log_input = InD_input.log()
        # InD_loss = torch.nn.NLLLoss()
        # InD_loss_value = InD_loss(log_input, InD_label)

        # ce_loss = torch.nn.CrossEntropyLoss()
        # InD_loss_value = ce_loss(InD_input, InD_label).mean()

        
        ##Loss value for OOD samples
        # one-hot of class 0 only: shape (1, C)
        all_class_onehot = label_2_onehot(all_class, C, device)

        # add a trailing axis: (1, C, 1) and (N, C, 1)
        all_class_onehot = torch.unsqueeze(all_class_onehot, -1)
        OOD_loss = SamplesLoss("sinkhorn", p=2, blur=1.)
        OOD_input = torch.unsqueeze(OOD_input, -1)
        OOD_batch_size = OOD_input.shape[0]
        
        #### eliminate min in label ####
        # replicate the single one-hot target for every OOD sample: (N, C, 1)
        all_class_onehot = all_class_onehot.repeat(OOD_batch_size,1,1)

        # NOTE(review): presumably geomloss' batched (weights, points, weights, points)
        # call form, with the same values used as weights and as 1-D point
        # coordinates -- confirm against the geomloss documentation
        OOD_loss_value = OOD_loss(OOD_input[:,:,0], OOD_input, all_class_onehot[:,:,0], all_class_onehot).mean()
        
        # ctx.save_for_backward(InD_input, InD_label_onehot, OOD_input, all_class_onehot, beta, C)
        ctx.save_for_backward(OOD_input, all_class_onehot, beta, C)
        
        ####
  
        # loss_value = InD_loss_value - beta * OOD_loss_value
        loss_value = -beta * OOD_loss_value
        
        return loss_value
# This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        """Return -beta * OOD_f as the gradient for `ood_inputs`; None for the rest.

        NOTE(review): `grad_output` is not used, so the upstream gradient is
        not applied here. Five values are returned although forward takes
        four inputs besides ctx; the extra entries are all None.
        """
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        #InD_input, InD_label_onehot, OOD_input, all_class_onehot, beta, OOD_ind, InD_ind, C, min_idx_OOD, min_idx_all = ctx.saved_tensors
        # InD_input, InD_label_onehot, OOD_input, all_class_onehot, beta, C= ctx.saved_tensors
        OOD_input, all_class_onehot, beta, C= ctx.saved_tensors
        
    #     InD_input, OOD_input, InD_label_onehot, all_class_onehot, beta, C= ctx.saved_tensors
        
        # same Sinkhorn setup as forward, but asking geomloss for the potentials
        OOD_loss = SamplesLoss("sinkhorn", p=2, blur=1., potentials=True)
        OOD_batch_size = OOD_input.shape[0]
        # the first saved one-hot row is re-tiled to the batch size
        OOD_f, OOD_g = OOD_loss(OOD_input[:,:,0], OOD_input, 
                                all_class_onehot[0:1].repeat(OOD_batch_size,1,1)[:,:,0],
                                all_class_onehot[0:1].repeat(OOD_batch_size,1,1))
        
        
        #print(OOD_ind, InD_ind)
        # grad_Input = torch.zeros([InD_batch_size+OOD_batch_size, C]).to('cuda')
        # NOTE(review): this zero tensor is overwritten by the next assignment
        grad_Input = torch.zeros([OOD_batch_size, C]).to('cuda')
        
        # OOD_g is unused; only the first returned potential enters the gradient
        grad_Input = -beta * OOD_f
        
        # grad_Input[:InD_batch_size,:] = -InD_label_onehot * (1. / InD_batch_size)
        
        
        return grad_Input, None, None, None, None




# =======================================================================



# ----------------------------------------------------------------
def adv_mh_dg(fake_cls, fake_label, fake_m):
    """Margin hinge on the gap between the largest wrong logit and the target logit.

    Args:
        fake_cls: class logits of generated samples, forwarded to ``re_crammer_d``.
        fake_label: labels of the generated samples, forwarded to ``re_crammer_d``.
        fake_m: margin placed inside the hinge.

    Returns:
        Mean over all entries of ``relu(fake_m - (max_wrong - target))``.
    """
    hardest_wrong, target_logit = re_crammer_d(fake_cls, fake_label)
    gap = hardest_wrong - target_logit
    return F.relu(fake_m - gap).mean()


def adv_mh_d(real_cls, real_label, fake_cls, fake_label):
    """Combine a linear term on real samples with a zero-margin hinge on fake samples.

    Args:
        real_cls: class logits of real samples.
        real_label: labels of the real samples.
        fake_cls: class logits of generated samples.
        fake_label: labels of the generated samples.

    Returns:
        ``mean(real_max_wrong - real_target)`` plus
        ``mean(relu(0.0 - (fake_max_wrong - fake_target)))``.
    """
    real_wrong, real_true = re_crammer_d(real_cls, real_label)
    fake_wrong, fake_true = re_crammer_d(fake_cls, fake_label)

    # The margin for the fake branch is fixed at zero.
    margin = 0.0
    real_term = (real_wrong - real_true).mean()
    fake_term = F.relu(margin - (fake_wrong - fake_true)).mean()
    return real_term + fake_term

def adv_mh_g(fake_cls, fake_label):
    """Mean gap between the largest wrong logit and the target logit of fake samples.

    Args:
        fake_cls: class logits of generated samples, forwarded to ``re_crammer_d``.
        fake_label: labels of the generated samples, forwarded to ``re_crammer_d``.

    Returns:
        ``mean(max_wrong - target)``.
    """
    hardest_wrong, target_logit = re_crammer_d(fake_cls, fake_label)
    return (hardest_wrong - target_logit).mean()
  
def sum_crammer(cls_output, label, fake_m):
    """Log-sum-exp of hinged margins between every wrong logit and the target logit.

    Adapted from the Crammer-Singer criterion in
    https://github.com/ilyakava/BigGAN-PyTorch/blob/master/train_fns.py

    Args:
        cls_output: 2-D logits; dim 1 indexes the classes.
        label: integer class index per row, used as a scatter/gather index on dim 1.
        fake_m: margin placed inside the hinge.

    Returns:
        Mean over rows of ``logsumexp(relu(fake_m - (wrong - target)), dim=1)``.
    """
    n_rows, n_cols = cls_output.shape[0], cls_output.shape[1]
    label_idx = label.unsqueeze(-1)

    # Boolean mask that is False only at each row's labelled column.
    not_target = torch.ones_like(cls_output).scatter_(1, label_idx, 0).bool()
    wrong_logits = cls_output.masked_select(not_target).reshape(n_rows, n_cols - 1)

    target_logit = cls_output.gather(1, label_idx)
    hinged = F.relu(fake_m - (wrong_logits - target_logit))
    return torch.logsumexp(hinged, dim=1).mean()
    
def re_crammer_d(cls_output, label):
    """Return the largest non-target logit and the target logit of every row.

    Adapted from the Crammer-Singer criterion in
    https://github.com/ilyakava/BigGAN-PyTorch/blob/master/train_fns.py

    Args:
        cls_output: 2-D logits; dim 1 indexes the classes.
        label: integer class index per row, used as a scatter/gather index on dim 1.

    Returns:
        Tuple ``(max_wrong, target)``, each with one column per row.
    """
    n_rows, n_cols = cls_output.shape[0], cls_output.shape[1]
    label_idx = label.unsqueeze(-1)

    # Boolean mask that is False only at each row's labelled column.
    not_target = torch.ones_like(cls_output).scatter_(1, label_idx, 0).bool()
    wrong_logits = cls_output.masked_select(not_target).reshape(n_rows, n_cols - 1)

    hardest_wrong = wrong_logits.max(1)[0].unsqueeze(-1)
    target_logit = cls_output.gather(1, label_idx)
    return hardest_wrong, target_logit

def re_crammer_g(cls_output, label):
    """Return the smallest non-target logit and the target logit of every row.

    Adapted from the Crammer-Singer criterion in
    https://github.com/ilyakava/BigGAN-PyTorch/blob/master/train_fns.py

    Args:
        cls_output: 2-D logits; dim 1 indexes the classes.
        label: integer class index per row, used as a scatter/gather index on dim 1.

    Returns:
        Tuple ``(min_wrong, target)``, each with one column per row.
    """
    n_rows, n_cols = cls_output.shape[0], cls_output.shape[1]
    label_idx = label.unsqueeze(-1)

    # Boolean mask that is False only at each row's labelled column.
    not_target = torch.ones_like(cls_output).scatter_(1, label_idx, 0).bool()
    wrong_logits = cls_output.masked_select(not_target).reshape(n_rows, n_cols - 1)

    easiest_wrong = wrong_logits.min(1)[0].unsqueeze(-1)
    target_logit = cls_output.gather(1, label_idx)
    return easiest_wrong, target_logit
# ----------------------------------------------------------------


def crammer_singer_loss(adv_output, label, DDP, **_):
    """Crammer-Singer multi-class hinge loss with unit margin.

    Adapted from
    https://github.com/ilyakava/BigGAN-PyTorch/blob/master/train_fns.py

    Args:
        adv_output: 2-D logits; dim 1 indexes the classes.
        label: integer class index per row, used as a scatter/gather index on dim 1.
        DDP: accepted for signature compatibility; not used in the computation.
        **_: extra keyword arguments are ignored.

    Returns:
        Mean of ``relu(1 + max_wrong - target)`` over the rows.
    """
    n_rows, n_cols = adv_output.shape[0], adv_output.shape[1]
    label_idx = label.unsqueeze(-1)

    # Boolean mask that is False only at each row's labelled column.
    not_target = torch.ones_like(adv_output).scatter_(1, label_idx, 0).bool()
    wrong_logits = adv_output.masked_select(not_target).reshape(n_rows, n_cols - 1)

    hardest_wrong = wrong_logits.max(1)[0].unsqueeze(-1)
    target_logit = adv_output.gather(1, label_idx)
    return F.relu(1 + hardest_wrong - target_logit).mean()


def feature_matching_loss(real_embed, fake_embed):
    """Mean absolute difference between the batch-averaged fake and real embeddings.

    Adapted from the feature matching criterion in
    https://github.com/ilyakava/BigGAN-PyTorch/blob/master/train_fns.py

    Args:
        real_embed: embeddings of real samples; averaged over dim 0.
        fake_embed: embeddings of generated samples; averaged over dim 0.

    Returns:
        Scalar tensor ``mean(|mean(fake_embed, 0) - mean(real_embed, 0)|)``.
    """
    fake_center = fake_embed.mean(0)
    real_center = real_embed.mean(0)
    return (fake_center - real_center).abs().mean()


def lecam_reg(d_logit_real, d_logit_fake, ema):
    """Squared-hinge regulariser tying current logits to the tracked ``ema`` values.

    Args:
        d_logit_real: discriminator logits on real samples.
        d_logit_fake: discriminator logits on generated samples.
        ema: object exposing ``D_real`` and ``D_fake`` attributes.

    Returns:
        ``mean(relu(d_logit_real - ema.D_fake)^2) + mean(relu(ema.D_real - d_logit_fake)^2)``.
    """
    real_excess = F.relu(d_logit_real - ema.D_fake)
    fake_deficit = F.relu(ema.D_real - d_logit_fake)
    return (real_excess ** 2).mean() + (fake_deficit ** 2).mean()


def cal_deriv(inputs, outputs, device):
    """Differentiate ``outputs`` with respect to ``inputs``, keeping the graph.

    Every element of ``outputs`` is seeded with a gradient of one, so the
    result equals the gradient of ``outputs.sum()``. The graph is created and
    retained so the returned gradient can itself be differentiated (as needed
    by gradient-penalty terms).

    Args:
        inputs: tensor that requires grad and contributed to ``outputs``.
        outputs: tensor to differentiate.
        device: kept for backward compatibility with existing callers. The
            seed gradient is now created directly with the dtype and device of
            ``outputs`` instead of being allocated on the host and copied to
            ``device``, so this argument no longer influences the result.

    Returns:
        Gradient of ``outputs`` with respect to ``inputs``, shaped like ``inputs``.
    """
    grads = autograd.grad(outputs=outputs,
                          inputs=inputs,
                          # ones_like matches outputs' dtype/device and avoids a
                          # host allocation plus transfer on every call.
                          grad_outputs=torch.ones_like(outputs),
                          create_graph=True,
                          retain_graph=True,
                          only_inputs=True)[0]
    return grads


def latent_optimise(zs, fake_labels, generator, discriminator, batch_size, lo_rate, lo_steps, lo_alpha, lo_beta, eval,
                    cal_trsp_cost, device):
    """Refine latent codes by gradient steps on the discriminator's adversarial output.

    Each step moves ``zs`` along ``lo_alpha * g / (lo_beta + ||g||^2)``, where
    ``g`` is the gradient of the discriminator's ``"adv_output"`` with respect
    to ``zs``, then clamps the result to ``[-1, 1]``. A random per-sample mask
    decides which samples are updated in a given step.

    Fixes relative to the previous version: the ``return`` sat inside the
    loop, so at most one step ever ran and the cost-accumulation branch was
    unreachable; and with ``lo_steps <= 1`` the loop body never ran, so the
    function fell through and returned ``None`` instead of a tuple.

    Args:
        zs: latent codes, one row per sample.
        fake_labels: labels passed to both generator and discriminator.
        generator: callable ``generator(zs, fake_labels, eval=...)``.
        discriminator: callable returning a dict that contains ``"adv_output"``.
        batch_size: number of rows in the update mask.
        lo_rate: probability that a sample is updated in a step.
        lo_steps: step budget; ``lo_steps - 1`` update steps are performed
            (count kept as in the original code).
        lo_alpha: step-size numerator.
        lo_beta: damping added to the squared gradient norm.
        eval: forwarded to generator and discriminator.
        cal_trsp_cost: whether to accumulate the squared step norms.
        device: device for the update mask and for ``cal_deriv``.

    Returns:
        Tuple ``(zs, trsf_cost)``. ``trsf_cost`` is ``None`` when
        ``cal_trsp_cost`` is false; otherwise it is the sum over steps of the
        batch-mean squared step norm (a zero scalar if no step was taken).
    """
    trsf_cost = None
    for _ in range(lo_steps - 1):
        # True with probability lo_rate: only those samples move this step.
        drop_mask = (torch.FloatTensor(batch_size, 1).uniform_() > 1 - lo_rate).to(device)

        zs = autograd.Variable(zs, requires_grad=True)
        fake_images = generator(zs, fake_labels, eval=eval)
        fake_dict = discriminator(fake_images, fake_labels, eval=eval)
        z_grads = cal_deriv(inputs=zs, outputs=fake_dict["adv_output"], device=device)
        z_grads_norm = torch.unsqueeze((z_grads.norm(2, dim=1)**2), dim=1)
        delta_z = lo_alpha * z_grads / (lo_beta + z_grads_norm)
        zs = torch.clamp(zs + drop_mask * delta_z, -1.0, 1.0)

        if cal_trsp_cost:
            step_cost = (delta_z.norm(2, dim=1)**2).mean()
            trsf_cost = step_cost if trsf_cost is None else trsf_cost + step_cost

    if cal_trsp_cost and trsf_cost is None:
        # No step ran; hand back a zero cost so callers can still use it in arithmetic.
        trsf_cost = zs.new_zeros(())
    return zs, trsf_cost


def cal_grad_penalty(real_images, real_labels, fake_images, discriminator, device):
    """Gradient penalty evaluated on random real/fake interpolates.

    Args:
        real_images: 4-D batch of real images (unpacked as batch, c, h, w).
        real_labels: labels passed to the discriminator with the interpolates.
        fake_images: generated images, mixed element-wise with ``real_images``.
        discriminator: callable returning a dict that contains ``"adv_output"``.
        device: device for the mixing coefficients, images and ``cal_deriv``.

    Returns:
        Batch mean of ``(||grad||_2 - 1)^2``, where the gradient is taken of
        the adversarial output with respect to the interpolates.
    """
    n_samples, _, _, _ = real_images.shape

    # One mixing coefficient per sample, broadcast over channels and pixels.
    mix = torch.rand(n_samples, 1).view(n_samples, 1, 1, 1).to(device)

    real_on_device = real_images.to(device)
    blended = mix * real_on_device + ((1 - mix) * fake_images)
    blended = autograd.Variable(blended.to(device), requires_grad=True)

    disc_out = discriminator(blended, real_labels, eval=False)
    grads = cal_deriv(inputs=blended, outputs=disc_out["adv_output"], device=device)
    grad_norms = grads.view(grads.size(0), -1).norm(2, dim=1)

    # The trailing "* 0" term adds nothing to the value; presumably it keeps
    # the interpolates referenced by the returned graph -- TODO confirm.
    return ((grad_norms - 1)**2).mean() + blended[:, 0, 0, 0].mean() * 0


def cal_dra_penalty(real_images, real_labels, discriminator, device):
    """Gradient penalty evaluated on randomly perturbed real images.

    Args:
        real_images: 4-D batch of real images (unpacked as batch, c, h, w).
        real_labels: labels passed to the discriminator with the perturbed images.
        discriminator: callable returning a dict that contains ``"adv_output"``.
        device: device for the random tensors, images and ``cal_deriv``.

    Returns:
        Batch mean of ``(||grad||_2 - 1)^2``, where the gradient is taken of
        the adversarial output with respect to the perturbed images.
    """
    n_samples, _, _, _ = real_images.shape

    # Per-sample scale in [0, 1) applied to the noise below.
    step_scale = torch.rand(n_samples, 1, 1, 1).to(device)

    real_on_device = real_images.to(device)
    # Uniform noise scaled by half of the batch-wide standard deviation.
    noise = 0.5 * real_on_device.std() * torch.rand(real_on_device.size()).to(device)
    perturbed = real_on_device + (step_scale * noise)
    perturbed = autograd.Variable(perturbed.to(device), requires_grad=True)

    disc_out = discriminator(perturbed, real_labels, eval=False)
    grads = cal_deriv(inputs=perturbed, outputs=disc_out["adv_output"], device=device)
    grad_norms = grads.view(grads.size(0), -1).norm(2, dim=1)

    # The trailing "* 0" term adds nothing to the value; presumably it keeps
    # the perturbed images referenced by the returned graph -- TODO confirm.
    return ((grad_norms - 1)**2).mean() + perturbed[:, 0, 0, 0].mean() * 0


def cal_maxgrad_penalty(real_images, real_labels, fake_images, discriminator, device):
    """Largest squared gradient norm over a batch of real/fake interpolates.

    Args:
        real_images: 4-D batch of real images (unpacked as batch, c, h, w).
        real_labels: labels passed to the discriminator with the interpolates.
        fake_images: generated images, mixed element-wise with ``real_images``.
        discriminator: callable returning a dict that contains ``"adv_output"``.
        device: device for the mixing coefficients, images and ``cal_deriv``.

    Returns:
        ``max_i ||grad_i||_2^2``, where the gradient is taken of the
        adversarial output with respect to the interpolates.
    """
    n_samples, _, _, _ = real_images.shape

    # One mixing coefficient per sample, broadcast over channels and pixels.
    mix = torch.rand(n_samples, 1).view(n_samples, 1, 1, 1).to(device)

    real_on_device = real_images.to(device)
    blended = mix * real_on_device + ((1 - mix) * fake_images)
    blended = autograd.Variable(blended.to(device), requires_grad=True)

    disc_out = discriminator(blended, real_labels, eval=False)
    grads = cal_deriv(inputs=blended, outputs=disc_out["adv_output"], device=device)
    squared_norms = grads.view(grads.size(0), -1).norm(2, dim=1)**2

    # The trailing "* 0" term adds nothing to the value; presumably it keeps
    # the interpolates referenced by the returned graph -- TODO confirm.
    return torch.max(squared_norms) + blended[:, 0, 0, 0].mean() * 0


def cal_r1_reg(adv_output, images, device):
    """R1-style regulariser: half the batch-mean squared gradient norm w.r.t. the images.

    Args:
        adv_output: discriminator output computed from ``images``.
        images: 4-D image batch that requires grad (indexed as ``[:, 0, 0, 0]``).
        device: forwarded to ``cal_deriv``.

    Returns:
        ``0.5 * mean_i(sum(grad_i ** 2))``.
    """
    n_samples = images.size(0)
    squared_grads = cal_deriv(inputs=images, outputs=adv_output.sum(), device=device).pow(2)
    assert (squared_grads.size() == images.size())

    per_sample = squared_grads.contiguous().view(n_samples, -1).sum(1)
    # The trailing "* 0" term adds nothing to the value; presumably it keeps
    # the images referenced by the returned graph -- TODO confirm.
    return 0.5 * per_sample.mean(0) + images[:, 0, 0, 0].mean() * 0


def adjust_k(current_k, topk_gamma, sup_k):
    """Scale ``current_k`` by ``topk_gamma`` without letting it fall below ``sup_k``.

    Args:
        current_k: current value of k.
        topk_gamma: multiplicative factor applied to ``current_k``.
        sup_k: floor for the returned value.

    Returns:
        The larger of ``current_k * topk_gamma`` and ``sup_k``.
    """
    decayed = current_k * topk_gamma
    return sup_k if sup_k > decayed else decayed


def normal_nll_loss(x, mu, var):
    """Negative log-likelihood of ``x`` under a factored Gaussian.

    Adapted from https://github.com/Natsu6767/InfoGAN-PyTorch/blob/master/utils.py
    Minimised in InfoGAN, where Q(c|x) is treated as a factored Gaussian.

    Args:
        x: observed values.
        mu: Gaussian means, broadcastable against ``x``.
        var: Gaussian variances, broadcastable against ``x``.

    Returns:
        Scalar tensor: log-likelihoods summed over dim 1, averaged over the
        remaining entries, and negated.
    """
    # 1e-6 guards both the logarithm and the division against zero variance.
    log_norm = -0.5 * torch.log(var * (2 * np.pi) + 1e-6)
    sq_term = ((x - mu) ** 2) / (var * 2.0 + 1e-6)
    log_lik = log_norm - sq_term
    return -(log_lik.sum(1).mean())


def stylegan_cal_r1_reg(adv_output, images):
    """R1 penalty as used for StyleGAN: half the squared image-gradient norm, batch-averaged.

    Args:
        adv_output: discriminator output computed from ``images``.
        images: 4-D image batch that requires grad (reduced over dims 1, 2, 3).

    Returns:
        Scalar tensor ``mean_i(sum(grad_i ** 2) / 2)``.
    """
    # The gradient is taken inside conv2d_gradfix's no_weight_gradients context.
    with conv2d_gradfix.no_weight_gradients():
        (image_grads,) = torch.autograd.grad(outputs=[adv_output.sum()],
                                             inputs=[images],
                                             create_graph=True,
                                             only_inputs=True)
    per_sample_penalty = image_grads.square().sum([1, 2, 3]) / 2
    return per_sample_penalty.mean()
