from . import color_mnist
from . import waterbirds
from . import celeba
from . import industrial

# NOTE: temporary implementation.
def get_ds(name):
    """Look up a dataset module by its registered name.

    Args:
        name: One of ``"color_mnist"``, ``"waterbirds"``, ``"celeba"`` or
            ``"industrial"``.

    Returns:
        The matching submodule of this package (imported at the top of this
        file via ``from . import ...``).

    Raises:
        ValueError: If ``name`` equals none of the known dataset names.
    """
    # Scanned in order and compared with ``==`` (not a dict lookup), so any
    # object -- hashable or not -- that matches no key ends in ValueError.
    registry = (
        ("color_mnist", color_mnist),
        ("waterbirds", waterbirds),
        ("celeba", celeba),
        ("industrial", industrial),
    )
    for key, module in registry:
        if name == key:
            return module
    raise ValueError(f"Not valid name: {name}")
