# utils.py
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
import random

def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch's RNG with ``seed``.

    When a CUDA device is available, the CUDA generator is also seeded
    explicitly via ``torch.cuda.manual_seed``.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

def _node_contrast_loss(z1, z2, tau=0.5):
    """One-directional node-level contrastive (InfoNCE-style) loss.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form the positive pair. The
    denominator for anchor ``z1[i]`` contains every other row of ``z1``
    (same-view negatives) and every row of ``z2`` (the cross-view positive
    plus cross-view negatives).

    Args:
        z1: anchor embeddings; assumed (N, D) with rows aligned to ``z2``.
        z2: embeddings of the other view, same shape as ``z1``.
        tau: temperature dividing the cosine similarities.

    Returns:
        Scalar tensor: the per-anchor loss averaged over the N anchors.
    """
    z1 = F.normalize(z1, p=2, dim=1)
    z2 = F.normalize(z2, p=2, dim=1)
    refl = (z1 @ z1.t()) / tau    # same-view logits
    cross = (z1 @ z2.t()) / tau   # cross-view logits; diagonal = positives
    # An anchor is not its own negative: give the same-view diagonal a logit
    # of -inf so it contributes exp(-inf) == 0 to the denominator.
    self_pair = torch.eye(z1.size(0), dtype=torch.bool, device=z1.device)
    refl = refl.masked_fill(self_pair, float("-inf"))
    # -log(exp(pos) / sum(exp(logits))) == logsumexp(logits) - pos.
    # Working in log-space avoids the exp() overflow the direct form hits for
    # small tau (logits reach 1/tau; float32 exp overflows above ~88), and
    # avoids the cancellation of computing the off-diagonal sum as sum - diag.
    log_denom = torch.logsumexp(torch.cat([refl, cross], dim=1), dim=1)
    loss = log_denom - torch.diagonal(cross)
    return loss.mean()

def symmetric_node_contrast_loss(z1, z2, tau=0.5):
    """Average the node contrastive loss computed in both directions.

    ``z1`` is used as the anchor view once and ``z2`` once; the two
    one-directional losses are averaged.
    """
    z1_as_anchor = _node_contrast_loss(z1, z2, tau)
    z2_as_anchor = _node_contrast_loss(z2, z1, tau)
    return 0.5 * (z1_as_anchor + z2_as_anchor)


def drop_edges(data: Data, p: float = 0.2) -> Data:
    """Randomly drop edges: each edge is removed independently with probability ``p``.

    Args:
        data: input graph; it is not modified.
        p: per-edge drop probability. ``p=0`` keeps every edge; ``p>=1``
            drops every edge.

    Returns:
        A new ``Data`` holding only ``x`` (shared with ``data``, not copied),
        the surviving ``edge_index`` columns and, if present, the matching
        ``edge_attr`` rows. No other attribute of ``data`` is carried over.
    """
    edge_index, edge_attr = data.edge_index, data.edge_attr
    num_edges = edge_index.size(1)
    # An edge is dropped when its uniform draw is < p and kept otherwise.
    # torch.rand samples from [0, 1) and can return exactly 0.0, so a `> p`
    # test could still drop an edge at p=0; `>= p` makes p=0 keep everything.
    keep = torch.rand(num_edges, device=edge_index.device) >= p
    kept_attr = edge_attr[keep] if edge_attr is not None else None
    return Data(x=data.x, edge_index=edge_index[:, keep], edge_attr=kept_attr)

def mask_features(data: Data, p: float = 0.2) -> Data:
    """Zero each entry of ``data.x`` independently with probability ``p``.

    Masking is element-wise, not per feature column. The input is not
    modified.

    Returns:
        A new ``Data`` with the masked copy of ``x``. ``edge_index`` and
        ``edge_attr`` are the same tensors as in ``data`` (not copied). No
        other attribute is carried over.
    """
    features = data.x
    zero_out = torch.rand_like(features) < p
    return Data(
        x=features.masked_fill(zero_out, 0),
        edge_index=data.edge_index,
        edge_attr=data.edge_attr,
    )

