import glob
import os
import random

from PIL import Image
import numpy as np
from torchvision.datasets.vision import VisionDataset


class MelanomaDataset(VisionDataset):
    """Dataset of JPEG images stored as ``<root>/<split>/<label>/*.jpg``.

    ``<split>`` is ``"train"`` or ``"test"`` and ``<label>`` must be one of
    the keys of :attr:`label_to_int`. Image paths are collected eagerly in
    ``__init__``; the images themselves are decoded lazily in
    ``__getitem__``.

    Attributes:
        data: NumPy array of image file paths, in shuffled order.
        targets: List of integer labels aligned with ``data``.
    """

    # Maps the name of an image's parent directory to its integer target.
    label_to_int = {"benign": 0, "malignant": 1}

    def __init__(self, root, transform=None, target_transform=None, train=True):
        """Collect the image paths and labels of one split.

        Args:
            root: Directory containing the ``train`` and ``test`` folders.
            transform: Optional callable applied to each loaded PIL image.
            target_transform: Optional callable applied to each integer label.
            train: Read the ``train`` split when true, otherwise ``test``.

        Raises:
            KeyError: If an image lives in a directory whose name is not a
                key of ``label_to_int``.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)

        split = "train" if train else "test"
        # glob returns files in filesystem-dependent order; sort first so the
        # seeded shuffle below yields the same ordering on every machine.
        paths = sorted(glob.glob(os.path.join(self.root, split, "*", "*.jpg")))

        # Shuffle with a fixed seed so labels are not contiguous, which can
        # result in errors during sanity checks that only see the first samples.
        random.Random(0).shuffle(paths)
        self.data = np.array(paths)

        self.targets = [self._label_from_path(path) for path in paths]

    def _label_from_path(self, path):
        """Return the integer target encoded by the parent directory of *path*."""
        # Use os.path instead of splitting on "/" so that paths built with the
        # platform separator (e.g. backslashes on Windows) are handled too.
        label = os.path.basename(os.path.dirname(path))
        try:
            return self.label_to_int[label]
        except KeyError:
            raise KeyError(
                f"Unknown label directory {label!r} for image {path!r}; "
                f"expected one of {sorted(self.label_to_int)}"
            ) from None

    def __getitem__(self, index):
        """Load the sample at *index*.

        Returns:
            A ``(sample, label)`` tuple: the image after ``transform`` (or the
            loaded PIL image when no transform is set) and its label after
            ``target_transform`` (or the raw integer label).
        """
        path = self.data[index]
        # Image.open is lazy and keeps the file open; decode the pixels inside
        # the context manager so the file handle is released right away.
        with Image.open(path) as image:
            image.load()
            sample = image
        if self.transform is not None:
            sample = self.transform(sample)
        # Expand single-channel outputs to three channels. Only objects that
        # expose ``shape`` (e.g. tensors) are touched; a plain PIL image, as
        # returned when no transform is set, has no ``shape`` attribute.
        if hasattr(sample, "shape") and sample.shape[0] == 1:
            sample = sample.repeat(3, 1, 1)
        label = self.targets[index]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.data)
