import torch
from torch_sparse import mul
from torch_sparse import sum as sparsesum
from torch_geometric.nn.conv.gcn_conv import gcn_norm
import torch_geometric
from torch_sparse import SparseTensor 
import numpy as np
from sklearn.preprocessing import normalize
from networkx import DiGraph
from torch_geometric.utils import from_networkx
import torch


def row_norm(adj):
    r"""
    Applies the row-wise normalization:
        \mathbf{D}_{out}^{-1} \mathbf{A}

    Each row of the sparse matrix ``adj`` is divided by its row sum. Rows whose
    sum is zero get a scale factor of 0 instead of ``inf`` (which would turn any
    stored zero entries in that row into NaN), mirroring ``directed_norm``.

    Args:
        adj: sparse adjacency matrix (torch_sparse SparseTensor).

    Returns:
        The row-normalized sparse matrix.
    """
    row_sum = sparsesum(adj, dim=1)
    row_sum_inv = 1 / row_sum
    # Zero-sum rows would otherwise be scaled by +/-inf.
    row_sum_inv.masked_fill_(torch.isinf(row_sum_inv), 0.0)

    return mul(adj, row_sum_inv.view(-1, 1))


def directed_norm(adj):
    r"""
    Applies the normalization for directed graphs:
        \mathbf{D}_{out}^{-1/2} \mathbf{A} \mathbf{D}_{in}^{-1/2}.

    ``D_in`` is built from the column sums of ``adj`` and ``D_out`` from its
    row sums. Nodes with zero in- or out-degree receive a scale factor of 0
    instead of ``inf``.

    Args:
        adj: sparse adjacency matrix (torch_sparse SparseTensor).

    Returns:
        The normalized sparse matrix.
    """
    in_deg = sparsesum(adj, dim=0)
    in_deg_inv_sqrt = in_deg.pow_(-0.5)
    # 0 ** -0.5 == inf: isolated targets contribute nothing instead of inf.
    in_deg_inv_sqrt.masked_fill_(in_deg_inv_sqrt == float("inf"), 0.0)

    out_deg = sparsesum(adj, dim=1)
    out_deg_inv_sqrt = out_deg.pow_(-0.5)
    out_deg_inv_sqrt.masked_fill_(out_deg_inv_sqrt == float("inf"), 0.0)

    # Scale rows by D_out^{-1/2}, then columns by D_in^{-1/2}.
    adj = mul(adj, out_deg_inv_sqrt.view(-1, 1))
    adj = mul(adj, in_deg_inv_sqrt.view(1, -1))
    return adj


def get_norm_adj(adj, norm):
    """
    Normalize the sparse adjacency ``adj`` with the scheme named by ``norm``.

    Supported schemes:
      * ``"sym"`` - ``gcn_norm`` without adding self loops;
      * ``"row"`` - ``row_norm``;
      * ``"dir"`` - ``directed_norm``.

    Raises:
        ValueError: if ``norm`` is not one of the supported names.
    """
    normalizers = (
        ("sym", lambda matrix: gcn_norm(matrix, add_self_loops=False)),
        ("row", row_norm),
        ("dir", directed_norm),
    )
    for scheme, normalizer in normalizers:
        if norm == scheme:
            return normalizer(adj)
    raise ValueError(f"{norm} normalization is not supported")


def get_mask(idx, num_nodes):
    """
    Build a boolean mask of length ``num_nodes``.

    Positions listed in ``idx`` are True; every other position is False.
    """
    selected = torch.full((num_nodes,), False, dtype=torch.bool)
    selected[idx] = True
    return selected


def get_adj(edge_index, num_nodes, graph_type="directed"):
    """
    Build a ``num_nodes x num_nodes`` SparseTensor adjacency with unit edge weights.

    ``graph_type`` controls how ``edge_index`` is interpreted:
      * ``"directed"``   - edges are used as given (row = source, col = target);
      * ``"transpose"``  - every edge is reversed;
      * ``"undirected"`` - edges are symmetrized via
        ``torch_geometric.utils.to_undirected``.

    Raises:
        ValueError: for any other ``graph_type``.
    """
    if graph_type == "transpose":
        # Swap the source and target rows to reverse every edge.
        edge_index = edge_index[[1, 0]]
    elif graph_type == "undirected":
        edge_index = torch_geometric.utils.to_undirected(edge_index)
    elif graph_type != "directed":
        raise ValueError(f"{graph_type} is not a valid graph type")

    sources, targets = edge_index[0], edge_index[1]
    weights = torch.ones((edge_index.size(1),), device=edge_index.device)
    return SparseTensor(
        row=sources,
        col=targets,
        value=weights,
        sparse_sizes=(num_nodes, num_nodes),
    )

