from .imagenet import ImagenetDataProvider
import torchvision.datasets as datasets
import os
from .al_sampler import CIFAR10_ROOT

# Public names exported by ``from <this module> import *``.
__all__ = [
    "CIFAR10DataProvider",
]


class CIFAR10DataProvider(ImagenetDataProvider):
    """Data provider for the torchvision CIFAR-10 dataset.

    Subclasses ``ImagenetDataProvider`` and overrides the dataset name, the
    number of classes, the dataset constructors and the on-disk paths. Both
    the train and the validation split are read from the same root
    directory (``save_path``), because torchvision's ``CIFAR10`` selects the
    split with its ``train`` flag rather than with a sub-directory.
    """

    @staticmethod
    def name():
        """Return the short identifier of this dataset (``"cifar10"``)."""
        return "cifar10"

    @property
    def n_classes(self):
        """Number of target classes in CIFAR-10."""
        return 10

    @property
    def save_path(self):
        """Root directory that holds the CIFAR-10 files.

        If no path has been set yet (``self._save_path is None``), it is
        resolved on first access and cached in ``self._save_path``:
        ``CIFAR10_ROOT`` is preferred, and ``~/dataset/cifar10`` is used when
        ``CIFAR10_ROOT`` does not exist. ``self._save_path`` is presumably
        initialised by the parent constructor — TODO confirm.
        """
        if self._save_path is None:
            self._save_path = CIFAR10_ROOT
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser(f"~/dataset/{self.name()}")
        return self._save_path

    def train_dataset(self, _transforms):
        """Build the CIFAR-10 training split.

        Args:
            _transforms: transform applied to each image by torchvision.

        Returns:
            A ``torchvision.datasets.CIFAR10`` training dataset. Nothing is
            downloaded (``download=False``), so the files must already be
            present under ``train_path``.
        """
        # Go through ``train_path`` rather than ``CIFAR10_ROOT`` so that an
        # explicitly configured save_path and the fallback resolved in
        # ``save_path`` are actually used to load the data.
        return datasets.CIFAR10(root=self.train_path, transform=_transforms, train=True, download=False)

    def test_dataset(self, _transforms):
        """Build the CIFAR-10 test split.

        Args:
            _transforms: transform applied to each image by torchvision.

        Returns:
            A ``torchvision.datasets.CIFAR10`` test dataset. Nothing is
            downloaded (``download=False``), so the files must already be
            present under ``valid_path``.
        """
        # Same reasoning as train_dataset: honour the resolved path instead
        # of the hard-coded module constant.
        return datasets.CIFAR10(root=self.valid_path, transform=_transforms, train=False, download=False)

    @property
    def train_path(self):
        """Directory of the training split; the same as ``save_path`` for CIFAR-10."""
        return self.save_path

    @property
    def valid_path(self):
        """Directory of the validation/test split; the same as ``save_path`` for CIFAR-10."""
        return self.save_path
