from __future__ import print_function

import os.path
import sys
import warnings

import numpy as np
import torch
import torch.utils.data as data
import torchvision
from PIL import Image
from torchvision import datasets, transforms


class iMNIST(datasets.MNIST):
    """Single-task dataset for Split MNIST.

    Keeps only the samples of one MNIST split whose original label is in
    ``classes`` and relabels them for the task.

    Args:
        root: Dataset root directory, forwarded to ``datasets.MNIST``.
        classes: Original MNIST label(s) belonging to this task: a single
            label, or a list / tuple / range / NumPy array of labels.
        task_num: Index of this task. Stored per sample in ``self.tt`` and
            used to offset the remapped targets.
        train: Load the training split if True, otherwise the test split.
        transform: Optional callable applied to each image in ``__getitem__``.
        target_transform: Optional callable applied to each target.
        download: Forwarded to ``datasets.MNIST``. The parent downloads the
            files when missing, and otherwise raises if they are absent.

    Attributes:
        data: list with one float32 NumPy array per kept image. Pixel values
            are divided by 255 (in [0, 1] assuming the parent stores uint8
            pixels -- NOTE(review): confirm for your torchvision version).
        targets: list of ints,
            ``class_mapping[original_label] + task_num * len(classes)``.
            NOTE(review): presumably this keeps labels unique across tasks of
            equal size; confirm against the training code.
        tt: list holding ``task_num`` for every kept sample.
        class_mapping: dict, original label -> position within ``classes``.
        class_indices: dict, position within ``classes`` -> indices of that
            class's samples in the full (unfiltered) split.
    """

    def __init__(self, root, classes, task_num, train, transform=None, target_transform=None, download=True):
        # The parent selects the split from its ``train`` argument. It must
        # receive ``train`` (not ``task_num``). The parent also expands ``~``
        # in ``root``, downloads on request, checks the files exist, and sets
        # ``self.data`` / ``self.targets`` / ``self.train`` for that split, so
        # none of that is repeated here.
        super(iMNIST, self).__init__(root, train=train, transform=transform,
                                     target_transform=target_transform, download=download)

        if not isinstance(classes, (list, tuple, range, np.ndarray)):
            classes = [classes]
        classes = list(classes)

        # Original label -> label within the task itself.
        self.class_mapping = {c: i for i, c in enumerate(classes)}
        # Label within the task -> indices of its samples in the full split.
        self.class_indices = {self.class_mapping[c]: [] for c in classes}

        images = self._as_numpy(self.data)
        labels = self._as_numpy(self.targets)

        # Indices (into the full split) of the samples belonging to this task.
        selected = np.flatnonzero(np.isin(labels, classes))
        offset = task_num * len(classes)

        targets = []
        for idx, label in zip(selected.tolist(), labels[selected].tolist()):
            local = self.class_mapping[label]
            targets.append(int(local + offset))
            self.class_indices[local].append(idx)

        # Filter first, then scale. The pixel values are the same as scaling
        # the whole split, but the kept arrays no longer reference a float
        # copy of every image in the split.
        kept = images[selected].astype(np.float32) / 255.0
        self.data = list(kept)
        self.targets = targets
        self.tt = [task_num] * len(targets)

    @staticmethod
    def _as_numpy(values):
        """Return ``values`` as a NumPy array (torch tensors via ``.numpy()``)."""
        if torch.is_tensor(values):
            return values.numpy()
        return np.asarray(values)

    @staticmethod
    def _apply_best_effort(fn, value, name):
        """Return ``fn(value)``; if ``fn`` raises, warn and return ``value`` unchanged."""
        try:
            return fn(value)
        except Exception as err:
            # Best-effort as before, but the failure is now reported, and
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            warnings.warn('iMNIST: %s failed (%s: %s); returning the untransformed value.'
                          % (name, type(err).__name__, err), RuntimeWarning)
            return value

    def __getitem__(self, index):
        """Return one sample of this task.

        Args:
            index (int): Position within this task's samples.

        Returns:
            tuple: ``(image, target, tt)``. ``image`` is the float32 array
            stored in ``self.data``, passed through ``self.transform`` if set.
            ``target`` is the remapped label, passed through
            ``self.target_transform`` if set. ``tt`` is the task id.
        """
        img, target, tt = self.data[index], int(self.targets[index]), self.tt[index]

        # Images are float32 NumPy arrays already divided by 255 in __init__.
        # They are handed to the transform as-is, with no PIL conversion, so a
        # ToTensor() transform does not rescale them again.
        if self.transform is not None:
            img = self._apply_best_effort(self.transform, img, 'transform')
        if self.target_transform is not None:
            target = self._apply_best_effort(self.target_transform, target, 'target_transform')

        return img, target, tt

    def __len__(self):
        """Return the number of samples kept for this task."""
        return len(self.data)
