from src.DiffusionModel_all_to_all import DiffusionAlltoAll as DiffusionAllToAll


def get_model_class(name: str):
    """Look up a model class by its configuration name.

    Args:
        name: Model identifier. The only name recognised here is
            ``'diffusion_all_to_all'``, and the match is exact
            (case-sensitive, no whitespace stripping).

    Returns:
        The class object itself (``DiffusionAllToAll``), not an instance.

    Raises:
        ValueError: If ``name`` is not a recognised model identifier.
    """
    # Reject unknown names up front so the success path stays flat.
    if name != 'diffusion_all_to_all':
        raise ValueError(f'Unknown model name {name}')
    return DiffusionAllToAll
