import os
import sys
import torch
import pickle as pkl
import numpy as np
import networkx as nx
import scipy.sparse as sp

from torch_geometric.datasets import WebKB
from torch_geometric.utils import to_scipy_sparse_matrix


def parse_index_file(filename):
    """Read a text file containing one integer per line.

    Args:
        filename: Path of the index file to read.

    Returns:
        A list of ints in file order. Lines that are empty or contain only
        whitespace (e.g. a trailing blank line) are skipped.

    Raises:
        ValueError: If a non-blank line cannot be parsed as an integer.
    """
    # The context manager guarantees the handle is closed even if int() raises.
    with open(filename) as handle:
        return [int(line) for line in map(str.strip, handle) if line]


def load_raw_data(data_path):
    """Load the pickled ``citeseer`` split files and assemble graph inputs.

    Reads ``ind.citeseer.{x,y,tx,ty,allx,ally,graph}`` and
    ``ind.citeseer.test.index`` from ``<data_path>/citeseer``.

    Args:
        data_path: Directory that contains the ``citeseer`` sub-directory.

    Returns:
        A 5-tuple ``(adj, features, labels, test_idx_range, len_y)``:

        * ``adj`` -- adjacency matrix built by networkx from the ``graph``
          dict of lists.
        * ``features`` -- sparse feature matrix (``allx`` stacked on the
          padded ``tx``) with uniform noise in [-0.01, 0.01) added to every
          entry; the noise is drawn from the global NumPy RNG.
        * ``labels`` -- 1-D array of class indices (argmax over the stacked
          ``ally`` / padded ``ty`` rows).
        * ``test_idx_range`` -- sorted NumPy array of the test indices.
        * ``len_y`` -- ``len(y)``, the number of rows in the ``y`` file.
    """
    dataset_str = 'citeseer'
    data_path = os.path.join(data_path, dataset_str)
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for name in names:
        with open(os.path.join(data_path, "ind.{}.{}".format(dataset_str, name)), 'rb') as f:
            # NOTE(security): unpickling can execute arbitrary code; only
            # point data_path at dataset files from a trusted source.
            objects.append(pkl.load(f, encoding='latin1'))

    x, y, tx, ty, allx, ally, graph = tuple(objects)
    test_idx_reorder = parse_index_file(os.path.join(data_path, "ind.{}.test.index".format(dataset_str)))
    test_idx_range = np.sort(test_idx_reorder)

    # The test indices may have gaps. Pad tx / ty out to the full contiguous
    # index range so every index in that range has a row; rows for indices
    # missing from the index file stay all-zero.
    test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
    tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
    tx_extended[test_idx_range - min(test_idx_range), :] = tx
    tx = tx_extended
    ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
    ty_extended[test_idx_range - min(test_idx_range), :] = ty
    ty = ty_extended

    # Test rows are stored in sorted order; move each one to the position
    # given by the (unsorted) order of the index file.
    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]
    # Collapse the per-class rows to a single class index per node; an
    # all-zero row (padding) yields index 0.
    labels = np.argmax(labels, 1)

    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
    # Dense noise is added to every entry, so the result is no longer
    # sparse in content even though it is stored in a sparse container.
    rand_feature = np.random.uniform(low=-0.01, high=0.01, size=(adj.shape[0], features.shape[1]))
    features = features + sp.csr_matrix(rand_feature)

    return adj, features, labels, test_idx_range, len(y)


def load_nc_data(args, data_path, split_seed=None):
    """Build the data dictionary for the node-classification task.

    Args:
        args: Not used by this function; kept for interface compatibility.
        data_path: Directory passed through to ``load_raw_data``.
        split_seed: Not used by this function; the split is fixed.

    Returns:
        A dict with keys ``'adj_train'``, ``'features'``, ``'labels'`` (a
        ``torch.LongTensor``), and the index lists ``'idx_train'``,
        ``'idx_val'`` and ``'idx_test'``.
    """
    adj, features, labels, test_idx_range, len_y = load_raw_data(data_path)

    # Fixed split: the first len_y nodes train, the next 500 validate, and
    # the indices from the test index file are used for testing.
    num_val = 500
    idx_train = list(range(len_y))
    # Materialised as a list so all three splits share one container type
    # (previously this was a lazy ``range`` while the other two were lists).
    idx_val = list(range(len_y, len_y + num_val))
    idx_test = test_idx_range.tolist()

    labels = torch.LongTensor(labels)
    data = {
        'adj_train': adj,
        'features': features,
        'labels': labels,
        'idx_train': idx_train,
        'idx_val': idx_val,
        'idx_test': idx_test,
    }
    return data


def build_distance(G):
    """Return the matrix of pairwise shortest-path lengths of *G*.

    Rows and columns follow the iteration order of ``G.nodes``. Entry
    ``[i][j]`` is the path length recorded from the j-th node to the i-th
    node. Pairs missing from the shortest-path result are filled with 0,
    the same value as the diagonal.

    Args:
        G: A networkx graph.

    Returns:
        A NumPy array of dtype ``int32``.
    """
    hops = dict(nx.all_pairs_shortest_path_length(G))
    node_order = list(G.nodes)

    rows = []
    for target in node_order:
        row = []
        for source in node_order:
            reachable = hops.get(source, {})
            row.append(reachable.get(target, 0))
        rows.append(row)

    return np.array(rows, dtype=np.int32)


def load_md_data(args, data_path, split_seed=None):
    """Build the data dictionary whose labels are pairwise graph distances.

    Only the adjacency matrix returned by ``load_raw_data`` is used; the
    features and labels it returns are discarded and replaced here.

    Args:
        args: Not used by this function.
        data_path: Directory passed through to ``load_raw_data``.
        split_seed: Not used by this function.

    Returns:
        A dict with keys ``'adj_train'`` (the adjacency matrix),
        ``'features'`` (identity matrix plus uniform noise in
        [-0.02, 0.02) drawn from the global NumPy RNG), ``'labels'`` (the
        output of ``build_distance``) and ``'G'`` (the networkx graph built
        from the adjacency matrix).
    """
    adj = load_raw_data(data_path)[0]
    num_nodes = adj.shape[0]

    # Rebuild a networkx graph from the dense adjacency to compute distances.
    G = nx.from_numpy_array(adj.toarray())
    distances = build_distance(G)

    # One-hot node identities, perturbed by small dense noise.
    noise = np.random.uniform(low=-0.02, high=0.02, size=(num_nodes, num_nodes))
    node_features = sp.eye(num_nodes) + sp.csr_matrix(noise)

    return {
        'adj_train': adj,
        'features': node_features,
        'labels': distances,
        'G': G,
    }