def make_view_RE_MF(data: Data, p_re=0.1, p_mf=0.1) -> Data:
    """Build one augmented view: random edge dropping, then feature masking.

    Args:
        data: input graph; never modified in place.
        p_re: per-edge drop probability for ``drop_edges``; that step is
            skipped when ``p_re <= 0``.
        p_mf: per-entry feature-mask probability for ``mask_features``; that
            step is skipped when ``p_mf <= 0``.

    Returns:
        The augmented graph. A 1-D ``edge_attr`` of shape (E,) is reshaped
        to (E, 1). When both steps are skipped and no reshape is needed,
        ``data`` itself is returned.
    """
    import copy

    v = data
    if p_re > 0:
        v = drop_edges(v, p_re)
    if p_mf > 0:
        v = mask_features(v, p_mf)
    edge_attr = getattr(v, "edge_attr", None)
    if edge_attr is not None and edge_attr.dim() == 1:
        # With both augmentations disabled, `v` is still the caller's object;
        # shallow-copy it so the reshape below does not mutate the input.
        if v is data:
            v = copy.copy(data)
        v.edge_attr = edge_attr.unsqueeze(-1)
    return v


def info_regularizer_cos(H1, H2, H_ref):
    """Hinge penalty on two views agreeing more than they agree with a reference.

    For each row ``i``, the penalty is
    ``relu(2*cos(H1_i, H2_i) - (cos(H1_i, Href_i) + cos(H2_i, Href_i)))``,
    where ``cos`` is row-wise cosine similarity. The penalty is averaged
    over rows.
    """
    def _rowwise_cosine(a, b):
        # Cosine similarity of matching rows (the diagonal of a @ b.T).
        a_unit = F.normalize(a, p=2, dim=1)
        b_unit = F.normalize(b, p=2, dim=1)
        return torch.sum(a_unit * b_unit, dim=1)

    agreement = _rowwise_cosine(H1, H2)
    anchor = _rowwise_cosine(H1, H_ref) + _rowwise_cosine(H2, H_ref)
    gap = 2 * agreement - anchor
    return F.relu(gap).mean()


@torch.no_grad()
def _project_l1_to_budget(w, w0, budget):
    """Pull ``w`` back toward ``w0`` so their L1 distance fits ``budget``, then clip to [0, 1].

    If ``||w - w0||_1`` exceeds ``budget``, the offset ``w - w0`` is scaled
    uniformly (radially) so its L1 norm equals ``budget``. This is a radial
    rescale, not the exact Euclidean projection onto the L1 ball. A
    ``budget`` of ``None`` or ``<= 0`` disables the constraint, and only the
    clip is applied. Runs without gradient tracking.
    """
    unconstrained = budget is None or budget <= 0
    if not unconstrained:
        delta = w - w0
        total = delta.abs().sum()
        if total > budget:
            # The 1e-8 guards the division; total > budget >= 0 already here.
            scale = budget / (total + 1e-8)
            w = w0 + delta * scale
    return w.clamp(0.0, 1.0)

def build_adv_edge_index_from_candidates(cands_src, cands_dst, w, thr=0.5, undirected=True):
    """Turn scored candidate edges into a (2, E) edge index.

    Candidate ``k`` (``cands_src[k] -> cands_dst[k]``) is kept when
    ``w[k] > thr``. Self-loops are always removed. With ``undirected=True``,
    every kept edge is emitted in both directions: first all forward edges,
    then all reversed ones. Duplicates are not merged.
    """
    selected = w > thr
    heads = cands_src[selected]
    tails = cands_dst[selected]
    not_loop = heads != tails
    heads, tails = heads[not_loop], tails[not_loop]
    if not undirected:
        return torch.stack([heads, tails], dim=0)
    return torch.stack(
        [torch.cat([heads, tails], dim=0), torch.cat([tails, heads], dim=0)],
        dim=0,
    )
