# utils.py
import random
import torch
import torch.nn.functional as F

def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch's RNGs with ``seed``.

    ``torch.manual_seed`` is always called; the additional CUDA seeding call
    is made only when ``torch.cuda.is_available()`` reports a usable device.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

# ===== Gumbel-Softmax (hard=True uses the straight-through estimator) =====
def gumbel_softmax(logits: torch.Tensor, tau: float = 1.0, hard: bool = True, dim: int = -1):
    """Draw a relaxed categorical sample via the Gumbel-Softmax trick.

    Args:
        logits: Unnormalized log-probabilities (floating-point tensor); the
            categories lie along ``dim``.
        tau: Softmax temperature; clamped below at 1e-6 to avoid division
            by zero.
        hard: If True, the forward value is one-hot along ``dim`` while the
            gradient is that of the soft sample (straight-through estimator).
        dim: Dimension holding the categories (default: last, as before).

    Returns:
        Tensor with the same shape as ``logits`` that sums to 1 along ``dim``.
    """
    # Exponential(1) samples can be exactly 0; -log(0) = inf would make the
    # softmax produce NaN (inf - inf). Clamp to the smallest positive normal
    # value of the dtype so the Gumbel noise is always finite.
    exp_noise = torch.empty_like(logits).exponential_()
    exp_noise = exp_noise.clamp(min=torch.finfo(exp_noise.dtype).tiny)
    gumbels = -exp_noise.log()  # Gumbel(0, 1)
    y = F.softmax((logits + gumbels) / max(tau, 1e-6), dim=dim)
    if hard:
        idx = y.argmax(dim=dim, keepdim=True)
        y_hard = torch.zeros_like(y).scatter_(dim, idx, 1.0)
        # Straight-through: forward value is y_hard, gradient flows through y.
        y = (y_hard - y).detach() + y
    return y

def node_contrast_loss(z1, z2, tau=0.5):
    """One-directional InfoNCE-style contrastive loss between two node views.

    For each row ``i`` (anchor ``z1[i]``), the positive is ``z2[i]``. The
    negatives are every other row of ``z1`` plus every other row of ``z2``.
    ``z2[i]`` also appears once in the denominator, as the positive term.
    Similarities are cosine similarities divided by ``tau``.

    Args:
        z1, z2: 2-D embedding matrices with the same number of rows
            (``torch.mm`` requires 2-D and the diagonal pairing requires
            matching row counts).
        tau: Temperature.

    Returns:
        Scalar tensor: mean over rows of
        ``-log(exp(s(z1_i, z2_i)) / (sum_{j != i} exp(s(z1_i, z1_j)) + sum_j exp(s(z1_i, z2_j))))``.
    """
    z1_norm = F.normalize(z1, p=2, dim=1)
    z2_norm = F.normalize(z2, p=2, dim=1)
    refl_logits = torch.mm(z1_norm, z1_norm.t()) / tau
    between_logits = torch.mm(z1_norm, z2_norm.t()) / tau
    # Work in log space: exp(sim / tau) overflows float32 once 1/tau > ~88,
    # and subtracting the (dominant) diagonal after summing loses precision.
    # Masking the self-similarity with -inf removes it exactly instead.
    self_mask = torch.eye(refl_logits.size(0), dtype=torch.bool, device=refl_logits.device)
    refl_logits = refl_logits.masked_fill(self_mask, float("-inf"))
    log_denom = torch.logsumexp(torch.cat([refl_logits, between_logits], dim=1), dim=1)
    pos_logits = torch.diagonal(between_logits)
    loss = log_denom - pos_logits  # == -log(pos_sim / denom)
    return loss.mean()

def symmetric_node_contrast_loss(z1, z2, tau=0.5):
    """Average ``node_contrast_loss`` over both anchor directions.

    The loss is computed once with ``z1`` as anchors and once with ``z2`` as
    anchors, using the same temperature ``tau``. The result is the mean of the two.
    """
    z1_anchored = node_contrast_loss(z1, z2, tau)
    z2_anchored = node_contrast_loss(z2, z1, tau)
    total = z1_anchored + z2_anchored
    return 0.5 * total

def view_similarity_loss(A1: torch.Tensor, A2: torch.Tensor, reduce: str = "mean"):
    """Cosine similarity between matching rows of ``A1`` and ``A2``.

    Both inputs are L2-normalized along the last dimension, and the
    similarity is taken over that dimension. Presumably each row
    corresponds to a node; confirm with the callers. Note that the value
    returned is a similarity (higher means more alike), not a distance.

    Args:
        A1, A2: Tensors of matching shape.
        reduce: ``"mean"`` or ``"sum"`` reduce over all per-row values; any
            other value returns the unreduced per-row tensor.
    """
    unit1, unit2 = (F.normalize(t, p=2, dim=-1) for t in (A1, A2))
    per_node = torch.sum(unit1 * unit2, dim=-1)
    if reduce == "sum":
        return per_node.sum()
    if reduce == "mean":
        return per_node.mean()
    return per_node
