import os
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from antgine.dataset import AbstractDataset


class ImageFolder(AbstractDataset):
    """
        Train/test dataset pair backed by :class:`torchvision.datasets.ImageFolder`.

        One torchvision ``ImageFolder`` is built for each split. Each split's
        directory is resolved by joining it onto a shared root. The loader
        methods wrap the splits in :class:`torch.utils.data.DataLoader`
        objects that share a batch size and worker count.
    """
    def __init__(self, root: str, train_path: str, test_path: str,
                 batch_size: int, train_transform: transforms.Compose,
                 test_transform: transforms.Compose, num_workers=8):
        """
        :param str root: Directory that ``train_path`` and ``test_path`` are joined onto.
        :param str train_path: Training-split directory, relative to ``root``.
        :param str test_path: Test-split directory, relative to ``root``.
        :param int batch_size: Number of samples per batch in every loader this object creates.
        :param transforms.Compose train_transform: Transform handed to the training ``ImageFolder``.
        :param transforms.Compose test_transform: Transform handed to the test ``ImageFolder``.
        :param int num_workers: Worker count forwarded to every ``DataLoader`` this object creates.
        """
        super().__init__()

        # Store the plain configuration first. The split construction below
        # reads the transforms back through their public properties.
        self._root = root
        self._batch_size = batch_size
        self._num_workers = num_workers
        self._train_transform = train_transform
        self._test_transform = test_transform

        # Build the training split first, then the test split.
        self._train_set = datasets.ImageFolder(
            os.path.join(self._root, train_path),
            transform=self.train_transform,
        )
        self._test_set = datasets.ImageFolder(
            os.path.join(self._root, test_path),
            transform=self.test_transform,
        )

    @property
    def root(self) -> str:
        """
            Root directory that both split paths were resolved against.
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._root

    @property
    def batch_size(self) -> int:
        """
            Batch size used by every loader built by this object.
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._batch_size

    @property
    def train_transform(self) -> transforms.Compose:
        """
            Transform attached to the training split.
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._train_transform

    @property
    def test_transform(self) -> transforms.Compose:
        """
            Transform attached to the test split.
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._test_transform

    @property
    def train_set(self) -> data.Dataset:
        """
            The torchvision ``ImageFolder`` built from ``root/train_path``.
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._train_set

    @property
    def test_set(self) -> data.Dataset:
        """
            The torchvision ``ImageFolder`` built from ``root/test_path``.
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        return self._test_set

    def _loader(self, dataset: data.Dataset, shuffle: bool):
        """
            Wrap ``dataset`` in a new ``DataLoader``.

            The loader uses this object's batch size and worker count.
            ``shuffle`` is passed through unchanged.
        """
        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self._num_workers,
        )
        return data.DataLoader(dataset, **loader_options)

    def train_loader(self, shuffle=True) -> data.DataLoader:
        """
            Return a new ``DataLoader`` over the training split.
            Shuffling is on by default.
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        split = self.train_set
        return self._loader(split, shuffle)

    def test_loader(self, shuffle=False) -> data.DataLoader:
        """
            Return a new ``DataLoader`` over the test split.
            Shuffling is off by default.
            See :meth:`antgine.dataset.AbstractDataset`.
        """
        split = self.test_set
        return self._loader(split, shuffle)
