import os
import pickle
import logging
from tqdm import tqdm
import numpy as np

from schnetpack.datasets import QM9
from schnetpack.data import load_dataset, AtomsLoader
from schnetpack import properties


class QM9Filtered(QM9):
    """
    QM9 dataset with a filter on the number of atoms. Only molecules of specific size are loaded.

    Optionally, the dataset can be reduced to a small set of molecules that is
    repeated to fill the dataset (useful to check that a model can overfit).
    """

    def __init__(
        self,
        datapath: str,
        batch_size: int,
        n_atoms_allowed: int = None,
        shuffle_train: bool = True,
        indices_path="n_atoms_indices.pkl",
        n_overfit_molecules=None,
        permute_indices=False,
        **kwargs
    ):
        """

        Args:
            datapath (str): path to the QM9 database; forwarded to ``QM9``.
            batch_size (int): batch size; forwarded to ``QM9``.
            n_atoms_allowed (int): the number of atoms that each molecule should have.
                ``None`` or a value <= 0 disables the filter.
            shuffle_train (bool): whether the training dataloader shuffles its data.
            indices_path (str): path of a pickle file caching, per atom count, the
                dataset indices of molecules with that many atoms.
            n_overfit_molecules (int): if > 0, only the first this many molecules
                (after filtering and optional permutation) are kept, repeated to
                fill the dataset.
            permute_indices (bool): randomly permute the filtered indices before
                selecting the overfit molecules (only used with ``n_overfit_molecules``).
            **kwargs: further keyword arguments forwarded to ``QM9``.
        """
        super().__init__(datapath=datapath, batch_size=batch_size, **kwargs)
        self.n_atoms_allowed = n_atoms_allowed
        self.indices_path = indices_path
        self.shuffle_train = shuffle_train
        self.n_overfit_molecules = n_overfit_molecules
        self.permute_indices = permute_indices

    def setup(self, stage=None):
        """Load the dataset, apply the atom-count / overfit selection and partition it.

        Loading, selection and partitioning only happen while ``self.dataset``
        is ``None``; transforms are set up on every call.
        """
        if self.data_workdir is None:
            datapath = self.datapath
        else:
            datapath = self._copy_to_workdir()

        # (re)load datasets
        if self.dataset is None:
            self.dataset = load_dataset(
                datapath,
                self.format,
                property_units=self.property_units,
                distance_unit=self.distance_unit,
                load_properties=self.load_properties,
            )

            # use only molecules with specific number of atoms
            indices = self._n_atoms_indices()
            indices = self._overfit_indices(indices)
            self.dataset = self.dataset.subset(indices)

            # load and generate partitions if needed
            if self.train_idx is None:
                self._load_partitions()

            # partition dataset
            self._train_dataset = self.dataset.subset(self.train_idx)
            self._val_dataset = self.dataset.subset(self.val_idx)
            self._test_dataset = self.dataset.subset(self.test_idx)

        self._setup_transforms()

    def _n_atoms_indices(self):
        """Return indices of molecules with ``n_atoms_allowed`` atoms, or all indices if the filter is off.

        Results are cached in ``self.indices_path`` as a dict mapping atom count
        to index list, so the full scan over the dataset runs once per atom count.
        """
        if self.n_atoms_allowed is None or self.n_atoms_allowed <= 0:
            return list(range(len(self.dataset)))

        # NOTE(review): the cache is keyed only by atom count, not by datapath,
        # so a cache file produced for a different database is reused silently.
        # pickle.load is only acceptable because this class writes the file
        # itself; never point indices_path at an untrusted file.
        cache = {}
        if os.path.exists(self.indices_path):
            with open(self.indices_path, "rb") as file:
                cache = pickle.load(file)

        if self.n_atoms_allowed not in cache:
            cache[self.n_atoms_allowed] = [
                i
                for i in tqdm(range(len(self.dataset)))
                if self.dataset[i][properties.n_atoms] == self.n_atoms_allowed
            ]
            with open(self.indices_path, "wb") as file:
                pickle.dump(cache, file)

        return cache[self.n_atoms_allowed]

    def _overfit_indices(self, indices):
        """Restrict ``indices`` to the first ``n_overfit_molecules`` entries, repeated.

        The selected molecules are repeated ceil(len(indices) / n_overfit_molecules)
        times, so every selected molecule appears equally often and the result
        is at least as long as ``indices`` (never inflated beyond one extra
        cycle). Returns ``indices`` unchanged when the option is disabled.
        """
        n_overfit = self.n_overfit_molecules
        if n_overfit is None or n_overfit <= 0:
            return indices

        if self.permute_indices:
            indices = np.random.permutation(indices).tolist()

        # Ceiling division: a repetition count, not quotient + remainder
        # (adding the remainder mixed an element count into a repeat count).
        n_repeats = (len(indices) + n_overfit - 1) // n_overfit
        indices = indices[:n_overfit] * n_repeats
        logging.warning(
            "Overfitting on {} molecules with indices {}".format(
                n_overfit, indices[:n_overfit]
            )
        )
        return indices

    def train_dataloader(self) -> AtomsLoader:
        """Return the (lazily created) training loader, shuffled per ``shuffle_train``."""
        if self._train_dataloader is None:
            self._train_dataloader = AtomsLoader(
                self.train_dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=self.shuffle_train,
                pin_memory=self._pin_memory,
            )
        return self._train_dataloader
