from mhvae_vasco.objective.features.objective import ObjectiveFeatures
from mhvae_vasco.objective.images.objective import ObjectiveImages


def get_objective(dataset_name):
    """Construct and return the objective object for ``dataset_name``.

    Args:
        dataset_name: Dataset identifier. ``'cub_ft'`` selects
            ``ObjectiveFeatures`` and ``'flowers'`` selects
            ``ObjectiveImages``.

    Returns:
        A new instance of the selected objective class. Each call builds a
        new one.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
            The message names the rejected value and the accepted ones.
    """
    if dataset_name == 'cub_ft':
        return ObjectiveFeatures()
    if dataset_name == 'flowers':
        return ObjectiveImages()
    # The original raised a bare ValueError, which hid the offending value.
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; "
        f"expected one of 'cub_ft', 'flowers'."
    )
