# utils.py
import random
import torch
import torch.nn.functional as F

def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch's generators with ``seed``.

    The CUDA generator is additionally seeded (explicitly) only when a CUDA
    device is available on this machine.

    Args:
        seed: Integer seed applied to every generator touched here.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    # NOTE(review): per PyTorch docs, torch.manual_seed already seeds all
    # devices, so this call is likely redundant; kept for explicitness.
    torch.cuda.manual_seed(seed)

def symmetric_node_contrast_loss_stable(z1, z2, tau=0.5):
    """Return the mean of the directional contrastive loss in both directions.

    Evaluates ``node_contrast_loss_stable`` with ``z1`` as anchors against
    ``z2``, then with ``z2`` as anchors against ``z1``, and averages the two.

    Args:
        z1: Embeddings of the first view (passed through unchanged).
        z2: Embeddings of the second view (passed through unchanged).
        tau: Temperature forwarded to both directional losses.

    Returns:
        ``0.5 * (loss(z1, z2) + loss(z2, z1))``.
    """
    forward = node_contrast_loss_stable(z1, z2, tau)
    backward = node_contrast_loss_stable(z2, z1, tau)
    return (forward + backward) * 0.5

def node_contrast_loss_stable(z1, z2, tau=0.5):
    """Compute a one-directional node-level contrastive (InfoNCE-style) loss.

    Rows of ``z1`` and ``z2`` are L2-normalised, and cosine similarities are
    scaled by ``1 / tau``. For anchor row ``i`` of ``z1``:

    * the positive is ``z2[i]``;
    * the denominator sums over every row of ``z2`` plus every row of ``z1``
      except ``i`` itself.

    The denominator is computed with ``logsumexp`` for numerical stability.

    Args:
        z1: 2-D anchor embeddings. Presumably ``[N, D]`` — confirm with callers.
        z2: 2-D embeddings of the other view. Assumed to have the same number
            of rows as ``z1``, with row ``i`` paired to ``z1[i]``.
        tau: Temperature dividing the similarity logits.

    Returns:
        Scalar tensor: mean over anchors of ``logsumexp(row) - positive``.
    """
    anchors = F.normalize(z1, p=2, dim=1)
    targets = F.normalize(z2, p=2, dim=1)
    num_nodes = anchors.shape[0]

    intra = anchors.matmul(anchors.T).div(tau)   # same-view similarities
    inter = anchors.matmul(targets.T).div(tau)   # cross-view similarities

    # An anchor must not count itself as a negative: -inf contributes
    # exp(-inf) = 0 to the logsumexp.
    self_pairs = torch.eye(num_nodes, dtype=torch.bool, device=anchors.device)
    intra = intra.masked_fill(self_pairs, float('-inf'))

    log_norm = torch.cat((intra, inter), dim=1).logsumexp(dim=1)
    positives = inter.diagonal()
    return (log_norm - positives).mean()

def apply_edge_keep_mask(edge_attr, keep_mask):
    """Multiply per-edge attributes by a per-edge keep mask.

    Both operands are brought to column form (``[E, 1]`` when 1-D) before
    multiplying, so the mask scales each edge's row of attributes. Previously
    only ``edge_attr`` was promoted. A 1-D mask then broadcast ``[E, 1] * [E]``
    into an ``[E, E]`` matrix instead of masking per edge.

    Args:
        edge_attr: Edge attributes, ``[E]`` or ``[E, D]``, or ``None``. When
            ``None``, a tensor of ones shaped and typed like the (column-form)
            mask is used, so the result equals the mask itself.
        keep_mask: Per-edge mask, ``[E]`` or ``[E, 1]`` (bool or numeric).
            Assumes one entry per edge; confirm with callers.

    Returns:
        ``edge_attr * keep_mask``, with 1-D inputs promoted to ``[E, 1]``.
    """
    # Promote a 1-D mask to a column so it broadcasts across attribute dims
    # rather than against the edge dimension.
    if keep_mask.dim() == 1:
        keep_mask = keep_mask.unsqueeze(-1)
    if edge_attr is None:
        edge_attr = torch.ones_like(keep_mask)
    if edge_attr.dim() == 1:
        edge_attr = edge_attr.unsqueeze(-1)
    return edge_attr * keep_mask