def generate_synthetic_directed_pa_graph(num_classes=5, num_nodes=1000, m=2, h=0.1):
    """
    Generate a synthetic directed graph by class-aware preferential attachment.

    Nodes are added one at a time. Each new node ``u`` draws a class label
    uniformly at random. It then adds out-edges to at most ``m`` distinct,
    previously added nodes ``v``. Each target is chosen with probability
    proportional to ``(in_degree(v) + 0.01) * H[y_u, y_v]``.

    ``H`` is a random compatibility matrix whose diagonal equals ``h`` and whose
    rows sum to 1. Nodes never link to themselves, so the first node has no
    out-edges. If fewer than ``m`` earlier nodes have non-zero probability, only
    that many edges are added.

    Args:
        num_classes: number of node classes.
        num_nodes: number of nodes to generate.
        m: maximum number of out-edges added per new node.
        h: probability mass on the same-class entry of each row of ``H``.

    Returns:
        The result of ``from_networkx`` on the generated graph, with ``y`` set to
        the node labels as a ``torch.Tensor`` (float, as before).
    """
    # Compatibility matrix: random off-diagonal entries whose per-row mass is
    # scaled to (1 - h), with h on the diagonal.
    H = np.random.rand(num_classes, num_classes)
    np.fill_diagonal(H, 0)
    H = (1 - h) * normalize(H, axis=1, norm="l1")
    np.fill_diagonal(H, h)
    np.testing.assert_allclose(H.sum(axis=1), np.ones(num_classes), rtol=1e-5, atol=0)

    G = DiGraph()
    y = np.empty(num_nodes, dtype=np.int64)
    in_degree = np.zeros(num_nodes)
    for u in range(num_nodes):
        G.add_node(u)
        y[u] = np.random.choice(num_classes)

        # Candidates are only the nodes added before u (ids 0..u-1), so u never
        # links to itself. Scores are proportional to in-degree and compatibility.
        scores = (in_degree[:u] + 0.01) * H[y[u], y[:u]]

        # Sampling without replacement needs at least `size` candidates with
        # non-zero probability, so cap the number of edges accordingly.
        num_edges = min(m, np.count_nonzero(scores))
        if num_edges <= 0:
            continue
        vs = np.random.choice(u, size=num_edges, replace=False, p=scores / scores.sum())
        G.add_edges_from((u, int(v)) for v in vs)
        in_degree[vs] += 1

    data = from_networkx(G)
    # NOTE(review): labels stay a float tensor for backward compatibility;
    # classification losses typically need `data.y.long()`.
    data.y = torch.Tensor(y.tolist())

    return data

  
def compute_unidirectional_edges_ratio(edge_index):
    """
    Return the percentage of connected node pairs linked in only one direction.

    The formula uses two counts:
      * ``E_dir`` - the number of distinct directed edges in ``edge_index``;
      * ``E_und`` - the number of distinct edges after adding every reverse
        edge (the same edge set ``torch_geometric.utils.to_undirected`` produces).

    ``E_und - E_dir`` is the number of edges whose reverse is missing. The result
    is ``100 * (E_und - E_dir) / (E_und / 2)``.

    Duplicate edges are ignored so they cannot skew (or negate) the count. An
    empty ``edge_index`` returns ``0.0`` instead of dividing by zero.

    Args:
        edge_index: edge list, assumed to be a ``(2, E)`` integer tensor.

    Returns:
        The ratio as a Python float in percent.
    """
    if edge_index.size(1) == 0:
        return 0.0

    num_directed_edges = torch.unique(edge_index, dim=1).size(1)
    symmetrized = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    num_undirected_edges = torch.unique(symmetrized, dim=1).size(1)

    num_unidirectional = num_undirected_edges - num_directed_edges

    return (num_unidirectional / (num_undirected_edges / 2)) * 100
