import logging

from data.cub.main_ft import load_cub_ft_data
from data.cub.main_raw import load_cub_data
from data.flowers.main_ft import load_flowers_ft_data
from data.flowers.main_raw import load_flowers_data
from data.synthetic import load_synthetic_data

# Named 'custom' logger (not __name__): every getLogger('custom') call in the
# process returns this same logger instance.
logger = logging.getLogger('custom')


def load_data(mode, args):
    """Build the dataset and loader selected by ``args.dset_name``.

    Args:
        mode: Mode value forwarded unchanged to the chosen loader function.
            When it equals ``'test'``, a conspicuous notice is logged first.
        args: Config object. ``args.dset_name`` selects the dataset and
            ``args.sizes['bs']`` supplies the batch size.

    Returns:
        The ``(dataset, loader)`` pair produced by the matching ``load_*``
        function.

    Raises:
        ValueError: If ``args.dset_name`` matches none of the known names.
    """
    dset_name = args.dset_name
    bs = args.sizes['bs']

    if mode == 'test':
        logger.info('!!!\nUSING TEST DATA\n!!!')

    # Ordered (name, loader function) pairs. They are compared with ``==`` in
    # sequence, exactly as a chain of if/elif tests would be.
    registry = (
        ('synthetic_data', load_synthetic_data),
        ('cub_ft', load_cub_ft_data),
        ('flowers', load_flowers_data),
        ('flowers_ft', load_flowers_ft_data),
        ('cub', load_cub_data),
    )

    for candidate, build in registry:
        if dset_name == candidate:
            break
    else:
        # Loop finished without a match: the name is not supported.
        raise ValueError(f'{dset_name} is not a legal dataset name.')

    data, data_loader = build(mode=mode, batch_size=bs)
    return data, data_loader
