from mhvae_vasco.model.features.model import MhvaeFeatures
from mhvae_vasco.model.images.model import MhvaeImages


def get_model(args, device):
    """Instantiate the model class selected by ``args.dset_name``.

    Args:
        args: Configuration object. Its ``dset_name`` attribute picks the
            model class; the whole object is forwarded to the constructor.
        device: Forwarded unchanged to the model constructor.

    Returns:
        A ``MhvaeFeatures`` instance when ``args.dset_name`` is ``'cub_ft'``,
        or a ``MhvaeImages`` instance when it is ``'flowers'``.

    Raises:
        ValueError: If ``args.dset_name`` is neither of the supported names.
    """
    dset_name = args.dset_name
    if dset_name == 'cub_ft':
        return MhvaeFeatures(args, device)
    if dset_name == 'flowers':
        return MhvaeImages(args, device)
    # Name the rejected value and the accepted ones so a misconfigured run
    # can be diagnosed straight from the traceback.
    raise ValueError(
        f"Unsupported dset_name {dset_name!r}; expected 'cub_ft' or 'flowers'."
    )