# Source code for archai.datasets.providers.food101_provider

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List, Tuple, Union, Optional
import os

from overrides import overrides, EnforceOverrides
from torch.utils.data.dataset import Dataset

import torchvision
from torchvision.transforms import transforms

from archai.datasets.dataset_provider import DatasetProvider, register_dataset_provider, TrainTestDatasets
from archai.common.config import Config


class Food101Provider(DatasetProvider):
    """Dataset provider for the Food-101 dataset.

    Expects the dataset to be extracted on disk as ImageFolder-style
    directories: ``<dataroot>/food-101/train`` and ``<dataroot>/food-101/test``.
    """

    def __init__(self, conf_dataset: Config):
        super().__init__(conf_dataset)
        # Root directory under which the 'food-101' folder lives.
        self._dataroot: str = conf_dataset['dataroot']

    @overrides
    def get_datasets(self, load_train: bool, load_test: bool,
                     transform_train, transform_test) -> TrainTestDatasets:
        """Build the requested splits.

        Returns a ``(trainset, testset)`` tuple; a member is ``None`` when the
        corresponding ``load_*`` flag is False.
        """
        trainset, testset = None, None

        if load_train:
            trainpath = os.path.join(self._dataroot, 'food-101', 'train')
            trainset = torchvision.datasets.ImageFolder(trainpath,
                                                        transform=transform_train)
        if load_test:
            testpath = os.path.join(self._dataroot, 'food-101', 'test')
            # BUG FIX: the original passed transform_train here, applying
            # training-time augmentation to the evaluation split.
            testset = torchvision.datasets.ImageFolder(testpath,
                                                       transform=transform_test)

        return trainset, testset

    @overrides
    def get_transforms(self) -> tuple:
        """Return ``(train_transform, test_transform)`` pipelines.

        TODO: Need to rethink the food101 transforms — images are of varying
        sizes (~512px per side) and are currently downscaled to 32x32.
        """
        MEAN = [0.5451, 0.4435, 0.3436]
        STD = [0.2171, 0.2251, 0.2260]  # TODO: should be [0.2517, 0.2521, 0.2573]

        # food101 has images of varying sizes and are ~512 each side
        train_transf = [
            transforms.Resize((32, 32)),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ]
        test_transf = [transforms.Resize((32, 32))]

        normalize = [
            transforms.ToTensor(),
            transforms.Normalize(MEAN, STD)
        ]

        train_transform = transforms.Compose(train_transf + normalize)
        test_transform = transforms.Compose(test_transf + normalize)
        return train_transform, test_transform
# Make this provider discoverable by name via the dataset-provider registry,
# so configs can select it with dataset name 'food101'.
register_dataset_provider('food101', Food101Provider)