from torchvision import datasets


class Pets(datasets.OxfordIIITPet):
    """Oxford-IIIT Pet dataset whose samples are returned as dicts.

    Every sample always carries three targets, whatever ``target`` is:
    the requested target as ``"y"``, a boolean mask computed from the
    segmentation target as ``"segmentation"``, and the class label as
    ``"category"``.

    Args:
        root: Directory holding the dataset (and download destination).
        split: Split forwarded to ``OxfordIIITPet``; ``"trainval"`` by default.
        transform: Optional transform applied to the image.
        target_transform: Optional transform forwarded to ``OxfordIIITPet``.
            Three target types are requested, so torchvision passes it the
            combined ``(target, segmentation, category)`` tuple rather than a
            single target. Whatever it returns must still be indexable at
            positions 0, 1 and 2 (see ``__getitem__``).
        target: Target type exposed under ``"y"``. It is forwarded as the
            first entry of ``target_types``.
        download: Whether to download the data when missing under ``root``.
            Defaults to ``True``, which matches the previously hard-coded
            behaviour.
    """

    def __init__(
        self,
        root,
        split="trainval",
        transform=None,
        target_transform=None,
        target="category",
        download=True,
    ):
        super().__init__(
            root=root,
            split=split,
            transform=transform,
            download=download,
            # Order matters: __getitem__ reads these back by position
            # (0 -> "y", 1 -> segmentation mask, 2 -> "category").
            target_types=[target, "segmentation", "category"],
            target_transform=target_transform,
        )

    def __getitem__(self, index: int) -> dict:
        """Return sample ``index`` as a dict.

        Returns:
            A dict with these keys:

            - ``"x"``: the (possibly transformed) image.
            - ``"y"``: the target selected by ``target`` in ``__init__``.
            - ``"segmentation"``: the result of ``trimap * 255 == 1``.
            - ``"category"``: the category target.
        """
        image, label = super().__getitem__(index)
        return {
            "x": image,
            "y": label[0],
            # NOTE(review): assumes target_transform turned the trimap into a
            # numeric tensor/array scaled to [0, 1] (e.g. via ToTensor). The
            # "* 255" undoes that scaling. A raw PIL image does not support
            # "*" and would raise here. Pixel value 1 is presumably the pet
            # (foreground) label of the Oxford-IIIT trimaps; verify against
            # the dataset README.
            "segmentation": label[1] * 255 == 1,
            "category": label[2],
        }
