import os

from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms as T

# Default preprocessing pipeline: resize to 32x32, convert the PIL image to a
# float tensor (ToTensor scales 8-bit RGB into [0, 1]), then normalize every
# channel with (x - 0.5) / 0.5 so pixel values end up in [-1, 1].
TRANSFORM = T.Compose(
    [
        T.Resize((32, 32)),
        T.ToTensor(),
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)


class Cifar100Dataset(Dataset):
    """Image-only view of the CIFAR-100 training split.

    Wraps ``torchvision.datasets.CIFAR100`` with ``train=True`` and drops the
    class label, so indexing yields only the (transformed) image.

    Args:
        transform: Callable applied to each image by the underlying
            torchvision dataset.
        root: Directory holding (or receiving) the CIFAR-100 files.
            Defaults to ``"./data_src"``, resolved relative to the current
            working directory.
        download: Forwarded to torchvision; when true the archive is fetched
            into ``root`` if it is not already there.
    """

    def __init__(self, transform, root="./data_src", download=True):
        self.cifar_training_dataset = datasets.CIFAR100(
            root, train=True, download=download, transform=transform
        )

    def __len__(self):
        """Return the number of training images."""
        return len(self.cifar_training_dataset)

    def __getitem__(self, index):
        """Return the image at ``index``; the label (item ``[1]``) is discarded."""
        return self.cifar_training_dataset[index][0]


class TinyImageNetDataset(Dataset):
    """Unlabelled dataset over the image files found directly inside one directory.

    The default ``dir_path`` is a hard-coded, user-specific absolute path to a
    Tiny ImageNet ``test/images`` folder. Pass ``dir_path`` explicitly when
    running anywhere else.

    Args:
        dir_path: Directory whose regular files are loaded as images.
        transform: Callable applied to each RGB PIL image. Defaults to the
            module-level ``TRANSFORM``.
    """

    def __init__(
        self,
        dir_path="/net/tscratch/people/plglukaszst/projects/diffusion-arithmetics/data_src/data_src/tiny-imagenet-200/test/images",
        transform=TRANSFORM,
    ):
        self.dir_path = dir_path
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sort so a given index always maps to the same file, and skip
        # non-file entries (e.g. sub-directories) that cannot be opened as images.
        self.files = sorted(
            name
            for name in os.listdir(self.dir_path)
            if os.path.isfile(os.path.join(self.dir_path, name))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files in ``dir_path``."""
        return len(self.files)

    def __getitem__(self, index):
        """Load the ``index``-th file (in sorted order) as RGB and apply ``transform``."""
        path = os.path.join(self.dir_path, self.files[index])
        # The context manager guarantees the file handle is closed. convert()
        # returns a new, fully loaded image, so it stays usable after the close.
        with Image.open(path) as raw:
            image = raw.convert("RGB")
        return self.transform(image)
