# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict, Tuple, Type, Optional
from abc import abstractmethod
from overrides import overrides, EnforceOverrides
from torch.utils.data.dataset import Dataset
from torchvision.transforms import transforms
from ..common.config import Config
# Return type of DatasetProvider.get_datasets: a (train, test) pair in which
# either slot may be None.
TrainTestDatasets = Tuple[
    Optional[Dataset],  # train split
    Optional[Dataset],  # test split
]
class DatasetProvider(EnforceOverrides):
    """Abstract base for objects that construct train/test datasets.

    Concrete providers implement :meth:`get_datasets`. The base class is
    ``EnforceOverrides`` from the ``overrides`` package, which is expected to
    require overriding methods to carry the ``@overrides`` decorator (see that
    package's documentation).
    """

    def __init__(self, conf_dataset: Config):
        """Initialize the provider.

        Args:
            conf_dataset: Dataset configuration. This base class neither reads
                nor stores it; subclasses may.
        """
        super().__init__()

    @abstractmethod
    def get_datasets(self, load_train: bool, load_test: bool,
                     transform_train, transform_test) -> TrainTestDatasets:
        """Build the requested train and/or test datasets.

        Args:
            load_train: Flag presumably selecting whether the train split is
                built.
            load_test: Flag presumably selecting whether the test split is
                built.
            transform_train: Transform for the train split. Not annotated
                here; presumably a torchvision transform (TODO confirm).
            transform_test: Transform for the test split (same caveat).

        Returns:
            A ``(train, test)`` tuple. Per ``TrainTestDatasets``, either
            element may be ``None``.
        """
# Annotation meaning "a DatasetProvider subclass" (the class object itself,
# not an instance). The previous ``type(DatasetProvider)`` evaluated to the
# class's metaclass (``type(cls)`` of any class is its metaclass). That value
# describes every class sharing that metaclass, not DatasetProvider
# subclasses specifically.
DatasetProviderType = Type[DatasetProvider]

# Module-level registry mapping a provider name to its provider class.
# register_dataset_provider writes to it; get_provider_type reads from it.
_providers: Dict[str, DatasetProviderType] = {}
def register_dataset_provider(name: str, class_type: DatasetProviderType) -> None:
    """Record ``class_type`` in the module registry under the key ``name``.

    Args:
        name: Registry key under which the provider class is stored.
        class_type: The provider class to store.

    Raises:
        KeyError: If ``name`` is already registered. An existing entry is
            never overwritten.
    """
    # The registry dict is mutated in place, never rebound, so no ``global``
    # statement is needed.
    if name not in _providers:
        _providers[name] = class_type
    else:
        raise KeyError(f'dataset provider with name {name} has already been registered')
def get_provider_type(name: str) -> DatasetProviderType:
    """Return the provider class registered under ``name``.

    Args:
        name: Key previously passed to ``register_dataset_provider``.

    Returns:
        The provider class stored for ``name``.

    Raises:
        KeyError: If nothing is registered under ``name``. The message lists
            the registered names so that misspelled names are easy to spot.
    """
    # Only reads the module-level registry, so no ``global`` statement is needed.
    try:
        return _providers[name]
    except KeyError:
        # Keep the KeyError type callers may already catch, but replace the
        # bare-key message with context. ``from None`` hides the redundant
        # inner KeyError from the traceback.
        raise KeyError(f'dataset provider with name {name} has not been registered; '
                       f'registered names: {sorted(_providers)}') from None