

import torch.nn.functional as F



def feature_loss(feature_stu, feature_tea):
    """Sum the per-layer MSE losses between student and teacher features.

    Args:
        feature_stu: Sequence of student feature tensors.
        feature_tea: Sequence of teacher feature tensors, paired with
            ``feature_stu`` by position. Each pair is passed to
            ``F.mse_loss`` (default ``'mean'`` reduction), so the shapes are
            presumably expected to match — confirm at the call site.

    Returns:
        The sum of the per-pair MSE values (a scalar tensor), or the int ``0``
        when both sequences are empty.

    Raises:
        ValueError: If the two sequences have different lengths.
    """
    # Features are paired strictly by position. Without this check a mismatch
    # would either drop teacher layers silently or fail with an opaque IndexError.
    if len(feature_stu) != len(feature_tea):
        raise ValueError(
            f"feature_stu and feature_tea must have the same length, "
            f"got {len(feature_stu)} and {len(feature_tea)}"
        )
    loss_all = 0
    for stu, tea in zip(feature_stu, feature_tea):
        # Detach the teacher so this loss produces no gradient for it.
        loss_all += F.mse_loss(stu, tea.detach())
    return loss_all


def logits_loss(outputs, teacher_outputs, T=1):
    """Temperature-scaled KL-divergence loss for Knowledge Distillation (KD).

    Computes ``KL(softmax(teacher_outputs / T) || softmax(outputs / T)) * T**2``.
    Both softmaxes are taken over ``dim=1``. ``reduction='batchmean'`` means
    the KL terms are summed and then divided by ``outputs.size(0)``.

    Args:
        outputs: Student logits; classes are expected along ``dim=1``.
        teacher_outputs: Teacher logits, presumably the same shape as
            ``outputs``. They are detached, matching ``feature_loss``, so this
            loss never produces gradients for the teacher.
        T: Softmax temperature; must be positive.

    Returns:
        The distillation loss as a scalar tensor.

    Raises:
        ValueError: If ``T`` is not positive (T == 0 would yield inf/nan).
    """
    if T <= 0:
        raise ValueError(f"temperature T must be positive, got {T!r}")
    student_log_probs = F.log_softmax(outputs / T, dim=1)
    teacher_probs = F.softmax(teacher_outputs.detach() / T, dim=1)
    # Multiplying by T^2 is the conventional scaling from Hinton et al. (2015).
    # It keeps gradient magnitudes roughly independent of the temperature.
    D_KL = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T * T)
    return D_KL