
import os
import numpy as np
from torchvision import transforms
from torch.utils.data import ConcatDataset, Dataset, Subset
import torch, torchvision


def get_multitask_experiment(config):
    """ Build per-task train, validation and test datasets for continual learning experiments.
        Args:
            config (dict): Configuration for the experiment. The entries read are
                config['data'] -> 'name', 'data_dir', 'n_tasks', 'img_size', 'n_classes',
                'shuffle_labels', 'pc_valid'; config['training']['scenario'];
                config['session']['seed'].
        Returns:
            train_datasets (list): One training dataset per task (validation samples removed).
            valid_datasets (list): One validation dataset per task, split off the training data.
            test_datasets (list): One test dataset per task.
            classes_per_task (int): Number of classes per task in dataset.
        Raises:
            ValueError: If a class-split dataset would need more tasks than it has classes.
            RuntimeError: If the dataset name is not handled by this function.
    """
    data_cfg = config['data']
    dataset = data_cfg['name']
    data_dir = data_cfg['data_dir']
    n_tasks = data_cfg['n_tasks']
    scenario = config['training']['scenario']
    img_size = data_cfg['img_size']
    total_num_classes = data_cfg['n_classes']
    shuffle_labels = data_cfg['shuffle_labels']  # read from the config but not used below
    pc_valid = data_cfg['pc_valid']  # fraction of each training set moved to the validation set
    seed = config['session']['seed']  # seeds the train/val split only

    # Only tensor conversion is applied (plus grayscale conversion for notMNIST).
    if dataset == 'notMNIST':
        data_transforms = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ])
    else:
        data_transforms = transforms.ToTensor()

    # Load the full (unsplit) datasets; train and test share the same transforms.
    train_dataset = get_dataset(dataset, data_dir, train=True, transforms=data_transforms)
    test_dataset = get_dataset(dataset, data_dir, train=False, transforms=data_transforms)

    # Datasets with a hard-coded task layout override the configured n_tasks.
    two_classes_per_task = ('MNIST', 'CIFAR10', 'FashionMNIST', 'notMNIST')
    five_classes_per_task = ('CIFAR100', 'miniImagenet')
    if dataset in two_classes_per_task:
        n_tasks, classes_per_task = 5, 2
    elif dataset in five_classes_per_task:
        n_tasks, classes_per_task = 20, 5
    elif dataset == 'GroceryStore':
        # keeps the configured n_tasks
        classes_per_task = int(np.floor(total_num_classes / n_tasks))

    class_split_datasets = two_classes_per_task + five_classes_per_task + ('GroceryStore',)
    train_datasets = []
    test_datasets = []
    if dataset in class_split_datasets:
        if n_tasks > total_num_classes:
            raise ValueError("Experiment %s cannot have more than %d tasks!" %(dataset, total_num_classes))
        # Task t owns the consecutive labels [t * classes_per_task, (t + 1) * classes_per_task).
        labels_per_task = []
        for task_id in range(n_tasks):
            first_label = classes_per_task * task_id
            labels_per_task.append(list(np.arange(classes_per_task) + first_label))
        print('Labels for each task: ', labels_per_task)
        for labels in labels_per_task:
            if scenario == 'domain':
                # Shift labels so every task uses the same output range starting at 0.
                target_transform = transforms.Lambda(lambda y, x=labels[0]: y - x)
            else:
                target_transform = None
            train_datasets.append(SubDataset(train_dataset, labels, target_transform=target_transform))
            test_datasets.append(SubDataset(test_dataset, labels, target_transform=target_transform))

    elif dataset == 'PermutedMNIST':
        # Keep a random subset of at most 10k training samples.
        kept_indices = np.random.permutation(len(train_dataset))[:10000]
        train_dataset = Subset(train_dataset, indices=kept_indices)

        classes_per_task = 10  # same as original MNIST
        # The first task is unpermuted MNIST; each later task gets its own pixel permutation.
        permutations = [None]
        permutations.extend(np.random.permutation(img_size**2) for _ in range(n_tasks-1))
        for task_id, perm in enumerate(permutations):
            if scenario in ('task', 'class'):
                # Give every task its own disjoint label range.
                target_transform = transforms.Lambda(lambda y, x=task_id: y + x*classes_per_task)
            else:
                target_transform = None
            train_pixel_transform = transforms.Lambda(lambda x, p=perm: _permutate_image_pixels(x, p))
            test_pixel_transform = transforms.Lambda(lambda x, p=perm: _permutate_image_pixels(x, p))
            train_datasets.append(TransformedDataset(
                train_dataset, transform=train_pixel_transform, target_transform=target_transform
            ))
            test_datasets.append(TransformedDataset(
                test_dataset, transform=test_pixel_transform, target_transform=target_transform
            ))
    else:
        raise RuntimeError('Given undefined dataset: {}'.format(dataset))

    # Split a validation set off every task's training set, re-seeding for each task.
    train_splits = []
    valid_datasets = []
    for train_set in train_datasets:
        n_total = len(train_set)
        n_valid = int(np.floor(pc_valid * n_total))
        split_generator = torch.Generator().manual_seed(seed)
        train_split, valid_split = torch.utils.data.random_split(
            train_set, lengths=[n_total - n_valid, n_valid], generator=split_generator)
        train_splits.append(train_split)
        valid_datasets.append(valid_split)

    return train_splits, valid_datasets, test_datasets, classes_per_task

