from torchvision.datasets.vision import VisionDataset
import glob, os
from PIL import Image
import numpy as np
import random


class DogsVsCatsData(VisionDataset):
    """Dogs vs. Cats images split deterministically into train / validation.

    Both splits are read from ``<root>/train`` using the file naming scheme
    ``<label>.<index>.jpg``. Per class, indices ``0 .. _train_size - 1`` form the
    training split and indices ``_train_size .. _images_per_class - 1`` form the
    validation split, so the two splits are disjoint and together cover every index.

    Args:
        root: Dataset root directory; images are expected under ``<root>/train``.
        transform: Optional callable applied to the loaded PIL image.
        target_transform: Optional callable applied to the integer label.
        train: ``True`` for the training split, ``False`` for the validation split.
    """

    label_to_int = {"cat": 0, "dog": 1}

    # Per-class index boundaries for the split. _images_per_class is the exclusive
    # upper bound on file indices (presumably the number of images per class in the
    # Kaggle training folder — confirm if the dataset layout changes).
    _train_size = 10000
    _images_per_class = 12500

    def __init__(self, root, transform=None, target_transform=None, train=True):
        super().__init__(root, transform=transform, target_transform=target_transform)

        if train:
            indices = range(0, self._train_size)
        else:
            # Start exactly at _train_size: starting at _train_size + 1 would drop
            # dog/cat.<_train_size>.jpg from both splits.
            indices = range(self._train_size, self._images_per_class)

        image_dir = os.path.join(self.root, "train")
        data = [os.path.join(image_dir, f"dog.{i}.jpg") for i in indices] + [
            os.path.join(image_dir, f"cat.{i}.jpg") for i in indices
        ]
        # Shuffle both splits with a fixed seed so samples are not grouped by label
        # (contiguous labels can break sanity checks that look at the first few
        # batches), while keeping the order reproducible across runs.
        random.Random(0).shuffle(data)
        self.data = np.array(data)

        # The label is the filename prefix before the first ".", e.g. "cat" in
        # "cat.42.jpg". os.path.basename handles the platform's path separator,
        # unlike splitting on "/".
        self.targets = [
            self.label_to_int[os.path.basename(path).split(".")[0]] for path in data
        ]

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is decoded as RGB and passed through ``transform`` if set; the
        integer label is passed through ``target_transform`` if set.
        """
        path = self.data[index]
        # Image.open is lazy and would keep the file open; reading through a
        # context-managed handle and forcing decoding with convert() closes it before
        # returning. Converting to RGB (as torchvision's default loader does) gives
        # every sample the same channel count regardless of the JPEG's stored mode.
        with open(path, "rb") as f:
            sample = Image.open(f).convert("RGB")
        if self.transform is not None:
            sample = self.transform(sample)
        target = self.targets[index]
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.data)


#
