import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, tau=0.5, eps=1e-8):
    """Symmetric NT-Xent (normalized temperature-scaled cross-entropy) loss.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form a positive pair. Every
    other row of the concatenated ``[z1; z2]`` batch is a negative for that
    anchor. The per-anchor loss is
    ``-log(exp(s_ip / tau) / sum_{k != i} exp(s_ik / tau))``, where ``s`` is
    cosine similarity. The result is averaged over all ``2B`` anchors.

    The loss is evaluated in log-space via ``F.cross_entropy``
    (log-sum-exp), so it stays finite even for very small ``tau`` or
    half-precision inputs. Exponentiating the raw similarities directly
    would overflow to ``inf`` and produce NaN.

    Args:
        z1: Embeddings of the first view, shape ``[B, D]``.
        z2: Embeddings of the second view, shape ``[B, D]``.
        tau: Temperature; similarities are divided by it before the softmax.
        eps: Kept for backward compatibility only. The log-space formulation
            needs no epsilon, so this value is ignored.

    Returns:
        Scalar tensor holding the mean loss over the ``2B`` anchors.

    Raises:
        ValueError: If ``z1`` is not 2-D or ``z1`` and ``z2`` differ in shape.
    """
    if z1.dim() != 2 or z1.shape != z2.shape:
        raise ValueError(
            f"z1 and z2 must be 2-D tensors of equal shape, "
            f"got {tuple(z1.shape)} and {tuple(z2.shape)}"
        )

    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    batch_size = z1.size(0)
    n = 2 * batch_size

    reps = torch.cat([z1, z2], dim=0)  # [2B, D]
    logits = reps @ reps.T / tau  # [2B, 2B] temperature-scaled cosine sims

    # Remove self-similarity from each anchor's softmax denominator;
    # exp(-inf) == 0, and log_softmax handles -inf without NaNs.
    self_mask = torch.eye(n, dtype=torch.bool, device=logits.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # Anchor i (from z1) pairs with i + B (from z2) and vice versa.
    targets = (torch.arange(n, device=logits.device) + batch_size) % n
    return F.cross_entropy(logits, targets)
