# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List, Tuple, Union, Optional
from overrides import overrides, EnforceOverrides
from torch.utils.data.dataset import Dataset
import torchvision
from torchvision.transforms import transforms
from archai.datasets.dataset_provider import DatasetProvider, register_dataset_provider, TrainTestDatasets
from archai.common.config import Config
class FashionMnistProvider(DatasetProvider):
    """Dataset provider that builds Fashion-MNIST splits via torchvision.

    The root directory for the dataset files is read from the ``dataroot``
    entry of the dataset config passed to the constructor.
    """

    def __init__(self, conf_dataset:Config):
        super().__init__(conf_dataset)
        # Directory handed to torchvision as ``root`` for both splits.
        self._dataroot = conf_dataset['dataroot']

    @overrides
    def get_datasets(self, load_train:bool, load_test:bool,
                     transform_train, transform_test)->TrainTestDatasets:
        """Construct the requested Fashion-MNIST splits.

        Args:
            load_train: if truthy, build the training split using
                ``transform_train``; otherwise that slot is ``None``.
            load_test: if truthy, build the test split using
                ``transform_test``; otherwise that slot is ``None``.
            transform_train: transform forwarded to the training split
                (presumably a torchvision transform callable — confirm).
            transform_test: transform forwarded to the test split.

        Returns:
            A ``(trainset, testset)`` pair; a split that was not requested
            is ``None``. The training split is constructed before the test
            split.
        """
        def _build_split(is_train, split_transform):
            # download=True is passed for both splits, letting torchvision
            # fetch the files into the configured root directory.
            return torchvision.datasets.FashionMNIST(root=self._dataroot,
                train=is_train, download=True, transform=split_transform)

        trainset = _build_split(True, transform_train) if load_train else None
        testset = _build_split(False, transform_test) if load_test else None
        return trainset, testset
# Register this provider under the name 'fashion_mnist' (presumably the key
# selected by the dataset config — confirm against register_dataset_provider).
register_dataset_provider(
    'fashion_mnist',
    FashionMnistProvider,
)