import scipy.sparse as sp
import numpy as np
import os
import pickle as pkl
import networkx as nx
import dgl
import torch
from dgl.data.citation_graph import CoraGraphDataset, CiteseerGraphDataset
from dgl.data import WikiCSDataset
from dgl.data import AmazonCoBuyComputerDataset


def load_synthetic_data(dataset_str, use_feats, data_path):
    """Load an edge-list dataset and return its adjacency, features and labels.

    Reads ``<data_path>/<dataset_str>.edges.csv``, where every non-empty line
    is ``node_a,node_b``. Node names are mapped to consecutive integer ids in
    order of first appearance, and a symmetric 0/1 adjacency matrix is built.

    Args:
        dataset_str: File-name prefix of the dataset inside ``data_path``.
        use_feats: If true, load ``<dataset_str>.feats.npz`` with
            ``scipy.sparse.load_npz``; otherwise use an identity matrix with
            one row per node.
        data_path: Directory containing the dataset files.

    Returns:
        A tuple ``(adj, features, labels)``: ``adj`` is a
        ``scipy.sparse.csr_matrix``, ``features`` is a scipy sparse matrix and
        ``labels`` is whatever ``np.load`` returns for
        ``<dataset_str>.labels.npy``.

    Raises:
        ValueError: If a non-empty line does not contain exactly two
            comma-separated fields.
    """
    object_to_idx = {}
    edges = []
    edge_file = os.path.join(data_path, "{}.edges.csv".format(dataset_str))
    with open(edge_file, 'r') as f:
        for line in f:
            line = line.rstrip()
            if not line:
                # Blank / whitespace-only lines carry no edge; skip them
                # instead of failing on the tuple unpacking below.
                continue
            n1, n2 = line.split(',')
            # len(object_to_idx) is the next unused id; setdefault only
            # stores it when the node name has not been seen before.
            i = object_to_idx.setdefault(n1, len(object_to_idx))
            j = object_to_idx.setdefault(n2, len(object_to_idx))
            edges.append((i, j))

    num_nodes = len(object_to_idx)
    adj = np.zeros((num_nodes, num_nodes))
    for i, j in edges:
        adj[i, j] = 1.  # comment this line for directed adjacency matrix
        adj[j, i] = 1.

    if use_feats:
        features = sp.load_npz(os.path.join(data_path, "{}.feats.npz".format(dataset_str)))
    else:
        features = sp.eye(adj.shape[0])
    labels = np.load(os.path.join(data_path, "{}.labels.npy".format(dataset_str)))
    return sp.csr_matrix(adj), features, labels


def bin_feat(feat, bins):
    """Bucket ``feat`` by ``bins`` and shift the bucket ids to start at zero.

    Args:
        feat: Array of values to discretise.
        bins: Bin edges passed to ``np.digitize``.

    Returns:
        Integer array of bucket ids, offset so the smallest id present is 0.
    """
    bin_ids = np.digitize(feat, bins)
    lowest = bin_ids.min()
    return bin_ids - lowest


def split_data(labels, val_prop, test_prop, seed):
    """Randomly split node indices into validation, test and training sets.

    Nodes are divided into a "positive" pool (non-zero label) and a
    "negative" pool (label equal to zero). Each pool is shuffled, and the
    same number of nodes is taken from each pool for validation and for test;
    everything left over goes to training.

    The negative pool is ``labels == 0``. For binary labels this matches
    ``(1 - labels) != 0``; for multi-class labels it keeps the two pools
    disjoint, so no node can appear in more than one split.

    Note:
        This seeds NumPy's global random state with ``seed``.

    Args:
        labels: 1-D NumPy array of node labels.
        val_prop: Fraction of the smaller pool used for validation.
        test_prop: Fraction of the smaller pool used for test.
        seed: Seed passed to ``np.random.seed``.

    Returns:
        A tuple ``(idx_val, idx_test, idx_train)`` of Python lists of indices.
    """
    np.random.seed(seed)
    pos_idx = labels.nonzero()[0]
    neg_idx = (labels == 0).nonzero()[0]
    np.random.shuffle(pos_idx)
    np.random.shuffle(neg_idx)
    pos_idx = pos_idx.tolist()
    neg_idx = neg_idx.tolist()

    # Validation / test sizes are driven by the smaller pool so both pools
    # contribute equally to those splits.
    nb_pos_neg = min(len(pos_idx), len(neg_idx))
    nb_val = round(val_prop * nb_pos_neg)
    nb_test = round(test_prop * nb_pos_neg)
    held_out = nb_val + nb_test

    idx_val = pos_idx[:nb_val] + neg_idx[:nb_val]
    idx_test = pos_idx[nb_val:held_out] + neg_idx[nb_val:held_out]
    idx_train = pos_idx[held_out:] + neg_idx[held_out:]
    return idx_val, idx_test, idx_train


