""" PPIDataset for inductive learning. """
import json
import numpy as np
import networkx as nx
from networkx.readwrite import json_graph
import os

from dgl.data.dgl_dataset import DGLBuiltinDataset
from dgl.data.utils import _get_dgl_url, save_graphs, save_info, load_info, load_graphs
from dgl import backend as F
from dgl.convert import from_networkx


class PPIDataset(DGLBuiltinDataset):
    r"""Protein-Protein Interaction dataset for inductive node classification

    A toy Protein-Protein Interaction network dataset. The dataset contains
    24 graphs. The average number of nodes per graph is 2372. Each node has
    50 features and 121 labels. 20 graphs for training, 2 for validation
    and 2 for testing.

    All 24 graphs are loaded into a single dataset (training graphs first,
    then validation, then test); use :meth:`get_idx_split` to obtain the
    indices of each split.

    Reference: `<http://snap.stanford.edu/graphsage/>`_

    Statistics:

    - Train examples: 20
    - Valid examples: 2
    - Test examples: 2

    Parameters
    ----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset.
        Default: False
    verbose : bool
        Whether to print out progress information.
        Default: False
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    num_classes : int
        Number of labels for each node (121).
    graphs : list[dgl.DGLGraph]
        The per-graph samples, each carrying ``ndata['feat']`` and
        ``ndata['label']``.

    Examples
    --------
    >>> dataset = PPIDataset()
    >>> num_classes = dataset.num_classes
    >>> split = dataset.get_idx_split()
    >>> for i in split['valid']:
    ...     g = dataset[int(i)]
    ...     feat = g.ndata['feat']
    ...     label = g.ndata['label']
    ...     # your code here
    """

    # Order in which the splits are concatenated into ``self.graphs``.
    _SPLITS = ("train", "valid", "test")

    # Half-open ranges [lo, hi) of the values stored in ``{mode}_graph_id.npy``
    # that make up each split: 20 graphs for training, 2 for validation and
    # 2 for testing. Shared by ``process_graph`` and ``get_idx_split`` so the
    # two cannot drift apart.
    _SPLIT_GRAPH_ID_RANGES = {
        "train": (1, 21),
        "valid": (21, 23),
        "test": (23, 25),
    }

    def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
        _url = _get_dgl_url("dataset/ppi.zip")
        super(PPIDataset, self).__init__(
            name="ppi",
            url=_url,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process_graph(self, mode):
        """Load one split from the raw files and cut it into per-graph subgraphs.

        Parameters
        ----------
        mode : str
            ``'train'``, ``'valid'`` or ``'test'``. It selects the ``{mode}_*``
            raw files and the range of graph ids to extract. Any other value
            uses the training id range.

        Returns
        -------
        tuple
            ``(graph, feats, labels, graphs)``:

            - ``graph`` is the whole split as a single DGLGraph.
            - ``feats`` and ``labels`` are the numpy arrays loaded for the split.
            - ``graphs`` is the list of per-graph subgraphs, each with float32
              ``ndata['feat']`` and ``ndata['label']``.
        """
        graph_file = os.path.join(self.save_path, "{}_graph.json".format(mode))
        label_file = os.path.join(self.save_path, "{}_labels.npy".format(mode))
        feat_file = os.path.join(self.save_path, "{}_feats.npy".format(mode))
        graph_id_file = os.path.join(self.save_path, "{}_graph_id.npy".format(mode))

        # Use a context manager so the JSON file handle is always closed.
        with open(graph_file, encoding="utf-8") as f:
            g_data = json.load(f)
        labels = np.load(label_file)
        feats = np.load(feat_file)
        graph = from_networkx(nx.DiGraph(json_graph.node_link_graph(g_data)))
        graph_id = np.load(graph_id_file)

        # lo, hi bound the graph ids belonging to this split. Unknown modes
        # fall back to the training range, as before.
        lo, hi = self._SPLIT_GRAPH_ID_RANGES.get(
            mode, self._SPLIT_GRAPH_ID_RANGES["train"]
        )

        float32 = F.data_type_dict["float32"]
        graphs = []
        for g_id in range(lo, hi):
            # Node indices (into the whole-split graph) of graph ``g_id``.
            g_mask = np.where(graph_id == g_id)[0]
            g = graph.subgraph(g_mask)
            g.ndata["feat"] = F.tensor(feats[g_mask], dtype=float32)
            g.ndata["label"] = F.tensor(labels[g_mask], dtype=float32)
            graphs.append(g)
        return graph, feats, labels, graphs

    def process(self):
        """Build all three splits and concatenate them (train, valid, test)."""
        self.train_graph, train_feats, train_labels, train_graphs = self.process_graph(
            "train"
        )
        self.valid_graph, valid_feats, valid_labels, valid_graphs = self.process_graph(
            "valid"
        )
        self.test_graph, test_feats, test_labels, test_graphs = self.process_graph(
            "test"
        )

        self._labels = np.concatenate([train_labels, valid_labels, test_labels], axis=0)
        self._feats = np.concatenate([train_feats, valid_feats, test_feats], axis=0)
        self.graphs = train_graphs + valid_graphs + test_graphs

    def has_cache(self):
        """Return True if both cached files written by :meth:`save` exist."""
        graph_list_path = os.path.join(self.save_path, "dgl_graph_list.bin")
        info_path = os.path.join(self.save_path, "info.pkl")
        return os.path.exists(graph_list_path) and os.path.exists(info_path)

    def save(self):
        """Write the per-graph list and the concatenated labels/features to disk."""
        graph_list_path = os.path.join(self.save_path, "dgl_graph_list.bin")
        info_path = os.path.join(self.save_path, "info.pkl")
        save_graphs(graph_list_path, self.graphs)
        save_info(info_path, {"labels": self._labels, "feats": self._feats})

    def load(self):
        """Restore the state written by :meth:`save`.

        Note that ``train_graph``/``valid_graph``/``test_graph`` are only set
        by :meth:`process`, not by this method.
        """
        graph_list_path = os.path.join(self.save_path, "dgl_graph_list.bin")
        info_path = os.path.join(self.save_path, "info.pkl")
        self.graphs = load_graphs(graph_list_path)[0]
        info = load_info(info_path)
        self._labels = info["labels"]
        self._feats = info["feats"]

    def get_idx_split(self):
        """Return the dataset indices of each split.

        Returns
        -------
        dict[str, Tensor]
            ``{'train': 0..19, 'valid': 20..21, 'test': 22..23}`` as int64
            tensors, matching the order in which :meth:`process`
            concatenates the splits.
        """
        int64 = F.data_type_dict["int64"]
        split = {}
        offset = 0
        for mode in self._SPLITS:
            lo, hi = self._SPLIT_GRAPH_ID_RANGES[mode]
            count = hi - lo
            split[mode] = F.tensor(np.arange(count) + offset, dtype=int64)
            offset += count
        return split

    @property
    def num_classes(self):
        """Number of labels per node (multi-label, 121)."""
        return 121

    def __len__(self):
        """Return number of samples in this dataset."""
        return len(self.graphs)

    def __getitem__(self, item):
        """Get the item^th sample.

        Parameters
        ----------
        item : int
            The sample index.

        Returns
        -------
        :class:`dgl.DGLGraph`
            graph structure, node features and node labels.

            - ``ndata['feat']``: node features
            - ``ndata['label']``: node labels
        """
        if self._transform is None:
            return self.graphs[item]
        else:
            return self._transform(self.graphs[item])


class LegacyPPIDataset(PPIDataset):
    """Legacy version of PPI Dataset.

    Behaves like :class:`PPIDataset`, except that indexing returns a
    ``(graph, features, labels)`` tuple instead of only the graph.
    """

    def __getitem__(self, item):
        """Get the item^th sample.

        Parameters
        ----------
        item : int
            The sample index.

        Returns
        -------
        (dgl.DGLGraph, Tensor, Tensor)
            The graph (with the dataset transform applied, if any), followed
            by that graph's ``ndata['feat']`` and ``ndata['label']``.
        """
        # Reuse the parent's lookup so transform handling lives in one place.
        g = super().__getitem__(item)
        return g, g.ndata["feat"], g.ndata["label"]
