from torch.utils.data import Dataset
from torchvision import datasets

class CIFAR100(Dataset):
    """Thin ``torch.utils.data.Dataset`` wrapper around torchvision's CIFAR-100.

    Construction, transforms and per-sample indexing are all delegated to an
    inner ``torchvision.datasets.CIFAR100`` instance stored in ``self.data``.

    Args:
        root: Dataset root directory, forwarded unchanged to torchvision.
        train: Split selector, forwarded unchanged to torchvision's ``train``.
        transform: Optional image transform, forwarded unchanged to torchvision.
        target_transform: Optional label transform, forwarded unchanged to
            torchvision.
        download: Forwarded unchanged to torchvision's ``download`` flag.
    """

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform = None,
        target_transform = None,
        download: bool = False,
    ):
        super().__init__()
        # Remember the constructor arguments so callers can inspect them later.
        self.root, self.train, self.download = root, train, download

        # Everything the wrapped torchvision dataset needs, passed by keyword.
        inner_kwargs = dict(
            root=root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        self.data = datasets.CIFAR100(**inner_kwargs)

    def __getitem__(self, index: int):
        """Return sample ``index`` exactly as the wrapped dataset produces it."""
        inner = self.data
        return inner[index]

    def __len__(self):
        """Return the number of samples reported by the wrapped dataset."""
        inner = self.data
        return len(inner)
