import sys; import os
sys.path.append(os.getcwd())
import torch as t
import torch_geometric.data as tgd
import torch_geometric.utils as tgu
import torch_geometric.transforms as transforms
import numpy as np
from torch_geometric.utils.convert import to_scipy_sparse_matrix
import dataset.meta as DM
import math

def invert_ix_map(xs):
    """
    Build the inverse lookup table of an index tensor.

    example:
    if xs = [3, 2, 7, 4], then
    invert_ix_map(xs) == [-1, -1, 1, 0, 3, -1, -1, 2]

    i.e. invert_ix_map(xs)[xs[i]] == i, and every index that does not occur
    in xs maps to -1. If an index occurs several times in xs, the position of
    its last occurrence is kept.

    Args:
        xs: 1-D integer tensor of non-negative indices.

    Returns:
        long tensor of length xs.max() + 1 on the same device as xs
        (an empty long tensor when xs is empty).
    """
    if xs.numel() == 0:
        # nothing to invert; xs.max() would raise on an empty tensor
        return t.empty(0, dtype=t.long, device=xs.device)

    inv_nums = t.full((int(xs.max()) + 1,), -1, dtype=t.long, device=xs.device)
    positions = t.arange(xs.numel(), device=xs.device)
    # "amax" keeps the largest position written to each slot, so duplicates
    # resolve deterministically to the last occurrence (the -1 fill never wins).
    inv_nums.scatter_reduce_(0, xs.long(), positions, reduce="amax", include_self=True)
    return inv_nums
assert t.allclose(invert_ix_map(t.tensor([3, 2, 7, 4])), t.tensor([-1, -1, 1, 0, 3, -1, -1, 2]))


def prepend_duplicate_rows(adj_sp, ixs_to_duplicate):
    """
    Add a copy of selected nodes in front of an existing graph and return the
    normalised adjacency of the enlarged graph.

    The k = len(ixs_to_duplicate) copies get the new node ids 0..k-1 (in the
    order given), and every original node id is shifted up by k. Only the
    sparsity pattern of `adj_sp` is used; its values are ignored.

    Edges of the enlarged graph:
      * tt: all original (non self-loop) edges, shifted by k;
      * it: for each duplicated node, an edge from its copy to the shifted
        neighbours of the original, plus one edge from the copy to the
        shifted original itself;
      * ii: an edge between two copies whenever their originals are adjacent.

    Args:
        adj_sp: torch sparse COO adjacency; `.indices()` is called on it, so
            it is presumably expected to be coalesced and to contain both
            (a, b) and (b, a) for every edge.
        ixs_to_duplicate: 1-D integer tensor of node ids to duplicate.

    Returns:
        Coalesced torch sparse tensor of size (n + k, n + k) produced by
        `calc_normalize_adj` on the symmetrised edge list.
    """
    device = adj_sp.device
    inv_ixs_to_duplicate = invert_ix_map(ixs_to_duplicate)

    edge_index = adj_sp.indices()
    # remove self-loops
    non_self_loops = edge_index[0] != edge_index[1]
    indices = edge_index[:, non_self_loops]

    # we assume the edge_index given has duplicate pairs (a,b) and (b,a);
    # keep only the a < b copy of each pair (a.k.a. force_undirected:
    # https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/dropout.html#dropout_edge)
    # this is un-done at the end of the function
    indices = indices[:, indices[0] < indices[1]]

    num_new_nodes = len(ixs_to_duplicate)
    # Accumulators start with an empty block of the *index* dtype: a float
    # placeholder would make t.cat promote every edge index to float.
    empty_block = t.zeros(2, 0, dtype=indices.dtype, device=device)
    ii_parts = [empty_block]
    it_parts = [empty_block]
    indices_tt = indices.clone() + num_new_nodes  # increase index of existing nodes

    for ixd in ixs_to_duplicate:
        # NOTE(review): after the a < b filter this only matches edges where
        # ixd is the smaller endpoint, so neighbours with a smaller id than
        # ixd are not connected to the copy -- confirm this is intended.
        mask = indices[0] == ixd

        # boolean-mask indexing copies, so the in-place write below does not
        # touch indices_tt
        _indices_it = indices_tt[:, mask]
        _indices_it[0, :] = inv_ixs_to_duplicate[ixd]
        it_parts.append(_indices_it)

        # make sure that all ind nodes are connected to their corresponding train/test nodes!
        extra_edge = t.tensor([inv_ixs_to_duplicate[ixd], ixd + num_new_nodes], device=adj_sp.device).unsqueeze(-1)
        it_parts.append(extra_edge)

        # edges between two duplicated nodes; each pair is visited once, from
        # its smaller endpoint
        mask = t.logical_and(mask, t.isin(indices[1], ixs_to_duplicate[ixd < ixs_to_duplicate]))
        _indices_ii = indices[:, mask]
        _indices_ii[0, :] = inv_ixs_to_duplicate[ixd]
        _indices_ii[1, :] = inv_ixs_to_duplicate[_indices_ii[1, :]]
        ii_parts.append(_indices_ii)

    indices_ii = t.cat(ii_parts, dim=1)
    indices_it = t.cat(it_parts, dim=1)

    indices = t.cat([indices_ii, indices_it, indices_tt], dim=1)
    indices = t.cat([indices, indices.flip(0)], dim=1)  # restore both directions of every edge
    new_size = adj_sp.size(0) + len(ixs_to_duplicate)  # new size of the adj matrix

    new_adj_sp = calc_normalize_adj(indices, new_size).coalesce()
    return new_adj_sp

