import torch
import torch.nn.functional as F
import torch.nn as nn


def KL_loss(feature_stu, feature_tea, T=1):
    """Return the temperature-scaled KL divergence between two feature maps.

    Both inputs are flattened to ``(B, -1)``, where ``B`` is the size of the
    first dimension of ``feature_stu``, and turned into distributions with a
    softmax over the flattened dimension.

    Args:
        feature_stu: Student features; the first dimension is used as the
            batch size.
        feature_tea: Teacher features; must be reshapeable to ``(B, -1)``.
        T: Temperature dividing both inputs before the softmax.

    Returns:
        ``F.kl_div`` with ``reduction='batchmean'`` multiplied by ``T * T``.
    """
    batch = feature_stu.shape[0]
    stu_flat = feature_stu.reshape(batch, -1)
    tea_flat = feature_tea.reshape(batch, -1)

    # F.kl_div expects log-probabilities as input and probabilities as target.
    stu_log_prob = F.log_softmax(stu_flat / T, dim=1)
    tea_prob = F.softmax(tea_flat / T, dim=1)

    divergence = F.kl_div(stu_log_prob, tea_prob, reduction='batchmean')
    return divergence * T * T


def feature_loss(feature_stu, feature_tea, fun='mse', T=20):
    """Sum a per-pair distillation loss over matching student/teacher features.

    The teacher feature at each index is detached, so no gradient flows back
    into it through this loss.

    Args:
        feature_stu: Sequence of student feature tensors.
        feature_tea: Sequence of teacher feature tensors, indexed in parallel
            with ``feature_stu``; entries beyond ``len(feature_stu)`` are
            ignored.
        fun: Per-pair loss to apply: ``'mse'`` (``F.mse_loss``), ``'l1'``
            (``F.l1_loss``) or ``'kl'`` (``KL_loss``).
        T: Temperature forwarded to ``KL_loss``; only used when
            ``fun == 'kl'``.

    Returns:
        The summed loss, or the int ``0`` when ``feature_stu`` is empty.

    Raises:
        ValueError: If ``fun`` is not one of the supported names.
        IndexError: If ``feature_tea`` has fewer entries than ``feature_stu``.
    """
    # Resolve the pairwise loss once. An unknown name used to fall through
    # every branch and silently yield a zero loss, hiding typos in `fun`.
    if fun == 'mse':
        pair_loss = F.mse_loss
    elif fun == 'l1':
        pair_loss = F.l1_loss
    elif fun == 'kl':
        def pair_loss(stu, tea):
            return KL_loss(stu, tea, T)
    else:
        raise ValueError(
            f"unsupported fun {fun!r}; expected 'mse', 'l1' or 'kl'"
        )

    loss_all = 0
    for i, stu in enumerate(feature_stu):
        # Index the teacher explicitly so a too-short teacher list raises
        # instead of being silently truncated.
        loss_all = loss_all + pair_loss(stu, feature_tea[i].detach())
    return loss_all

def logits_loss(outputs, teacher_outputs, T=1):
    """Loss function for Knowledge Distillation (KD) on logits.

    Args:
        outputs: Student logits; softmax is taken over ``dim=1``.
        teacher_outputs: Teacher logits, same layout as ``outputs``.
        T: Temperature dividing both sets of logits before the softmax.

    Returns:
        ``F.kl_div`` between the softened student log-probabilities and the
        softened teacher probabilities (``reduction='batchmean'``),
        multiplied by ``T * T``.
    """
    student_log_prob = F.log_softmax(outputs / T, dim=1)
    teacher_prob = F.softmax(teacher_outputs / T, dim=1)
    scale = T * T
    return F.kl_div(student_log_prob, teacher_prob, reduction='batchmean') * scale

# class Norm_loss(nn.Module):
#     def __init__(self,dim = [64,128,256,512],layer = [1,1,1,1]):
#         super(Norm_loss, self).__init__()
#         self.bn = nn.ModuleList()
#         for channel in dim:
#             self.tmp = nn.BatchNorm2d(channel,affine=False)
#             self.tmp.requires_grad_(requires_grad=False)
#             self.bn.append(self.tmp)
#         self.layer = layer
#
#     def forward(self,fea_stu,fea_tea):
#         loss = 0
#         h = 0
#         feature_snn = []
#         feature_cnn = []
#         for i in range(len(self.layer)):
#             for j in range(self.layer[i]):
#                 f1 = self.bn[i](fea_stu[h + j])
#                 f2 = self.bn[i](fea_tea[h + j])
#                 loss += F.mse_loss(f1,f2)
#                 feature_snn.append(F.normalize(f1.view(f1.shape[0],-1),dim=1))
#                 feature_cnn.append(F.normalize(f2.view(f1.shape[0],-1),dim=1))
#             h += self.layer[i]
#         return loss,feature_snn,feature_cnn



