# augmentations.py

import torch
import networkx as nx
from torch_geometric.utils import (
    dropout_adj,
    degree,
    to_undirected,
    to_networkx,
)
from torch_geometric.data import Data


def compute_pr(data, alpha=0.85, max_iter=100):
    """Compute the PageRank score of every node in ``data``.

    The graph is converted with ``to_networkx(..., to_undirected=True)`` and
    scored with ``networkx.pagerank``.

    Args:
        data: Graph object accepted by ``to_networkx``; nodes
            ``0 .. data.num_nodes - 1`` are scored.
        alpha: Damping factor forwarded to ``nx.pagerank``.
        max_iter: Maximum iteration count forwarded to ``nx.pagerank``.

    Returns:
        Float tensor of length ``data.num_nodes`` ordered by node index. It is
        placed on ``data.x``'s device, or on ``data.edge_index``'s device when
        the graph carries no node features (``data.x is None``).
    """
    G = to_networkx(data, to_undirected=True)
    pr_dict = nx.pagerank(G, alpha=alpha, max_iter=max_iter)
    # Featureless graphs have x=None; fall back to the edge tensor's device
    # instead of failing with an AttributeError.
    device = data.x.device if data.x is not None else data.edge_index.device
    return torch.tensor([pr_dict[i] for i in range(data.num_nodes)],
                        dtype=torch.float, device=device)

def eigenvector_centrality(data, max_iter=100):
    """Compute the eigenvector centrality of every node in ``data``.

    The graph is converted with ``to_networkx(..., to_undirected=True)`` and
    scored with ``networkx.eigenvector_centrality_numpy``.
    NOTE(review): some networkx versions may reject disconnected graphs here
    — confirm against the installed version.

    Args:
        data: Graph object accepted by ``to_networkx``; nodes
            ``0 .. data.num_nodes - 1`` are scored.
        max_iter: Forwarded to ``nx.eigenvector_centrality_numpy``.

    Returns:
        Float tensor of length ``data.num_nodes`` ordered by node index. It is
        placed on ``data.x``'s device, or on ``data.edge_index``'s device when
        the graph carries no node features (``data.x is None``).
    """
    G = to_networkx(data, to_undirected=True)
    evc_dict = nx.eigenvector_centrality_numpy(G, max_iter=max_iter)
    # Featureless graphs have x=None; fall back to the edge tensor's device.
    device = data.x.device if data.x is not None else data.edge_index.device
    return torch.tensor([evc_dict[i] for i in range(data.num_nodes)],
                        dtype=torch.float, device=device)


def degree_drop_weights(edge_index, num_nodes=None):
    """Per-edge drop weights derived from the degree of each edge's column node.

    Degrees are counted on the undirected version of the graph. For each edge,
    the degree of ``edge_index[1]`` is log-scaled. The result is normalized to
    ``(max - s) / mean(max - s)``, so edges whose column node has a lower
    degree get larger weights.

    Args:
        edge_index: Edge index of shape ``[2, E]``.
        num_nodes: Number of nodes; inferred from ``edge_index`` when ``None``.

    Returns:
        Float tensor of shape ``[E]``. It is all ones when every column node has
        the same degree, and empty when there are no edges.
    """
    if num_nodes is None:
        # max() is undefined on an empty edge set.
        num_nodes = int(edge_index.max().item()) + 1 if edge_index.numel() > 0 else 0
    und = to_undirected(edge_index)
    deg = degree(und[1], num_nodes=num_nodes, dtype=torch.float)
    s = torch.log(deg[edge_index[1]] + 1e-8)
    if s.numel() == 0:
        return s
    gap = s.max() - s
    # gap.mean() equals s.max() - s.mean(), but it is exactly 0 when all
    # entries are equal. That lets us return uniform weights instead of
    # 0/0 = NaN.
    scale = gap.mean()
    if not scale > 0:
        return torch.ones_like(s)
    return gap / scale

def pr_drop_weights(data):
    """Per-edge drop weights derived from the PageRank of each edge's column node.

    The PageRank of ``data.edge_index[1]`` is log-scaled and normalized to
    ``(max - s) / mean(max - s)``, so edges whose column node has lower
    PageRank get larger weights.

    Args:
        data: Graph with ``edge_index``; PageRank is computed by ``compute_pr``.

    Returns:
        Float tensor with one weight per edge. It is all ones when every column
        node has the same score, and empty when there are no edges.
    """
    pr = compute_pr(data)
    s = torch.log(pr[data.edge_index[1]] + 1e-8)
    if s.numel() == 0:
        return s
    gap = s.max() - s
    # Exactly 0 for constant scores; avoids 0/0 = NaN weights.
    scale = gap.mean()
    if not scale > 0:
        return torch.ones_like(s)
    return gap / scale