"""
utils adapted from https://github.com/tkipf/pygcn for creating normalized adj
"""
def _sparse_mx_to_torch_sparse_tensor(sparse_mx, num_nodes=None):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    Args:
        sparse_mx: any scipy sparse matrix (it is converted to COO first).
        num_nodes: optional side length of the square output; must be at least
            as large as both dimensions of `sparse_mx`. When None the output
            has the shape of `sparse_mx`.

    Returns:
        Uncoalesced torch sparse COO tensor with int64 indices and float32
        values.

    Raises:
        AssertionError: if `num_nodes` is smaller than a dimension of
            `sparse_mx`.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = t.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = t.from_numpy(sparse_mx.data)
    if num_nodes is None:
        shape = t.Size(sparse_mx.shape)
    else:
        assert num_nodes >= sparse_mx.shape[0] and num_nodes >= sparse_mx.shape[1], "num_nodes is wrong"
        shape = t.Size((num_nodes, num_nodes))
    # t.sparse.FloatTensor is a deprecated legacy constructor;
    # sparse_coo_tensor builds the same tensor (dtype follows `values`).
    return t.sparse_coo_tensor(indices, values, shape)
def calc_normalize_adj(edge_ix, num_nodes, lmbda=1.):
    """
    Symmetrically normalised adjacency with weighted self-loops.

    The edge list is turned into a scipy matrix A, symmetrised with an
    element-wise maximum against its transpose, and then
        D'^-1/2 @ A' @ D'^-1/2
    is returned as a torch sparse tensor, where A' = A + lmbda * I and D' is
    the diagonal matrix of the row sums of A'. Rows with zero degree are
    scaled by 0 rather than by inf.
    """
    from scipy.sparse import coo_matrix, eye

    raw_adj = to_scipy_sparse_matrix(edge_ix, num_nodes=num_nodes)
    size = raw_adj.shape[0]

    # symmetrise, then add the custom-weighted self-loop
    adj_with_loops = raw_adj.maximum(raw_adj.T) + eye(size) * lmbda

    # D'^-1/2 as a vector; isolated rows (degree 0) give inf -> replace by 0
    degree = np.asarray(adj_with_loops.sum(axis=1)).ravel()
    scale = np.power(degree, -0.5)
    scale[np.isinf(scale)] = 0

    # the same diagonal matrix is applied from the left and from the right
    diag_ix = np.arange(size)
    scale_mat = coo_matrix((scale, (diag_ix, diag_ix)))
    normalised = scale_mat.dot(adj_with_loops).dot(scale_mat)

    return _sparse_mx_to_torch_sparse_tensor(normalised)

def Ahat_interp_id(edge_ix, num_nodes, lmbda=0.65):
    """
    calculates (1-lmbda) * normalize_adj(edge_ixs) + lmbda * eye

    normalize_adj is D'^-1/2 @ A' @ D'^-1/2 with A' = max(A, A^T) + lmbda * I
    and D' the diagonal of the row sums of A'; note that the same lmbda is
    used for the self-loop weight and for the interpolation. The result is
    returned as a torch sparse tensor.
    """
    from scipy.sparse import coo_matrix, eye

    graph = to_scipy_sparse_matrix(edge_ix, num_nodes=num_nodes)
    size = graph.shape[0]

    # symmetrise and add the custom-weighted self-loop
    with_loops = graph.maximum(graph.T) + eye(size) * lmbda

    # inverse square-root degrees; zero-degree rows give inf -> replace by 0
    row_sums = np.asarray(with_loops.sum(axis=1)).ravel()
    inv_sqrt = np.power(row_sums, -0.5)
    inv_sqrt[np.isinf(inv_sqrt)] = 0

    diag_ix = np.arange(size)
    scaling = coo_matrix((inv_sqrt, (diag_ix, diag_ix)))
    a_hat = scaling.dot(with_loops).dot(scaling)

    # blend the normalised adjacency with the identity
    blended = a_hat * (1 - lmbda) + lmbda * eye(size)
    return _sparse_mx_to_torch_sparse_tensor(blended)

# Transforms that get_dataset composes for every dataset (optionally followed
# by NormalizeFeatures): make the edge list undirected, then apply GCN
# normalisation with self-loops added.
default_transforms = [transforms.ToUndirected(), transforms.GCNNorm(add_self_loops=True)]

