import torch
import torch.nn.functional as F
from einops import rearrange, einsum, repeat

def nt_xent_loss(queries, keys, temperature = 0.1):
    '''
    Compute the NT-Xent contrastive loss between queries and keys.

    The queries and keys of one batch element are stacked into a single set
    of 2 * slot L2-normalized vectors. For every vector, the positive is the
    vector at the same slot index in the other set; every remaining vector
    of the same batch element (self excluded) is a negative. Vectors from
    different batch elements are never compared with each other.

    queries: torch.Tensor, shape [..., slot, dim]
    keys: torch.Tensor, same shape as queries
    temperature: float or scalar torch.Tensor dividing the cosine
        similarities before the softmax
    Any number of leading batch dimensions (including none) is accepted.

    Returns a scalar torch.Tensor: the cross-entropy averaged over all
    2 * slot rows of all batch elements.

    Raises ValueError if the shapes differ, if the inputs have fewer than
    two dimensions, or if there are no slots.
    '''
    if queries.shape != keys.shape:
        raise ValueError(
            f'queries and keys must have the same shape, got '
            f'{tuple(queries.shape)} and {tuple(keys.shape)}')
    if keys.dim() < 2:
        raise ValueError(
            f'expected inputs of shape [..., slot, dim], got '
            f'{tuple(keys.shape)}')

    b, device = keys.shape[-2], keys.device
    if b == 0:
        raise ValueError('at least one slot is required')

    n = b * 2
    projs = F.normalize(torch.cat((queries, keys), dim=-2), dim=-1)
    # Cosine similarity of every vector with every other one: [..., n, n].
    sims = projs @ projs.transpose(-1, -2)

    # Remove the self-similarities on the diagonal. Each row keeps n - 1
    # candidates, and columns to the right of the diagonal shift left by one.
    off_diagonal = ~torch.eye(n, dtype=torch.bool, device=device)
    logits = sims[..., off_diagonal].reshape(-1, n - 1)
    # Out-of-place on purpose: avoids mutating an autograd intermediate and
    # keeps a tensor-valued (e.g. learnable) temperature differentiable.
    logits = logits / temperature

    slots = torch.arange(b, device=device)
    # Query row i has its positive (key i) at column i + b, which lies right
    # of the diagonal and therefore lands at i + b - 1 after the removal.
    # Key row b + i has its positive (query i) at column i, left of the
    # diagonal, so its index is unchanged.
    labels = torch.cat((slots + b - 1, slots), dim=0)
    # One copy of the label pattern per (flattened) batch element.
    labels = labels.repeat(logits.shape[0] // n)
    return F.cross_entropy(logits, labels)