
from __future__ import print_function
import torch
import numpy as np
import os.path
import sys
import random

import torchvision 
from torchvision import datasets, transforms
from torch.utils.data import Subset
import torchvision.transforms.functional as TF

class CLDataset(object):
    """Continual-learning benchmark that splits one dataset into tasks.

    For every task ``t`` the instance holds ``train_set[t]``, ``valid_set[t]``
    and ``test_set[t]``, built by ``make_train_val_split`` from the label
    groups in ``task_ids``. For ``RotatedMNIST`` every task uses the same
    MNIST classes and differs only by a rotation of ``t * 180 / n_tasks``
    degrees.
    """

    def __init__(self, args, dataset):
        """Read the configuration and build all per-task datasets.

        Args:
            args: configuration namespace. Reads ``verbose``, ``cl.batch_size``,
                ``pc_valid``, ``data_dir``, ``n_tasks``, ``normalize_data``,
                ``num_workers``, ``pin_memory`` and, for CIFAR100 only,
                ``classes_per_task``.
            dataset: dataset name understood by ``get_statistics``.

        Raises:
            ValueError: if there is no default number of classes per task, or
                no default task split, for ``dataset``.
        """
        super(CLDataset, self).__init__()

        self.verbose = args.verbose
        self.batch_size = args.cl.batch_size
        self.pc_valid = args.pc_valid
        self.root = args.data_dir
        self.dataset = dataset

        self.n_tasks = args.n_tasks
        mean, std, classes, img_size, in_channels = get_statistics(self.dataset)
        self.n_classes = classes
        self.input_size = [in_channels, img_size, img_size]

        self.taskcla = [[t, int(self.n_classes / self.n_tasks)] for t in range(self.n_tasks)]

        # Default number of classes per task for each dataset.
        if self.dataset in ['MNIST', 'FashionMNIST', 'notMNIST', 'CIFAR10']:
            self.classes_per_task = 2  # 5-task Split-dataset
        elif self.dataset in ['CIFAR100']:
            self.classes_per_task = args.classes_per_task  # e.g. 20-task Split-CIFAR100
        elif self.dataset in ['RotatedMNIST']:
            self.classes_per_task = 10  # same number as in original MNIST
        else:
            raise ValueError('Default classes per task not implemented for dataset {:s}'.format(self.dataset))

        self.task_ids = self._default_task_ids(self.dataset, self.n_classes,
                                               self.classes_per_task, self.n_tasks)

        self.transformation = transforms.Compose([transforms.ToTensor(),])
        if args.normalize_data:
            self.transformation = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

        if self.dataset in ['RotatedMNIST']:
            # RotatedMNIST replaces the single transform with one
            # (transform, target_transform) pair per task; note that these
            # per-task transforms do not apply args.normalize_data.
            nc = self.classes_per_task
            self.transformation = []
            self.target_transformation = []
            per_task_rotation = 180.0 / self.n_tasks
            for t, _ in enumerate(self.task_ids):
                rotation_degree = t * per_task_rotation
                transform = torchvision.transforms.Compose([
                    torchvision.transforms.ToTensor(),
                    RotationTransform(rotation_degree),
                ])
                self.transformation.append(transform)
                # Subtract nc * t from the label. The offset is bound as a
                # default argument so each lambda keeps its own task's value.
                # NOTE(review): presumably undoes a per-task label offset
                # applied by iMNIST -- confirm against dataloaders.mnist.
                self.target_transformation.append(transforms.Lambda(lambda y, x=nc * t: y - x))

        self.indices = {}
        self.dataloaders = {}
        self.idx = {}

        self.num_workers = args.num_workers
        self.pin_memory = args.pin_memory

        self.train_set = {}
        self.test_set = {}

        self.train_split = {}
        self.valid_split = {}
        self.valid_set = {}

        self.make_train_val_split()

    @staticmethod
    def _default_task_ids(dataset, n_classes, classes_per_task, n_tasks):
        """Return the default list of label groups, one group per task.

        Raises:
            ValueError: if ``dataset`` has no default task split.
        """
        if dataset in ['MNIST', 'FashionMNIST', 'notMNIST']:
            # Fixed 5-task split; n_tasks is not consulted here.
            return [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
        if dataset in ['CIFAR100']:
            labels = list(range(n_classes))
            return [labels[x:x + classes_per_task] for x in range(0, len(labels), classes_per_task)]
        if dataset in ['RotatedMNIST']:
            # Every task sees the same classes; tasks differ by rotation only.
            return [list(range(classes_per_task)) for _ in range(n_tasks)]
        raise ValueError('Default task ids not implemented for dataset {:s}'.format(dataset))

    def get_dataset_name(self):
        """Return a name encoding the dataset and the label groups of each task."""
        dataset_name = '{}-5Split'.format(self.dataset)
        for ids in self.task_ids:
            dataset_name += '-{}'.format(ids)
        return dataset_name

    def make_train_val_split(self):
        """Build the train / validation / test datasets of every task.

        Only task 0 is created with ``download=True``. For RotatedMNIST the
        training set is first subsampled to at most 10k examples using the
        global NumPy RNG. The validation set is a random ``pc_valid`` fraction
        of the training set, drawn by ``random_split`` with torch's default
        generator.
        """
        self.datasets = {}
        for task_id in range(self.n_tasks):

            self.dataloaders[task_id] = {}
            sys.stdout.flush()

            # Only the first task asks for the raw files to be downloaded.
            download = task_id == 0

            if self.dataset in ['RotatedMNIST']:
                self.train_set[task_id] = get_task_dataset("MNIST", root=self.root, classes=self.task_ids[task_id],
                                                           task_num=task_id, train=True, download=download,
                                                           transform=self.transformation[task_id],
                                                           target_transform=self.target_transformation[task_id])
                self.test_set[task_id] = get_task_dataset("MNIST", root=self.root, classes=self.task_ids[task_id],
                                                          task_num=task_id, train=False, download=download,
                                                          transform=self.transformation[task_id],
                                                          target_transform=self.target_transformation[task_id])

                # Reduce the number of training examples to (at most) 10k.
                train_dataset = self.train_set[task_id]
                ind = np.random.permutation(len(train_dataset))[:10000]
                self.train_set[task_id] = Subset(train_dataset, indices=ind)
            else:
                self.train_set[task_id] = get_task_dataset(self.dataset, root=self.root, classes=self.task_ids[task_id],
                                                           task_num=task_id, train=True, download=download,
                                                           transform=self.transformation)
                self.test_set[task_id] = get_task_dataset(self.dataset, root=self.root, classes=self.task_ids[task_id],
                                                          task_num=task_id, train=False, download=download,
                                                          transform=self.transformation)

            split = int(np.floor(self.pc_valid * len(self.train_set[task_id])))
            train_split, valid_split = torch.utils.data.random_split(self.train_set[task_id], [len(self.train_set[task_id]) - split, split])
            self.train_set[task_id] = train_split
            self.valid_set[task_id] = valid_split

            if self.verbose > 0:
                print('Info on {} T{}'.format(self.dataset, task_id))
                print ("Training set size:      {}  images of {}x{}".format(len(train_split), self.input_size[1], self.input_size[2]))
                print ("Validation set size:    {}  images of {}x{}".format(len(valid_split), self.input_size[1], self.input_size[2]))
                print ("Test set size:          {}  images of {}x{}".format(len(self.test_set[task_id]), self.input_size[1], self.input_size[2]))
                print()

    def get_dataset_for_task(self, task_id):
        """Return the train/valid/test datasets and a name for ``task_id``.

        ``task_id`` ranges over 0, 1, ..., n_tasks - 1.
        """
        task_dataset = {}
        task_dataset['train'] = self.train_set[task_id]
        task_dataset['valid'] = self.valid_set[task_id]
        task_dataset['test'] = self.test_set[task_id]
        task_dataset['name'] = '5Split-{}-{}-{}'.format(self.dataset, task_id, self.task_ids[task_id])
        return task_dataset

    # NOTE(review): the dataloader factory below is disabled (kept only as a
    # string literal). It reads ``self.seed``, which this base class never
    # sets (only ShuffledLabelsCLDataset does) -- fix that before re-enabling.
    """
    def get_dataloader(self, task_id, shuffle=True):
        assert 0 <= task_id and task_id < self.n_tasks, "Task id {} must be within range (0, {})".format(task_id, self.n_tasks-1)
        #return self.dataloaders[task_id]
        #
        #print('in get_dataloader, seed: %d' %(self.seed))
        dataloaders = {}
        train_loader = torch.utils.data.DataLoader(self.train_set[task_id], batch_size=self.batch_size, num_workers=self.num_workers,
                                                pin_memory=self.pin_memory, drop_last=False, shuffle=shuffle,
                                                generator=torch.Generator().manual_seed(self.seed))
        valid_loader = torch.utils.data.DataLoader(self.valid_set[task_id], batch_size=self.batch_size, num_workers=self.num_workers, 
                                                pin_memory=self.pin_memory, drop_last=False, shuffle=False)
        test_loader = torch.utils.data.DataLoader(self.test_set[task_id], batch_size=self.batch_size, num_workers=self.num_workers,
                                                pin_memory=self.pin_memory, drop_last=False, shuffle=False)

        dataloaders['train'] = train_loader
        dataloaders['valid'] = valid_loader
        dataloaders['test'] = test_loader
        dataloaders['name'] = '5Split-{}-{}-{}'.format(self.dataset, task_id, self.task_ids[task_id])
        return dataloaders
    """

class ShuffledLabelsCLDataset(CLDataset):
    """CLDataset whose per-task label groups come from a seeded permutation.

    The base class first builds its default task split; this class then
    permutes the class labels with ``torch.randperm`` seeded by ``seed``,
    regroups them into tasks of ``classes_per_task`` labels and rebuilds all
    per-task datasets.
    """

    # Datasets for which label shuffling is implemented.
    _SHUFFLE_SUPPORTED = ['MNIST', 'FashionMNIST', 'notMNIST', 'CIFAR10']

    def __init__(self, args, dataset, seed):
        """Build the shuffled-label task split.

        Args:
            args: configuration namespace, forwarded to ``CLDataset``.
            dataset: dataset name; must be in ``_SHUFFLE_SUPPORTED``.
            seed: seed of the torch generator used to permute the labels.

        Raises:
            ValueError: if label shuffling is not implemented for ``dataset``.
        """
        # Fail fast, before the base class builds (and discards) its default
        # per-task datasets.
        if dataset not in self._SHUFFLE_SUPPORTED:
            raise ValueError('Shuffling task ids not implemented for dataset {}'.format(dataset))
        super().__init__(args, dataset)
        self.seed = seed

        generator = torch.Generator().manual_seed(self.seed)
        labels = torch.randperm(self.n_classes, generator=generator).tolist()
        n = self.classes_per_task
        self.task_ids = [labels[i:i + n] for i in range(0, len(labels), n)]
        # Position in the permutation -> class label at that position.
        self.class_mapping = dict(enumerate(labels))

        # Drop the default-split datasets built by the base class and rebuild
        # them for the shuffled task_ids.
        self.indices = {}
        self.dataloaders = {}
        self.idx = {}

        self.train_set = {}
        self.test_set = {}
        self.valid_set = {}

        self.train_split = {}
        self.valid_split = {}

        self.make_train_val_split()

    def get_dataset_name(self):
        """Return the base dataset name extended with the shuffle seed."""
        return super().get_dataset_name() + '-ShuffleSeed-{}'.format(self.seed)

class RotationTransform:
    """Callable that rotates an image by a fixed angle.

    CLDataset builds one instance per task for the Rotated MNIST setup.
    """

    def __init__(self, angle):
        # Rotation angle handed unchanged to TF.rotate (in degrees).
        self.angle = angle

    def __call__(self, x):
        # Pixels uncovered by the rotation are filled with 0.
        background = (0,)
        return TF.rotate(x, self.angle, fill=background)

def get_task_dataset(dataset, root, classes, task_num, train, transform, target_transform=None, download=True):
    """Instantiate the single-task dataset class registered for ``dataset``.

    The project dataset classes are imported lazily, so only the requested
    one is loaded. RotatedMNIST is not accepted here: CLDataset requests
    'MNIST' and applies the rotation through ``transform``.

    Args:
        dataset: one of 'MNIST', 'FashionMNIST', 'notMNIST', 'CIFAR100'.
        root, classes, task_num, train, transform, target_transform, download:
            forwarded unchanged as keyword arguments to the dataset class.

    Returns:
        An instance of iMNIST, iFashionMNIST, iNotMNIST or iCIFAR100_new.

    Raises:
        NotImplementedError: if ``dataset`` has no task-dataset class.
    """
    if dataset == 'MNIST':
        from dataloaders.mnist import iMNIST as task_dataset_cls
    elif dataset == 'FashionMNIST':
        from dataloaders.fashionmnist import iFashionMNIST as task_dataset_cls
    elif dataset == 'notMNIST':
        from dataloaders.notmnist import iNotMNIST as task_dataset_cls
    elif dataset == 'CIFAR100':
        from dataloaders.cifar import iCIFAR100_new as task_dataset_cls
    else:
        raise NotImplementedError(
            "No task dataset implemented for dataset {!r}; supported: "
            "'MNIST', 'FashionMNIST', 'notMNIST', 'CIFAR100'".format(dataset))

    return task_dataset_cls(root=root, classes=classes, task_num=task_num, train=train,
                            transform=transform, target_transform=target_transform,
                            download=download)


# Added notMNIST stats from https://github.com/facebookresearch/Adversarial-Continual-Learning/blob/master/src/dataloaders/mulitidatasets.py
# Per-dataset statistics, stored in the order get_statistics returns them:
# (per-channel mean, per-channel std, number of classes, image side length,
#  number of input channels).
# NOTE(review): FashionMNIST reuses the MNIST mean/std; FashionMNIST's own
# commonly quoted statistics differ (~0.286 / ~0.353) -- confirm intentional.
_DATASET_STATISTICS = {
    'MNIST':        ((0.1307,), (0.3081,), 10, 28, 1),
    'RotatedMNIST': ((0.1307,), (0.3081,), 10, 28, 1),
    'FashionMNIST': ((0.1307,), (0.3081,), 10, 28, 1),
    'notMNIST':     ((0.4254,), (0.4501,), 10, 28, 1),
    'SVHN':         ((0.4377, 0.4438, 0.4728), (0.1969, 0.1999, 0.1958), 10, 32, 3),
    'CIFAR10':      ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010), 10, 32, 3),
    'CIFAR100':     ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761), 100, 32, 3),
    'CINIC10':      ((0.47889522, 0.47227842, 0.43047404),
                     (0.24205776, 0.23828046, 0.25874835), 10, 32, 3),
    'TinyImagenet': ((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262), 200, 64, 3),
    'ImageNet100':  ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), 100, 224, 3),
    'ImageNet':     ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), 1000, 224, 3),
}


def get_statistics(dataset):
    '''
    Return ``(mean, std, classes, inp_size, in_channels)`` for ``dataset``.

    ``mean`` and ``std`` are per-channel tuples, ``classes`` the number of
    classes, ``inp_size`` the image side length and ``in_channels`` the
    number of input channels. To add a new dataset, add its entry to
    ``_DATASET_STATISTICS``.

    Raises:
        ValueError: if ``dataset`` is not a known dataset name.
    '''
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    try:
        mean, std, classes, inp_size, in_channels = _DATASET_STATISTICS[dataset]
    except KeyError:
        raise ValueError('Unknown dataset {!r}; expected one of {}'.format(
            dataset, sorted(_DATASET_STATISTICS))) from None
    return mean, std, classes, inp_size, in_channels