from data.cub.main_ft import load_cub_ft_data
from data.flowers.main_raw import load_flowers_data


def load_data(mode, args) -> (dict, dict):
    """Load the dataset selected by ``args.dset_name`` and its loader.

    Picks the dataset-specific loading function for the requested name and
    calls it with the given mode and batch size.

    Args:
        mode: Forwarded unchanged as the ``mode`` keyword of the selected
            loading function.
        args: Object exposing ``dset_name`` (``'cub_ft'`` or ``'flowers'``)
            and ``bs``, which is forwarded as ``batch_size``.

    Returns:
        A 2-tuple holding the two values produced by the selected loading
        function, in the order it returned them (dataset, then loader).

    Raises:
        ValueError: If ``args.dset_name`` is neither ``'cub_ft'`` nor
            ``'flowers'``.

    NOTE(review): the return annotation is a tuple expression rather than a
    typing construct, and the concrete return types are not visible here --
    confirm against the loader implementations before tightening it.
    """
    requested = args.dset_name

    # Choose the loading function first so the call below is written once.
    if requested == 'cub_ft':
        fetch = load_cub_ft_data
    elif requested == 'flowers':
        fetch = load_flowers_data
    else:
        raise ValueError(f'{args.dset_name} is not a legal dataset name for this model.')

    # Unpack explicitly: the selected function must yield exactly two values.
    data, batches = fetch(mode=mode, batch_size=args.bs)
    return data, batches
