import os
import warnings
from typing import Callable, Optional

from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder


class TinyImageNet(Dataset):
    """Tiny ImageNet dataset backed by :class:`torchvision.datasets.ImageFolder`.

    Images are read from ``<root>/tiny-64/train`` when ``train`` is true and
    from ``<root>/tiny-64/val`` otherwise. Because loading is delegated to
    ``ImageFolder``, each split directory must use that layout: one
    sub-directory per class, containing that class's images.

    NOTE(review): the official ``tiny-imagenet-200`` archive stores its
    validation images in one flat folder plus a ``val_annotations.txt`` file.
    The ``tiny-64/val`` directory therefore presumably has to be reorganised
    into per-class folders beforehand. Confirm against the data-prep script.

    Args:
        root: Directory that contains the ``tiny-64`` folder.
        train: Load the training split if true, the validation split otherwise.
        transform: Optional callable applied to each loaded image.
        target_transform: Optional callable applied to each class index.
        download: Accepted for signature compatibility with torchvision
            datasets, but downloading is not implemented. Passing ``True``
            emits a ``UserWarning`` and the data must already be on disk.
    """

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ):
        super().__init__()
        self.root = root
        self.train = train
        self.download = download

        if download:
            # The flag used to be silently ignored; tell the caller explicitly
            # instead of pretending a download happened.
            warnings.warn(
                "TinyImageNet does not support download=True; the flag is ignored "
                "and the dataset must already exist under "
                f"{os.path.join(root, 'tiny-64')!r}.",
                UserWarning,
                stacklevel=2,
            )

        split_dir = os.path.join(root, "tiny-64", "train" if train else "val")

        # All file discovery, image loading and transforms are handled by
        # ImageFolder; this class only selects the split directory.
        self.data = ImageFolder(
            root=split_dir,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index: int):
        """Return the ``index``-th item of the underlying ``ImageFolder``.

        For ``ImageFolder`` this is a ``(sample, target)`` pair.
        """
        return self.data[index]

    def __len__(self) -> int:
        """Return the number of items in the selected split."""
        return len(self.data)