from collections.abc import Callable

from torchvision.datasets import CelebA


def celeba(
    root: str = "./data",
    split: str = "train",
    transform: Callable | None = None,
    *,
    target_type: str | list[str] = "attr",
    target_transform: Callable | None = None,
    download: bool = False,
) -> "CelebA":
    """Build a torchvision ``CelebA`` dataset.

    All arguments are forwarded unchanged to ``torchvision.datasets.CelebA``;
    see its documentation for the accepted values.

    Args:
        root: Dataset root directory passed to ``CelebA``.
        split: Which split to load (forwarded as ``split``).
        transform: Optional callable applied to each image (forwarded as
            ``transform``).
        target_type: Target kind(s) to return with each image. Defaults to
            ``"attr"``.
        target_transform: Optional callable applied to each target (forwarded
            as ``target_transform``).
        download: Forwarded as ``download``. Defaults to ``False``.

    Returns:
        The constructed ``CelebA`` dataset instance.
    """
    # The new parameters are keyword-only so the existing positional order
    # (root, split, transform) is unchanged for current callers.
    return CelebA(
        root=root,
        split=split,
        target_type=target_type,
        transform=transform,
        target_transform=target_transform,
        download=download,
    )
