import os
import random

import numpy as np
from PIL import Image
from torchvision.datasets.vision import VisionDataset


class BR35HDataset(VisionDataset):
    """Binary image dataset read from ``<root>/yes`` and ``<root>/no``.

    Files are expected to be named ``y{i}.jpg`` inside ``yes`` and
    ``no{i}.jpg`` inside ``no``.  Indices ``0..1199`` of each class form the
    training split and indices ``1200..1498`` form the evaluation split.

    NOTE(review): the class name suggests the Br35H brain-tumor dataset; if
    that dataset ships images up to index 1499, the exclusive upper bound
    below leaves the last image of each class unused -- confirm.

    Attributes:
        data: NumPy array of image paths, in a fixed shuffled order.
        targets: List of integer labels aligned with ``data``
            (``1`` for ``yes``, ``0`` for ``no``).
    """

    label_to_int = {"yes": 1, "no": 0}

    # Split boundaries: [0, _TRAIN_END) is train, [_TRAIN_END, _EVAL_END) is eval.
    _TRAIN_END = 1200
    _EVAL_END = 1499

    def __init__(self, root, transform=None, target_transform=None, train=True):
        """Build the list of image paths and labels for one split.

        Args:
            root: Directory containing the ``yes`` and ``no`` sub-directories.
            transform: Optional callable applied to each loaded PIL image.
            target_transform: Optional callable applied to each integer label.
            train: Select the training split when true, otherwise the
                evaluation split.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)

        if train:
            indices = range(self._TRAIN_END)
        else:
            indices = range(self._TRAIN_END, self._EVAL_END)

        paths = [os.path.join(self.root, "yes", f"y{i}.jpg") for i in indices]
        paths += [os.path.join(self.root, "no", f"no{i}.jpg") for i in indices]

        # Shuffle (with a fixed seed, so the order is reproducible) to avoid
        # contiguous runs of a single label, which can result in errors during
        # sanity checks that only look at the first few samples.
        random.Random(0).shuffle(paths)
        self.data = np.array(paths)

        # The label is the name of the directory holding the file.  Using
        # os.path here (instead of splitting on "/") keeps the labels correct
        # on platforms whose path separator is not "/".
        self.targets = [
            self.label_to_int[os.path.basename(os.path.dirname(path))]
            for path in paths
        ]

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``.

        The image is loaded with PIL and passed through ``transform`` if one
        was given.  When the result exposes a ``shape`` whose first dimension
        is 4, only the first three entries along that dimension are kept
        (presumably dropping an alpha channel from a channel-first tensor).
        Results without a ``shape`` attribute, such as PIL images, are returned
        unchanged.  The label is passed through ``target_transform`` if one
        was given.
        """
        path = self.data[index]

        # Image.open is lazy; read the pixel data while the file is still open
        # so the handle can be released deterministically.
        with open(path, "rb") as handle:
            sample = Image.open(handle)
            sample.load()

        if self.transform is not None:
            sample = self.transform(sample)

        # Only array-like samples have a shape; guard so that running without
        # a tensor-producing transform does not raise AttributeError.
        shape = getattr(sample, "shape", None)
        if shape is not None and len(shape) > 0 and shape[0] == 4:
            sample = sample[:3]

        label = self.targets[index]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)