def edges_to_adj(edge_ix, edge_w, num_nodes=None):
    """Build a torch sparse matrix from an edge list and its edge weights.

    When `num_nodes` is falsy (None or 0) the size is inferred as the largest
    node index in `edge_ix` plus one.
    """
    if not num_nodes:
        num_nodes = edge_ix.max().item() + 1
    adj = tgu.to_torch_sparse_tensor(edge_ix, edge_w, size=num_nodes)
    return adj
def get_adj(x: tgd.Data):
    """Return the sparse adjacency of a data object, sized by its feature rows."""
    # The node count comes from the feature matrix: x.edge_index.max() + 1
    # would miss nodes that are not connected to anything.
    return edges_to_adj(x.edge_index, x.edge_weight, num_nodes=x.x.size(0))


def _rewire_edges(indices, y, num_to_rewire, make_like):
    """Redirect the end node of `num_to_rewire` randomly chosen edges, in place.

    With make_like=True, edges joining different classes are picked and their
    end node is replaced by a random node of the start node's class; with
    make_like=False, same-class edges are picked and their end node is
    replaced by a random node of any other class.
    """
    device = indices.device

    # candidate edges are those that currently have the *opposite* property
    currently_like = y[indices[0]] == y[indices[1]]
    candidate_mask = ~currently_like if make_like else currently_like
    candidate_ixs = t.arange(indices.size(1), device=device)[candidate_mask]
    shuffled = candidate_ixs[t.randperm(len(candidate_ixs), device=device)]
    ixs_to_change = shuffled[:num_to_rewire]
    assert len(ixs_to_change) == num_to_rewire, "not enough edges available to rewire"

    start_classes = y[indices[0, ixs_to_change]]
    end_nodes = t.empty(len(ixs_to_change), dtype=t.long, device=device)
    all_nodes = t.arange(len(y), device=device)
    for out_class in start_classes.unique():
        to_modify = start_classes == out_class
        pool_mask = (y == out_class) if make_like else (y != out_class)
        possible_end_nodes = all_nodes[pool_mask]
        if len(possible_end_nodes) == 0:
            raise ValueError(f"no node available as new end node for class {out_class.item()}")
        # uniform sampling with replacement from the pool
        sample = t.multinomial(t.ones(len(possible_end_nodes), device=device),
                               int(to_modify.sum()), replacement=True)
        end_nodes[to_modify] = possible_end_nodes[sample]

    indices[1, ixs_to_change] = end_nodes  # actually rewrite the end nodes

    # sanity: every rewired edge now has the requested property
    rewired_like = y[indices[0, ixs_to_change]] == y[indices[1, ixs_to_change]]
    assert bool((rewired_like == make_like).all()), \
        "should be a like edge" if make_like else "should not be a like edge"


def adjust_homophily_ratio(edge_index, y, desired_ratio=None):
    """
    Randomly rewire edges so that the edge homophily ratio (fraction of edges
    whose two end points share a label) becomes floor(desired_ratio * E) / E.

    Self-loops are dropped and only the a < b copy of every edge is kept while
    rewiring; both directions are restored in the returned edge list. The
    number of undirected edges E is preserved, but rewired edges may coincide
    with existing edges or become self-loops. `edge_index` itself is not
    modified. `edge_index` and `y` are assumed to live on the same device.

    Args:
        edge_index: (2, num_edges) long tensor assumed to contain both (a, b)
            and (b, a) for every edge.
        y: 1-D tensor of node labels.
        desired_ratio: target homophily ratio in [0, 1]. Required, despite the
            None default kept for signature compatibility.

    Returns:
        (2, 2 * E) edge index holding every edge in both directions.

    Raises:
        TypeError: if `desired_ratio` is None.
        ValueError: if `desired_ratio` is outside [0, 1], or no node of a
            different class exists when unlike edges must be created.
        AssertionError: if the achieved ratio is more than 0.01 away from
            `desired_ratio` (possible on graphs with very few edges).
    """
    if desired_ratio is None:
        raise TypeError("desired_ratio must be a number in [0, 1], got None")
    if not 0 <= desired_ratio <= 1:
        raise ValueError(f"desired_ratio must be in [0, 1], got {desired_ratio}")

    # remove self-loops
    non_self_loops = edge_index[0] != edge_index[1]
    indices = edge_index[:, non_self_loops]

    # we assume the edge_index given has duplicate pairs (a,b) and (b,a);
    # keep one copy of each (a.k.a. force_undirected:
    # https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/dropout.html#dropout_edge)
    indices = indices[:, indices[0] < indices[1]]

    num_edges = indices.size(1)
    if num_edges == 0:
        # nothing to rewire, and the ratio is undefined for an empty graph
        return t.cat([indices, indices.flip(0)], dim=1)

    required_num_like_edges = math.floor(desired_ratio * num_edges)
    num_like_edges = int((y[indices[0]] == y[indices[1]]).sum())
    to_change = num_like_edges - required_num_like_edges

    if to_change != 0:
        # too few like edges -> create some; too many -> destroy some
        _rewire_edges(indices, y, abs(to_change), make_like=to_change < 0)

    num_like_edges = (y[indices[0]] == y[indices[1]]).sum()
    new_homophily_ratio = num_like_edges / num_edges
    print(f"new homophily ratio = {new_homophily_ratio:.2f}")
    assert abs(desired_ratio - new_homophily_ratio) <= 0.01, "new homophily ratio isn't correct!"

    # we re-duplicate the edges
    # i.e. force_undirected
    return t.cat([indices, indices.flip(0)], dim=1)

