import torch

class NoiseContrastiveLoss:
    """Callable loss: cross-entropy over a symbol-vs-context score matrix.

    Every row of ``symbol_emb`` is scored against every row of
    ``context_emb`` with a dot product, and the class index for row ``i``
    is ``i`` itself.
    """

    def __call__(self, symbol_emb, context_emb, reduction="mean", **kwargs):
        """Return the cross-entropy between the score matrix and ``arange(N)``.

        Args:
            symbol_emb: Tensor whose first dimension ``N`` sets the number
                of targets. Presumably shaped ``(N, D)`` -- TODO confirm
                against callers.
            context_emb: Tensor whose first two dimensions are swapped
                before the matrix product. Presumably shaped ``(M, D)``
                with row ``i`` being the matching context for symbol
                ``i`` -- TODO confirm against callers.
            reduction: Passed unchanged to
                ``torch.nn.functional.cross_entropy``.
            **kwargs: Accepted but not used.

        Returns:
            Whatever ``torch.nn.functional.cross_entropy`` returns for the
            requested ``reduction``.
        """
        num_symbols = symbol_emb.size(0)

        # Pairwise dot products: entry (i, j) scores symbol i against
        # context j.
        context_t = torch.transpose(context_emb, 0, 1)
        scores = torch.matmul(symbol_emb, context_t)

        # Target class for row i is column i; built on the same device as
        # the symbol embeddings.
        diagonal_idx = torch.arange(num_symbols, device=symbol_emb.device)

        return torch.nn.functional.cross_entropy(
            scores,
            diagonal_idx,
            reduction=reduction,
        )


