from src.dataloaders.utils import SelectedDataset, MaskDataset


def load_dataset(dataset: str):
    """Look up a dataloader by name and return the imported object.

    Each supported name maps to one module under ``src.dataloaders``. The
    import is performed inside this function, so only the module for the
    requested dataset is imported at call time. The imported object is
    returned as-is; it is not called or instantiated here.

    Args:
        dataset: Name of the dataset. Compared with ``==`` against
            ``'Forest'``, ``'P53'``, ``'QSAR'``, ``'Gisette'`` and
            ``'Shopping'``, so the match is exact and case-sensitive.

    Returns:
        ``Forest``, ``P53``, ``QSAR``, ``Gisette`` or ``ShoppingGender``,
        depending on ``dataset``.

    Raises:
        FileNotFoundError: If ``dataset`` matches none of the names above.
    """
    # Guard clauses: each branch imports lazily and returns immediately.
    if dataset == 'Forest':
        from src.dataloaders.forest import Forest
        return Forest

    if dataset == 'P53':
        from src.dataloaders.p53 import P53
        return P53

    if dataset == 'QSAR':
        from src.dataloaders.qsar import QSAR
        return QSAR

    if dataset == 'Gisette':
        from src.dataloaders.gisette import Gisette
        return Gisette

    if dataset == 'Shopping':
        # Note: the 'Shopping' name resolves to the ShoppingGender loader.
        from src.dataloaders.shopping import ShoppingGender
        return ShoppingGender

    # Nothing matched: report the unknown name to the caller.
    raise FileNotFoundError(f'No dataset named {dataset}')