class GraphDataset:
    """
    Dataset-agnostic container for a graph learning problem.

    Every constructor argument is stored unchanged on the instance under the
    same name; nothing is validated or converted.
    """
    def __init__(self, is_classification_task,
                       is_node_task,
                       is_multi_graph,
                       num_features,
                       num_classes,
                       # only for single graph
                       X=None,
                       adj_sp=None,
                       y=None,
                       train_mask=None, val_mask=None, test_mask=None,
                       # only for multi graph
                       train_ds=None,
                       val_ds=None,
                       test_ds=None,
                       # for datasets with many splits, should be an iterator
                       many_ds=None,
                       edge_index=None,
                       Pi=None,
                       mc_samples=None,
                       chunk_size=None,
                       minibatch_size=None):
        # --- description of the task ---
        self.is_classification_task = is_classification_task
        self.is_node_task = is_node_task
        self.is_multi_graph = is_multi_graph
        self.num_features, self.num_classes = num_features, num_classes

        # --- single-graph payload ---
        self.X, self.y = X, y
        self.adj_sp = adj_sp
        self.edge_index = edge_index
        self.train_mask, self.val_mask, self.test_mask = train_mask, val_mask, test_mask

        # --- multi-graph payload ---
        self.train_ds, self.val_ds, self.test_ds = train_ds, val_ds, test_ds
        self.many_ds = many_ds

        # --- extra per-dataset settings, stored as given ---
        self.Pi = Pi
        self.mc_samples = mc_samples
        self.chunk_size = chunk_size
        self.minibatch_size = minibatch_size

def get_adjusted_homophily(edge_ix, y):
    """
    Edge homophily and adjusted homophily of a labelled graph.

    see section 2.4 of https://openreview.net/pdf?id=m7PIJWOdlY

    Self-loops are ignored. The edge list is assumed to contain both (a, b)
    and (b, a) for every edge, so node degrees are the per-node counts of
    source entries and their sum is twice the number of edges. Labels may be
    arbitrary values; they do not need to be 0..C-1. `edge_ix` and `y` are
    assumed to live on the same device.

    Args:
        edge_ix: (2, num_edges) long tensor of node indices.
        y: 1-D tensor of node labels, one per node.

    Returns:
        (h_edge, h_adj) as 0-dim float tensors, where
        h_edge = (#like edges) / (#edges) and
        h_adj = (h_edge - sum_k pbar_k^2) / (1 - sum_k pbar_k^2)
        with pbar_k the degree-weighted share of class k. Both are nan for a
        graph without edges.

    Raises:
        ValueError: if an edge refers to a node index >= len(y).
    """
    ## remove self loops
    non_self_loops = edge_ix[0] != edge_ix[1]
    edge_ix = edge_ix[:, non_self_loops]

    if edge_ix.numel() > 0 and int(edge_ix.max()) >= len(y):
        raise ValueError("edge_ix refers to a node index outside of y")

    # degree of each node: number of edges leaving it
    d = t.bincount(edge_ix[0], minlength=len(y))

    # total degree per class, iterating over the labels that actually occur
    D_k = t.stack([d[y == c].sum() for c in y.unique()])

    # degree weighted distribution of class labels
    _2E = d.sum()  # twice the number of edges
    pbar_k = D_k / _2E

    ## edge homophily ratio
    num_like_edges_x2 = (y[edge_ix[0]] == y[edge_ix[1]]).sum()  # twice the number of like edges
    h_edge = num_like_edges_x2 / _2E

    chance_level = pbar_k.square().sum()
    h_adj = (h_edge - chance_level) / (1 - chance_level)
    return h_edge, h_adj




## some datasets (e.g. reddit and arxiv) are slow to load on network storage
## so we cache them here, keyed by the dataset name passed to the loader.
## NOTE: the key includes neither the transform nor the split. get_planetoid,
## get_reddit and get_arxiv store a finished GraphDataset, whereas
## get_wikipedia_network and get_heterophilous store the raw
## torch_geometric dataset object.
DS_CACHE = dict()

