# utils.py
import random
import torch
import torch.nn.functional as F

def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch's generators.

    Args:
        seed: Seed value applied to every generator below.

    When CUDA is available, ``torch.cuda.manual_seed`` is also called so the
    current CUDA device's generator receives the same seed.

    Note: this only seeds RNGs; it does not change cuDNN or
    deterministic-algorithm settings.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

def node_contrast_loss(z1, z2, tau=0.5):
    """
    Compute the GRACE node-level contrastive (InfoNCE) loss from view 1 to view 2.

    For node ``i`` the positive pair is ``(z1[i], z2[i])``. The denominator
    contains every other node of the same view (``z1[j]``, ``j != i``) plus
    every node of the other view (``z2[j]`` for all ``j``, i.e. the positive
    pair is included, as in GRACE).

    The loss is evaluated in log space with ``torch.logsumexp`` rather than as
    ``log(exp(a) / sum(exp(b)))``. The naive form overflows ``exp`` to ``inf``
    (giving ``nan``) in float32 once ``1 / tau`` exceeds roughly 88. It also
    loses precision when the self-similarity is subtracted from the row sum.

    Args:
        z1, z2: Node embeddings from two augmented views of the same graph,
            both of shape [N, d]; row i of each must describe the same node.
        tau: Temperature dividing the cosine similarities (expected > 0).

    Returns:
        Scalar tensor: the per-node loss averaged over the N nodes.

    Raises:
        ValueError: if ``z1`` and ``z2`` are not 2-D tensors of equal shape.
    """
    if z1.dim() != 2 or z1.shape != z2.shape:
        raise ValueError(
            "z1 and z2 must both be [N, d] with equal shapes, "
            f"got {tuple(z1.shape)} and {tuple(z2.shape)}"
        )

    # L2-normalise so the matrix products below are cosine similarities.
    z1_norm = F.normalize(z1, p=2, dim=1)
    z2_norm = F.normalize(z2, p=2, dim=1)

    # Temperature-scaled similarity logits (the log of the usual exp terms).
    refl_logits = torch.mm(z1_norm, z1_norm.t()) / tau      # within view 1
    between_logits = torch.mm(z1_norm, z2_norm.t()) / tau   # view 1 vs view 2

    # Remove each node's similarity with itself from the intra-view terms.
    # exp(-inf) == 0, so masking with -inf drops the term exactly instead of
    # subtracting it from a (possibly overflowed) sum.
    self_mask = torch.eye(z1_norm.size(0), dtype=torch.bool, device=z1_norm.device)
    refl_logits = refl_logits.masked_fill(self_mask, float("-inf"))

    # Positive logits: the same node as seen in the two views.
    pos_logits = torch.diagonal(between_logits)

    # -log(exp(pos) / sum(exp(all))) == logsumexp(all) - pos
    all_logits = torch.cat([refl_logits, between_logits], dim=1)
    loss = torch.logsumexp(all_logits, dim=1) - pos_logits
    return loss.mean()

def symmetric_node_contrast_loss(z1, z2, tau=0.5):
    """Symmetrised node contrastive loss.

    Evaluates ``node_contrast_loss`` with view 1 as the anchor (z1 -> z2) and
    again with view 2 as the anchor (z2 -> z1), then returns the mean of the
    two directional losses.
    """
    forward_loss = node_contrast_loss(z1, z2, tau)
    backward_loss = node_contrast_loss(z2, z1, tau)
    total = forward_loss + backward_loss
    return 0.5 * total