def load_data_airport(dataset_str, data_path, return_label=False):
    """Load a pickled NetworkX graph and return its adjacency and features.

    The graph is read from ``<data_path>/<dataset_str>.p``. Every node must
    carry a ``'feat'`` attribute; these are stacked into one feature array.

    Args:
        dataset_str: File-name stem of the pickle inside ``data_path``.
        data_path: Directory containing the pickle file.
        return_label: If true, feature column 4 is split off, discretised
            with ``bin_feat`` and returned as labels, and only columns 0-3
            are returned as features.

    Returns:
        ``(adj, features)`` when ``return_label`` is false, otherwise
        ``(adj, features, labels)``. ``adj`` is a ``scipy.sparse.csr_matrix``.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(os.path.join(data_path, dataset_str + '.p'), 'rb') as f:
        graph = pkl.load(f)
    adj = sp.csr_matrix(nx.adjacency_matrix(graph))
    features = np.array([graph.nodes[u]['feat'] for u in graph.nodes])
    if not return_label:
        return adj, features

    label_idx = 4
    # The label column is bucketed at the edges 7/7, 8/7 and 9/7.
    labels = bin_feat(features[:, label_idx], bins=[7.0/7, 8.0/7, 9.0/7])
    return adj, features[:, :label_idx], labels


def _contiguous_split_masks(num_nodes, train_end, val_end):
    """Return boolean train/val/test masks that split node ids by range.

    Ids ``[0, train_end)`` are train, ``[train_end, val_end)`` are validation
    and ``[val_end, num_nodes)`` are test.
    """
    node_ids = torch.arange(num_nodes)
    train_mask = node_ids < train_end
    val_mask = (node_ids >= train_end) & (node_ids < val_end)
    test_mask = node_ids >= val_end
    return train_mask, val_mask, test_mask


def eu_load_data_from_dgl_dataset_class(args):
    """Build a DGL graph with train/val/test masks for ``args.dataset``.

    Supported names:

    * ``'cora'``, ``'citeseer'``, ``'amz'``, ``'wikics'`` -- loaded through
      the corresponding DGL dataset class. The masks are overwritten with a
      fixed split by node-id order (roughly 0.7/0.15/0.15).
    * ``'disease'``, ``'airport'`` -- loaded from local files under
      ``dataset/``; the split comes from ``split_data`` with 15% validation
      and 15% test proportions and ``args.seed``.

    Args:
        args: Object with a ``dataset`` attribute. ``use_feats`` is also read
            for ``'disease'``, and ``seed`` for ``'disease'`` and
            ``'airport'``.

    Returns:
        A DGL graph whose ``ndata`` holds ``'train_mask'``, ``'val_mask'``
        and ``'test_mask'`` (plus ``'feat'`` and ``'label'``, set here for
        the local datasets).

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    dataset_name = args.dataset

    if dataset_name in ('disease', 'airport'):
        if dataset_name == 'disease':
            adj, features, labels = load_synthetic_data(
                'disease_nc', args.use_feats, 'dataset/disease_nc')
            # load_synthetic_data returns a scipy sparse feature matrix.
            features = np.array(features.todense())
        else:
            adj, features, labels = load_data_airport(
                'airport', 'dataset/airport', return_label=True)
        # 0.7/0.15/0.15
        idx_val, idx_test, idx_train = split_data(labels, 0.15, 0.15, seed=args.seed)

        adj_coo = adj.tocoo()
        # Pass num_nodes explicitly so the node count always matches the
        # adjacency matrix, even if the highest-numbered nodes have no edges.
        graph = dgl.graph((adj_coo.row, adj_coo.col), num_nodes=adj.shape[0])
        graph.ndata['feat'] = torch.from_numpy(features)
        graph.ndata['label'] = torch.from_numpy(labels).type(torch.int64)

        num_nodes = graph.num_nodes()
        for key, idx in (('train_mask', idx_train),
                         ('val_mask', idx_val),
                         ('test_mask', idx_test)):
            mask = np.zeros(num_nodes, dtype=bool)
            mask[idx] = True
            graph.ndata[key] = torch.from_numpy(mask)
        return graph

    # (train_end, val_end) node-id boundaries of the fixed 0.7/0.15/0.15 split.
    split_bounds = {
        'cora': (1895, 2301),
        'citeseer': (2327, 2827),
        'amz': (9626, 11689),
        'wikics': (8190, 9945),
    }
    if dataset_name == 'cora':
        dataset = CoraGraphDataset()
    elif dataset_name == 'citeseer':
        dataset = CiteseerGraphDataset()
    elif dataset_name == 'amz':
        dataset = AmazonCoBuyComputerDataset()
    elif dataset_name == 'wikics':
        dataset = WikiCSDataset()
    else:
        raise ValueError(
            "Unsupported dataset '{}'; expected one of: {}".format(
                dataset_name,
                ', '.join(sorted(split_bounds) + ['airport', 'disease'])))

    graph = dataset[0]
    train_end, val_end = split_bounds[dataset_name]
    train_mask, val_mask, test_mask = _contiguous_split_masks(
        graph.num_nodes(), train_end, val_end)
    # Write the masks on the graph object that is returned rather than on
    # repeated dataset[0] lookups.
    graph.ndata['train_mask'] = train_mask
    graph.ndata['val_mask'] = val_mask
    graph.ndata['test_mask'] = test_mask
    return graph
