import torch.utils.data as data
import torchvision.transforms as transforms
import antgine
from antgine.dataset import AbstractDataset


_default_train_transform = transforms.Compose([
    transforms.Resize(112),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[127.5/255.] * 3, std=[128.0/255.] * 3),
])

_default_test_transform = transforms.Compose([
    transforms.Resize(112),
    transforms.ToTensor(),
    transforms.Normalize(mean=[127.5/255.] * 3, std=[128.0/255.] * 3),
])


class HDF5ImageFolder(AbstractDataset):
    """
        HDF5ImageFolder dataset class.
    """

    def __init__(self, root: str, train_path: str, test_path: str, batch_size: int,
                 train_transform : transforms.Compose = _default_train_transform,
                 test_transform : transforms.Compose =_default_test_transform, num_workers=8):
        """

        :param str root: Dataset's root directory.
        :param int batch_size: Batch size.
        :param str train_path: HDF5's file for training dataset.
        :param str test_path: HDF5's file for training dataset.
        :param transforms.Compose train_transform: Transform applied to inputs during training.
        :param transforms.Compose test_transform: Transform applied to inputs during testing.
        :param int num_workers: Number of workers launched for loading data.
        """
        super().__init__()
        self._root = root
        self._train_path = train_path
        self._test_path = test_path
        self._batch_size = batch_size
        self._train_transform = train_transform
        self._test_transform = test_transform
        self._num_workers = num_workers

        self._train_set = antgine.datasets.HDF5Dataset(self._train_path, self._train_transform)
        self._test_set = antgine.datasets.HDF5Dataset(self._test_path, self._test_transform)

    @property
    def root(self) -> str:
        """
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._root

    @property
    def batch_size(self) -> int:
        """
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._batch_size

    @property
    def train_transform(self) -> transforms.Compose:
        """
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._train_transform

    @property
    def test_transform(self) -> transforms.Compose:
        """
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._test_transform

    @property
    def train_set(self) -> data.Dataset:
        """
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._train_set

    @property
    def test_set(self) -> data.Dataset:
        """
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._test_set

    def _loader(self, dataset: data.Dataset, shuffle: bool):
        return data.DataLoader(dataset, batch_size=self.batch_size,
                               shuffle=shuffle, num_workers=self._num_workers)

    def train_loader(self, shuffle=True) -> data.DataLoader:
        """
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._loader(self.train_set, shuffle)

    def test_loader(self, shuffle=False) -> data.DataLoader:
        """
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._loader(self.test_set, shuffle)
