from torchvision.datasets.vision import VisionDataset
from torchvision import transforms
import torch
import glob, os
from PIL import Image
import numpy as np
import random


class PneumoniaData(VisionDataset):
    """Chest X-ray dataset with two classes: NORMAL (label 0) and PNEUMONIA (label 1).

    Expects ``root`` to contain ``train/`` and ``test/`` split directories, each
    holding ``NORMAL/`` and ``PNEUMONIA/`` sub-directories of ``*.jpeg`` files.
    Labels are 0-d float tensors (``0.`` for NORMAL, ``1.`` for PNEUMONIA).

    Args:
        root: Dataset root directory.
        transform: Optional callable applied to the loaded PIL image.
        target_transform: Optional callable applied to the label tensor.
        train: Load the ``train`` split if True, otherwise the ``test`` split.
        seed: Seed for the one-off, reproducible shuffle of the sample order.

    Raises:
        ValueError: If no ``*.jpeg`` files are found for the selected split.
    """

    def __init__(self, root, transform=None, target_transform=None, train=True, seed=0):
        super().__init__(root, transform=transform, target_transform=target_transform)
        print(root)
        split = "train" if train else "test"
        # glob() returns files in arbitrary, filesystem-dependent order; sorting
        # makes the seeded shuffle below produce the same order on every machine.
        self.data_normal = sorted(
            glob.glob(os.path.join(self.root, split, "NORMAL", "*.jpeg"))
        )
        self.data_pneumonia = sorted(
            glob.glob(os.path.join(self.root, split, "PNEUMONIA", "*.jpeg"))
        )
        if not self.data_normal and not self.data_pneumonia:
            raise ValueError(
                f"No '*.jpeg' images found under {os.path.join(self.root, split)!r}"
            )

        self.labels_normal = torch.zeros(len(self.data_normal))
        self.labels_pneumonia = torch.ones(len(self.data_pneumonia))
        labels = torch.cat([self.labels_normal, self.labels_pneumonia], dim=0)

        # Shuffle once with a fixed seed so the two classes are not stored as two
        # contiguous runs (which can trip sanity checks that look at the first
        # batches of an unshuffled loader), while keeping the order reproducible.
        samples = list(zip(self.data_normal + self.data_pneumonia, labels))
        random.Random(seed).shuffle(samples)
        paths, targets = zip(*samples)

        # Tuple of 0-d float tensors, in the same order as ``self.data``.
        self.targets = targets
        # Keep ``labels`` aligned with ``data``/``targets`` after the shuffle.
        self.labels = torch.stack(targets)
        # Paths stored as a numpy string array rather than a Python list;
        # presumably to limit copy-on-access memory growth in DataLoader worker
        # processes — NOTE(review): confirm this is the intent.
        self.data = np.array(paths)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is a fully loaded PIL image (after ``transform``, if set) and
        the label is a 0-d float tensor (after ``target_transform``, if set).
        """
        path = self.data[index]
        # Load eagerly inside ``with`` so the file handle is released now rather
        # than whenever the lazily-opened image is garbage-collected.
        # NOTE(review): the image keeps whatever mode the file stores (no
        # convert()); confirm ``transform`` copes with mixed modes.
        with Image.open(path) as img:
            img.load()
        sample = img
        if self.transform is not None:
            sample = self.transform(sample)
        label = self.targets[index]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)
