import torch
import torch.nn.functional as F


def nt_xent_loss(proj_pairs, temperature=0.5):
    """Compute the NT-Xent (normalized temperature-scaled cross-entropy) loss.

    The batch is expected to be interleaved: rows ``2k`` and ``2k + 1`` hold
    the two augmented views of sample ``k``. For every row, its partner row is
    the positive and every other row except itself is a negative. The loss is
    the mean cross-entropy of picking the positive among those candidates.

    Args:
        proj_pairs: 2-D tensor of shape ``(batch_size, dim)`` holding the
            projections of both views, interleaved as described above.
        temperature: Divisor applied to the cosine similarities before the
            softmax.

    Returns:
        Scalar tensor with the mean loss over all ``batch_size`` rows.

    Raises:
        ValueError: If the batch is empty or ``batch_size`` is odd.
    """
    batch_size = proj_pairs.shape[0]
    # Explicit raises (not ``assert``) so validation survives ``python -O``.
    if batch_size == 0:
        raise ValueError("Batch must contain at least one pair of projections")
    if batch_size % 2 != 0:
        raise ValueError("Batch size should be even after augmentations")

    device = proj_pairs.device

    # Pairwise cosine similarity, shape (batch_size, batch_size). Normalising
    # once and using a matmul avoids the (batch_size, batch_size, dim)
    # intermediate that broadcasting F.cosine_similarity would allocate.
    unit = F.normalize(proj_pairs, dim=-1)
    sim_matrix = unit @ unit.t()

    rows = torch.arange(batch_size, device=device)

    # Positive similarity for each row, kept in row order so that it lines up
    # with that same row's negatives below. XOR with 1 maps 2k <-> 2k + 1.
    pos_sim = sim_matrix[rows, rows ^ 1].unsqueeze(1)

    # Negatives: drop the row itself and its partner, i.e. every column that
    # shares the row's pair id ([0, 0, 1, 1, ...]). Boolean indexing walks the
    # matrix row-major, so each row yields exactly batch_size - 2 entries.
    pair_id = torch.arange(batch_size // 2, device=device).repeat_interleave(2)
    neg_mask = pair_id.unsqueeze(1) != pair_id.unsqueeze(0)
    neg_sim = sim_matrix[neg_mask].view(batch_size, batch_size - 2)

    # The positive always sits at column 0, so every target label is 0.
    logits = torch.cat([pos_sim, neg_sim], dim=1) / temperature
    labels = torch.zeros(batch_size, dtype=torch.long, device=device)

    return F.cross_entropy(logits, labels)