def evc_drop_weights(data):
    """Per-edge drop weights derived from the eigenvector centrality of each edge's column node.

    The centrality of ``data.edge_index[1]`` is clamped to be non-negative,
    log-scaled, and normalized to ``(max - s) / mean(max - s)``. Edges whose
    column node is less central get larger weights.

    Args:
        data: Graph with ``edge_index``; centrality is computed by
            ``eigenvector_centrality``.

    Returns:
        Float tensor with one weight per edge. It is all ones when every column
        node has the same centrality, and empty when there are no edges.
    """
    evc = eigenvector_centrality(data)
    # The log of a negative value is NaN; clamp any negative centrality the
    # solver might return to zero before taking the log.
    evc = evc.clamp(min=0.0)
    s = torch.log(evc[data.edge_index[1]] + 1e-8)
    if s.numel() == 0:
        return s
    gap = s.max() - s
    # Exactly 0 for constant scores; avoids 0/0 = NaN weights.
    scale = gap.mean()
    if not scale > 0:
        return torch.ones_like(s)
    return gap / scale


def feature_drop_weights(data, node_c):
    """Per-feature drop weights from centrality-weighted feature magnitudes.

    Each feature column ``f`` is scored by ``sum_i |x[i, f]| * node_c[i]``.
    Scores are clamped to be non-negative, log-scaled, and normalized to
    ``(max - s) / mean(max - s)``. Features that carry less centrality-weighted
    mass therefore get larger weights.

    Args:
        data: Object exposing a 2-D node-feature matrix ``x`` (nodes x features).
        node_c: Per-node centrality vector with one entry per row of ``x``.

    Returns:
        Float tensor with one weight per feature column. It is all ones when all
        features score the same, and empty when ``x`` has no columns.
    """
    x = data.x
    # Negative centralities could make a score negative, whose log is NaN.
    w = (x.abs().t() @ node_c).clamp(min=0.0)
    s = torch.log(w + 1e-8)
    if s.numel() == 0:
        return s
    gap = s.max() - s
    # gap.mean() == s.max() - s.mean(), but it is exactly 0 for constant
    # scores, so that case yields uniform weights instead of 0/0 = NaN.
    scale = gap.mean()
    if not scale > 0:
        return torch.ones_like(s)
    return gap / scale


def drop_edge_weighted(edge_index, edge_attr, weights, p, threshold=0.7, eps=1e-6):
    """Randomly remove edges with per-edge probabilities proportional to ``weights``.

    Edge ``i`` is dropped with probability
    ``clamp(weights[i] / (mean(weights) + eps) * p, 0, threshold)``, capped at 1.
    With this scaling, the average drop rate is about ``p`` before clamping.

    Args:
        edge_index: Edge index whose columns are edges (indexed as ``[:, mask]``).
        edge_attr: Per-edge attributes indexed along dim 0, or ``None``.
        weights: One non-negative weight per edge; NaN entries are treated as 0.
        p: Base drop rate.
        threshold: Upper bound on any single edge's drop probability.
        eps: Added to the mean to avoid division by zero.

    Returns:
        ``(edge_index, edge_attr)`` restricted to the kept edges. ``edge_attr``
        stays ``None`` if it was ``None``.
    """
    # Sanitize before averaging. Otherwise a single NaN makes the mean NaN,
    # every scaled weight becomes NaN -> 0, and no edge is ever dropped.
    weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights)
    mean_w = weights.mean()
    w = weights / (mean_w + eps) * p
    w = w.clamp(min=0.0, max=threshold)
    w = w.nan_to_num(0.0)  # safety net, e.g. for an empty or all-inf input

    keep_prob = (1.0 - w).clamp(min=0.0, max=1.0)
    sel_mask = torch.bernoulli(keep_prob).to(torch.bool)

    ei = edge_index[:, sel_mask]
    ea = edge_attr[sel_mask] if edge_attr is not None else None
    return ei, ea

def drop_feature_weighted(x, weights, p, threshold=0.7, eps=1e-6):
    """Zero out whole feature columns of ``x`` with probabilities proportional to ``weights``.

    Column ``f`` is zeroed with probability
    ``clamp(weights[f] / (mean(weights) + eps) * p, 0, threshold)``.

    Args:
        x: 2-D feature matrix; columns are features.
        weights: One non-negative weight per column; NaN entries are treated as 0.
        p: Base drop rate.
        threshold: Upper bound on any single column's drop probability.
        eps: Added to the mean to avoid division by zero.

    Returns:
        A copy of ``x`` with the dropped columns set to 0. ``x`` is not modified.
    """
    # Sanitize before averaging so one NaN weight does not disable all drops.
    weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights)
    mean_w = weights.mean()
    w = weights / (mean_w + eps) * p
    w = w.clamp(min=0.0, max=threshold)
    w = w.nan_to_num(0.0)  # safety net, e.g. for an empty or all-inf input

    drop_mask = torch.bernoulli(w).to(torch.bool)

    x2 = x.clone()
    x2[:, drop_mask] = 0.0
    return x2

