from medmnist import OCTMNIST, BreastMNIST
from PIL import Image
from torchvision.datasets.vision import VisionDataset


class MedMNISTDataset(VisionDataset):
    """Torchvision-style wrapper around a single-label ``medmnist`` dataset.

    Concrete subclasses set ``MEDMNIST_CLASS`` to a ``medmnist`` dataset class
    and may remap ``self.targets`` after calling ``super().__init__``. Samples
    are read from the wrapped dataset's raw ``imgs`` array and passed through
    this wrapper's own ``transform`` / ``target_transform``.
    """

    # Binary label convention: 0 = benign, 1 = malignant.
    # NOTE(review): not referenced in this file; kept for external callers.
    label_to_int = {"benign": 0, "malignant": 1}

    # medmnist dataset class to instantiate; each concrete subclass sets this.
    MEDMNIST_CLASS = None

    def __init__(
        self,
        root,
        transform=None,
        target_transform=None,
        train=True,
        *,
        download=False,
        size=224,
    ):
        """Load the train or test split of ``MEDMNIST_CLASS``.

        Args:
            root: Directory holding (or receiving) the medmnist data files.
            transform: Callable applied to each PIL image in ``__getitem__``.
            target_transform: Callable applied to each integer target.
            train: Load the ``"train"`` split if true, otherwise ``"test"``.
            download: Forwarded to the medmnist class (fetch data if missing).
            size: Image-size variant forwarded to the medmnist class.

        Raises:
            TypeError: If the class does not define ``MEDMNIST_CLASS``.
        """
        if self.MEDMNIST_CLASS is None:
            raise TypeError(
                f"{type(self).__name__} must define MEDMNIST_CLASS; "
                "instantiate a concrete subclass instead."
            )
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.dataset = self.MEDMNIST_CLASS(
            root=self.root,
            download=download,
            # Not used by this wrapper (it indexes the raw arrays below);
            # forwarded so `self.dataset[i]` behaves consistently.
            transform=transform,
            split="train" if train else "test",
            size=size,
        )
        self.data = self.dataset.imgs
        labels = self.dataset.labels
        # One label per sample is assumed (`int(...)` in __getitem__ needs it).
        # Reshape to (N,) and copy. The former `labels.squeeze()` returned a
        # view, so in-place edits to `self.targets` in subclasses silently
        # rewrote `self.dataset.labels`. It also collapsed a one-sample split
        # to a 0-d array that cannot be indexed.
        self.targets = labels.reshape(len(labels)).copy()

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        The raw array is turned into a PIL image, converted to RGB when the
        wrapped dataset has ``as_rgb`` set, then ``transform`` and
        ``target_transform`` are applied when present.
        """
        img = Image.fromarray(self.data[index])
        target = int(self.targets[index])

        if self.dataset.as_rgb:
            img = img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data)


class BreastMNISTDataset(MedMNISTDataset):
    """BreastMNIST with its binary labels swapped (0 <-> 1).

    After the swap, targets follow ``label_to_int`` (benign -> 0,
    malignant -> 1). NOTE(review): this implies the raw medmnist encoding is
    the reverse (0 = malignant, 1 = normal/benign); confirm against
    ``BreastMNIST.INFO``.
    """

    MEDMNIST_CLASS = BreastMNIST

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        raw_targets = self.targets
        # `1 - x` allocates a fresh array, so the wrapped dataset's own
        # label array is left as loaded.
        self.targets = 1 - raw_targets


class OctMNISTDataset(MedMNISTDataset):
    """OCTMNIST collapsed to a binary task: raw label 3 -> 0, labels 0-2 -> 1.

    NOTE(review): in medmnist's OCTMNIST, class 3 is presumably "normal" and
    0-2 are disease classes, making this 0 = normal, 1 = abnormal; confirm
    against ``OCTMNIST.INFO``.
    """

    MEDMNIST_CLASS = OCTMNIST

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        original = self.targets
        # Remap into a copy. The base class's targets may be a view onto
        # `self.dataset.labels`, so in-place edits would silently rewrite the
        # wrapped dataset's labels (and fail on a read-only memmap). Both
        # masks come from the untouched original, so the two assignments do
        # not depend on each other's order.
        remapped = original.copy()
        remapped[original < 3] = 1
        remapped[original == 3] = 0
        self.targets = remapped
