import scipy.sparse as sp
import torch
import numpy as np
import os
import pickle as pkl
import networkx as nx

from dgl.data import AmazonCoBuyComputerDataset
from dgl.data.citation_graph import CoraGraphDataset, CiteseerGraphDataset
from dgl.data import WikiCSDataset


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    Args:
        sparse_mx: A scipy sparse matrix (any format providing ``tocoo()``).

    Returns:
        A torch sparse COO tensor with the same shape and non-zero entries
        as ``sparse_mx``; values are stored as ``torch.float32``.
    """
    coo = sparse_mx.tocoo()
    # Row/column coordinates stacked into the (2, nnz) int64 layout torch expects.
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    # torch.tensor always copies, so the result never aliases the scipy buffer.
    values = torch.tensor(coo.data, dtype=torch.float32)
    # torch.sparse.FloatTensor is deprecated; sparse_coo_tensor is the
    # supported constructor and yields the same float32 sparse tensor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

def normalize(mx):
    """Row-normalize a matrix by dividing each row by its sum.

    Args:
        mx: A scipy sparse matrix or a 2-D numpy array.

    Returns:
        ``diag(1 / rowsum) . mx``. Rows whose sum is zero are scaled by 0
        (instead of infinity), so they come back as all-zero rows. The
        result is whatever ``scipy.sparse`` returns for the product (sparse
        for sparse input, dense for dense input).
    """
    rowsum = np.asarray(mx.sum(1)).flatten()
    # Integer row sums cannot be raised to a negative power by numpy
    # ("Integers to negative integer powers are not allowed"), so promote
    # anything that is not already float/complex.
    if not np.issubdtype(rowsum.dtype, np.inexact):
        rowsum = rowsum.astype(np.float64)
    # Division by zero for empty rows is expected and handled right below.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1)
    r_inv[np.isinf(r_inv)] = 0.
    return sp.diags(r_inv).dot(mx)

def process(adj, features, normalize_adj, normalize_feats):
    """Turn an adjacency matrix and a feature matrix into torch tensors.

    Args:
        adj: Scipy sparse adjacency matrix.
        features: Node features, either scipy sparse or array-like.
        normalize_adj: If true, add the identity to ``adj`` and
            row-normalize the sum before conversion.
        normalize_feats: If true, row-normalize ``features``.

    Returns:
        ``(adj, features)`` where ``adj`` is a torch sparse tensor and
        ``features`` is a dense ``torch.Tensor``.
    """
    # issparse also recognises scipy sparse *arrays*, which isspmatrix
    # rejects; without densifying them torch.Tensor() below would fail.
    if sp.issparse(features):
        features = features.toarray()
    if normalize_feats:
        features = normalize(features)
    features = torch.Tensor(features)
    if normalize_adj:
        # Self-loops are added before normalization.
        adj = normalize(adj + sp.eye(adj.shape[0]))
    adj = sparse_mx_to_torch_sparse_tensor(adj)
    return adj, features

def load_synthetic_data(dataset_str, use_feats, data_path):
    """Load a synthetic graph dataset stored under ``data_path``.

    Reads ``<dataset_str>.edges.csv`` (one ``node1,node2`` pair per line),
    gives every distinct node name a consecutive integer id in order of
    first appearance, and builds a symmetric 0/1 adjacency matrix.

    Args:
        dataset_str: File-name prefix of the dataset files.
        use_feats: If true, features are loaded from
            ``<dataset_str>.feats.npz``; otherwise an identity matrix with
            one row per node is used.
        data_path: Directory containing the dataset files.

    Returns:
        ``(adj, features, labels)``: ``adj`` is a float ``csr_matrix``,
        ``features`` a scipy sparse matrix, and ``labels`` whatever
        ``np.load`` returns for ``<dataset_str>.labels.npy``.
    """
    object_to_idx = {}
    src, dst = [], []
    edge_file = os.path.join(data_path, "{}.edges.csv".format(dataset_str))
    with open(edge_file, 'r') as f:
        for line in f:
            line = line.rstrip()
            if not line:
                # Tolerate blank lines (e.g. trailing newlines) in the file.
                continue
            n1, n2 = line.split(',')
            # setdefault assigns the next free id on first sight of a node.
            src.append(object_to_idx.setdefault(n1, len(object_to_idx)))
            dst.append(object_to_idx.setdefault(n2, len(object_to_idx)))

    nb_nodes = len(object_to_idx)
    src = np.asarray(src, dtype=np.int64)
    dst = np.asarray(dst, dtype=np.int64)
    # Every edge is stored in both directions so the matrix is symmetric;
    # keep only one of the two orientations for a directed adjacency matrix.
    rows = np.concatenate((src, dst))
    cols = np.concatenate((dst, src))
    # Build the sparse matrix directly instead of going through a dense
    # (nb_nodes x nb_nodes) array.
    adj = sp.csr_matrix(
        (np.ones(rows.shape[0]), (rows, cols)), shape=(nb_nodes, nb_nodes)
    )
    # Repeated edges (and mirrored self-loops) get summed; collapse them
    # back to 1 so the matrix stays binary.
    adj.sum_duplicates()
    adj.data[:] = 1.

    if use_feats:
        features = sp.load_npz(os.path.join(data_path, "{}.feats.npz".format(dataset_str)))
    else:
        features = sp.eye(nb_nodes)
    labels = np.load(os.path.join(data_path, "{}.labels.npy".format(dataset_str)))
    return adj, features, labels


def bin_feat(feat, bins):
    """Bucket ``feat`` by the edges in ``bins`` and shift ids to start at 0.

    Args:
        feat: Array of values to discretize.
        bins: Bin edges passed to ``np.digitize``.

    Returns:
        Integer bin indices from ``np.digitize`` with the smallest index
        that occurs subtracted, so the lowest occupied bin is numbered 0.
    """
    bin_ids = np.digitize(feat, bins)
    lowest = bin_ids.min()
    return bin_ids - lowest


def load_data_airport(dataset_str, data_path, return_label=False):
    """Load the pickled airport graph and its node features.

    Args:
        dataset_str: Base name of the pickle file (``<dataset_str>.p``).
        data_path: Directory containing the pickle file.
        return_label: If true, column 4 of the node feature vectors is
            split off, discretized with ``bin_feat`` and returned as labels;
            only the first four feature columns are kept as features.

    Returns:
        ``(adj, features)`` or, when ``return_label`` is true,
        ``(adj, features, labels)``. ``adj`` is a ``csr_matrix``.
    """
    graph_file = os.path.join(data_path, dataset_str + '.p')
    # NOTE(security): unpickling can execute arbitrary code; only load
    # dataset files from a trusted source.
    # The context manager closes the file handle, which the bare
    # pkl.load(open(...)) form left to the garbage collector.
    with open(graph_file, 'rb') as f:
        graph = pkl.load(f)
    adj = sp.csr_matrix(nx.adjacency_matrix(graph))
    features = np.array([graph.nodes[u]['feat'] for u in graph.nodes])
    if not return_label:
        return adj, features
    label_idx = 4
    labels = features[:, label_idx]
    features = features[:, :label_idx]
    labels = bin_feat(labels, bins=[7.0/7, 8.0/7, 9.0/7])
    return adj, features, labels


def split_data(labels, val_prop, test_prop, seed):
    """Split node indices into validation, test and training index lists.

    Indices are grouped into those where ``labels`` is non-zero and those
    where ``1 - labels`` is non-zero. Each group is shuffled with the global
    numpy RNG (re-seeded with ``seed``), and the same number of validation
    and test indices is taken from the front of each group; the number is
    the given proportion of the smaller group's size.

    NOTE: a label other than 0 or 1 is non-zero in both ``labels`` and
    ``1 - labels``, so such nodes appear in both groups.

    Args:
        labels: Numpy array of node labels.
        val_prop: Fraction of the smaller group used for validation.
        test_prop: Fraction of the smaller group used for testing.
        seed: Seed passed to ``np.random.seed``.

    Returns:
        ``(idx_val, idx_test, idx_train)`` as Python lists of indices.
    """
    np.random.seed(seed)
    shuffled_groups = []
    # First the "positive" group, then the "negative" one; the shuffle
    # order matters for reproducibility.
    for indicator in (labels, 1. - labels):
        members = indicator.nonzero()[0]
        np.random.shuffle(members)
        shuffled_groups.append(members.tolist())
    positives, negatives = shuffled_groups

    smaller = min(len(positives), len(negatives))
    n_val = round(val_prop * smaller)
    n_test = round(test_prop * smaller)
    cut = n_val + n_test

    val_ids = positives[:n_val] + negatives[:n_val]
    test_ids = positives[n_val:cut] + negatives[n_val:cut]
    train_ids = positives[cut:] + negatives[cut:]
    return val_ids, test_ids, train_ids


def hy_load_data_from_dgl_dataset_class(args):
    """Load a node-classification dataset and package it as a dict.

    Supported ``args.dataset`` values are ``'cora'``, ``'citeseer'``,
    ``'wikics'`` and ``'amz'`` (loaded through DGL, with fixed contiguous
    index ranges as the roughly 0.7/0.15/0.15 train/val/test split) and
    ``'disease'`` and ``'airport'`` (loaded from files under ``dataset/``
    and split with ``split_data``).

    Args:
        args: Namespace-like object. Reads ``dataset``, ``normalize_adj``
            and ``normalize_feats``; additionally ``seed`` for
            ``'disease'``/``'airport'`` and ``use_feats`` for ``'disease'``.
            As a side effect ``args.hyp_percent`` is set to a per-dataset
            constant.

    Returns:
        A dict with keys ``'adj_train'`` (scipy CSR adjacency),
        ``'adj_train_norm'`` (torch sparse tensor produced by ``process``),
        ``'features'`` (torch tensor produced by ``process``), ``'labels'``,
        ``'idx_train'``, ``'idx_val'`` and ``'idx_test'``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    dataset_name = args.dataset
    dgl_datasets = ('cora', 'citeseer', 'wikics', 'amz')
    file_datasets = ('disease', 'airport')
    # Fail early with a clear message instead of a NameError on `adj` later.
    if dataset_name not in dgl_datasets + file_datasets:
        raise ValueError(
            "Unknown dataset {!r}; expected one of {}".format(
                dataset_name, ', '.join(dgl_datasets + file_datasets)
            )
        )

    if dataset_name in dgl_datasets:
        # Each branch uses a fixed 0.7/0.15/0.15 split over node order.
        if dataset_name == 'cora':
            args.hyp_percent = 0.15
            dataset = CoraGraphDataset()
            idx_train = list(range(0, 1895))
            idx_val = list(range(1895, 2301))
            idx_test = list(range(2301, 2708))
        elif dataset_name == 'citeseer':
            args.hyp_percent = 0.45
            dataset = CiteseerGraphDataset()
            idx_train = list(range(0, 2327))
            idx_val = list(range(2327, 2827))
            idx_test = list(range(2827, 3327))
        elif dataset_name == 'wikics':
            args.hyp_percent = 0.55
            dataset = WikiCSDataset()
            idx_train = list(range(0, 8190))
            idx_val = list(range(8190, 9945))
            idx_test = list(range(9945, 11701))
        else:  # 'amz'
            args.hyp_percent = 0.50
            dataset = AmazonCoBuyComputerDataset()
            idx_train = list(range(0, 9626))
            idx_val = list(range(9626, 11689))
            idx_test = list(range(11689, 13752))
        g = dataset[0]
        dense = g.adjacency_matrix().to_dense()
        adj = sp.csr_matrix(dense.numpy())
        labels = g.ndata['label']
        features = sp.lil_matrix(g.ndata['feat'])
    else:
        # 0.7/0.15/0.15
        val_prop, test_prop = 0.15, 0.15
        if dataset_name == 'disease':
            args.hyp_percent = 0.85
            data_path = os.path.abspath('dataset/disease_nc')
            adj, features, labels = load_synthetic_data('disease_nc', args.use_feats, data_path)
        else:  # 'airport'
            args.hyp_percent = 0.2
            data_path = 'dataset/airport'
            adj, features, labels = load_data_airport('airport', data_path, return_label=True)
        idx_val, idx_test, idx_train = split_data(labels, val_prop, test_prop, seed=args.seed)
        labels = torch.from_numpy(labels).type(torch.int64)

    data = {
        'adj_train': adj,
        'features': features,
        'labels': labels,
        'idx_train': idx_train,
        'idx_val': idx_val,
        'idx_test': idx_test,
    }

    # process() returns torch versions; 'features' is replaced in place.
    data['adj_train_norm'], data['features'] = process(
        data['adj_train'], data['features'], args.normalize_adj, args.normalize_feats
    )
    return data





