from torch.utils.data import Subset
from torch.utils.data import ConcatDataset
from PIL import Image
from torchvision.datasets import MNIST, USPS
from base.torchvision_dataset import TorchvisionDataset
from .preprocessing import get_target_label_idx, global_contrast_normalization
import numpy as np
import torchvision.transforms as transforms


class MNIST_INVERT_Dataset(TorchvisionDataset):
    """MNIST combined with a pixel-inverted copy of itself.

    Each split (train and test) is a ``ConcatDataset`` whose first part is the
    inverted MNIST split (``MyINVERT``) and whose second part is the regular
    MNIST split (``MyMNIST``). Both parts share the same image transform and
    use no target transform.
    """

    def __init__(self, root: str):
        super().__init__(root)

        # Take the central 24x24 crop, scale it back up to 28x28, convert to a
        # tensor in [0, 1] and then normalise with mean 0.5 / std 0.5.
        preprocess = transforms.Compose([
            transforms.CenterCrop(24),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

        def build_split(is_train):
            # The inverted copy comes first and the original digits second;
            # ConcatDataset indices follow this order.
            parts = [
                dataset_cls(root=self.root, train=is_train, download=True,
                            transform=preprocess, target_transform=None)
                for dataset_cls in (MyINVERT, MyMNIST)
            ]
            return ConcatDataset(parts)

        self.train_set = build_split(True)
        self.test_set = build_split(False)
        self.train_size = len(self.train_set)

class MyMNIST(MNIST):
    """Torchvision MNIST whose ``__getitem__`` also returns the sample index
    and a protected attribute.

    Every sample from this dataset is tagged with ``protected_attribute = 1``.
    Construction arguments are the same as for ``torchvision.datasets.MNIST``.
    """

    def __getitem__(self, index):
        """Return one MNIST sample plus bookkeeping information.

        Args:
            index (int): Index of the sample within this dataset.

        Returns:
            tuple: ``(image, target, index, protected_attribute)`` where
            ``image`` is the (transformed) digit image, ``target`` is the class
            index (after ``target_transform``, if set), ``index`` echoes the
            argument and ``protected_attribute`` is always ``1``.
        """
        img, target = self.data[index], int(self.targets[index])

        # Convert to a PIL Image so transforms behave as they do for the other
        # torchvision datasets. torchvision stores MNIST pixels as a 2-D uint8
        # tensor per sample, from which Pillow infers mode "L" by itself; the
        # explicit ``mode`` argument is redundant (and deprecated in Pillow 11.3).
        img = Image.fromarray(img.numpy())
        protected_attribute = 1

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, protected_attribute
        
class MyINVERT(MNIST):
    """Torchvision MNIST with inverted pixel intensities.

    ``__getitem__`` replaces every pixel value ``v`` by ``255 - v`` before any
    transform is applied. It also returns the sample index and a protected
    attribute, which is always ``0`` for this dataset. Construction arguments
    are the same as for ``torchvision.datasets.MNIST``.
    """

    def __getitem__(self, index):
        """Return one inverted MNIST sample plus bookkeeping information.

        Args:
            index (int): Index of the sample within this dataset.

        Returns:
            tuple: ``(image, target, index, protected_attribute)`` where
            ``image`` is the inverted (and transformed) digit image, ``target``
            is the class index (after ``target_transform``, if set), ``index``
            echoes the argument and ``protected_attribute`` is always ``0``.
        """
        img, target = self.data[index], int(self.targets[index])

        # Invert intensities (v -> 255 - v). For uint8 pixels in [0, 255] the
        # result stays in range and keeps the uint8 dtype.
        img = 255 - img

        # Convert to a PIL Image so transforms behave as they do for the other
        # torchvision datasets. Pillow infers mode "L" from a 2-D uint8 array,
        # so the explicit ``mode`` argument is redundant (and deprecated in
        # Pillow 11.3).
        img = Image.fromarray(img.numpy())
        protected_attribute = 0

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, protected_attribute

