# utils.py
import random
import torch
import torch.nn.functional as F

def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch's generators with *seed*.

    When CUDA is available, ``torch.cuda.manual_seed`` is also called with
    the same value.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

def node_contrast_loss(z1, z2, tau=0.5):
    """Return the mean InfoNCE-style contrastive loss of ``z1`` against ``z2``.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form the positive pair. The
    negatives for anchor ``i`` are every other row of ``z1`` (self excluded)
    and every row of ``z2`` other than the positive. Rows are L2-normalised
    first, so similarities are cosine similarities scaled by ``1 / tau``.

    For each anchor ``i`` the per-row loss is::

        -log( exp(s12[i, i]) / (sum_{j != i} exp(s11[i, j]) + sum_j exp(s12[i, j])) )

    Here ``s11 = z1n @ z1n.T / tau`` and ``s12 = z1n @ z2n.T / tau``. It is
    evaluated in log space with ``logsumexp``, so small temperatures do not
    overflow ``exp``.

    Args:
        z1: 2-D tensor of anchor embeddings (``torch.mm`` requires 2-D input).
        z2: 2-D tensor of embeddings for the other view. It is expected to
            have the same shape as ``z1``.
        tau: Temperature; smaller values sharpen the similarity distribution.

    Returns:
        Scalar tensor with the loss averaged over the rows of ``z1``.
    """
    z1_norm = F.normalize(z1, p=2, dim=1)
    z2_norm = F.normalize(z2, p=2, dim=1)

    # Work with logits (scaled similarities) rather than their exponentials;
    # exp(1 / tau) overflows float32 once tau drops below roughly 0.0113.
    refl_logits = torch.mm(z1_norm, z1_norm.t()) / tau
    between_logits = torch.mm(z1_norm, z2_norm.t()) / tau

    # Exclude each anchor's similarity with itself by masking it to -inf
    # (exp(-inf) == 0). The former approach subtracted the dominant diagonal
    # term from the row sum, which cancels catastrophically for small tau.
    n = refl_logits.size(0)
    self_mask = torch.eye(n, dtype=torch.bool, device=refl_logits.device)
    refl_logits = refl_logits.masked_fill(self_mask, float("-inf"))

    # log(denominator) computed stably over all intra- and inter-view terms.
    log_denom = torch.logsumexp(torch.cat([refl_logits, between_logits], dim=1), dim=1)

    # -log(pos / denom) == log(denom) - log(pos), and log(pos) is the logit itself.
    loss = log_denom - torch.diagonal(between_logits)
    return loss.mean()

def symmetric_node_contrast_loss(z1, z2, tau=0.5):
    """Average ``node_contrast_loss`` taken in both directions.

    One term uses ``z1`` as the anchor view against ``z2``; the other swaps
    the roles. Both use the same temperature ``tau``.
    """
    forward = node_contrast_loss(z1, z2, tau)
    backward = node_contrast_loss(z2, z1, tau)
    return (forward + backward) / 2