def _print_dataset_stats(res):
    """Print size, homophily and split statistics of a single-graph dataset."""
    ## first remove self-loops
    edges = res.edge_index
    non_self_loops = edges[0] != edges[1]
    indices = edges[:, non_self_loops]
    num_edges = indices.size(-1) // 2  # double counted, so div by 2
    y = res.y

    num_like_edges = (y[indices[0]] == y[indices[1]]).sum() // 2  # avoid double counting
    homophily_ratio = num_like_edges / num_edges  # ratio of 'like edges' to all edges

    print("num_edges =", num_edges)
    print("num_nodes =", res.X.size(0))
    print(f"homophily_ratio = {homophily_ratio.item():.2f}")
    print("num_classes =", res.num_classes)
    print("num_features =", res.num_features)

    h_edge, h_adj = get_adjusted_homophily(res.edge_index, res.y)
    print(f"homophily_ratio (edge) = {h_edge:.2f}")
    print(f"homophily_ratio (adj) = {h_adj:.2f}")

    total = res.train_mask.sum() + res.val_mask.sum() + res.test_mask.sum()
    train_prop = res.train_mask.sum().item() / total
    val_prop = res.val_mask.sum().item() / total
    test_prop = res.test_mask.sum().item() / total
    print(f"train/val/test = {train_prop:.2f}/{val_prop:.2f}/{test_prop:.2f}")


def get_dataset(name: str, normalize_features=False, print_stats=False, edit_hr=None, **kwargs) -> GraphDataset:
    """
    Load a dataset by its short name.

    Args:
        name: one of 'cora', 'citeseer', 'pubmed', 'chameleon', 'squirrel',
            'roman-empire', 'tolokers', 'minesweeper', 'amazon-ratings',
            'reddit', 'arxiv'.
        normalize_features: append NormalizeFeatures to the default transforms.
        print_stats: print edge/node counts, homophily ratios and split sizes.
        edit_hr: if given, rewire the graph with adjust_homophily_ratio to this
            edge homophily ratio and recompute the normalised adjacency.
        **kwargs: forwarded to the dataset-specific loader (e.g. `split`).

    Returns:
        The GraphDataset produced by the loader. When `edit_hr` is given, a
        shallow copy with replaced `edge_index` / `adj_sp` is returned instead.

    Raises:
        NotImplementedError: if `name` is not a known dataset.
    """
    import copy

    loader_groups = (
        (('cora', 'citeseer', 'pubmed'), get_planetoid),
        (('chameleon', 'squirrel'), get_wikipedia_network),
        (('roman-empire', 'tolokers', 'minesweeper', 'amazon-ratings'), get_heterophilous),
        (('reddit',), get_reddit),
        (('arxiv',), get_arxiv),
    )
    loader = next((fn for names, fn in loader_groups if name in names), None)
    if loader is None:
        raise NotImplementedError(f"unknown dataset '{name}'")

    trans = list(default_transforms)
    if normalize_features:
        trans.append(transforms.NormalizeFeatures())
    transform = transforms.Compose(trans)

    res = loader(name=DM.correct_names[name], transform=transform, **kwargs)
    if print_stats:
        _print_dataset_stats(res)

    if edit_hr is not None:
        # Some loaders hand back the very object stored in DS_CACHE; edit a
        # shallow copy so that later calls still get the unmodified graph.
        res = copy.copy(res)
        edge_index = adjust_homophily_ratio(res.edge_index, res.y, desired_ratio=edit_hr)
        adj_sp = calc_normalize_adj(edge_index, len(res.y))
        res.edge_index = edge_index
        res.adj_sp = adj_sp
    return res

## single graph datasets
########################
def get_planetoid(name='cora', split='public', citeseer_float32_fix=False, transform=None):
    """
    Load a Planetoid dataset with its public split as a GraphDataset.

    The result is stored in DS_CACHE under `name`; a cached entry is returned
    as-is, so `transform` only takes effect on the first call for a name.

    Args:
        name: dataset name handed to torch_geometric's Planetoid.
        split: must be 'public'.
        citeseer_float32_fix: accepted but not used by this function.
        transform: transform handed to Planetoid.

    Raises:
        AssertionError: if `split` is not 'public'.
    """
    assert split == 'public', f"expected public split, got '{split}'"
    if name not in DS_CACHE:
        from torch_geometric.datasets import Planetoid
        d = Planetoid(root="_datasets", name=name, split='public', transform=transform)
        # every d[0] access re-applies the transform, so fetch the graph once
        data = d[0]
        X = data.x
        y = data.y
        res = GraphDataset(is_classification_task=True,
                            is_node_task=True,
                            is_multi_graph=False,
                            num_features=X.size(1),
                            num_classes=len(y.unique()),
                            X=X,
                            adj_sp=get_adj(data),
                            y=y,
                            train_mask=data.train_mask,
                            val_mask=data.val_mask,
                            test_mask=data.test_mask,
                            edge_index=data.edge_index,
                            Pi=data.train_mask.sum())
        DS_CACHE[name] = res
    return DS_CACHE[name]

