"""Basic DGL Dataset
"""

from __future__ import absolute_import

import abc
import hashlib
import os
import sys
import traceback

from ..utils import retry_method_with_fix
from .utils import download, extract_archive, get_download_dir, makedirs


class DGLDataset(object):
    r"""The basic DGL dataset for creating graph datasets.
    This class defines a basic template class for DGL Dataset.
    The following steps will be executed automatically:

      1. Check whether there is a dataset cache on disk
         (already processed and stored on the disk) by
         invoking ``has_cache()``. If true, goto 5.
         This step is skipped when ``force_reload`` is True.
      2. Call ``download()`` to fetch the raw data if ``raw_path``
         does not exist yet.
      3. Call ``process()`` to process the data.
      4. Call ``save()`` to save the processed dataset on disk and goto 6.
      5. Call ``load()`` to load the processed dataset from disk.
         If ``load()`` raises an exception, fall back to steps 2-4.
      6. Done.

    Users can overwrite these functions with their
    own data processing logic.

    Note that this class does not use ``abc.ABCMeta`` as its metaclass,
    so the ``@abc.abstractmethod`` markers below are not enforced at
    instantiation time; they only document which methods subclasses
    are expected to implement.

    Parameters
    ----------
    name : str
        Name of the dataset
    url : str
        Url to download the raw dataset. Default: None
    raw_dir : str
        Specifying the directory that will store the
        downloaded data or the directory that
        already stores the input data.
        Default: ~/.dgl/
    save_dir : str
        Directory to save the processed dataset.
        Default: same as raw_dir
    hash_key : tuple
        A tuple of values as the input for the hash function.
        Users can distinguish instances (and their caches on the disk)
        from the same dataset class by comparing the hash values.
        Default: (), the corresponding hash value is ``'f9065fa7'``.
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: False
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access. This base class only stores it
        (as ``self._transform``); applying it is left to the subclass's
        ``__getitem__``.

    Attributes
    ----------
    url : str
        The URL to download the dataset
    name : str
        The dataset name
    raw_dir : str
        Directory to store all the downloaded raw datasets.
    raw_path : str
        Path to the downloaded raw dataset folder. An alias for
        ``os.path.join(self.raw_dir, self.name)``.
    save_dir : str
        Directory to save all the processed datasets.
    save_path : str
        Path to the processed dataset folder. An alias for
        ``os.path.join(self.save_dir, self.name)``.
    verbose : bool
        Whether to print more runtime information.
    hash : str
        Hash value for the dataset and the setting.
    """

    def __init__(
        self,
        name,
        url=None,
        raw_dir=None,
        save_dir=None,
        hash_key=(),
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        self._name = name
        self._url = url
        self._force_reload = force_reload
        self._verbose = verbose
        self._hash_key = hash_key
        self._hash = self._get_hash()
        self._transform = transform

        # if no dir is provided, the default dgl download dir is used.
        if raw_dir is None:
            self._raw_dir = get_download_dir()
        else:
            self._raw_dir = raw_dir

        # Processed files live next to the raw data unless told otherwise.
        if save_dir is None:
            self._save_dir = self._raw_dir
        else:
            self._save_dir = save_dir

        # Runs the whole cache/download/process/save pipeline eagerly,
        # so constructing a dataset may perform disk and network I/O.
        self._load()

    def download(self):
        r"""Overwrite to realize your own logic of downloading data.

        It is recommended to download the data to the :obj:`self.raw_dir`
        folder. Can be ignored if the dataset is
        already in :obj:`self.raw_dir`.
        """
        pass

    def save(self):
        r"""Overwrite to realize your own logic of
        saving the processed dataset into files.

        It is recommended to use ``dgl.data.utils.save_graphs``
        to save dgl graph into files and use
        ``dgl.data.utils.save_info`` to save extra
        information into files.
        """
        pass

    def load(self):
        r"""Overwrite to realize your own logic of
        loading the saved dataset from files.

        It is recommended to use ``dgl.data.utils.load_graphs``
        to load dgl graph from files and use
        ``dgl.data.utils.load_info`` to load extra information
        into python dict object.
        """
        pass

    @abc.abstractmethod
    def process(self):
        r"""Overwrite to realize your own logic of processing the input data."""
        pass

    def has_cache(self):
        r"""Overwrite to realize your own logic of
        deciding whether there exists a cached dataset.

        By default False.
        """
        return False

    # NOTE(review): ``download`` below is resolved while the class body
    # executes, so it is this class's own (no-op) ``download`` defined above,
    # not a subclass override. Confirm this is the intended "fix" step for
    # retry_method_with_fix.
    @retry_method_with_fix(download)
    def _download(self):
        """Download dataset by calling ``self.download()``
        if the dataset does not exist under ``self.raw_path``.

        By default ``self.raw_path = os.path.join(self.raw_dir, self.name)``
        One can override the ``raw_path`` property to change the path.
        """
        if os.path.exists(self.raw_path):  # pragma: no cover
            return

        makedirs(self.raw_dir)
        self.download()

    def _load(self):
        """Entry point from __init__ to load the dataset.

        If cache exists (and ``force_reload`` is False):

          - Load the dataset from saved dgl graph and information files.
          - If the loading process fails, re-download and process the dataset.

        else:

          - Download the dataset if needed.
          - Process the dataset and build the dgl graph.
          - Save the processed dataset into files.
        """
        load_flag = not self._force_reload and self.has_cache()

        if load_flag:
            try:
                self.load()
            except Exception:  # pylint: disable=broad-except
                # Any ordinary failure while reading the cache (corrupt or
                # stale files, format changes) falls back to re-processing.
                # KeyboardInterrupt / SystemExit are not Exception subclasses
                # and therefore propagate to the caller.
                load_flag = False
                if self.verbose:
                    print(traceback.format_exc())
                    print("Loading from cache failed, re-processing.")
            else:
                if self.verbose:
                    print("Done loading data from cached files.")

        if not load_flag:
            self._download()
            self.process()
            self.save()
            if self.verbose:
                print("Done saving data into cached files.")

    def _get_hash(self):
        """Compute the hash of the input tuple

        The hash is the first 8 hex digits of the SHA-1 of ``str(hash_key)``,
        so it is stable across processes (unlike the builtin ``hash``).

        Example
        -------
        Assume `self._hash_key = (10, False, True)`

        >>> hash_value = self._get_hash()
        >>> hash_value
        'a770b222'
        """
        hash_func = hashlib.sha1()
        hash_func.update(str(self._hash_key).encode("utf-8"))
        return hash_func.hexdigest()[:8]

    @property
    def url(self):
        r"""Get url to download the raw dataset."""
        return self._url

    @property
    def name(self):
        r"""Name of the dataset."""
        return self._name

    @property
    def raw_dir(self):
        r"""Raw file directory contains the input data folder."""
        return self._raw_dir

    @property
    def raw_path(self):
        r"""Directory contains the input data files.
        By default raw_path = os.path.join(self.raw_dir, self.name)
        """
        return os.path.join(self.raw_dir, self.name)

    @property
    def save_dir(self):
        r"""Directory to save the processed dataset."""
        return self._save_dir

    @property
    def save_path(self):
        r"""Path to save the processed dataset.
        By default save_path = os.path.join(self.save_dir, self.name)
        """
        # Go through the property (like raw_path does) so subclasses that
        # override ``save_dir`` get a consistent ``save_path``.
        return os.path.join(self.save_dir, self.name)

    @property
    def verbose(self):
        r"""Whether to print information."""
        return self._verbose

    @property
    def hash(self):
        r"""Hash value for the dataset and the setting."""
        return self._hash

    @abc.abstractmethod
    def __getitem__(self, idx):
        r"""Gets the data object at index."""
        pass

    @abc.abstractmethod
    def __len__(self):
        r"""The number of examples in the dataset."""
        pass

    def __repr__(self):
        return (
            f'Dataset("{self.name}", num_graphs={len(self)},'
            + f" save_path={self.save_path})"
        )


class DGLBuiltinDataset(DGLDataset):
    r"""The Basic DGL Builtin Dataset.

    A :class:`DGLDataset` whose raw data is a single zip archive fetched
    from ``url`` and unpacked into ``raw_path``. ``save_dir`` is not a
    parameter here: ``None`` is passed to the base class, so processed
    files are stored under ``raw_dir``.

    Parameters
    ----------
    name : str
        Name of the dataset.
    url : str
        Url to download the raw dataset. When None, ``download()`` does
        nothing.
    raw_dir : str
        Specifying the directory that will store the
        downloaded data or the directory that
        already stores the input data.
        Default: ~/.dgl/
    hash_key : tuple
        A tuple of values as the input for the hash function.
        Users can distinguish instances (and their caches on the disk)
        from the same dataset class by comparing the hash values.
        Default: ()
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: False
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    """

    def __init__(
        self,
        name,
        url,
        raw_dir=None,
        hash_key=(),
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        # save_dir=None makes the base class reuse raw_dir for processed files.
        super().__init__(
            name,
            url=url,
            raw_dir=raw_dir,
            save_dir=None,
            hash_key=hash_key,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def download(self):
        r"""Automatically download data and extract it.

        Saves ``self.url`` to ``<raw_dir>/<name>.zip`` and extracts that
        archive into ``self.raw_path``. Does nothing when ``url`` is None.
        """
        if self.url is None:
            return
        archive_path = os.path.join(self.raw_dir, self.name + ".zip")
        download(self.url, path=archive_path)
        extract_archive(archive_path, self.raw_path)
