
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """
    KL-divergence loss for knowledge distillation.

    Computes ``KL(teacher || student)`` between the temperature-softened
    distributions of ``teacher_logits`` and ``student_logits`` (softmax over
    the last dimension), scaled by ``temperature ** 2``.

    Args:
        student_logits: Unnormalized student scores; normalized over ``dim=-1``.
        teacher_logits: Unnormalized teacher scores, expected to match the
            shape of ``student_logits``. This function does not detach it, so
            gradients flow into it if it requires grad; detach at the call
            site when the teacher is frozen.
        temperature: Softening temperature; must be strictly positive.

    Returns:
        Scalar loss tensor. ``reduction="batchmean"`` sums the pointwise KL
        terms and divides by ``student_logits.size(0)`` only, so for inputs
        with more than two dimensions (e.g. per-token logits) the result is
        not averaged over the extra dimensions.

    Raises:
        ValueError: If ``temperature`` is not strictly positive (including NaN).
    """
    # A zero, negative, or NaN temperature yields inf/NaN logits or an inverted
    # distribution, which silently poisons training; fail fast instead.
    # ``not x > 0`` (rather than ``x <= 0``) also rejects NaN.
    if not temperature > 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")

    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    # Keep the teacher in log space: with log_target=True, kl_div computes
    # exp(target) * (target - input) directly, avoiding an explicit log of
    # softmax probabilities (recommended by the PyTorch kl_div docs).
    teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
    loss = F.kl_div(
        student_log_probs,
        teacher_log_probs,
        reduction="batchmean",
        log_target=True,
    )
    # Soft-target gradients scale as 1/T^2; multiplying by T^2 keeps their
    # magnitude comparable across temperatures (Hinton et al., 2015).
    return loss * (temperature ** 2)
