from torch.nn import Module
import torch
import torch.nn as nn
import numpy as np
import math


# Small helper functions used by the modified loss (Myloss) below.
def compute_distances(x):
    """Return the matrix of pairwise squared Euclidean distances between rows of ``x``.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the whole
    matrix comes from one matrix product. Floating-point cancellation can make
    some entries slightly negative, so the result is clamped to be non-negative.

    Args:
        x: 2-D tensor whose rows are the points (``torch.mm`` requires 2-D input).

    Returns:
        Tensor of shape (n, n), where n = x.shape[0].
    """
    sq_norms = x.pow(2).sum(dim=1)
    col = sq_norms.view(-1, 1)  # ||x_i||^2 broadcast along rows
    row = sq_norms.view(1, -1)  # ||x_j||^2 broadcast along columns
    gram = torch.mm(x, x.t())
    return (col + row - 2.0 * gram).clamp(min=0, max=np.inf)

def KDE_XT_estimation(logvar_t, mean_t):
    """Kernel-density estimate of the mutual-information term I(X;T), in nats.

    Each row of ``mean_t`` is treated as the centre of an isotropic Gaussian
    kernel with variance ``exp(logvar_t)``. The returned value is

        log(N) - mean_i logsumexp_j( -||m_i - m_j||^2 / (2 * var) )

    where N is the number of rows of ``mean_t``.

    Args:
        logvar_t: log-variance of the kernel, as a tensor (``torch.exp`` is applied).
        mean_t: 2-D tensor of per-sample means. Its shape is unpacked into exactly
            two values, so any other rank raises ``ValueError``.

    Returns:
        Scalar tensor holding the estimate in nats.
    """
    n_batch, _ = mean_t.shape
    # Small epsilon keeps the division finite if exp(logvar_t) underflows to 0.
    kernel_var = torch.exp(logvar_t) + 1e-10
    pairwise_sq = compute_distances(mean_t)
    log_kernel_sums = torch.logsumexp(input=-0.5 * pairwise_sq / kernel_var, dim=1)
    return math.log(n_batch) - torch.mean(log_kernel_sums)


class Myloss(Module):
    """Cross-entropy loss combined with a KDE estimate of I(X;T).

    ``forward`` returns ``H(Y|T) - beta * I(X;T) / timesteps``:
      * H(Y|T) is the mean cross-entropy of the logits against the labels (nats);
      * I(X;T) comes from ``KDE_XT_estimation`` and is converted to bits;
      * ``logvar_t`` (the KDE kernel log-variance) is a learnable parameter of
        this module.
    """

    def __init__(self, args, beta=0.003):
        """
        Args:
            args: configuration object. Reads ``timesteps``, ``num_classes`` and
                ``device``.
            beta: coefficient of the I(X;T) term in ``get_loss``. It is a tunable
                hyperparameter; the default keeps the previously hard-coded value.
        """
        super().__init__()
        self.args = args
        self.timesteps = self.args.timesteps
        self.beta = beta
        # H(Y) in nats, taken as log(num_classes), i.e. assuming uniformly
        # distributed labels.
        self.HY = math.log(self.args.num_classes)
        # Learnable kernel log-variance (initial value -1.0). The tensor is created
        # directly on the target device and then wrapped in nn.Parameter. Calling
        # .to(device) on a Parameter returns a plain non-leaf tensor when the
        # device changes, which would silently drop it from self.parameters().
        self.logvar_t = nn.Parameter(
            torch.tensor([-1.0], device=self.args.device)
        )

    def get_XT(self, mean_t):
        """Return the KDE estimate of I(X;T) for ``mean_t``, converted from nats to bits."""
        XT = KDE_XT_estimation(self.logvar_t, mean_t)  # in nats
        XT = XT / math.log(2)  # in bits
        return XT.to(self.args.device)

    def get_TY(self, logits_y, y):
        """Return ``(TY, HY_given_T)``.

        HY_given_T is the mean cross-entropy of ``logits_y`` against ``y``, in nats.
        TY is ``(H(Y) - HY_given_T)`` converted to bits.
        """
        HY_given_T = nn.functional.cross_entropy(logits_y, y.long())
        TY = (self.HY - HY_given_T) / math.log(2)
        return TY.to(self.args.device), HY_given_T.to(self.args.device)

    def get_loss(self, XT_upper, TY_lower, HY_given_T):
        """Combine the terms into the scalar loss ``HY_given_T - beta * XT_upper``.

        ``TY_lower`` is accepted for interface compatibility but is not used.
        """
        # NOTE(review): the I(X;T) term is *subtracted*, so minimising this loss
        # pushes the I(X;T) estimate up. The usual information-bottleneck
        # Lagrangian adds it (H(Y|T) + beta * I(X;T)). Confirm the sign is intended.
        loss = HY_given_T - self.beta * XT_upper
        return loss.to(self.args.device)

    def forward(self, outputs, mean_t, labels):
        """Compute the loss for one batch.

        Args:
            outputs: logits passed to the cross-entropy term.
            mean_t: 2-D tensor whose rows are used for the KDE I(X;T) estimate.
            labels: target class indices (cast to long).
        """
        # I(X;T) is divided by args.timesteps; presumably this averages a
        # per-timestep contribution. Confirm with the training loop.
        XT_t = self.get_XT(mean_t) / self.timesteps
        TY_t, HY_given_T = self.get_TY(outputs, labels)
        loss = self.get_loss(XT_t, TY_t, HY_given_T)
        return loss.to(self.args.device)
