from conf.dataset import DatasetParams
from data.BlenderDataset.BlenderDataset import BlenderDataModule
from data.brats2020.bratsloader import BRATS2020DataModule
from data.CelebADataset.CelebADataset import CelebADataModule


def get_dm(params: DatasetParams):
    """Instantiate the data module selected by ``params.data_params.name``.

    Recognised names are ``'blender'``, ``'brats2020'`` and ``'celeba'``. The
    matching data-module class is constructed with the full ``params`` object.

    Args:
        params: Dataset configuration. Only ``params.data_params.name`` is read
            here; the whole object is forwarded unchanged to the data-module
            constructor.

    Returns:
        A new ``BlenderDataModule``, ``BRATS2020DataModule`` or
        ``CelebADataModule`` instance, depending on the name.

    Raises:
        ValueError: If the name is not one of the recognised dataset names.
            ``ValueError`` subclasses ``Exception``, so callers that caught the
            previously raised generic ``Exception`` still catch it.
    """
    # Name -> constructor table. Adding a dataset is a one-line change here
    # instead of another elif branch.
    data_modules = {
        'blender': BlenderDataModule,
        'brats2020': BRATS2020DataModule,
        'celeba': CelebADataModule,
    }
    dataset_name = params.data_params.name
    data_module_cls = data_modules.get(dataset_name)
    if data_module_cls is None:
        # List the valid choices so a config typo is easy to diagnose.
        raise ValueError(
            f'Dataset type not available: {dataset_name=}; '
            f'expected one of {sorted(data_modules)}'
        )
    return data_module_cls(params)
