from typing import Dict, List, Callable, Union, Optional
import numpy as np
import logging

import torch

from nequip.data import AtomicData
from nequip.utils.savenload import atomic_write
from nequip.data.transforms import TypeMapper
from nequip.data import AtomicDataset


class ExampleCustomDataset(AtomicDataset):
    """Template for a custom NequIP dataset with optional raw-data preprocessing.

    See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets.

    If you don't need downloading or pre-processing, just don't define any of the relevant methods/properties.

    Args:
        root: root directory passed to ``AtomicDataset``; the processed-data cache lives under it.
        custom_option1: placeholder for a dataset-specific option (no default).
        custom_option2: placeholder for a dataset-specific option with a default.
        type_mapper: optional ``TypeMapper`` forwarded to ``AtomicDataset``.
    """

    def __init__(
        self,
        root: str,
        custom_option1,
        custom_option2="default",
        type_mapper: Optional[TypeMapper] = None,
    ):
        # Store the dataset options on `self` *before* the base initializer runs.
        # `processed_file_names` (below) notes that `processed_dir` is derived from a
        # hash of the options given to `__init__`; options that are not kept as
        # attributes presumably cannot contribute to that hash, so changing them
        # would silently reuse a stale cache.
        # NOTE(review): confirm against `AtomicDataset`'s hashing logic.
        self.custom_option1 = custom_option1
        self.custom_option2 = custom_option2

        # `process()` assigns this when it runs. When the processed files are
        # already cached, `process()` is skipped, so the attribute must exist
        # beforehand or the check below raises AttributeError.
        self.mydata = None

        # Initialize the AtomicDataset, which runs .download() (if present) and .process()
        # See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets
        # This will only run download and preprocessing if cached dataset files aren't found
        super().__init__(root=root, type_mapper=type_mapper)

        # If the processed paths didn't exist, `self.process()` has been called at this
        # point (if it is defined) and may have filled `self.mydata`. Otherwise, load
        # the data from the cached pre-processed dir:
        if self.mydata is None:
            self.mydata = torch.load(self.processed_paths[0])
        # If you didn't define `process()`, this is where you would unconditionally load your data.

    def len(self) -> int:
        """Return the number of frames in the dataset.

        Placeholder: replace the constant with the real frame count of your data.
        """
        return 42

    @property
    def raw_file_names(self) -> List[str]:
        """Return a list of filenames for the raw data.

        Need to be simple filenames to be looked for in `self.raw_dir`.
        """
        return ["data.dat"]

    @property
    def raw_dir(self) -> str:
        """Directory containing the raw files listed in `raw_file_names`.

        Placeholder path: point this at the folder holding your raw data.
        """
        return "/path/to/dataset-folder/"

    @property
    def processed_file_names(self) -> List[str]:
        """Like `self.raw_file_names`, but for the files generated by `self.process()`.

        Should not be paths, just file names. These will be stored in `self.processed_dir`,
        which is set by NequIP in `AtomicDataset` based on `self.root` and a hash of the
        dataset options provided to `__init__`.
        """
        return ["processed-data.pth"]

    # def download(self):
    #     """Optional method to download raw data before preprocessing if the `raw_paths` do not exist."""
    #     pass

    def process(self):
        """Read the raw data, cache it to `self.processed_paths[0]`, and keep it in memory.

        Side effects: writes the processed file atomically and sets `self.mydata`.
        """
        # Load things from the raw data, in whatever way is appropriate for your format.
        # `self.raw_paths` joins `self.raw_dir` with each entry of `self.raw_file_names`,
        # avoiding hand-built paths (and the doubled "/" a trailing-slash `raw_dir` gives).
        data = np.load(self.raw_paths[0])

        # If any pre-processing is necessary, do it and cache the results to
        # `self.processed_paths` as you defined above:
        with atomic_write(self.processed_paths[0], binary=True) as f:
            # e.g., anything that takes a file `f` will work
            torch.save(data, f)
            # ^ use atomic writes to avoid race conditions between
            # different trainings that use the same dataset.
            # Since those separate trainings should all produce the same results,
            # it doesn't matter if they overwrite each other's cached
            # datasets. It only matters that they don't simultaneously try
            # to write the _same_ file, corrupting it.

        logging.info("Cached processed data to disk")

        # Optionally, save the processed data on the Dataset object
        # to avoid a roundtrip from disk in `__init__` (see above).
        self.mydata = data

    def get(self, idx: int) -> AtomicData:
        """Return the data frame with a given index as an `AtomicData` object.

        Placeholder: currently returns ``None``. Build and return an ``AtomicData``
        for frame ``idx`` (e.g. from ``self.mydata``) in a real implementation.
        """
        build_an_AtomicData_here = None
        return build_an_AtomicData_here

    def statistics(
        self,
        fields: List[Union[str, Callable]],
        modes: List[str],
        stride: int = 1,
        unbiased: bool = True,
        kwargs: Optional[Dict[str, dict]] = None,
    ) -> List[tuple]:
        """Optional method to compute statistics over an entire dataset.

        This must correctly handle `self._indices` for subsets!!!

        If not provided, options like `avg_num_neighbors: auto`, `per_species_rescale_scales: dataset_*`,
        and others that compute dataset statistics will not work. This only needs to support the statistics
        modes that are necessary for what you need to run (i.e. if you do not use `dataset_per_species_*`
        statistics, you do not need to implement them).

        `kwargs` defaults to ``None`` rather than a shared mutable ``{}``; an
        implementation should treat ``None`` as an empty dict.

        See `AtomicInMemoryDataset` for full documentation and example implementation.

        Raises:
            NotImplementedError: always, in this template.
        """
        raise NotImplementedError
