import logging
import os
import shutil
from copy import copy
from typing import Optional, List, Dict, Tuple, Union


import numpy as np
import fasteners

import pytorch_lightning as pl
import torch

from schnetpack.data import (
    AtomsDataFormat,
    resolve_format,
    load_dataset,
    BaseAtomsData,
    AtomsLoader,
    calculate_stats,
    SplittingStrategy,
    RandomSplit,
)


# Public names of this module, i.e. what ``from <module> import *`` exports.
__all__ = ["AtomsDataModule", "AtomsDataModuleError"]


class AtomsDataModuleError(Exception):
    """Error raised by ``AtomsDataModule`` for invalid data module settings."""


class AtomsDataModule(pl.LightningDataModule):
    """
    A general ``LightningDataModule`` for SchNetPack datasets.

    During ``setup`` the dataset is loaded with ``load_dataset`` (optionally after
    copying it to a working directory), partitioned into training, validation and
    test subsets -- read from ``split_file`` if it exists, otherwise generated by
    the ``splitting`` strategy -- and the configured transforms are attached to
    each subset. The ``*_dataloader`` methods create their ``AtomsLoader`` lazily
    and cache it.
    """

    def __init__(
        self,
        datapath: str,
        batch_size: int,
        num_train: Optional[Union[int, float]] = None,
        num_val: Optional[Union[int, float]] = None,
        num_test: Optional[Union[int, float]] = None,
        split_file: Optional[str] = "split.npz",
        format: Optional[AtomsDataFormat] = None,
        load_properties: Optional[List[str]] = None,
        val_batch_size: Optional[int] = None,
        test_batch_size: Optional[int] = None,
        transforms: Optional[List[torch.nn.Module]] = None,
        train_transforms: Optional[List[torch.nn.Module]] = None,
        val_transforms: Optional[List[torch.nn.Module]] = None,
        test_transforms: Optional[List[torch.nn.Module]] = None,
        num_workers: int = 8,
        num_val_workers: Optional[int] = None,
        num_test_workers: Optional[int] = None,
        property_units: Optional[Dict[str, str]] = None,
        distance_unit: Optional[str] = None,
        data_workdir: Optional[str] = None,
        cleanup_workdir_stage: Optional[str] = "test",
        splitting: Optional[SplittingStrategy] = None,
        pin_memory: Optional[bool] = False,
    ):
        """
        Args:
            datapath: path to dataset
            batch_size: (train) batch size
            num_train: number of training examples (absolute or relative)
            num_val: number of validation examples (absolute or relative)
            num_test: number of test examples (absolute or relative)
            split_file: path to npz file with data partitions
            format: dataset format
            load_properties: subset of properties to load
            val_batch_size: validation batch size. If None, use test_batch_size, then
                batch_size.
            test_batch_size: test batch size. If None, use val_batch_size, then
                batch_size.
            transforms: Preprocessing transform applied to each system separately before
                batching.
            train_transforms: Overrides ``transforms`` for training.
            val_transforms: Overrides ``transforms`` for validation.
            test_transforms: Overrides ``transforms`` for testing.
            num_workers: Number of data loader workers.
            num_val_workers: Number of validation data loader workers
                (overrides num_workers).
            num_test_workers: Number of test data loader workers
                (overrides num_workers).
            property_units: Dictionary from property to corresponding unit as a string
                (eV, kcal/mol, ...).
            distance_unit: Unit of the atom positions and cell as a string
                (Ang, Bohr, ...).
            data_workdir: Copy data here as part of setup, e.g. to a local file
                system for faster performance.
            cleanup_workdir_stage: Determines after which stage to remove the data
                workdir
            splitting: Method to generate train/validation/test partitions
                (default: RandomSplit)
            pin_memory: Passed on as ``pin_memory`` to all data loaders
                (default: False).
        """
        super().__init__()

        # A stage-specific list wins; otherwise each stage gets its own shallow
        # copy of the shared ``transforms`` list (or an empty list).
        self._train_transforms = train_transforms or copy(transforms) or []
        self._val_transforms = val_transforms or copy(transforms) or []
        self._test_transforms = test_transforms or copy(transforms) or []

        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or test_batch_size or batch_size
        self.test_batch_size = test_batch_size or val_batch_size or batch_size
        self.num_train = num_train
        self.num_val = num_val
        self.num_test = num_test
        self.splitting = splitting or RandomSplit()
        self.split_file = split_file
        self.datapath, self.format = resolve_format(datapath, format)
        self.load_properties = load_properties
        self.num_workers = num_workers
        self.num_val_workers = (
            num_workers if num_val_workers is None else num_val_workers
        )
        self.num_test_workers = (
            num_workers if num_test_workers is None else num_test_workers
        )
        self.property_units = property_units
        self.distance_unit = distance_unit
        # cache for ``get_stats``, keyed by its argument tuple
        self._stats = {}
        self._is_setup = False
        self.data_workdir = data_workdir
        self.cleanup_workdir_stage = cleanup_workdir_stage
        self._pin_memory = pin_memory

        # populated during ``setup``
        self.train_idx = None
        self.val_idx = None
        self.test_idx = None
        self.dataset = None
        self._train_dataset = None
        self._val_dataset = None
        self._test_dataset = None

        # created lazily by the ``*_dataloader`` methods
        self._train_dataloader = None
        self._val_dataloader = None
        self._test_dataloader = None

    @property
    def train_transforms(self):
        """
        Optional transforms (or collection of transforms) you can apply to train
        dataset.
        """
        return self._train_transforms

    @property
    def val_transforms(self):
        """
        Optional transforms (or collection of transforms) you can apply to validation
        dataset.
        """
        return self._val_transforms

    @property
    def test_transforms(self):
        """
        Optional transforms (or collection of transforms) you can apply to test dataset.
        """
        return self._test_transforms

    def setup(self, stage: Optional[str] = None):
        """
        Load the dataset, create the partitions and attach the transforms.

        If ``data_workdir`` is set, the data is first copied there, which also
        resets the loaded datasets so that they are reloaded from the copy.
        Nothing is reloaded if a dataset is already present.

        Args:
            stage: stage name passed in by the caller; not used here.
        """
        # check whether data needs to be copied
        if self.data_workdir is None:
            datapath = self.datapath
        else:
            datapath = self._copy_to_workdir()

        # nothing to do if the dataset is still loaded
        if self.dataset is not None:
            return

        # (re)load datasets
        self.dataset = load_dataset(
            datapath,
            self.format,
            property_units=self.property_units,
            distance_unit=self.distance_unit,
            load_properties=self.load_properties,
        )

        # load and generate partitions if needed
        if self.train_idx is None:
            self._load_partitions()

        # partition dataset
        self._train_dataset = self.dataset.subset(self.train_idx)
        self._val_dataset = self.dataset.subset(self.val_idx)
        self._test_dataset = self.dataset.subset(self.test_idx)
        self._setup_transforms()

    def _copy_to_workdir(self):
        """
        Copies the data to given (fast) working location. Useful for working on cluster
        with slow shared and fast local file systems.

        The copy is guarded by an inter-process lock file inside the working
        directory, so that only one process copies while the others wait. The
        loaded datasets and cached data loaders are reset so that they are
        rebuilt from the copied data.

        Returns:
            path to data in workdir
        """
        # ``exist_ok`` makes a separate existence check (and its race) unnecessary
        os.makedirs(self.data_workdir, exist_ok=True)
        name = os.path.basename(self.datapath)
        datapath = os.path.join(self.data_workdir, name)
        lock = fasteners.InterProcessLock(
            os.path.join(self.data_workdir, f"dataworkdir_{name}.lock")
        )
        with lock:
            self._log_with_rank("Enter lock")

            # retry reading, in case other process finished in the meantime
            if not os.path.exists(datapath):
                self._log_with_rank("Copy data to data workdir")
                shutil.copy(self.datapath, datapath)

            # reset datasets in case they need to be reloaded
            self.dataset = None
            self._train_dataset = None
            self._val_dataset = None
            self._test_dataset = None

            # cached loaders wrap the datasets discarded above; rebuild on demand
            self._train_dataloader = None
            self._val_dataloader = None
            self._test_dataloader = None

            # reset cleanup
            self._has_teardown_fit = False
            self._has_teardown_val = False
            self._has_teardown_test = False
        self._log_with_rank("Exited lock")
        return datapath

    def teardown(self, stage: Optional[str] = None):
        """
        Remove the data workdir after the configured stage and tear down transforms.

        Removing the workdir is best effort: a directory that is already gone is
        ignored and any other OS error is logged as a warning instead of raised.

        Args:
            stage: stage that has just finished; the workdir is only removed if
                it equals ``cleanup_workdir_stage``.
        """
        if (
            self.cleanup_workdir_stage
            and stage == self.cleanup_workdir_stage
            and self.data_workdir is not None
        ):
            try:
                shutil.rmtree(self.data_workdir)
            except FileNotFoundError:
                # nothing left to remove, e.g. another process was faster
                pass
            except OSError as err:
                logging.warning(
                    "Could not remove data workdir %r: %s", self.data_workdir, err
                )
            self._has_setup_fit = False
            self._has_setup_val = False
            self._has_setup_test = False

        # teardown transforms
        for t in self.train_transforms:
            t.teardown()
        for t in self.val_transforms:
            t.teardown()
        for t in self.test_transforms:
            t.teardown()

    def _load_partitions(self):
        """
        Fill ``train_idx``, ``val_idx`` and ``test_idx``.

        If ``split_file`` exists, the indices are read from it and a warning is
        logged for every requested partition size that does not match. Otherwise
        the partitions are generated with ``self.splitting`` and, if a
        ``split_file`` is configured, saved to it.

        Raises:
            AtomsDataModuleError: if a split has to be generated but ``num_train``
                or ``num_val`` is not set.
        """
        # NOTE: the lock file path is relative to the current working directory
        lock = fasteners.InterProcessLock("splitting.lock")

        with lock:
            self._log_with_rank("Enter splitting lock")

            if self.split_file is not None and os.path.exists(self.split_file):
                self._log_with_rank("Load split")

                # context manager closes the underlying npz archive
                with np.load(self.split_file) as split:
                    self.train_idx = split["train_idx"].tolist()
                    self.val_idx = split["val_idx"].tolist()
                    self.test_idx = split["test_idx"].tolist()

                self._warn_on_split_mismatch("train", self.num_train, self.train_idx)
                self._warn_on_split_mismatch("val", self.num_val, self.val_idx)
                self._warn_on_split_mismatch("test", self.num_test, self.test_idx)
            else:
                self._log_with_rank("Create split")

                if not self.num_train or not self.num_val:
                    raise AtomsDataModuleError(
                        "If no `split_file` is given, the sizes of the training and"
                        + " validation partitions need to be set!"
                    )

                self.train_idx, self.val_idx, self.test_idx = self.splitting.split(
                    self.dataset, self.num_train, self.num_val, self.num_test
                )

                if self.split_file is not None:
                    self._log_with_rank("Save split")
                    np.savez(
                        self.split_file,
                        train_idx=self.train_idx,
                        val_idx=self.val_idx,
                        test_idx=self.test_idx,
                    )

        self._log_with_rank("Exit splitting lock")

    def _warn_on_split_mismatch(
        self, name: str, requested: Optional[Union[int, float]], indices: List[int]
    ):
        """
        Log a warning if a requested partition size disagrees with a loaded split.

        Args:
            name: partition name used in the message (``train``, ``val``, ``test``).
            requested: requested partition size; nothing is checked if it is unset.
            indices: indices of the partition as read from the split file.
        """
        if not requested:
            return

        actual = len(indices)
        if 0 < requested < 1 and self.dataset is not None:
            # Sizes strictly between 0 and 1 are treated as fractions of the
            # dataset. Allow a deviation of less than one sample, since the
            # rounding applied when the split was created is not known here.
            mismatch = abs(requested * len(self.dataset) - actual) >= 1
        else:
            mismatch = requested != actual

        if mismatch:
            logging.warning(
                f"Split file was given, but `num_{name} ({requested})"
                + f" != len({name}_idx)` ({actual})!"
            )

    def _log_with_rank(self, msg: str):
        """
        Log ``msg`` at debug level, prefixed with the trainer ranks if available.

        Args:
            msg: message to log.
        """
        # Values are passed as lazy %-style arguments; every argument needs a
        # matching placeholder in the format string.
        if self.trainer is not None:
            logging.debug(
                "Global rank: %s, local rank: %s >> %s",
                self.trainer.global_rank,
                self.trainer.local_rank,
                msg,
            )
        else:
            logging.debug(">> %s", msg)

    def _setup_transforms(self):
        """
        Pass this data module to every transform and attach the transform lists
        to the corresponding partitions.
        """
        for t in self.train_transforms:
            t.datamodule(self)
        for t in self.val_transforms:
            t.datamodule(self)
        for t in self.test_transforms:
            t.datamodule(self)
        self._train_dataset.transforms = self.train_transforms
        self._val_dataset.transforms = self.val_transforms
        self._test_dataset.transforms = self.test_transforms

    def get_stats(
        self, property: str, divide_by_atoms: bool, remove_atomref: bool
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return statistics of a property computed over the training data loader.

        Results are cached per argument combination, so ``calculate_stats`` runs
        only once for each distinct call.

        Args:
            property: name of the property.
            divide_by_atoms: forwarded to ``calculate_stats`` for this property.
            remove_atomref: if True, the atomrefs of the training dataset are
                passed to ``calculate_stats``; otherwise ``None`` is passed.

        Returns:
            the entry for ``property`` in the result of ``calculate_stats``.
        """
        key = (property, divide_by_atoms, remove_atomref)
        if key in self._stats:
            return self._stats[key]

        stats = calculate_stats(
            self.train_dataloader(),
            divide_by_atoms={property: divide_by_atoms},
            atomref=self.train_dataset.atomrefs if remove_atomref else None,
        )[property]
        self._stats[key] = stats
        return stats

    @property
    def train_dataset(self) -> BaseAtomsData:
        """Training partition; ``None`` until ``setup`` has loaded the data."""
        return self._train_dataset

    @property
    def val_dataset(self) -> BaseAtomsData:
        """Validation partition; ``None`` until ``setup`` has loaded the data."""
        return self._val_dataset

    @property
    def test_dataset(self) -> BaseAtomsData:
        """Test partition; ``None`` until ``setup`` has loaded the data."""
        return self._test_dataset

    def train_dataloader(self) -> AtomsLoader:
        """Return the (cached) shuffling loader over the training partition."""
        if self._train_dataloader is None:
            self._train_dataloader = AtomsLoader(
                self.train_dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=True,
                pin_memory=self._pin_memory,
            )
        return self._train_dataloader

    def val_dataloader(self) -> AtomsLoader:
        """Return the (cached) loader over the validation partition."""
        if self._val_dataloader is None:
            self._val_dataloader = AtomsLoader(
                self.val_dataset,
                batch_size=self.val_batch_size,
                num_workers=self.num_val_workers,
                pin_memory=self._pin_memory,
            )
        return self._val_dataloader

    def test_dataloader(self) -> AtomsLoader:
        """Return the (cached) loader over the test partition."""
        if self._test_dataloader is None:
            self._test_dataloader = AtomsLoader(
                self.test_dataset,
                batch_size=self.test_batch_size,
                num_workers=self.num_test_workers,
                pin_memory=self._pin_memory,
            )
        return self._test_dataloader
