Source code for archai.datasets.providers.imagenet_provider

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List, Tuple, Union, Optional
import os
import shutil
import torch

import torchvision
from torchvision.transforms import transforms
from torchvision import datasets
from torchvision.datasets.utils import check_integrity, download_url

from PIL import Image

from overrides import overrides, EnforceOverrides

from archai.common.common import logger
from archai.datasets.dataset_provider import DatasetProvider, register_dataset_provider, TrainTestDatasets
from archai.common.config import Config
from archai.common import utils
from archai.datasets.transforms.lighting import Lighting
from .imagenet_folder import ImageNetFolder

[docs]class ImagenetProvider(DatasetProvider): def __init__(self, conf_dataset:Config): super().__init__(conf_dataset) self._dataroot = utils.full_path(conf_dataset['dataroot'])
[docs] @overrides def get_datasets(self, load_train:bool, load_test:bool, transform_train, transform_test)->TrainTestDatasets: trainset, testset = None, None if load_train: trainset = datasets.ImageFolder(root=os.path.join(self._dataroot, 'ImageNet', 'train'), transform=transform_train) # compatibility with older PyTorch if not hasattr(trainset, 'targets'): trainset.targets = [lb for _, lb in trainset.samples] if load_test: testset = datasets.ImageFolder(root=os.path.join(self._dataroot, 'ImageNet', 'val'), transform=transform_test) return trainset, testset
[docs] @overrides def get_transforms(self)->tuple: MEAN = [0.485, 0.456, 0.406] STD = [0.229, 0.224, 0.225] _IMAGENET_PCA = { 'eigval': [0.2175, 0.0188, 0.0045], 'eigvec': [ [-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203], ] } transform_train, transform_test = None, None transform_train = transforms.Compose([ transforms.RandomResizedCrop(224, scale=(0.08, 1.0), # TODO: these two params are normally not specified interpolation=Image.BICUBIC), transforms.RandomHorizontalFlip(), transforms.ColorJitter( brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2 ), transforms.ToTensor(), # TODO: Lighting is not used in original darts paper # Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']), transforms.Normalize(mean=MEAN, std=STD) ]) transform_test = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=MEAN, std=STD) ]) return transform_train, transform_test
register_dataset_provider('imagenet', ImagenetProvider)