#!/usr/bin/env python

import json
import numpy as np
import scipy.sparse as sp
from time import time
import os, sys
import pickle as pkl
import networkx as nx
from networkx.readwrite import json_graph
import metis


def parse_index_file(filename):
    """Parse an index file containing one integer per line.

    Blank lines (for example a trailing empty line) are skipped. The file is
    closed as soon as reading finishes, including when parsing raises.

    Args:
        filename: Path of the text file to read.

    Returns:
        List of the integers in file order.

    Raises:
        ValueError: if a non-blank line is not a valid integer.
    """
    with open(filename) as fin:
        # int() tolerates surrounding whitespace, so no explicit strip is needed.
        return [int(line) for line in fin if line.strip()]


def sample_mask(idx, l):
    """Create a boolean mask of length ``l`` that is True at positions ``idx``.

    Args:
        idx: Any NumPy-compatible index (list, range, array) of positions to set.
        l: Length of the mask.

    Returns:
        1-D ``numpy`` array of dtype ``bool``.
    """
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is equivalent.
    mask = np.zeros(l, dtype=bool)
    mask[idx] = True
    return mask


def load_gcn_data(root, dataset_str, normalization="gcn", partition_size=1):
    """Load a Planetoid-style GCN dataset, optionally METIS-partitioned, with caching.

    On the first call the raw ``<root>/data/ind.<dataset_str>.*`` pickles are
    parsed and preprocessed. The result is written to
    ``<root>/data/<dataset_str>_<normalization>_<partition_size>.npz``, and
    later calls load that cache instead.

    Args:
        root: Directory that contains the ``data`` folder.
        dataset_str: Dataset name used in file names. ``"citeseer"`` and
            ``"nell"`` get special handling.
        normalization: ``"gcn"`` selects D^-1/2 (A + I) D^-1/2 normalization.
            Any other value selects row normalization D^-1 A (no self-loops).
        partition_size: Number of METIS parts; ``1`` disables partitioning.

    Returns:
        A 14-tuple ``(num_data, train_adj, full_adj, part_adj, part_adj, feats,
        train_feats, test_feats, labels, train_data, val_data, test_data,
        parts, parts)``. ``part_adj`` and ``parts`` are each returned twice,
        presumably to mirror the tuple layout of ``load_graphsage_data``.
    """
    npz_file = "{}/data/{}_{}_{}.npz".format(
        root, dataset_str, normalization, partition_size
    )
    if os.path.exists(npz_file):
        start_time = time()
        print("Found preprocessed dataset {}, loading...".format(npz_file))
        # ``parts`` is stored as an object array when partitioning is enabled,
        # so unpickling must be allowed. The cache is written by this function
        # itself; do not point it at untrusted files.
        with np.load(npz_file, allow_pickle=True) as data:

            def _csr(name):
                # Rebuild a CSR matrix from the <name>_{data,indices,indptr,shape} keys.
                return sp.csr_matrix(
                    (
                        data[name + "_data"],
                        data[name + "_indices"],
                        data[name + "_indptr"],
                    ),
                    shape=data[name + "_shape"],
                )

            num_data = data["num_data"]
            labels = data["labels"]
            train_data = data["train_data"]
            val_data = data["val_data"]
            test_data = data["test_data"]
            parts = data["parts"]
            train_adj = _csr("train_adj")
            full_adj = _csr("full_adj")
            part_adj = _csr("part_adj")
            feats = _csr("feats")
            train_feats = _csr("train_feats")
            test_feats = _csr("test_feats")
        print("Finished in {} seconds.".format(time() - start_time))
    else:
        # Load the raw pickles. NOTE: pickle.load executes arbitrary code;
        # only use dataset files from a trusted source.
        names = ["x", "y", "tx", "ty", "allx", "ally", "graph"]
        objects = []
        for name in names:
            with open(
                "{}/data/ind.{}.{}".format(root, dataset_str, name), "rb"
            ) as f:
                if sys.version_info > (3, 0):
                    objects.append(pkl.load(f, encoding="latin1"))
                else:
                    objects.append(pkl.load(f))

        x, y, tx, ty, allx, ally, graph = tuple(objects)

        if dataset_str != "nell":
            test_idx_reorder = parse_index_file(
                "{}/data/ind.{}.test.index".format(root, dataset_str)
            )
            test_idx_range = np.sort(test_idx_reorder)

            if dataset_str == "citeseer":
                # Fix citeseer dataset (there are some isolated nodes in the graph)
                # Find isolated nodes, add them as zero-vecs into the right position
                test_idx_range_full = range(
                    min(test_idx_reorder), max(test_idx_reorder) + 1
                )
                tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
                tx_extended[test_idx_range - min(test_idx_range), :] = tx
                tx = tx_extended
                ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
                ty_extended[test_idx_range - min(test_idx_range), :] = ty
                ty = ty_extended

            # Permute the stacked test rows: row test_idx_reorder[i] receives
            # the row currently at test_idx_range[i].
            features = sp.vstack((allx, tx)).tolil()
            features[test_idx_reorder, :] = features[test_idx_range, :]
            G = nx.from_dict_of_lists(graph)
            adj = nx.adjacency_matrix(G)

            labels = np.vstack((ally, ty))
            labels[test_idx_reorder, :] = labels[test_idx_range, :]

            # Train: the first len(y) nodes; val: the next 500 nodes.
            idx_test = test_idx_range.tolist()
            idx_train = range(len(y))
            idx_val = range(len(y), len(y) + 500)

            train_mask = sample_mask(idx_train, labels.shape[0])
            val_mask = sample_mask(idx_val, labels.shape[0])
            test_mask = sample_mask(idx_test, labels.shape[0])

            y_train = np.zeros(labels.shape)
            y_val = np.zeros(labels.shape)
            y_test = np.zeros(labels.shape)
            y_train[train_mask, :] = labels[train_mask, :]
            y_val[val_mask, :] = labels[val_mask, :]
            y_test[test_mask, :] = labels[test_mask, :]
        else:
            test_idx_reorder = parse_index_file(
                "{}/data/ind.{}.test.index".format(root, dataset_str)
            )
            features = allx.tocsr()
            G = nx.from_dict_of_lists(graph)
            adj = nx.adjacency_matrix(G)
            labels = ally
            # Train: the first len(y) nodes; val: the next 969 nodes.
            idx_test = test_idx_reorder
            idx_train = range(len(y))
            idx_val = range(len(y), len(y) + 969)
            train_mask = sample_mask(idx_train, labels.shape[0])
            val_mask = sample_mask(idx_val, labels.shape[0])
            test_mask = sample_mask(idx_test, labels.shape[0])
            y_train = np.zeros(labels.shape)
            y_val = np.zeros(labels.shape)
            y_test = np.zeros(labels.shape)
            y_train[train_mask, :] = labels[train_mask, :]
            y_val[val_mask, :] = labels[val_mask, :]
            y_test[test_mask, :] = labels[test_mask, :]

        # num_data, (v, coords), feats, labels, train_d, val_d, test_d
        num_data = features.shape[0]

        def _normalize_adj(adj):
            # Row normalization D^-1 A; the epsilon avoids 1/0 for isolated rows.
            rowsum = np.array(adj.sum(1)).flatten()
            d_inv = 1.0 / (rowsum + 1e-20)
            d_mat_inv = sp.diags(d_inv, 0)
            adj = d_mat_inv.dot(adj).tocoo()
            coords = np.array((adj.row, adj.col)).astype(np.int32)
            return adj.data.astype(np.float32), coords

        def gcn_normalize_adj(adj):
            # Add self-loops, then compute (A D^-1/2)^T D^-1/2, which equals
            # D^-1/2 A D^-1/2 when A is symmetric.
            adj = adj + sp.eye(adj.shape[0])
            rowsum = np.array(adj.sum(1)) + 1e-20
            d_inv_sqrt = np.power(rowsum, -0.5).flatten()
            d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
            d_mat_inv_sqrt = sp.diags(d_inv_sqrt, 0)
            adj = adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt)
            adj = adj.tocoo()
            coords = np.array((adj.row, adj.col)).astype(np.int32)
            return adj.data.astype(np.float32), coords

        # Normalize features so each row sums to (approximately) 1.
        rowsum = np.array(features.sum(1)) + 1e-9
        r_inv = np.power(rowsum, -1).flatten()
        r_inv[np.isinf(r_inv)] = 0.0
        r_mat_inv = sp.diags(r_inv, 0)
        features = r_mat_inv.dot(features)

        # graph partitioning
        # (future work) try weighted edges
        if partition_size > 1:
            print("do graph partitioning, size {}".format(partition_size))
            tmp_time = time()
            _, nd_group = metis.part_graph(G, partition_size)
            print("metis finished in {} seconds.".format(time() - tmp_time))

            tmp_time = time()
            # Keep only edges whose endpoints fall into the same METIS part.
            part_row = []
            part_col = []
            part_data = []
            parts = [[] for _ in range(partition_size)]
            for nd_idx in range(num_data):
                gp_idx = nd_group[nd_idx]
                parts[gp_idx].append(nd_idx)
                for nb_idx in adj[nd_idx].indices:
                    if nd_group[nb_idx] == gp_idx:
                        part_row.append(nd_idx)
                        part_col.append(nb_idx)
                        part_data.append(1)
            # Explicit shape: otherwise trailing nodes without intra-part edges
            # would shrink (or make non-square) the inferred matrix shape.
            part_adj = sp.coo_matrix(
                (part_data, (part_row, part_col)), shape=(num_data, num_data)
            ).tocsr()
            # Partitions differ in size; NumPy >= 1.24 rejects ragged
            # np.array(...) input, so fill a 1-D object array explicitly.
            parts_arr = np.empty(len(parts), dtype=object)
            for gp_idx, members in enumerate(parts):
                parts_arr[gp_idx] = np.array(members)
            parts = parts_arr
            print(len(adj.data))
            print(len(part_adj.data))
            print(adj[0:5])
            print("---")
            print(part_adj[0:5])
            print("adj and parts constructed in {} seconds.".format(time() - tmp_time))
            print(
                "{} / {} edges are remained.".format(
                    len(part_adj.data) // 2, len(adj.data) // 2
                )
            )
        else:
            part_adj = adj
            parts = np.array(np.array([]))

        if normalization == "gcn":
            full_v, full_coords = gcn_normalize_adj(adj)
            part_v, part_coords = gcn_normalize_adj(part_adj)
        else:
            full_v, full_coords = _normalize_adj(adj)
            part_v, part_coords = _normalize_adj(part_adj)
        full_v = full_v.astype(np.float32)
        part_v = part_v.astype(np.float32)
        full_coords = full_coords.astype(np.int32)
        part_coords = part_coords.astype(np.int32)
        train_v, train_coords = full_v, full_coords
        labels = (y_train + y_val + y_test).astype(np.float32)
        train_data = np.nonzero(train_mask)[0].astype(np.int32)
        val_data = np.nonzero(val_mask)[0].astype(np.int32)
        test_data = np.nonzero(test_mask)[0].astype(np.int32)

        feats = (features.data, features.indices, features.indptr, features.shape)

        def _get_adj(data, coords):
            # Build a num_data x num_data CSR matrix from (values, 2 x nnz coords).
            adj = sp.csr_matrix(
                (data, (coords[0, :], coords[1, :])), shape=(num_data, num_data)
            )
            return adj

        part_adj = _get_adj(part_v, part_coords)
        train_adj = _get_adj(train_v, train_coords)
        full_adj = _get_adj(full_v, full_coords)
        feats = sp.csr_matrix(
            (feats[0], feats[1], feats[2]), shape=feats[-1], dtype=np.float32
        )

        # Pre-propagate features once through the normalized adjacency.
        train_feats = train_adj.dot(feats)
        test_feats = full_adj.dot(feats)

        with open(npz_file, "wb") as fwrite:
            np.savez(
                fwrite,
                num_data=num_data,
                train_adj_data=train_adj.data,
                train_adj_indices=train_adj.indices,
                train_adj_indptr=train_adj.indptr,
                train_adj_shape=train_adj.shape,
                full_adj_data=full_adj.data,
                full_adj_indices=full_adj.indices,
                full_adj_indptr=full_adj.indptr,
                full_adj_shape=full_adj.shape,
                part_adj_data=part_adj.data,
                part_adj_indices=part_adj.indices,
                part_adj_indptr=part_adj.indptr,
                part_adj_shape=part_adj.shape,
                feats_data=feats.data,
                feats_indices=feats.indices,
                feats_indptr=feats.indptr,
                feats_shape=feats.shape,
                train_feats_data=train_feats.data,
                train_feats_indices=train_feats.indices,
                train_feats_indptr=train_feats.indptr,
                train_feats_shape=train_feats.shape,
                test_feats_data=test_feats.data,
                test_feats_indices=test_feats.indices,
                test_feats_indptr=test_feats.indptr,
                test_feats_shape=test_feats.shape,
                labels=labels,
                train_data=train_data,
                val_data=val_data,
                test_data=test_data,
                parts=parts,
            )

    return (
        num_data,
        train_adj,
        full_adj,
        part_adj,
        part_adj,
        feats,
        train_feats,
        test_feats,
        labels,
        train_data,
        val_data,
        test_data,
        parts,
        parts,
    )


