# Source code for archai.datasets.providers.flower102_provider
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List, Tuple, Union, Optional
import os
from overrides import overrides, EnforceOverrides
from torch.utils.data.dataset import Dataset
import torchvision
from torchvision.transforms import transforms
from archai.datasets.dataset_provider import DatasetProvider, register_dataset_provider, TrainTestDatasets
from archai.common.config import Config
class Flower102Provider(DatasetProvider):
    """Dataset provider for the Flower102 image-classification dataset.

    Images are read with ``torchvision.datasets.ImageFolder`` from
    ``<dataroot>/flower102/train`` and ``<dataroot>/flower102/test``, so both
    directories must follow the ``ImageFolder`` layout.
    """

    def __init__(self, conf_dataset:Config):
        """Create the provider.

        Args:
            conf_dataset: dataset section of the config. It must contain the
                key ``'dataroot'``, the directory that holds ``flower102/``.
        """
        super().__init__(conf_dataset)
        self._dataroot = conf_dataset['dataroot']

    def _load_split(self, split:str, transform)->Dataset:
        """Load ``<dataroot>/flower102/<split>`` as an ImageFolder using ``transform``."""
        path = os.path.join(self._dataroot, 'flower102', split)
        return torchvision.datasets.ImageFolder(path, transform=transform)

    @overrides
    def get_datasets(self, load_train:bool, load_test:bool,
                     transform_train, transform_test)->TrainTestDatasets:
        """Build the requested train/test datasets.

        Args:
            load_train: if True, load the ``train`` split; otherwise return None for it.
            load_test: if True, load the ``test`` split; otherwise return None for it.
            transform_train: transform applied to every training image.
            transform_test: transform applied to every test image.

        Returns:
            ``(trainset, testset)``; an entry is None when its split was not requested.
        """
        trainset = self._load_split('train', transform_train) if load_train else None
        # The test split must use transform_test. Using transform_train here
        # (previous behavior) applied random crop/flip/color jitter at evaluation time.
        testset = self._load_split('test', transform_test) if load_test else None
        return trainset, testset

    @overrides
    def get_transforms(self)->tuple:
        """Return ``(train_transform, test_transform)`` for Flower102.

        The train pipeline does the following:
            * random resized crop to 224;
            * random horizontal flip;
            * color jitter.

        The test pipeline resizes to 256 and then center-crops to 224. Both
        pipelines then convert to a tensor and normalize with the per-channel
        MEAN/STD below.
        """
        # MEAN, STD computed for flower102
        MEAN = [0.5190, 0.4101, 0.3274]
        STD = [0.2972, 0.2488, 0.2847]

        # transformations match that in
        # https://github.com/antoyang/NAS-Benchmark/blob/master/DARTS/preproc.py
        train_transf = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.2)
        ]
        test_transf = [transforms.Resize(256), transforms.CenterCrop(224)]

        # Shared tail: PIL image -> tensor, then per-channel normalization.
        normalize = [
            transforms.ToTensor(),
            transforms.Normalize(MEAN, STD)
        ]

        train_transform = transforms.Compose(train_transf + normalize)
        test_transform = transforms.Compose(test_transf + normalize)
        return train_transform, test_transform
# Register this provider under the key 'flower102' at import time.
# Presumably this key is matched against the dataset name in the config (confirm in dataset_provider).
register_dataset_provider('flower102', Flower102Provider)