import sys

from deel.utils.yaml_to_params import getParams, getFunctionFromModules
from deel.datasets.toy_dataset import toy_ternary_generator,toy_binary_generator
from deel.datasets.mnist_dataset import mnist_dataset
from deel.datasets.fashion_mnist_dataset import fashion_mnist_dataset,fashion_mnist_dataset_oneclass
from deel.datasets.cifar10_dataset import cifar10_dataset,cifar10_dataset_oneclass
from deel.datasets.kmoons_dataset import kmoons_generator,kmoons_generator_binary
from deel.datasets.celeb_a_dataset import celeba_glass_generator
from deel.datasets.celeb_a_dataset import celeba_glass_generator
from deel.datasets.cat_vs_dog_dataset import cat_vs_dog_dataset


# Resolver built from this module object (sys.modules[__name__]).
# NOTE(review): presumably getFunctionFromModules returns a callable that maps a
# name string to an attribute of this module (e.g. the dataset factories
# imported above) — confirm against deel.utils.yaml_to_params.
getFunction = getFunctionFromModules(sys.modules[__name__])

def load_dataset(dataset_config):
    """Build a dataset from a config mapping.

    Looks up the factory named by ``dataset_config['type']`` via
    ``getFunction``. Builds keyword arguments with
    ``getParams(dataset_config, 'params', getFunction)``. Then calls the
    factory with those arguments and returns its result.

    Args:
        dataset_config: Mapping with a ``'type'`` key naming the factory.
            It presumably also carries a ``'params'`` entry consumed by
            ``getParams`` (TODO confirm the exact schema in
            deel.utils.yaml_to_params).

    Returns:
        Whatever the selected dataset factory returns.

    Raises:
        KeyError: If ``dataset_config`` has no ``'type'`` key.
    """
    # Resolve the factory first, then its parameters. This matches the
    # evaluation order of the original single-expression form.
    dataset_fn = getFunction(dataset_config['type'])
    # getFunction is passed along, presumably so parameter values can refer
    # to other module-level callables by name.
    params = getParams(dataset_config, 'params', getFunction)
    return dataset_fn(**params)
