import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets.cifar as cifar
from antgine.dataset import AbstractDataset


# Per-channel (R, G, B) normalization statistics shared by the train and test
# pipelines, so the two normalizations cannot silently drift apart.
# https://github.com/kuangliu/pytorch-cifar/blob/bf78d3b8b358c4be7a25f9f9438c842d837801fd/main.py#L35
# NOTE(review): confirm the std values against the linked reference.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.247, 0.243, 0.261)

# TODO: these default pipelines are temporary; callers can already override them
# through the ``train_transform`` / ``test_transform`` constructor arguments.

# Training pipeline: random 32x32 crop of the image padded by 4 pixels on each
# side, random horizontal flip, conversion to a tensor, then normalization.
_default_train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CIFAR10_MEAN, std=_CIFAR10_STD),
])

# Evaluation pipeline: deterministic (no augmentation). Resize(32) scales the
# shorter side to 32 px; for native 32x32 CIFAR-10 images it leaves the size unchanged.
_default_test_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CIFAR10_MEAN, std=_CIFAR10_STD),
])

class CIFAR10(AbstractDataset):
    """
        CIFAR10 dataset class.

        Wraps the torchvision CIFAR-10 train and test splits and builds
        :class:`torch.utils.data.DataLoader` objects over them, implementing the
        :class:`antgine.dataset.AbstractDataset` interface.
    """
    def __init__(self, root: str, batch_size: int,
                 train_transform: transforms.Compose = _default_train_transform,
                 test_transform: transforms.Compose = _default_test_transform,
                 num_workers: int = 8,
                 download: bool = True):
        """
        Both splits are instantiated eagerly here, so constructing this object
        may hit the network when ``download`` is ``True``.

        :param str root: Dataset's root directory.
        :param int batch_size: Batch size used by both the train and test loaders.
        :param transforms.Compose train_transform: Transform applied to inputs during training.
        :param transforms.Compose test_transform: Transform applied to inputs during testing.
        :param int num_workers: Number of workers launched by each loader for loading data.
        :param bool download: Forwarded to torchvision's ``CIFAR10``; when ``True``
            (the default, matching the previous hard-coded behavior) the data is
            downloaded into ``root`` if it is not already present. Pass ``False``
            on machines without network access; the data must then already exist.
        """
        super().__init__()
        self._root = root
        self._batch_size = batch_size
        self._train_transform = train_transform
        self._test_transform = test_transform
        self._num_workers = num_workers
        self._train_set = cifar.CIFAR10(self.root, train=True, transform=self.train_transform,
                                        download=download)
        self._test_set = cifar.CIFAR10(self.root, train=False, transform=self.test_transform,
                                       download=download)

    @property
    def root(self) -> str:
        """
            See :attr:`antgine.dataset.AbstractDataset.root`.
        """
        return self._root

    @property
    def batch_size(self) -> int:
        """
            See :attr:`antgine.dataset.AbstractDataset.batch_size`.
        """
        return self._batch_size

    @property
    def train_transform(self) -> transforms.Compose:
        """
            See :attr:`antgine.dataset.AbstractDataset.train_transform`.
        """
        return self._train_transform

    @property
    def test_transform(self) -> transforms.Compose:
        """
            See :attr:`antgine.dataset.AbstractDataset.test_transform`.
        """
        return self._test_transform

    @property
    def train_set(self) -> data.Dataset:
        """
            See :attr:`antgine.dataset.AbstractDataset.train_set`.
        """
        return self._train_set

    @property
    def test_set(self) -> data.Dataset:
        """
            See :attr:`antgine.dataset.AbstractDataset.test_set`.
        """
        return self._test_set

    def _loader(self, dataset: data.Dataset, shuffle: bool) -> data.DataLoader:
        """
            Build a DataLoader over ``dataset`` using this instance's batch size
            and worker count.

            :param data.Dataset dataset: Dataset to iterate over.
            :param bool shuffle: Whether the loader reshuffles the data each epoch.
            :return: A new :class:`torch.utils.data.DataLoader`.
        """
        return data.DataLoader(dataset, batch_size=self.batch_size,
                               shuffle=shuffle, num_workers=self._num_workers)

    def train_loader(self, shuffle=True) -> data.DataLoader:
        """
            See :meth:`antgine.dataset.AbstractDataset.train_loader`.
            Shuffles by default.
        """
        return self._loader(self.train_set, shuffle)

    def test_loader(self, shuffle=False) -> data.DataLoader:
        """
            See :meth:`antgine.dataset.AbstractDataset.test_loader`.
            Does not shuffle by default, so evaluation order is deterministic.
        """
        return self._loader(self.test_set, shuffle)
