from loader.cifar10_loader import load_images as load_cifar10


def load_data(data_name: str, data_type: str = 'train', augmix: bool = False):
    """Return the data loader for the requested dataset.

    Prints a short banner with the dataset name and split, validates
    ``data_name`` against the list of recognised dataset names, and then
    dispatches to the matching loader function.

    Args:
        data_name: Dataset identifier. Must be one of the recognised names
            listed in ``known_names`` below.
        data_type: Forwarded unchanged to the underlying loader function.
        augmix: Currently unused by this function; kept so existing callers
            that pass it keep working.

    Returns:
        Whatever the underlying loader function returns.

    Raises:
        ValueError: If ``data_name`` is not a recognised dataset name.
        NotImplementedError: If ``data_name`` is recognised but no loader is
            wired up for it in this function.
    """
    print('-' * 50)
    print('DATA NAME:', data_name)
    print('DATA TYPE:', data_type)
    print('-' * 50)

    known_names = (
        'cifar10', 'cifar100', 'mnist', 'fashion_mnist', 'svhn',
        'stl10', 'mnin', 'mini', 'tiny_imagenet', 'dcase',
    )
    # Explicit raise instead of `assert`: assertions are removed under
    # `python -O`, which would let invalid names slip through.
    if data_name not in known_names:
        raise ValueError(
            f"Unknown data_name {data_name!r}; expected one of {list(known_names)}"
        )

    if data_name == 'dcase':
        # NOTE(review): 'dcase' is served by the CIFAR-10 loader here, which
        # looks unusual -- confirm this mapping is intended.
        return load_cifar10(data_type)

    # Recognised name without a loader: fail with a clear message rather
    # than hitting an UnboundLocalError on an unassigned local.
    raise NotImplementedError(
        f"No loader is wired up for dataset {data_name!r}"
    )