
import os 
import shutil
import numpy as np
import torch
from torch._C import dtype
import torchvision
from PIL import Image

from dataloaders.utils import *

class iNotMNIST(torch.utils.data.Dataset):
    """Single-task dataset for Split notMNIST.

    On first use the notMNIST archive is downloaded (if allowed) and
    extracted into ``root``, and both splits are cached as
    ``training.pt`` / ``test.pt`` under ``<root>/notMNIST``. Only samples
    whose original label is in ``classes`` are kept; their labels are
    remapped to positions within ``classes`` and each sample is tagged with
    ``task_num``.

    Args:
        root: Directory holding (or receiving) ``notMNIST.zip`` and the
            extracted ``notMNIST`` folder.
        classes: Original label (0-9, from letter folders A-J) or list of
            labels that make up this task.
        task_num: Task identifier stored for every sample in ``self.tt``.
        train: Load the training split if true, otherwise the test split.
        transform: Optional callable applied to each image, which is a
            float32 numpy array scaled to [0, 1].
        target_transform: Optional callable applied to the *task label*
            ``tt`` returned by ``__getitem__`` (not to the class target).
        download: Download the archive if it is missing from ``root``.

    Raises:
        RuntimeError: If the archive is missing and ``download`` is False.
    """

    def __init__(self, root, classes, task_num, train, transform=None, target_transform=None, download=True):
        self.train = train  # training set or test set
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        self.processed_folder = os.path.join(root, 'notMNIST')
        self.url = "https://github.com/facebookresearch/Adversarial-Continual-Learning/raw/master/data/notMNIST.zip"
        self.filename = 'notMNIST.zip'

        fpath = os.path.join(root, self.filename)
        self._ensure_archive(fpath, download)

        processed = [os.path.join(self.processed_folder, f) for f in ('training.pt', 'test.pt')]
        if not all(os.path.isfile(p) for p in processed):
            self._extract_and_process(fpath)

        data_file = 'training.pt' if self.train else 'test.pt'
        data, targets = torch.load(os.path.join(self.processed_folder, data_file))

        # The cache stores raw 0-255 grey values. Scale to [0, 1] here because
        # __getitem__ hands numpy arrays (not PIL images) to `transform`.
        data = data.numpy().astype(np.float32) / 255.0
        targets = targets.numpy().tolist()

        if not isinstance(classes, list):
            classes = [classes]
        # Map each original label to its label within this task.
        self.class_mapping = {c: i for i, c in enumerate(classes)}
        # Task label -> indices of its samples in the *full* split, i.e. not
        # positions in the filtered self.data built below.
        self.class_indices = {label: [] for label in self.class_mapping.values()}

        # Keep only this task's samples, together with remapped targets and
        # the task label (tt).
        self.data, self.targets, self.tt = [], [], []
        for i, (img, label) in enumerate(zip(data, targets)):
            if label not in self.class_mapping:
                continue
            task_label = self.class_mapping[label]
            self.data.append(img)
            self.targets.append(task_label)
            self.tt.append(task_num)
            self.class_indices[task_label].append(i)

    def _ensure_archive(self, fpath, download):
        """Make sure the notMNIST zip exists at ``fpath``, downloading it if allowed."""
        if os.path.isfile(fpath):
            return
        if not download:
            raise RuntimeError('Dataset not found. You can use download=True to download it')
        print('Downloading from '+self.url)
        download_url(self.url, self.root, filename=self.filename)

    def _extract_and_process(self, fpath):
        """Extract the archive into ``root`` and cache both splits as float32 tensors."""
        import zipfile
        with zipfile.ZipFile(fpath, 'r') as zip_ref:
            zip_ref.extractall(self.root)

        # Remove the __MACOSX dir if the archive contained one; it may already
        # be gone (e.g. on a re-run), which must not be an error.
        shutil.rmtree(os.path.join(self.root, '__MACOSX'), ignore_errors=True)

        print('Processing and saving data files in folder {}...'.format(self.processed_folder))

        for split, fname in (('Train', 'training.pt'), ('Test', 'test.pt')):
            X, Y = self._read_split(os.path.join(self.processed_folder, split))
            # Same float32 tensors torch.Tensor(X) would give, but built via
            # numpy, which is far faster than a list of ndarrays.
            data = [
                torch.from_numpy(np.asarray(X, dtype=np.float32)),
                torch.from_numpy(np.asarray(Y, dtype=np.float32)),
            ]
            torch.save(data, os.path.join(self.processed_folder, fname))

    @staticmethod
    def _read_split(data_dir):
        """Read every readable image under ``data_dir/<letter>/``.

        Returns ``(X, Y)``: lists of greyscale numpy arrays and integer
        labels (folder 'A' -> 0, 'B' -> 1, ...). Unreadable (broken) files
        are reported and skipped. Entries that are not single-letter
        directories are ignored.
        """
        X, Y = [], []
        # Sorted so the cached sample order is the same on every machine.
        for folder in sorted(os.listdir(data_dir)):
            folder_path = os.path.join(data_dir, folder)
            if len(folder) != 1 or not os.path.isdir(folder_path):
                continue
            label = ord(folder) - ord('A')  # Folders are A-J so labels will be 0-9
            for ims in sorted(os.listdir(folder_path)):
                img_path = os.path.join(folder_path, ims)
                try:
                    with Image.open(img_path) as im:
                        arr = np.array(im.convert('L'))
                except Exception:
                    print("File {}/{} is broken".format(folder, ims))
                    continue
                # Append image and label together so X and Y stay aligned.
                X.append(arr)
                Y.append(label)
        return X, Y

    def __getitem__(self, index):
        """
        Args:
            index (int): Index into this task's (filtered) samples.
        Returns:
            tuple: (img, target, tt) where img is the float32 numpy image in
            [0, 1] (after ``transform`` if set), target is the class index
            within this task, and tt is the task label (after
            ``target_transform`` if set).
        """
        img, target, tt = self.data[index], int(self.targets[index]), self.tt[index]

        # Images are float numpy arrays (see __init__) and are passed to
        # `transform` as-is. Transform errors propagate instead of silently
        # returning untransformed data.
        if self.transform is not None:
            img = self.transform(img)
        # NOTE(review): target_transform is applied to the task label, not
        # the class target. This is kept for compatibility with existing callers.
        if self.target_transform is not None:
            tt = self.target_transform(tt)

        return img, target, tt

    def __len__(self):
        return len(self.data)