def subsample_edges(edges, num_data, max_degree):
    """Greedily keep a random subset of edges so no node exceeds ``max_degree``.

    The edges are shuffled in place with NumPy's global RNG and then scanned
    once. An edge is kept only while both endpoints are still below the cap,
    and keeping it increments the degree of both endpoints. A self-loop
    therefore counts twice.

    Args:
        edges: Sequence of ``(u, v)`` integer pairs.
        num_data: Number of nodes; endpoints must be in ``[0, num_data)``.
        max_degree: Maximum number of kept edges touching any node.

    Returns:
        List of kept ``(u, v)`` tuples (NumPy int32 scalars), in shuffled order.
    """
    shuffled = np.array(edges, dtype=np.int32)
    np.random.shuffle(shuffled)
    deg = np.zeros(num_data, dtype=np.int32)

    kept = []
    for src, dst in shuffled:
        if deg[src] >= max_degree or deg[dst] >= max_degree:
            continue
        kept.append((src, dst))
        deg[src] += 1
        deg[dst] += 1
    return kept


def load_graphsage_data(
    prefix,
    normalize=True,
    partition_list=[1],
    max_degree=-1,
    has_feats=True,
    class_list=False,
):
    """Load a GraphSAGE-format dataset and build normalized (partitioned) adjacencies.

    Raw inputs are read from the following files:
    - ``<prefix>-G.json``: node-link graph whose nodes carry ``val``/``test`` flags.
    - ``<prefix>-id_map.json``
    - ``<prefix>-class_map.json``
    - ``<prefix>-feats.npy``: read only when ``has_feats`` is true.

    The preprocessed result is cached to ``<prefix>[_deg<max_degree>]_<sizes>.npz``.
    When that file exists, it is loaded instead of the raw files.

    Args:
        prefix: Path prefix of the dataset files.
        normalize: Standardize features using statistics of training nodes
            (nodes that are neither val nor test). Only used when building
            from the raw files.
        partition_list: METIS partition sizes for the training graph (read
            only, never mutated). If the first entry is 1, partitioning is
            disabled and every entry maps to the unpartitioned training
            adjacency.
        max_degree: If not -1, randomly subsample edges so that no node keeps
            more than this many edges.
        has_feats: Load node features. Otherwise sparse identity features are
            used, and the propagated features equal the adjacency matrices.
        class_list: Class-map values are lists of positive class indices.
            Without this flag, a list value is treated as a dense label vector
            and an int value as a single class id.

    Returns:
        ``(num_data, train_adj, full_adj, train_part_adj_list, full_part_adj,
        feats, train_feats, test_feats, labels, train_data, val_data, test_data,
        train_parts_list, full_parts)``.

    Raises:
        AssertionError: if the cache is missing and networkx is newer than
            1.11 (the raw loader relies on the ``G.node`` attribute API).
    """
    # Save normalized version; the name encodes degree cap and partition sizes.
    if max_degree == -1:
        npz_file = "{}_{}.npz".format(prefix, ",".join(map(str, partition_list)))
    else:
        npz_file = "{}_deg{}_{}.npz".format(
            prefix, max_degree, ",".join(map(str, partition_list))
        )

    if os.path.exists(npz_file):
        start_time = time()
        print("Found preprocessed dataset {}, loading...".format(npz_file))
        # allow_pickle: partition arrays are object arrays. This cache is
        # written by this function; do not point it at untrusted files.
        with np.load(npz_file, allow_pickle=True) as data:

            def _csr(key, suffix=""):
                # Rebuild a CSR matrix from <key>_{data,indices,indptr,shape}<suffix>.
                return sp.csr_matrix(
                    (
                        data[key + "_data" + suffix],
                        data[key + "_indices" + suffix],
                        data[key + "_indptr" + suffix],
                    ),
                    shape=data[key + "_shape" + suffix],
                )

            num_data = data["num_data"]
            labels = data["labels"]
            train_data = data["train_data"]
            val_data = data["val_data"]
            test_data = data["test_data"]
            full_parts = data["full_parts"]
            train_adj = _csr("train_adj")
            full_adj = _csr("full_adj")
            full_part_adj = _csr("full_part_adj")

            if has_feats:
                feats = data["feats"]
                train_feats = data["train_feats"]
                test_feats = data["test_feats"]
            else:
                # NOTE: Does not handle noPP now
                feats = sp.eye(num_data).astype(np.float32).tocsr()
                train_feats = train_adj
                test_feats = full_adj

            train_part_adj_list = []
            train_parts_list = []
            for psize in partition_list:
                suffix = "_{}".format(psize)
                train_part_adj_list.append(_csr("train_part_adj", suffix))
                train_parts_list.append(data["train_parts" + suffix])

        print("Finished in {} seconds.".format(time() - start_time))
    else:
        # Only the raw loader needs the old networkx (it uses ``G.node[...]``).
        # Compare (major, minor) as a tuple; a component-wise check is wrong.
        import re

        nx_version = tuple(int(p) for p in re.findall(r"\d+", nx.__version__)[:2])
        assert (
            nx_version <= (1, 11)
        ), "networkx major version must be <= 1.11 in order to load graphsage data"

        print("Loading data...")
        start_time = time()

        with open(prefix + "-G.json") as fin:
            G = json_graph.node_link_graph(json.load(fin))

        with open(prefix + "-id_map.json") as fin:
            id_map = json.load(fin)
        # JSON keys are strings; convert them back to ints for int node ids.
        if list(id_map.keys())[0].isdigit():
            conversion = lambda n: int(n)
        else:
            conversion = lambda n: n
        id_map = {conversion(k): int(v) for k, v in id_map.items()}

        with open(prefix + "-class_map.json") as fin:
            class_map = json.load(fin)
        if isinstance(list(class_map.values())[0], list):
            lab_conversion = lambda n: n
        else:
            lab_conversion = lambda n: int(n)
        class_map = {conversion(k): lab_conversion(v) for k, v in class_map.items()}

        # Construct adjacency matrix
        print(
            "Loaded data ({} seconds).. now preprocessing..".format(time() - start_time)
        )
        start_time = time()

        # Keep edges whose endpoints both have ids, remapped to row indices.
        edges = [
            (id_map[u], id_map[v]) for u, v in G.edges() if u in id_map and v in id_map
        ]
        print("{} edges".format(len(edges)))
        num_data = len(id_map)

        if has_feats:
            feats = np.load(prefix + "-feats.npy").astype(np.float32)
        else:
            feats = sp.eye(num_data).astype(np.float32).tocsr()

        if max_degree != -1:
            print("Subsampling edges...")
            edges = subsample_edges(edges, num_data, max_degree)

        val_data = np.array(
            [id_map[n] for n in G.nodes() if G.node[n]["val"]], dtype=np.int32
        )
        test_data = np.array(
            [id_map[n] for n in G.nodes() if G.node[n]["test"]], dtype=np.int32
        )
        # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is equivalent.
        is_train = np.ones(num_data, dtype=bool)
        is_train[val_data] = False
        is_train[test_data] = False
        train_data = np.nonzero(is_train)[0].astype(np.int32)
        full_data = np.arange(num_data)

        train_edges = [(e[0], e[1]) for e in edges if is_train[e[0]] and is_train[e[1]]]
        # reshape keeps a (0, 2) shape for empty edge lists so column indexing works.
        edges = np.array(edges, dtype=np.int32).reshape(-1, 2)
        train_edges = np.array(train_edges, dtype=np.int32).reshape(-1, 2)

        # Process labels
        if class_list:
            # Each value is a list of positive class indices.
            num_classes = 0
            for v in class_map.values():
                num_classes = max(num_classes, max(v, default=-1) + 1)
            labels = np.zeros((num_data, num_classes), dtype=np.float32)
            for k, v in class_map.items():
                labels[id_map[k], v] = 1
        elif isinstance(list(class_map.values())[0], list):
            # Each value is a dense label vector.
            num_classes = len(list(class_map.values())[0])
            labels = np.zeros((num_data, num_classes), dtype=np.float32)
            for k, v in class_map.items():
                labels[id_map[k], :] = np.array(v)
        else:
            # Each value is a single class id. Size by the largest id so
            # non-contiguous ids do not index out of bounds (identical to the
            # distinct-count for ids 0..C-1).
            num_classes = max(class_map.values()) + 1
            labels = np.zeros((num_data, num_classes), dtype=np.float32)
            for k, v in class_map.items():
                labels[id_map[k], v] = 1
        print("number of classes", num_classes)

        if normalize:
            from sklearn.preprocessing import StandardScaler

            # Scaler statistics come from training nodes only.
            train_ids = np.array(
                [
                    id_map[n]
                    for n in G.nodes()
                    if not G.node[n]["val"] and not G.node[n]["test"]
                ]
            )
            scaler = StandardScaler()
            scaler.fit(feats[train_ids])
            feats = scaler.transform(feats)

        def _my_normalize_adj(adj):
            # Row normalization D^-1 A; the epsilon avoids 1/0 for isolated rows.
            rowsum = np.array(adj.sum(1)).flatten()
            d_inv = 1.0 / (rowsum + 1e-20)
            d_mat_inv = sp.diags(d_inv, 0)
            return d_mat_inv.dot(adj)

        def _new_normalize_adj(edge_arr):
            # Symmetrize an (E, 2) edge array into num_data x num_data, then
            # row-normalize. Edges listed in both directions get weight 2.
            adj = sp.csr_matrix(
                (
                    np.ones((edge_arr.shape[0]), dtype=np.float32),
                    (edge_arr[:, 0], edge_arr[:, 1]),
                ),
                shape=(num_data, num_data),
            )
            adj += adj.transpose()
            return _my_normalize_adj(adj)

        train_adj = _new_normalize_adj(train_edges)
        full_adj = _new_normalize_adj(edges)
        train_feats = train_adj.dot(feats)
        test_feats = full_adj.dot(feats)

        print("train nonzeros", len(train_adj.data))
        # graph partitioning
        # (future work) try weighted edges

        # NOTE (original author): train_adj contains self-edges!
        num_train_data = train_data.size
        # Neighbor index lists (the adjacency-list input of metis.part_graph).
        # Training lists use positions within train_data (local ids).
        train_adj_lil = train_adj[train_data, :][:, train_data].tolil()
        full_adj_lil = full_adj.tolil()
        train_adj_lists = [list(row) for row in train_adj_lil.rows]
        full_adj_lists = [list(row) for row in full_adj_lil.rows]
        # ord_map: global node id -> position in the corresponding node list.
        train_ord_map = {train_data[i]: i for i in range(num_train_data)}
        full_ord_map = {full_data[i]: i for i in range(num_data)}

        if partition_list[0] > 1:

            def get_adj_parts(
                num_data,
                orig_adj,
                nd_data,
                ord_map,
                nd_group,
                part_size,
                num_whole_data,
            ):
                # Keep only edges whose endpoints share a METIS group; returns a
                # global-id 0/1 adjacency and a per-group array of global ids.
                part_row = []
                part_col = []
                part_data = []
                parts = [[] for _ in range(part_size)]
                for nd_idx in range(num_data):
                    gp_idx = nd_group[nd_idx]
                    nd_real_idx = nd_data[nd_idx]
                    parts[gp_idx].append(nd_real_idx)
                    for nb_real_idx in orig_adj[nd_real_idx].indices:
                        nb_idx = ord_map[nb_real_idx]
                        if nd_group[nb_idx] == gp_idx:
                            part_row.append(nd_real_idx)
                            part_col.append(nb_real_idx)
                            part_data.append(1)
                # Explicit zero in the last cell forces a
                # num_whole_data x num_whole_data inferred shape.
                part_data.append(0)
                part_row.append(num_whole_data - 1)
                part_col.append(num_whole_data - 1)
                part_adj = sp.coo_matrix((part_data, (part_row, part_col))).tocsr()
                # Groups differ in size; NumPy >= 1.24 rejects ragged
                # np.array(...) input, so fill a 1-D object array explicitly.
                parts_np = np.empty(len(parts), dtype=object)
                for gp_idx, members in enumerate(parts):
                    parts_np[gp_idx] = np.array(members)

                return part_adj, parts_np

            print("do graph partitioning, size {}".format(partition_list))

            train_part_adj_list = []
            train_parts_list = []
            for psize in partition_list:
                tmp_time = time()
                print("run metis with partition size {}".format(psize))
                _, train_nd_group = metis.part_graph(train_adj_lists, psize)
                print("metis finished in {} seconds.".format(time() - tmp_time))
                print("train group {}".format(len(train_nd_group)))
                train_part_adj, train_parts = get_adj_parts(
                    num_train_data,
                    train_adj,
                    train_data,
                    train_ord_map,
                    train_nd_group,
                    psize,
                    num_data,
                )
                train_part_adj = _my_normalize_adj(train_part_adj)

                train_part_adj_list.append(train_part_adj)
                train_parts_list.append(train_parts)
                print(
                    "{} / {} edges are remained (train may includes self-edge).".format(
                        len(train_part_adj.data) // 2, len(train_adj.data) // 2
                    )
                )

            print("-------finish train partitions")
            tmp_time = time()
            # NOTE: should we consider the same partition size?
            full_partition_size = min(30, partition_list[0])
            print("do graph partitioning, size {}".format(full_partition_size))
            _, full_nd_group = metis.part_graph(full_adj_lists, full_partition_size)
            print("metis finished in {} seconds.".format(time() - tmp_time))
            print("full group {}".format(len(full_nd_group)))
            full_part_adj, full_parts = get_adj_parts(
                num_data,
                full_adj,
                full_data,
                full_ord_map,
                full_nd_group,
                full_partition_size,
                num_data,
            )
            full_part_adj = _my_normalize_adj(full_part_adj)

            print(
                "{} / {} edges are remained (full may inclues self-edge).".format(
                    len(full_part_adj.data) // 2, len(full_adj.data) // 2
                )
            )
            print("train_data {} data {}".format(num_train_data, num_data))

        else:
            # Partitioning disabled. The save and return code below needs one
            # (adjacency, parts) entry per requested size.
            train_part_adj_list = [train_adj for _ in partition_list]
            train_parts_list = [np.array(np.array([])) for _ in partition_list]
            full_part_adj = full_adj
            full_parts = np.array(np.array([]))
        print("Done. {} seconds.".format(time() - start_time))

        arrays = dict(
            num_data=num_data,
            train_adj_data=train_adj.data,
            train_adj_indices=train_adj.indices,
            train_adj_indptr=train_adj.indptr,
            train_adj_shape=train_adj.shape,
            full_adj_data=full_adj.data,
            full_adj_indices=full_adj.indices,
            full_adj_indptr=full_adj.indptr,
            full_adj_shape=full_adj.shape,
            full_part_adj_data=full_part_adj.data,
            full_part_adj_indices=full_part_adj.indices,
            full_part_adj_indptr=full_part_adj.indptr,
            full_part_adj_shape=full_part_adj.shape,
            labels=labels,
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            full_parts=full_parts,
        )
        for train_part_adj, train_parts, psize in zip(
            train_part_adj_list, train_parts_list, partition_list
        ):
            suffix = "_{}".format(psize)
            arrays["train_part_adj_data" + suffix] = train_part_adj.data
            arrays["train_part_adj_indices" + suffix] = train_part_adj.indices
            arrays["train_part_adj_indptr" + suffix] = train_part_adj.indptr
            arrays["train_part_adj_shape" + suffix] = train_part_adj.shape
            arrays["train_parts" + suffix] = train_parts
        if has_feats:
            # Identity features are rebuilt on load, so they are not cached.
            arrays.update(feats=feats, train_feats=train_feats, test_feats=test_feats)

        with open(npz_file, "wb") as fwrite:
            print("Saving {} edges".format(full_adj.nnz))
            np.savez(fwrite, **arrays)

    return (
        num_data,
        train_adj,
        full_adj,
        train_part_adj_list,
        full_part_adj,
        feats,
        train_feats,
        test_feats,
        labels,
        train_data,
        val_data,
        test_data,
        train_parts_list,
        full_parts,
    )


def load_data_new(root, dataset, partition_list=[1]):
    """Route ``dataset`` to the matching loader for files under ``<root>/data``.

    Behavior depends on the dataset name:
    - Planetoid GCN datasets (cora, citeseer, pubmed, nell) are not supported:
      a message is printed and ``None`` is returned.
    - Every other name goes through ``load_graphsage_data``.
    - The feature-less Amazon variants are loaded without features and
      without feature normalization.
    - The amazon2M variants are loaded with ``class_list=True``.

    Returns:
        ``None`` for GCN datasets, otherwise the tuple returned by
        ``load_graphsage_data``.
    """
    if dataset in {"cora", "citeseer", "pubmed", "nell"}:
        print("gcn data not yet supported")
        return None

    if dataset in {"amazon", "amazon-0.1", "amazon-0.3"}:
        options = {"has_feats": False, "normalize": False}
    elif dataset in {"amazon2M", "amazon2M-47"}:
        options = {"has_feats": True, "class_list": True}
    else:
        options = {"has_feats": True}

    data_prefix = "{}/data/{}".format(root, dataset)
    return load_graphsage_data(data_prefix, partition_list=partition_list, **options)
