import torch

class ContrastiveLoss:
    """Cross-entropy loss over similarity logits, callable like a loss function."""

    def __call__(self, logits, tgt, reduction="mean", **kwargs):
        """Return ``F.cross_entropy(logits, target, reduction=reduction)``.

        Args:
            logits: Unnormalized scores, shape ``(N, C, ...)`` or unbatched ``(C,)``.
            tgt: Either class-index targets, possibly with extra size-1
                dimensions (e.g. ``(N, 1)``), or floating-point class
                probabilities with exactly the same shape as ``logits``.
            reduction: Forwarded to ``cross_entropy`` (``"mean"``, ``"sum"``,
                ``"none"``).
            **kwargs: Accepted for call-site compatibility; ignored.

        Returns:
            The loss tensor produced by ``torch.nn.functional.cross_entropy``.
        """
        if tgt.is_floating_point() and tgt.shape == logits.shape:
            # Probability targets must match the input shape exactly; squeezing
            # would drop the batch dim when N == 1 and break that match.
            target = tgt
        else:
            target = tgt.squeeze()
            # A bare squeeze() also removes size-1 batch/spatial dims
            # (e.g. (1, 1) -> ()), which cross_entropy rejects for batched
            # logits. Restore the expected (N, d1, ...) index-target shape.
            expected_ndim = max(logits.dim() - 1, 0)
            if target.dim() < expected_ndim:
                target = target.reshape(tuple(logits.shape[:1]) + tuple(logits.shape[2:]))
        return torch.nn.functional.cross_entropy(logits, target, reduction=reduction)
