import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision.datasets import ImageFolder
from torchvision.datasets.vision import VisionDataset


class KvasirDataset(VisionDataset):
    """Binary-labelled image dataset built on top of an ``ImageFolder`` tree.

    Images under ``root`` are discovered with ``ImageFolder``. The sample list is
    split into a train/test partition with a fixed seed, and the multi-class
    folder indices are collapsed into binary labels: indices listed in
    ``negative_classes`` become ``0`` and every other index becomes ``1``.

    Args:
        root: Directory laid out as ``root/<class_name>/<image file>``.
        transform: Optional callable applied to the decoded RGB ``PIL.Image``.
        target_transform: Optional callable applied to the binary label.
        train: If True keep the training partition, otherwise the test partition.
        test_size: Forwarded to ``train_test_split``; as a float it is the
            fraction of samples held out for the test partition.
        random_state: Seed forwarded to ``train_test_split``. Building a train
            instance and a test instance with the same seed gives disjoint,
            complementary partitions.
        negative_classes: ``ImageFolder`` class indices mapped to label ``0``.
            NOTE(review): the default ``(3, 4, 5)`` depends on the (sorted)
            folder names present under ``root`` — confirm it selects the
            intended classes for your copy of the dataset.

    Attributes:
        image_folder: The underlying ``ImageFolder`` used for file discovery.
        data: NumPy array of image file paths for the selected partition.
        targets: List of binary labels (0/1) aligned with ``data``.
        label_to_int: The original multi-class ``class_to_idx`` mapping from
            ``ImageFolder`` (NOT the binary labels stored in ``targets``).
    """

    def __init__(
        self,
        root,
        transform=None,
        target_transform=None,
        train=True,
        *,
        test_size=0.3,
        random_state=42,
        negative_classes=(3, 4, 5),
    ):
        super().__init__(root, transform=transform, target_transform=target_transform)
        image_folder = ImageFolder(self.root, transform=self.transform)
        self.image_folder = image_folder
        self.data = np.array([path for path, _ in image_folder.samples])
        self.targets = np.array([class_idx for _, class_idx in image_folder.samples])
        self.label_to_int = image_folder.class_to_idx

        # Split indices rather than the arrays themselves so paths and targets
        # stay aligned; the fixed seed makes train/test instances complementary.
        train_idx, test_idx = train_test_split(
            np.arange(self.data.shape[0]),
            test_size=test_size,
            random_state=random_state,
        )
        keep = train_idx if train else test_idx
        self.data = self.data[keep]

        # Collapse multi-class folder indices into binary labels:
        # 0 for the configured negative classes, 1 for everything else.
        negatives = set(negative_classes)
        self.targets = [0 if int(idx) in negatives else 1 for idx in self.targets[keep]]

    def __getitem__(self, index):
        """Return ``(sample, label)`` for the item at ``index``.

        The image is decoded as 3-channel RGB, so RGBA, grayscale and palette
        files all produce the same channel count. No channel slicing is needed
        afterwards, and ``transform=None`` yields a plain RGB ``PIL.Image``.
        ``transform`` is applied to the image and ``target_transform`` (if
        given) to the binary label.
        """
        path = self.data[index]
        # ``Image.open`` is lazy and would keep the file open; ``convert`` forces
        # the decode so the handle can be closed when the ``with`` block exits.
        with open(path, "rb") as fh:
            sample = Image.open(fh).convert("RGB")
        if self.transform is not None:
            sample = self.transform(sample)
        label = self.targets[index]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label

    def __len__(self):
        """Return the number of samples in the selected partition."""
        return len(self.data)
