import torch
import torch.nn as nn
import torch.nn.functional as F

class VT(nn.Module):
    """Total Variation (TV) distance loss for knowledge distillation.

    Both sets of logits are converted to probability distributions with a
    softmax over dimension 1. The per-sample loss is half the L1 distance
    between those two distributions, which lies in [0, 1]. The returned
    value is the mean of the per-sample losses over the batch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits_s, logits_t):
        """Return the batch-averaged TV distance between student and teacher.

        Args:
            logits_s (torch.Tensor): Student logits, expected shape
                (batch_size, num_classes). Softmax is taken over dim 1.
            logits_t (torch.Tensor): Teacher logits, same shape as ``logits_s``.

        Returns:
            torch.Tensor: A scalar tensor equal to
            ``mean_over_batch(0.5 * sum_over_classes(|p_s - p_t|))``.
            The loss is symmetric in its two arguments. Gradients flow
            through both inputs; detach the teacher logits at the call site
            if that is not wanted.
        """
        student_probs = F.softmax(logits_s, dim=1)
        teacher_probs = F.softmax(logits_t, dim=1)

        # Per-sample L1 distance between the two class distributions.
        per_sample_l1 = (student_probs - teacher_probs).abs().sum(dim=1)

        # TV distance is half the L1 distance; average it over the batch.
        return 0.5 * per_sample_l1.mean()