def drop_feature(x, drop_prob):
    """Zero out whole feature columns of ``x``, each independently with probability ``drop_prob``.

    Args:
        x: 2-D feature matrix; columns are features.
        drop_prob: Probability of zeroing any given column.

    Returns:
        A copy of ``x`` with the selected columns set to 0. ``x`` is not modified.
    """
    num_features = x.size(1)
    # One uniform draw per column; a column is dropped when its draw falls below drop_prob.
    dropped_cols = torch.rand(num_features, device=x.device).lt(drop_prob)
    out = x.clone()
    out[:, dropped_cols] = 0
    return out


def _centrality_edge_weights(node_c, edge_index):
    """Map per-node centrality to per-edge drop weights keyed on ``edge_index[1]``."""
    s = torch.log(node_c[edge_index[1]] + 1e-8)
    if s.numel() == 0:
        return s
    gap = s.max() - s
    # gap.mean() == s.max() - s.mean(), but it is exactly 0 for constant
    # scores, so that case yields uniform weights instead of 0/0 = NaN.
    scale = gap.mean()
    if not scale > 0:
        return torch.ones_like(s)
    return gap / scale


def augment_graph(
    data: Data,
    drop_scheme: str,
    drop_edge_rate_1: float,
    drop_edge_rate_2: float,
    drop_feature_rate_1: float,
    drop_feature_rate_2: float,
    threshold: float = 0.7
) -> "tuple[Data, Data]":
    """Build two stochastically augmented views of ``data``.

    Args:
        data: Input graph. ``x`` and ``edge_index`` are read. ``edge_attr``
            (may be ``None``) is filtered alongside the surviving edges.
        drop_scheme: ``'uniform'`` drops edges and feature columns uniformly at
            random. ``'degree'``, ``'pr'`` and ``'evc'`` weight the drop
            probabilities by node degree, PageRank or eigenvector centrality.
        drop_edge_rate_1: Edge-drop rate for view 1.
        drop_edge_rate_2: Edge-drop rate for view 2.
        drop_feature_rate_1: Feature-drop rate for view 1.
        drop_feature_rate_2: Feature-drop rate for view 2.
        threshold: Cap on any single edge/feature drop probability
            (centrality-weighted schemes only).

    Returns:
        ``(view1, view2)`` as ``Data`` objects.

    Raises:
        ValueError: If ``drop_scheme`` is not one of the supported values.
    """
    if drop_scheme == 'uniform':
        # Keep edge_attr aligned with the surviving edges, as the weighted
        # schemes below do.
        ei1, ea1 = dropout_adj(data.edge_index, data.edge_attr, p=drop_edge_rate_1)
        ei2, ea2 = dropout_adj(data.edge_index, data.edge_attr, p=drop_edge_rate_2)
        x1 = drop_feature(data.x, drop_feature_rate_1)
        x2 = drop_feature(data.x, drop_feature_rate_2)
        return (
            Data(x=x1, edge_index=ei1, edge_attr=ea1),
            Data(x=x2, edge_index=ei2, edge_attr=ea2),
        )

    # Compute node centrality once and reuse it for both edge and feature
    # weights. PageRank and eigenvector centrality each need a full networkx
    # conversion, so computing them twice was pure overhead.
    if drop_scheme == 'degree':
        node_c = degree(to_undirected(data.edge_index)[1],
                        num_nodes=data.num_nodes)
    elif drop_scheme == 'pr':
        node_c = compute_pr(data)
    elif drop_scheme == 'evc':
        node_c = eigenvector_centrality(data)
    else:
        raise ValueError(f"Unknown drop_scheme: {drop_scheme}")

    # The log of a negative centrality is NaN; degree and PageRank are already
    # non-negative, so this only guards against negative solver output.
    node_c = node_c.clamp(min=0.0)

    ew = _centrality_edge_weights(node_c, data.edge_index)
    fw = feature_drop_weights(data, node_c)

    # View 1
    ei1, ea1 = drop_edge_weighted(
        data.edge_index, data.edge_attr, ew,
        drop_edge_rate_1, threshold=threshold
    )
    x1 = drop_feature_weighted(
        data.x, fw, drop_feature_rate_1, threshold=threshold
    )

    # View 2
    ei2, ea2 = drop_edge_weighted(
        data.edge_index, data.edge_attr, ew,
        drop_edge_rate_2, threshold=threshold
    )
    x2 = drop_feature_weighted(
        data.x, fw, drop_feature_rate_2, threshold=threshold
    )

    return (
        Data(x=x1, edge_index=ei1, edge_attr=ea1),
        Data(x=x2, edge_index=ei2, edge_attr=ea2),
    )
