import torch
import torch.nn as nn
import torch.nn.functional as F


class DAxisDistillLoss(nn.Module):
    """Distillation loss built from cosine similarity taken along ``dim=1``.

    For every slice along ``dim=1`` the cosine similarity between
    ``feature`` and ``target_feature`` is computed, mapped through
    ``-log(sigmoid(.))`` and averaged into a single scalar.

    Because cosine similarity lies in ``[-1, 1]``, the per-entry loss lies
    in ``[log(1 + e^-1), log(1 + e)]`` (about ``[0.313, 1.313]``); it is
    smallest when the compared vectors point in the same direction.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, feature: torch.Tensor, target_feature: torch.Tensor
    ) -> torch.Tensor:
        """Return the mean negative log-sigmoid cosine similarity.

        Args:
            feature: Tensor to be pulled towards ``target_feature``.
                The original comment gives the layout as (B, T, D); under
                that layout the similarity is reduced over the T axis,
                yielding one value per (B, D) entry.
                NOTE(review): confirm the layout against the caller.
            target_feature: Reference tensor, broadcast-compatible with
                ``feature``.

        Returns:
            Scalar tensor holding the loss averaged over all remaining
            entries after the ``dim=1`` reduction.
        """
        # feature: (B, T, D)
        cosine_sim = F.cosine_similarity(feature, target_feature, dim=1)
        # logsigmoid is the fused, numerically stable form of log(sigmoid(x)).
        return -F.logsigmoid(cosine_sim).mean()


if __name__ == '__main__':
    # Smoke test: evaluate the loss once on two random tensors of equal shape.
    sample_shape = (10, 48, 768)
    sample_feature = torch.randn(sample_shape)
    sample_target = torch.randn(sample_shape)

    criterion = DAxisDistillLoss()
    print(criterion(sample_feature, sample_target))
