# utils.py
import random
import torch
import torch.nn.functional as F

def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch's generators with ``seed``.

    The CUDA generator is additionally seeded only when CUDA is available.
    Note: NumPy is not seeded here, and cuDNN determinism flags are not touched.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

def nt_xent_loss(z1, z2, temperature):
    """
    Calculate the NT-Xent (normalized temperature-scaled cross entropy) loss.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form a positive pair. Every
    other row of the concatenated ``[z1; z2]`` batch serves as a negative.

    Args:
        z1, z2: Embeddings of different views of the same graph, each of shape
            (batch_size, dim). Row i of z1 pairs with row i of z2.
        temperature: Scaling parameter that divides the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all 2 * batch_size anchors.

    Raises:
        ValueError: if z1 and z2 do not have the same shape.
    """
    if z1.shape != z2.shape:
        raise ValueError(
            "z1 and z2 must have the same shape, got "
            f"{tuple(z1.shape)} and {tuple(z2.shape)}"
        )
    batch_size = z1.shape[0]
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    z = torch.cat([z1, z2], dim=0)
    sim_matrix = torch.matmul(z, z.T) / temperature
    # Remove self-similarity from the softmax. Use the dtype's most negative
    # finite value: a hard-coded constant such as -9e15 overflows float16,
    # and masked_fill_ rejects it there.
    mask = torch.eye(2 * batch_size, device=z.device, dtype=torch.bool)
    sim_matrix.masked_fill_(mask, torch.finfo(sim_matrix.dtype).min)
    # Anchor i (from z1) has its positive at i + batch_size (from z2), and
    # vice versa. Build the targets directly on the embeddings' device.
    positive_indices = torch.cat([
        torch.arange(batch_size, 2 * batch_size, device=z.device),
        torch.arange(0, batch_size, device=z.device),
    ])
    return F.cross_entropy(sim_matrix, positive_indices)
