from torch.utils.data import Subset
from torch.utils.data import ConcatDataset
from PIL import Image
from torchvision.datasets import MNIST, USPS
from base.torchvision_dataset import TorchvisionDataset
from .preprocessing import get_target_label_idx, global_contrast_normalization
import numpy as np
import torchvision.transforms as transforms


class MNIST_USPS_Dataset(TorchvisionDataset):
    """USPS and MNIST digits combined into single train and test sets.

    Each split is a ``ConcatDataset`` of the USPS part first, followed by the
    MNIST part. Both sources share one preprocessing pipeline:
    resize to 28x28, convert to a tensor, then normalize with mean 0.5 and
    std 0.5, i.e. ``(x - 0.5) / 0.5``.

    Attributes set here:
        train_set: concatenation of the USPS and MNIST training splits.
        test_set: concatenation of the USPS and MNIST test splits.
        train_size: number of samples in ``train_set``.
    """

    def __init__(self, root: str):
        super().__init__(root)

        preprocess = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, )),
        ])

        def _load(dataset_cls, is_train):
            # Build one split of one source with the shared preprocessing.
            return dataset_cls(root=self.root, train=is_train, download=True,
                               transform=preprocess, target_transform=None)

        # USPS is always placed before MNIST in each concatenation.
        sources = (MyUSPS, MyMNIST)

        self.train_set = ConcatDataset([_load(cls, True) for cls in sources])
        self.test_set = ConcatDataset([_load(cls, False) for cls in sources])
        self.train_size = len(self.train_set)


class MyMNIST(MNIST):
    """Torchvision MNIST whose ``__getitem__`` also returns the sample index and
    a protected-attribute flag.

    Every sample from this dataset is tagged with ``PROTECTED_ATTRIBUTE``
    (``1``). Callers can use the tag to tell MNIST samples apart from samples
    of other sources they are combined with; the USPS wrapper in this module
    tags its samples with ``0``.
    """

    # Group label attached to every sample of this dataset; override in a
    # subclass to use a different tag.
    PROTECTED_ATTRIBUTE = 1

    def __getitem__(self, index):
        """Override the original ``__getitem__`` of the MNIST class.

        Args:
            index (int): Index of the sample within this dataset.

        Returns:
            tuple: ``(image, target, index, protected_attribute)``.
            ``target`` is the index of the target class (after
            ``target_transform``, if one is set). ``index`` is the argument
            passed in. ``protected_attribute`` is ``PROTECTED_ATTRIBUTE``.

        NOTE(review): ``index`` is local to this dataset. When the dataset is
        wrapped in a ``ConcatDataset`` it is not the global index, so it can
        collide with indices from the other concatenated datasets. Pair it
        with ``protected_attribute`` to tell samples apart.
        """
        img, target = self.data[index], int(self.targets[index])

        # Convert to a PIL Image so the transform pipeline receives the same
        # input type as with other torchvision datasets.
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, self.PROTECTED_ATTRIBUTE
        
class MyUSPS(USPS):
    """Torchvision USPS whose ``__getitem__`` also returns the sample index and
    a protected-attribute flag.

    Every sample from this dataset is tagged with ``PROTECTED_ATTRIBUTE``
    (``0``). Callers can use the tag to tell USPS samples apart from samples
    of other sources they are combined with; the MNIST wrapper in this module
    tags its samples with ``1``.
    """

    # Group label attached to every sample of this dataset; override in a
    # subclass to use a different tag.
    PROTECTED_ATTRIBUTE = 0

    def __getitem__(self, index):
        """Override the original ``__getitem__`` of the USPS class.

        Args:
            index (int): Index of the sample within this dataset.

        Returns:
            tuple: ``(image, target, index, protected_attribute)``.
            ``target`` is the index of the target class (after
            ``target_transform``, if one is set). ``index`` is the argument
            passed in. ``protected_attribute`` is ``PROTECTED_ATTRIBUTE``.

        NOTE(review): ``index`` is local to this dataset. When the dataset is
        wrapped in a ``ConcatDataset`` it is not the global index, so it can
        collide with indices from the other concatenated datasets. Pair it
        with ``protected_attribute`` to tell samples apart.
        """
        img, target = self.data[index], int(self.targets[index])

        # Convert to a PIL Image so the transform pipeline receives the same
        # input type as with other torchvision datasets. ``self.data`` entries
        # are passed to PIL directly here (no ``.numpy()`` call).
        img = Image.fromarray(img, mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, self.PROTECTED_ATTRIBUTE

