import torch.nn.functional as F
from torch import nn


class KLDiv(nn.Module):
    """KL divergence between temperature-scaled teacher and student logits.

    Both sets of logits are divided by ``temperature`` and converted to
    log-probabilities over the last dimension before being compared.
    """

    def forward(self, student_logits, teacher_logits, temperature):
        """Return KL(teacher || student) over temperature-scaled logits.

        Args:
            student_logits: Logits passed to ``F.kl_div`` as the input
                (after scaling and log-softmax over the last dimension).
            teacher_logits: Logits passed to ``F.kl_div`` as the target
                (after the same scaling and log-softmax).
            temperature: Divisor applied to both sets of logits before the
                log-softmax.

        Returns:
            The divergence reduced with ``reduction="batchmean"``: summed
            over all elements and divided by the size of the first
            dimension.

        Note:
            The result is not multiplied by ``temperature ** 2`` and neither
            input is detached here; callers that want either must do so
            themselves.
        """
        # Apply the same scaling + log-softmax to both inputs, student first.
        student_log_probs, teacher_log_probs = (
            F.log_softmax(logits / temperature, dim=-1)
            for logits in (student_logits, teacher_logits)
        )
        # log_target=True because the target is supplied as log-probabilities.
        return F.kl_div(
            student_log_probs,
            teacher_log_probs,
            log_target=True,
            reduction="batchmean",
        )
