import torch
import numpy as np
import networkx as nx
import scipy.sparse as sp

from torch_geometric.datasets import WebKB
from torch_geometric.utils import to_scipy_sparse_matrix


def load_raw_data(data_path, name='wisconsin'):
    """Load one WebKB graph as ``(adj, features, labels)``.

    Args:
        data_path: Root directory handed to ``torch_geometric.datasets.WebKB``.
        name: WebKB sub-dataset to load. Defaults to ``'wisconsin'``, the
            dataset this loader was originally hard-wired to, so existing
            callers are unaffected.

    Returns:
        adj: SciPy sparse adjacency matrix built from the graph's
            ``edge_index``, sized ``num_nodes x num_nodes``, where
            ``num_nodes`` is the number of feature rows.
        features: Node feature matrix ``x`` as a NumPy array, one row per node.
        labels: Node labels ``y`` as a NumPy array.
    """
    dataset = WebKB(data_path, name, transform=None, pre_transform=None)
    # WebKB datasets hold a single graph; index it explicitly rather than
    # relying on dataset-level attribute proxying.
    graph = dataset[0]

    num_nodes = graph.x.shape[0]
    adj = to_scipy_sparse_matrix(graph.edge_index, num_nodes=num_nodes)
    features = graph.x.numpy()
    labels = graph.y.numpy()
    return adj, features, labels


def load_nc_data(args, data_path, split_seed=None):
    """Load the graph and randomly split its nodes into train/val/test sets.

    Args:
        args: Object exposing ``test_prop`` and ``val_prop``, the fractions of
            nodes assigned to the test and validation sets. All remaining nodes
            go to the training set.
        data_path: Root directory forwarded to ``load_raw_data``.
        split_seed: Seed for the global NumPy RNG used to draw the split. When
            ``None``, the global RNG is used in whatever state the caller left
            it.

    Returns:
        dict with keys ``'adj_train'`` (the full adjacency from
        ``load_raw_data``), ``'features'`` (NumPy array), ``'labels'``
        (``torch.LongTensor``), and ``'idx_train'`` / ``'idx_val'`` /
        ``'idx_test'`` (lists of node indices; disjoint, covering all nodes).

    Raises:
        ValueError: if the proportions produce a negative split size, or if
            test + validation would need more nodes than the graph has.
    """
    adj, features, labels = load_raw_data(data_path)
    print(adj.shape)

    # Only reseed when a seed is supplied. np.random.seed(None) pulls fresh OS
    # entropy, which would wipe out any global seed the caller set for
    # reproducibility.
    if split_seed is not None:
        np.random.seed(split_seed)

    num_nodes = features.shape[0]
    test_num = round(args.test_prop * num_nodes)
    val_num = round(args.val_prop * num_nodes)
    if test_num < 0 or val_num < 0 or test_num + val_num > num_nodes:
        raise ValueError(
            f"invalid split: test_prop={args.test_prop}, "
            f"val_prop={args.val_prop} for {num_nodes} nodes"
        )

    # Same draw as permutation(np.arange(num_nodes)).
    indices = np.random.permutation(num_nodes)
    idx_test = indices[:test_num].tolist()
    idx_val = indices[test_num:test_num + val_num].tolist()
    idx_train = indices[test_num + val_num:].tolist()

    labels = torch.LongTensor(labels)
    data = {
        'adj_train': adj,
        'features': features,
        'labels': labels,
        'idx_train': idx_train,
        'idx_val': idx_val,
        'idx_test': idx_test,
    }
    return data


def build_distance(G):
    """Return the all-pairs shortest-path hop-count matrix of ``G``.

    Rows and columns follow ``G.nodes`` iteration order. ``R[i, j]`` is the
    length of the shortest path from the i-th node to the j-th node, as
    computed by ``nx.all_pairs_shortest_path_length``. Pairs with no
    connecting path are left at 0, the same value as the diagonal.

    Args:
        G: networkx graph (directed or undirected).

    Returns:
        ``np.int32`` array of shape ``(len(G), len(G))``. This is ``(0, 0)``
        for an empty graph.
    """
    nodes = list(G.nodes)
    position = {node: i for i, node in enumerate(nodes)}
    R = np.zeros((len(nodes), len(nodes)), dtype=np.int32)
    # Index as [source, target] so directed graphs are not transposed; for
    # undirected graphs the matrix is symmetric and unchanged.
    for source, lengths in nx.all_pairs_shortest_path_length(G):
        row = position[source]
        for target, hops in lengths.items():
            R[row, position[target]] = hops
    return R


def load_md_data(args, data_path, split_seed=None):
    """Load the graph and build a pairwise-distance regression target.

    The raw node features and labels are discarded. Instead:

    * ``'labels'`` is the shortest-path distance matrix from
      ``build_distance``;
    * ``'features'`` is an identity matrix plus uniform noise drawn from
      ``np.random.uniform(-0.02, 0.02)``. It is stored as a SciPy sparse
      matrix, although the noise makes every entry nonzero.

    ``args`` and ``split_seed`` are accepted for signature parity with
    ``load_nc_data`` but are not used. The noise comes from the global NumPy
    RNG, which is not seeded here.

    Returns:
        dict with keys ``'adj_train'``, ``'features'``, ``'labels'`` and
        ``'G'`` (the networkx graph built from the dense adjacency).
    """
    adj, _, _ = load_raw_data(data_path)
    n_nodes = adj.shape[0]

    graph = nx.from_numpy_array(adj.toarray())
    distances = build_distance(graph)

    noise = np.random.uniform(low=-0.02, high=0.02, size=(n_nodes, n_nodes))
    node_features = sp.eye(n_nodes) + sp.csr_matrix(noise)

    return {
        'adj_train': adj,
        'features': node_features,
        'labels': distances,
        'G': graph,
    }