def get_dataset(dataset, data_dir, train, transforms, target_transforms=None):
    """ Get torchvision datasets.
        Args:
            dataset (str): Dataset name.
            data_dir (str): Directory where dataset is downloaded to (or, for
                ImageFolder-style datasets, the directory containing '<dataset>/train'
                and '<dataset>/test').
            train (bool): Get training or test data?
            transforms (transforms.Compose): Composition of data transforms.
            target_transforms (...): Transforms of target label.
        Returns:
            (torch.Dataset): The wanted dataset.
    """

    # Support for *some* pytorch default loaders is provided. Any other dataset name is
    # assumed to be stored in ImageFolder format, so adding new datasets is easy.
    if dataset in ('CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST'):
        # The dataset name doubles as the torchvision class name.
        loader = getattr(torchvision.datasets, dataset)
        return loader(root=data_dir, train=train, download=True,
                      transform=transforms, target_transform=target_transforms)
    if dataset == 'PermutedMNIST':
        # Permuted MNIST starts from plain MNIST; the permutation is applied by the caller.
        return torchvision.datasets.MNIST(root=data_dir, train=train, download=True,
                                          transform=transforms, target_transform=target_transforms)
    if dataset == 'SVHN':
        # SVHN selects its partition with a 'split' string instead of a 'train' flag.
        split = 'train' if train else 'test'
        return torchvision.datasets.SVHN(root=data_dir, split=split, download=True,
                                         transform=transforms, target_transform=target_transforms)

    # e.g. ['notMNIST', 'miniImagenet']
    subfolder = 'train' if train else 'test'  # ImageNet 'val' is labeled as 'test' here.
    # os.path.join also accepts path-like data_dir values and avoids producing an
    # absolute '/<dataset>/...' path when data_dir is empty.
    image_root = os.path.join(data_dir, dataset, subfolder)
    return torchvision.datasets.ImageFolder(image_root, transform=transforms,
                                            target_transform=target_transforms)

#----------------------------------------------------------------------------------------------------------#

class SubDataset(Dataset):
    '''To sub-sample a dataset, taking only those samples with label in [sub_labels].

    After this selection of samples has been made, it is possible to transform the target-labels,
    which can be useful when doing continual learning with fixed number of output units.

    [original_dataset]  indexable dataset yielding (input, label) pairs; if it has a `targets`
                        attribute, labels are read from there instead of loading every sample
    [sub_labels]        container of labels to keep (membership is tested with `in`)
    [target_transform]  optional callable applied to the label of every returned sample'''

    def __init__(self, original_dataset, sub_labels, target_transform=None):
        super().__init__()
        self.dataset = original_dataset
        self.target_transform = target_transform

        n_samples = len(original_dataset)
        # Decide once how labels are obtained instead of re-checking for every sample.
        if hasattr(original_dataset, "targets"):
            targets = original_dataset.targets
            # Datasets may expose `targets` without a `target_transform` attribute.
            inner_transform = getattr(original_dataset, "target_transform", None)
            if inner_transform is None:
                labels = (targets[index] for index in range(n_samples))
            else:
                # Select on the labels as the wrapped dataset would actually return them.
                labels = (inner_transform(targets[index]) for index in range(n_samples))
        else:
            # No label list available: fetch every sample and read its label.
            labels = (original_dataset[index][1] for index in range(n_samples))

        # Indices (into the wrapped dataset) of the samples that are kept.
        self.sub_indeces = [index for index, label in enumerate(labels) if label in sub_labels]

    def __len__(self):
        return len(self.sub_indeces)

    def __getitem__(self, index):
        sample = self.dataset[self.sub_indeces[index]]
        if self.target_transform:
            target = self.target_transform(sample[1])
            sample = (sample[0], target)
        return sample

class TransformedDataset(Dataset):
    '''Wrap an existing dataset and apply extra transforms on access.

    Used for creating multiple MNIST-permutations w/o loading data every time.

    [original_dataset]  indexable dataset yielding (input, target) pairs
    [transform]         optional callable applied to the input of every sample
    [target_transform]  optional callable applied to the target of every sample'''

    def __init__(self, original_dataset, transform=None, target_transform=None):
        super().__init__()
        self.target_transform = target_transform
        self.transform = transform
        self.dataset = original_dataset

    def __len__(self):
        # Same number of samples as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        sample = self.transform(sample) if self.transform else sample
        label = self.target_transform(label) if self.target_transform else label
        return (sample, label)

#----------------------------------------------------------------------------------------------------------#

################## PERMUTED MNIST CODE ##################

def _permutate_image_pixels(image, permutation):
    '''Permutate the pixels of an image according to [permutation].

    [image]         3D-tensor containing the image
    [permutation]   <ndarray> of pixel-indeces in their new order'''

    if permutation is None:
        return image
    else:
        c, h, w = image.size()
        image = image.view(c, -1)
        image = image[:, permutation]  #--> same permutation for each channel
        image = image.view(c, h, w)
        return image