def get_wikipedia_network(name='chameleon', split=None, transform=None, **kwargs):
    """
    Load a WikipediaNetwork dataset with one of its 10 predefined splits.

    The raw torch_geometric dataset is stored in DS_CACHE under `name`, so
    `transform` only takes effect on the first call for a name.

    Args:
        name: dataset name handed to torch_geometric's WikipediaNetwork.
        split: string starting with 'val-' or 'test-'; the text after the last
            '-' is the split index (0..9).
        transform: transform handed to WikipediaNetwork.
        **kwargs: ignored.

    Raises:
        AssertionError: if `split` is malformed, the split index is out of
            range, or a sanity check on the masks fails.
    """
    assert isinstance(split, str) and split.startswith(('val-', 'test-')), \
        f"expected split to start with 'val-' or 'test-', got '{split}'"
    split_int = int(split.split('-')[-1])
    assert split_int in range(10)

    if name not in DS_CACHE:
        from torch_geometric.datasets import WikipediaNetwork
        DS_CACHE[name] = WikipediaNetwork(root="_datasets", name=name, transform=transform)
    d = DS_CACHE[name]

    # every d[0] access re-applies the transform, so fetch the graph once
    data = d[0]
    X = data.x
    y = data.y
    adj_sp = get_adj(data)

    train_mask = data.train_mask[:, split_int]
    val_mask = data.val_mask[:, split_int]
    test_mask = data.test_mask[:, split_int]
    assert train_mask.sum() + val_mask.sum() + test_mask.sum() == len(y)

    # sanity checks on the splits
    # NOTE(review): the overlap check only rejects nodes that are in all three
    # masks at once; pairwise overlaps are not detected.
    num_classes = len(y.unique())
    assert (train_mask & val_mask & test_mask).sum().item() == 0, "train/val/test masks don't overlap"
    assert len(y[train_mask].unique()) == num_classes, "train set contains all classes"
    assert len(y[val_mask].unique()) == num_classes, "val set contains all classes"
    assert len(y[test_mask].unique()) == num_classes, "test set contains all classes"
    return GraphDataset(is_classification_task=True,
                        is_node_task=True,
                        is_multi_graph=False,
                        num_features=X.size(1),
                        num_classes=num_classes,
                        X=X,
                        adj_sp=adj_sp,
                        y=y,
                        edge_index=data.edge_index,
                        train_mask=train_mask,
                        val_mask=val_mask,
                        test_mask=test_mask)

def get_heterophilous(name='tolokers', split=None, transform=None, **kwargs):
    """
    Load a HeterophilousGraphDataset with one of its 10 predefined splits.

    The raw torch_geometric dataset is stored in DS_CACHE under `name`, so
    `transform` only takes effect on the first call for a name.

    Args:
        name: dataset name handed to HeterophilousGraphDataset.
        split: string starting with 'val-' or 'test-'; the text after the last
            '-' is the split index (0..9).
        transform: transform handed to HeterophilousGraphDataset.
        **kwargs: ignored.

    Raises:
        AssertionError: if `split` is malformed, the split index is out of
            range, or a sanity check on the masks fails.
    """
    assert isinstance(split, str) and split.startswith(('val-', 'test-')), \
        f"expected split to start with 'val-' or 'test-', got '{split}'"
    split_int = int(split.split('-')[-1])
    assert split_int in range(10)

    if name not in DS_CACHE:
        from torch_geometric.datasets import HeterophilousGraphDataset
        DS_CACHE[name] = HeterophilousGraphDataset(root="_datasets", name=name, transform=transform)
    d = DS_CACHE[name]

    # every d[0] access re-applies the transform, so fetch the graph once
    data = d[0]
    X = data.x
    y = data.y
    adj_sp = get_adj(data)

    train_mask = d.train_mask[:, split_int]
    val_mask = d.val_mask[:, split_int]
    test_mask = d.test_mask[:, split_int]
    assert train_mask.sum() + val_mask.sum() + test_mask.sum() == len(y)

    # sanity checks on the splits
    # NOTE(review): only nodes that are in all three masks at once are
    # rejected; pairwise overlaps are not detected.
    num_classes = len(y.unique())
    assert (train_mask & val_mask & test_mask).sum().item() == 0, "train/val/test masks don't overlap"
    return GraphDataset(is_classification_task=True,
                        is_node_task=True,
                        is_multi_graph=False,
                        num_features=X.size(1),
                        num_classes=num_classes,
                        X=X,
                        adj_sp=adj_sp,
                        y=y,
                        train_mask=train_mask,
                        val_mask=val_mask,
                        test_mask=test_mask,
                        edge_index=data.edge_index,
                        chunk_size=10)

