from collections.abc import Callable


def build_dataset(
    name: str,
    split: str,
    root: str = "./data",
    transform: Callable | None = None,
):
    if name == "celeba":
        from .celeba import celeba
        return celeba(root=root, split=split, transform=transform)

    elif name == "color_mnist":
        from .color_mnist import color_mnist
        return color_mnist(root=root, split=split, transform=transform)

    elif name == "waterbirds":
        from .waterbirds import waterbirds
        return waterbirds(root=root, split=split, transform=transform)

    elif name.startswith("mvtec:"):
        from .mvtec import mvtec
        category = name.split(":")[1]
        return mvtec(category=category, root=root, split=split, transform=transform)

    else:
        raise ValueError(f"Unknown dataset type '{name}'")
