import os
import sys
import torch
import pickle as pkl
import numpy as np
import networkx as nx
import scipy.sparse as sp

from torch_geometric.datasets import WebKB
from torch_geometric.utils import to_scipy_sparse_matrix


def bin_feat(feat, bins):
    """Map each value of ``feat`` to a bucket id, shifted so the lowest id is 0.

    Bucket ids come from ``np.digitize(feat, bins)``. Subtracting the smallest
    id that actually occurs makes the returned ids start at 0. An empty
    ``feat`` raises, because it has no minimum.
    """
    bucket_ids = np.digitize(feat, bins)
    lowest = bucket_ids.min()
    return np.subtract(bucket_ids, lowest)


def split_data(labels, val_prop, test_prop, seed):
    """Split node indices into balanced validation, test and training sets.

    Nodes with a non-zero label are treated as positives and nodes whose
    label is zero as negatives. Both groups contribute the same number of
    nodes to the validation set and to the test set; those counts are
    proportions of the smaller group. Every remaining node goes to training.

    Args:
        labels: 1-D array-like of per-node labels.
        val_prop: Fraction of the smaller group placed in validation (per group).
        test_prop: Fraction of the smaller group placed in test (per group).
        seed: Seed passed to ``np.random.seed`` (``None`` reseeds from OS
            entropy). This mutates NumPy's global random state.

    Returns:
        Tuple ``(idx_val, idx_test, idx_train)`` of Python lists of node
        indices. Each list holds its positive indices followed by its negative
        indices, and the three lists are disjoint.
    """
    np.random.seed(seed)
    labels = np.asarray(labels)
    pos_idx = labels.nonzero()[0]
    # Select negatives with ``labels == 0`` rather than ``1 - labels``. For
    # multi-class labels (e.g. 0..3) ``1 - labels`` is also non-zero for
    # classes 2 and 3. Those nodes then sat in both pools and could end up in
    # several splits at once. For 0/1 labels both forms select the same
    # nodes, so binary splits are unchanged.
    neg_idx = (labels == 0).nonzero()[0]
    np.random.shuffle(pos_idx)
    np.random.shuffle(neg_idx)
    pos_idx = pos_idx.tolist()
    neg_idx = neg_idx.tolist()

    nb_pos_neg = min(len(pos_idx), len(neg_idx))
    nb_val = round(val_prop * nb_pos_neg)
    nb_test = round(test_prop * nb_pos_neg)
    val_end = nb_val
    test_end = nb_val + nb_test

    idx_val = pos_idx[:val_end] + neg_idx[:val_end]
    idx_test = pos_idx[val_end:test_end] + neg_idx[val_end:test_end]
    idx_train = pos_idx[test_end:] + neg_idx[test_end:]
    return idx_val, idx_test, idx_train


def load_raw_data(data_path):
    """Load the pickled airport graph and split it into adjacency, features and labels.

    Each node's ``'feat'`` attribute is stacked into a matrix. Column 4 is
    taken as the raw label and binned with ``bin_feat`` at thresholds
    ``7/7, 8/7, 9/7``, giving integer labels in at most ``0..3``. Columns
    ``0..3`` are kept as the node features.

    Args:
        data_path: Directory containing ``airport/airport.p``.

    Returns:
        Tuple ``(adj, features, labels)``: the adjacency as a
        ``scipy.sparse.csr_matrix``, the feature columns as a NumPy array,
        and the binned integer labels.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point data_path at trusted data.
    # Use a context manager so the file handle is closed even if unpickling fails.
    with open(os.path.join(data_path, 'airport/airport.p'), 'rb') as handle:
        graph = pkl.load(handle)
    adj = nx.adjacency_matrix(graph)
    features = np.array([graph.nodes[u]['feat'] for u in graph.nodes()])
    label_idx = 4
    labels = features[:, label_idx]
    features = features[:, :label_idx]
    labels = bin_feat(labels, bins=[7.0/7, 8.0/7, 9.0/7])
    return sp.csr_matrix(adj), features, labels


def load_nc_data(args, data_path, split_seed=None):
    """Assemble the node-classification inputs for the airport graph.

    Args:
        args: Object exposing ``val_prop`` and ``test_prop``, which are
            forwarded to ``split_data``.
        data_path: Directory passed through to ``load_raw_data``.
        split_seed: Seed forwarded to ``split_data`` as ``seed``.

    Returns:
        dict with ``'adj_train'`` and ``'features'`` as returned by
        ``load_raw_data``, ``'labels'`` as a ``torch.LongTensor``, and the
        ``'idx_train'``, ``'idx_val'`` and ``'idx_test'`` index lists.
    """
    adj, features, raw_labels = load_raw_data(data_path)
    # The split is computed on the NumPy labels, before the tensor conversion.
    idx_val, idx_test, idx_train = split_data(
        raw_labels, args.val_prop, args.test_prop, seed=split_seed)
    return dict(
        adj_train=adj,
        features=features,
        labels=torch.LongTensor(raw_labels),
        idx_train=idx_train,
        idx_val=idx_val,
        idx_test=idx_test,
    )


def build_distance(G):
    """Return the dense ``int32`` matrix of shortest-path hop counts in ``G``.

    Rows and columns follow ``G.nodes`` order. Entry ``[i, j]`` is the hop
    count from the ``j``-th node to the ``i``-th node. Pairs with no
    connecting path are reported as 0, the same value as the diagonal.
    """
    hops = dict(nx.all_pairs_shortest_path_length(G))
    order = list(G.nodes)
    rows = []
    for target in order:
        rows.append([hops.get(source, {}).get(target, 0) for source in order])
    return np.array(rows, dtype=np.int32)


def load_md_data(args, data_path, split_seed=None):
    """Build the inputs for the distance-prediction task on the airport graph.

    The labels and features from ``load_raw_data`` are discarded. The target
    becomes the matrix of pairwise hop counts from ``build_distance``. The
    input features become a sparse identity matrix plus uniform noise drawn
    from ``[-0.02, 0.02)``.

    Args:
        args: Not used by this function.
        data_path: Directory passed through to ``load_raw_data``.
        split_seed: Not used by this function. The noise comes from NumPy's
            global RNG and is not reseeded here.

    Returns:
        dict with ``'adj_train'`` (adjacency from ``load_raw_data``),
        ``'features'`` (sparse ``n x n`` matrix), ``'labels'`` (``n x n``
        int32 hop counts) and ``'G'`` (networkx graph built from the
        adjacency).
    """
    adj, _, _ = load_raw_data(data_path)
    graph = nx.from_numpy_array(adj.toarray())
    hop_counts = build_distance(graph)

    n_nodes = adj.shape[0]
    identity = sp.eye(n_nodes)
    # Noise is drawn for every entry, so the resulting sparse matrix is
    # essentially fully populated.
    noise = np.random.uniform(low=-0.02, high=0.02, size=(n_nodes, n_nodes))
    perturbed = identity + sp.csr_matrix(noise)

    return dict(adj_train=adj, features=perturbed, labels=hop_counts, G=graph)

