import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcls.registry import MODELS


@MODELS.register_module()
class NKDLoss(nn.Module):
    """Distillation loss; the class keeps the NKD name but computes WTTM.

    The teacher softmax is raised element-wise to the power ``gamma`` and
    renormalised along dim 1. The loss is the KL divergence from that
    transformed teacher distribution to the student softmax, weighted per
    sample by the normalising sum, averaged over dim 0 and scaled by
    ``temp``.

    Args:
        name: Accepted for interface compatibility; not used by this class.
        use_this: Accepted for interface compatibility; not used by this
            class.
        temp: Scalar multiplier applied to the final loss.
        gamma: Exponent applied to the teacher probabilities.
    """

    def __init__(self,
                 name,
                 use_this,
                 temp=1.8,
                 gamma=0.9,
                 ):
        super().__init__()

        # `temp` plays the role of β in the paper.
        self.temp = temp
        # `gamma` plays the role of γ (or 1/T) in the paper.
        self.gamma = gamma

    def forward(self, logit_s, logit_t, gt_label):
        """Return the scalar distillation loss.

        Args:
            logit_s: Student logits; softmax is taken over dim 1
                (presumably shaped (batch, num_classes) - confirm with
                the caller).
            logit_t: Teacher logits with the same layout as ``logit_s``.
            gt_label: Not used by this loss.

        Returns:
            A 0-dim tensor: ``temp * mean(norm * KL)``.
        """
        student_log_prob = F.log_softmax(logit_s, dim=1)

        # Power-transform the teacher probabilities; they no longer sum to 1.
        powered = logit_t.softmax(dim=1).pow(self.gamma)
        # Per-sample normaliser, kept as a column so it broadcasts below.
        mass = powered.sum(dim=1, keepdim=True)
        teacher_prob = powered / mass

        # Element-wise KL terms, reduced over dim 1 to one value per sample.
        per_sample_kl = F.kl_div(
            student_log_prob, teacher_prob, reduction='none'
        ).sum(dim=1)

        # Each sample's KL is re-weighted by its own normaliser.
        weighted = mass.squeeze(1) * per_sample_kl
        return self.temp * weighted.mean()