from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F


def normalize(logit):
    """Standardize ``logit`` along its last dimension (z-score).

    Every slice along the last axis is shifted by its mean and divided by
    its standard deviation (computed with ``Tensor.std`` defaults) plus a
    small constant that keeps the division finite when the deviation is 0.

    Args:
        logit: Tensor of scores; statistics are taken over ``dim=-1``.

    Returns:
        Tensor with the same shape as ``logit``.
    """
    eps = 1e-7  # guards against division by zero for constant rows
    center = logit.mean(dim=-1, keepdim=True)
    spread = logit.std(dim=-1, keepdim=True)
    centered = logit - center
    return centered / (eps + spread)


class LSKDLoss(nn.Module):
    """Knowledge-distillation loss computed on standardized logits.

    Student and teacher logits are each passed through ``normalize``,
    divided by the temperature, and compared with a KL divergence
    (``batchmean`` reduction) that is rescaled by the temperature squared.
    """

    def __init__(self, temp):
        """Store the softmax temperature.

        Args:
            temp: Temperature that divides the standardized logits and
                whose square rescales the resulting loss.
        """
        super(LSKDLoss, self).__init__()
        self.T = temp

    def forward(self, y_s, y_t):
        """Return the temperature-scaled KL divergence between the outputs.

        Args:
            y_s: Student logits.
            y_t: Teacher logits (used as the KL target distribution).

        Returns:
            ``KL(softmax(teacher) || softmax(student)) * T * T`` with
            ``batchmean`` reduction.

        NOTE(review): ``normalize`` reduces over ``dim=-1`` while the
        softmaxes below use ``dim=1``; the two agree only for 2-D inputs.
        Presumably inputs are (batch, num_classes) -- TODO confirm.
        """
        temperature = self.T
        student_log_probs = F.log_softmax(normalize(y_s) / temperature, dim=1)
        teacher_probs = F.softmax(normalize(y_t) / temperature, dim=1)
        divergence = F.kl_div(
            student_log_probs, teacher_probs, reduction='batchmean'
        )
        # Same multiplication order as before (x * T * T) so the result is
        # numerically unchanged.
        return divergence * temperature * temperature