from abc import ABC, abstractmethod
import torch.utils.data as data
import torchvision.transforms as transforms


class AbstractDataset(ABC):
    """
    Interface that every dataset wrapper must implement.

    Concrete subclasses provide:

    * the dataset's root directory and batch size,
    * the input transforms used for training and for testing,
    * the training and testing dataset objects, and
    * methods returning data loaders over those datasets.

    Because every member below is declared abstract, the class cannot be
    instantiated until a subclass overrides all of them. Each base
    implementation simply raises ``NotImplementedError``.
    """

    # ------------------------------------------------------------------
    # General configuration
    # ------------------------------------------------------------------

    @property
    @abstractmethod
    def root(self) -> str:
        """
        Root directory of the dataset.

        :return: The dataset's root directory.
        :rtype: str
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def batch_size(self) -> int:
        """
        Number of samples per batch.

        :return: The batch size.
        :rtype: int
        """
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Input transforms
    # ------------------------------------------------------------------

    @property
    @abstractmethod
    def train_transform(self) -> transforms.Compose:
        """
        Transform pipeline used on inputs while training.

        :return: The training-time input transform.
        :rtype: torchvision.transforms.Compose
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def test_transform(self) -> transforms.Compose:
        """
        Transform pipeline used on inputs while testing.

        :return: The testing-time input transform.
        :rtype: torchvision.transforms.Compose
        """
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Dataset objects
    # ------------------------------------------------------------------

    @property
    @abstractmethod
    def train_set(self) -> data.Dataset:
        """
        Dataset object holding the training split.

        :return: The training set.
        :rtype: torch.utils.data.Dataset
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def test_set(self) -> data.Dataset:
        """
        Dataset object holding the testing split.

        :return: The testing set.
        :rtype: torch.utils.data.Dataset
        """
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Loaders
    # ------------------------------------------------------------------

    @abstractmethod
    def train_loader(self, shuffle: bool) -> data.DataLoader:
        """
        Build a loader over the training set.

        :param bool shuffle: Whether to shuffle the dataset before sampling.
        :return: A data loader for the training set.
        :rtype: torch.utils.data.DataLoader
        """
        raise NotImplementedError

    @abstractmethod
    def test_loader(self, shuffle: bool) -> data.DataLoader:
        """
        Build a loader over the testing set.

        :param bool shuffle: Whether to shuffle the dataset before sampling.
        :return: A data loader for the testing set.
        :rtype: torch.utils.data.DataLoader
        """
        raise NotImplementedError
