from .imagenet import ImagenetDataProvider
import torchvision.datasets as datasets
import os
from .al_sampler import CIFAR100_ROOT

# Public API of this module: only the CIFAR-100 provider class is exported
# by `from <module> import *`.
__all__ = ["CIFAR100DataProvider"]


class CIFAR100DataProvider(ImagenetDataProvider):
    """CIFAR-100 data provider built on top of ``ImagenetDataProvider``.

    Only the dataset identity (name, class count), the on-disk location and
    the dataset constructors are overridden here; everything else is
    inherited from ``ImagenetDataProvider``.
    """

    @staticmethod
    def name():
        """Return the short identifier of this dataset."""
        return "cifar100"

    @property
    def n_classes(self):
        """Number of target classes in CIFAR-100."""
        return 100

    @property
    def save_path(self):
        """Root directory holding the CIFAR-100 files.

        Resolved lazily on first access and cached in ``self._save_path``.
        If ``self._save_path`` is already set (presumably by the parent
        constructor -- confirm), that value is returned unchanged. Otherwise
        ``CIFAR100_ROOT`` is used when it exists on disk, falling back to
        ``~/dataset/cifar100``.
        """
        if self._save_path is None:
            self._save_path = CIFAR100_ROOT
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser(f"~/dataset/{self.name()}")
        return self._save_path

    def train_dataset(self, _transforms):
        """Return the CIFAR-100 training split with ``_transforms`` applied.

        ``download=False``: the dataset files must already be present under
        ``self.save_path``.
        """
        # Read from the resolved save_path rather than CIFAR100_ROOT directly,
        # so the ~/dataset fallback and any pre-set _save_path are honored and
        # the data location agrees with train_path.
        return datasets.CIFAR100(
            root=self.save_path, transform=_transforms, train=True, download=False
        )

    def test_dataset(self, _transforms):
        """Return the CIFAR-100 test split with ``_transforms`` applied.

        ``download=False``: the dataset files must already be present under
        ``self.save_path``.
        """
        # Same root resolution as train_dataset, keeping it consistent with
        # valid_path.
        return datasets.CIFAR100(
            root=self.save_path, transform=_transforms, train=False, download=False
        )

    @property
    def train_path(self):
        """Directory of the training data; same root as ``save_path``."""
        return self.save_path

    @property
    def valid_path(self):
        """Directory of the validation/test data; same root as ``save_path``."""
        return self.save_path