def get_reddit(name, split=None, transform=None, **kwargs):
    """
    Load the Reddit dataset as a GraphDataset.

    Two caches are used: DS_CACHE (in memory, keyed by `name`) and a pickle
    file at _datasets_pkl/reddit/data.pkl. `transform` is only applied when
    that pickle file is first created; later calls read the stored tensors.

    Args:
        name: key used for DS_CACHE.
        split: must be 'public'.
        transform: transform handed to torch_geometric's Reddit dataset.
        **kwargs: ignored.

    Raises:
        AssertionError: if `split` is not 'public' or a sanity check on the
            masks fails.
    """
    assert split == 'public', f"there is only one reddit split"
    if name not in DS_CACHE:
        # loading the reddit dataset via torch_geometric.Reddit is very slow,
        # so we cache the result in a pickle file for slightly better perf
        # (though it is still slow :()
        from torch_geometric.datasets import Reddit
        from pathlib import Path
        import pickle as pk
        pkl_path = Path("_datasets_pkl/reddit")
        pkl_file = pkl_path / 'data.pkl'
        if not pkl_file.exists():
            d = Reddit(root="_datasets", transform=transform)
            # every d[0] access re-applies the transform to the whole graph,
            # so fetch it exactly once
            data = d[0]
            X = data.x
            y = data.y
            adj_sp = get_adj(data).coalesce()
            edge_index = data.edge_index

            train_mask = d.train_mask
            val_mask = d.val_mask
            test_mask = d.test_mask
            assert train_mask.sum() + val_mask.sum() + test_mask.sum() == len(y)

            pkl_dict = dict(X=X, y=y, adj_sp=adj_sp, edge_index=edge_index,
                            train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
            pkl_path.mkdir(parents=True, exist_ok=True)
            with open(pkl_file, 'wb') as f:
                pk.dump(pkl_dict, f)
        else:
            # NOTE: unpickling executes arbitrary code; this file must only
            # ever be one written by this function.
            with open(pkl_file, 'rb') as f:
                pkl_dict = pk.load(f)
            X = pkl_dict['X']
            y = pkl_dict['y']
            adj_sp = pkl_dict['adj_sp'].coalesce()
            train_mask = pkl_dict['train_mask']
            val_mask = pkl_dict['val_mask']
            edge_index = pkl_dict['edge_index']
            test_mask = pkl_dict['test_mask']

        # sanity checks on the splits
        num_classes = len(y.unique())
        assert (train_mask & val_mask & test_mask).sum().item() == 0, "train/val/test masks don't overlap"
        res = GraphDataset(is_classification_task=True,
                            is_node_task=True,
                            is_multi_graph=False,
                            num_features=X.size(1),
                            num_classes=num_classes,
                            X=X,
                            adj_sp=adj_sp,
                            y=y,
                            train_mask=train_mask,
                            val_mask=val_mask,
                            test_mask=test_mask,
                            edge_index=edge_index,
                            chunk_size=10, mc_samples=50,
                            minibatch_size=116483 # ~ num_nodes / 2
                            )
        DS_CACHE[name] = res
    return DS_CACHE[name]


def get_arxiv(name, split=None, transform=None, **kwargs):
    """adapted from https://github.com/niuzehao/gnn-gp/blob/master/src/datasets.py

    Load ogbn-arxiv as a GraphDataset.

    Two caches are used: DS_CACHE (in memory, keyed by `name`) and a pickle
    file at _datasets_pkl/arxiv/data.pkl. `transform` is only applied when
    that pickle file is first created; later calls read the stored tensors.

    Args:
        name: key used for DS_CACHE.
        split: must be 'public'.
        transform: transform handed to the dataset class defined below.
        **kwargs: ignored.

    Raises:
        AssertionError: if `split` is not 'public' or a sanity check on the
            masks fails.
    """
    assert split == 'public', f"there is only one arxiv split"
    if name not in DS_CACHE:
        import os.path as osp

        import torch

        # from torch_sparse import coalesce
        from torch_geometric.utils import coalesce
        from torch_geometric.data import InMemoryDataset, download_url, extract_zip, Data

        class OGB_arxiv(InMemoryDataset):
            """
            The ogbn dataset from the `"Open Graph Benchmark: Datasets for
            Machine Learning on Graphs" <https://arxiv.org/abs/2005.00687>` paper.
            ogbn-arxiv is a paper citation network of arXiv papers.
            Each node is an ArXiv paper and each directed edge indicates that one paper cites another one.
            Node features are 128-dimensional vector obtained by averaging the WORD2VEC embeddings of words in its title and abstract.
            The task is to predict the 40 subject areas of ARXIV CS papers.

            Args:
                root (string): Root directory where the dataset should be saved.
                transform (callable, optional): A function/transform that takes in a
                    :obj:`torch_geometric.data.Data` object and returns a
                    transformed version. The data object will be transformed before
                    every access. (default: `None`)
                pre_transform (callable, optional): A function/transform that takes in
                    a :obj:`torch_geometric.data.Data` object and returns a
                    transformed version. The data object will be transformed before
                    being saved to disk. (default: `None`)
            """

            url = "http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip"

            def __init__(self, root: str, transform = None, pre_transform = None):
                super().__init__(root, transform, pre_transform)
                # NOTE(review): recent torch versions default torch.load to
                # weights_only=True, which may refuse this file -- confirm
                # against the torch version in use.
                self.data, self.slices = torch.load(self.processed_paths[0])

            @property
            def num_classes(self) -> int:
                return int(self.data.y.max()) + 1

            @property
            def raw_dir(self) -> str:
                return osp.join(self.root, "arxiv", "raw")

            @property
            def processed_dir(self) -> str:
                return osp.join(self.root, "arxiv", "processed")

            @property
            def raw_file_names(self) -> list:
                # order matters: process() reads the first three by position
                # and download() treats the last three as the split files
                file_names = ["node-feat.csv.gz", "node-label.csv.gz", "edge.csv.gz", "train.csv.gz", "valid.csv.gz", "test.csv.gz"]
                return file_names

            @property
            def processed_file_names(self) -> str:
                return "data.pt"

            def download(self):
                """Download the archive and flatten its files into raw_dir."""
                import os, shutil
                path = download_url(self.url, self.raw_dir)
                extract_zip(path, self.raw_dir)
                for file_name in self.raw_file_names[:3]:
                    path = osp.join(self.raw_dir, "arxiv", "raw", file_name)
                    shutil.move(path, self.raw_dir)
                for file_name in self.raw_file_names[3:]:
                    path = osp.join(self.raw_dir, "arxiv", "split", "time", file_name)
                    shutil.move(path, self.raw_dir)
                shutil.rmtree(osp.join(self.raw_dir, "arxiv"))
                os.remove(osp.join(self.raw_dir, "arxiv.zip"))

            def process(self):
                """Read the raw csv files and save one collated Data object."""
                import pandas as pd
                import numpy as np

                values = pd.read_csv(self.raw_paths[0], compression="gzip", header=None, dtype=np.float32).values
                x = torch.from_numpy(values)

                values = pd.read_csv(self.raw_paths[1], compression="gzip", header=None, dtype=np.int64).values
                y = torch.from_numpy(values).view(-1)

                # A symmetrization is required for the graph.
                values = pd.read_csv(self.raw_paths[2], compression="gzip", header=None, dtype=np.int64).values
                values = torch.from_numpy(values).t().contiguous()
                edge_index = torch.unique(torch.cat((values, values.flip(0)), dim=1), dim=1)
                # torch_geometric.utils.coalesce takes (edge_index, edge_attr,
                # num_nodes, reduce); the old torch_sparse-style (m, n) pair
                # would put the second size into `reduce`.
                edge_index, _ = coalesce(edge_index, None, num_nodes=x.size(0))

                data = Data(x=x, edge_index=edge_index, y=y)

                for f, v in [("train", "train"), ("valid", "val"), ("test", "test")]:
                    values = pd.read_csv(f"{self.raw_dir}/{f}.csv.gz", compression="gzip", header=None, dtype=np.int64).values
                    idx = torch.from_numpy(values).view(-1)
                    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
                    mask[idx] = True
                    data[f"{v}_mask"] = mask

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                torch.save(self.collate([data]), self.processed_paths[0])
        from pathlib import Path
        import pickle as pk

        pkl_path = Path("_datasets_pkl/arxiv")
        pkl_file = pkl_path / 'data.pkl'
        if not pkl_file.exists():
            d = OGB_arxiv(root="_datasets/", transform=transform)
            # every d[0] access re-applies the transform, so fetch the graph once
            data = d[0]
            X = data.x
            y = data.y
            edge_index = data.edge_index
            adj_sp = get_adj(data).coalesce()

            train_mask = d.train_mask
            val_mask = d.val_mask
            test_mask = d.test_mask
            assert train_mask.sum() + val_mask.sum() + test_mask.sum() == len(y)

            pkl_dict = dict(X=X, y=y, adj_sp=adj_sp,
                            train_mask=train_mask, val_mask=val_mask, test_mask=test_mask, edge_index=edge_index)
            pkl_path.mkdir(parents=True, exist_ok=True)
            with open(pkl_file, 'wb') as f:
                pk.dump(pkl_dict, f)
        else:
            # NOTE: unpickling executes arbitrary code; this file must only
            # ever be one written by this function.
            with open(pkl_file, 'rb') as f:
                pkl_dict = pk.load(f)
            X = pkl_dict['X']
            y = pkl_dict['y']
            adj_sp = pkl_dict['adj_sp'].coalesce()
            train_mask = pkl_dict['train_mask']
            val_mask = pkl_dict['val_mask']
            test_mask = pkl_dict['test_mask']
            edge_index = pkl_dict['edge_index']

        # sanity checks on the splits
        num_classes = len(y.unique())
        assert (train_mask & val_mask & test_mask).sum().item() == 0, "train/val/test masks don't overlap"
        res = GraphDataset(is_classification_task=True,
                            is_node_task=True,
                            is_multi_graph=False,
                            num_features=X.size(1),
                            num_classes=num_classes,
                            X=X,
                            adj_sp=adj_sp,
                            y=y,
                            train_mask=train_mask,
                            val_mask=val_mask,
                            test_mask=test_mask,
                            edge_index=edge_index,
                            chunk_size=10, mc_samples=100,
                            minibatch_size=84672 # ~ num_nodes / 2
                            )
        DS_CACHE[name] = res
    return DS_CACHE[name]

if __name__ == '__main__':
    # Smoke test: load chameleon. get_wikipedia_network only parses the text
    # after the last '-' of `split` as the split index (here 9).
    ds = get_dataset('chameleon', split='test-100-9')