import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import copy
import matplotlib as plt 
def get_task_load_train(train_dataset, batch_size):
    """Wrap one task's training set in a shuffling DataLoader.

    Loading happens in the main process (``num_workers=0``) with pinned
    memory enabled. The number of batches is printed before returning.

    Args:
        train_dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` that reshuffles ``train_dataset`` every epoch.
    """
    loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )
    print('Train loader length', len(loader))
    return loader

def get_task_load_test(test_dataset, test_batch_size):
    """Wrap one task's evaluation set in a deterministic DataLoader.

    Samples are visited in dataset order (no shuffling), loaded in the main
    process (``num_workers=0``) with pinned memory enabled.

    Args:
        test_dataset: Dataset to evaluate on.
        test_batch_size: Number of samples per batch.

    Returns:
        A sequential ``DataLoader`` over ``test_dataset``.
    """
    return DataLoader(
        test_dataset,
        batch_size=test_batch_size,
        pin_memory=True,
        num_workers=0,
        shuffle=False,
    )

def load_data():
    """Load the CIFAR-10 and CIFAR-100 train/test splits together.

    Delegates to ``load_cifar10`` and ``load_cifar100`` so the augmentation
    and normalization pipeline is defined in one place per dataset instead
    of being duplicated here. Datasets are built in the same order as
    before: CIFAR-10 train, CIFAR-10 test, CIFAR-100 train, CIFAR-100 test.

    Returns:
        Tuple ``(full_dataset_cifar10, test_dataset_cifar10,
        full_dataset_cifar100, test_dataset_cifar100)``.
    """
    full_dataset_cifar10, test_dataset_cifar10 = load_cifar10()
    full_dataset_cifar100, test_dataset_cifar100 = load_cifar100()
    return (
        full_dataset_cifar10,
        test_dataset_cifar10,
        full_dataset_cifar100,
        test_dataset_cifar100,
    )

def load_cifar10():
    """Build the CIFAR-10 train and test datasets under ``'_data'``.

    Training images are reflect-padded by 4 pixels per side, randomly
    cropped back to 32x32, randomly flipped horizontally, then normalized.
    Test images are only converted to tensors and normalized. The training
    split is downloaded if missing; the test split reuses that download.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    normalize = transforms.Normalize(channel_mean, channel_std)

    augment = transforms.Compose([
        transforms.ToTensor(),
        # Add a temporary leading dim for F.pad, reflect-pad 4 px on every
        # side, then squeeze the extra dim away again.
        transforms.Lambda(
            lambda img: F.pad(img.unsqueeze(0), (4, 4, 4, 4), mode='reflect').squeeze()
        ),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    evaluate = transforms.Compose([transforms.ToTensor(), normalize])

    train_set = datasets.CIFAR10('_data', train=True, transform=augment, download=True)
    test_set = datasets.CIFAR10('_data', train=False, transform=evaluate, download=False)
    return train_set, test_set

def load_cifar100():
    """Build the CIFAR-100 train and test datasets under ``'_data'``.

    Uses the same pipeline as CIFAR-10. Training images are reflect-padded
    by 4 pixels per side, randomly cropped to 32x32, randomly flipped, then
    normalized. Test images are only converted to tensors and normalized.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    # NOTE(review): these are the CIFAR-10 channel statistics, reused for
    # CIFAR-100 -- confirm this is intentional.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    normalize = transforms.Normalize(channel_mean, channel_std)

    augment = transforms.Compose([
        transforms.ToTensor(),
        # Add a temporary leading dim for F.pad, reflect-pad 4 px on every
        # side, then squeeze the extra dim away again.
        transforms.Lambda(
            lambda img: F.pad(img.unsqueeze(0), (4, 4, 4, 4), mode='reflect').squeeze()
        ),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    evaluate = transforms.Compose([transforms.ToTensor(), normalize])

    train_set = datasets.CIFAR100('_data', train=True, transform=augment, download=True)
    test_set = datasets.CIFAR100('_data', train=False, transform=evaluate, download=False)
    return train_set, test_set


def task_construction(task_labels, target_task_labels, benchmark):
    """Build per-task train and test dataset lists for a benchmark.

    Args:
        task_labels: One sequence of original class labels per task.
        target_task_labels: For 'CIFAR10'/'CIFAR100', the labels each task's
            classes are remapped to (parallel to ``task_labels``).
        benchmark: 'CIFAR10', 'CIFAR100', or anything else. Any other value
            selects the combined CIFAR-10 (first task) + CIFAR-100 (later
            tasks) split, which ignores ``target_task_labels``.

    Returns:
        ``(train_datasets, test_datasets)``, each a list with one dataset
        per task.
    """
    if benchmark == 'CIFAR10':
        train_source, test_source = load_cifar10()
    elif benchmark == 'CIFAR100':
        train_source, test_source = load_cifar100()
    else:
        cifar10_train, cifar10_test, cifar100_train, cifar100_test = load_data()
        return (
            split_dataset_by_labels_cifar10_100(cifar10_train, cifar100_train, task_labels),
            split_dataset_by_labels_cifar10_100(cifar10_test, cifar100_test, task_labels),
        )

    return (
        split_dataset_by_labels(train_source, task_labels, target_task_labels),
        split_dataset_by_labels(test_source, task_labels, target_task_labels),
    )

def split_dataset_by_labels(dataset, task_labels, target_task_labels):
    """Split ``dataset`` into one sub-dataset per task, remapping labels.

    For each ``(labels, target_labels)`` pair, keep only the samples whose
    target is in ``labels``. Their targets are rewritten via
    ``change_labels`` so that ``labels[i]`` becomes ``target_labels[i]``.

    Each returned dataset is a *shallow* copy of ``dataset`` with fresh
    ``data`` and ``targets`` attributes. Other attributes (e.g. the
    transform) are shared with the source rather than deep-copied, which
    avoids duplicating the full image array once per task.

    Args:
        dataset: Object with array-like ``data`` and a sequence ``targets``
            of int labels (e.g. a torchvision CIFAR dataset).
        task_labels: One sequence of original labels per task.
        target_task_labels: Parallel sequence of replacement labels.

    Returns:
        List of per-task datasets. Each one's ``targets`` is a LongTensor.
    """
    task_datasets = []
    targets_np = np.asarray(dataset.targets)
    # Loop-invariant: convert the label list to a tensor once.
    targets_tensor = torch.as_tensor(targets_np, dtype=torch.long)
    for labels, target_labels in zip(task_labels, target_task_labels):
        # np.isin replaces np.in1d, which is deprecated in NumPy 2.0.
        mask = np.isin(targets_np, labels)
        task_dataset = copy.copy(dataset)
        task_dataset.targets = change_labels(
            labels, target_labels, targets_tensor[torch.from_numpy(mask)]
        )
        # Boolean indexing returns a new array, so the source data is untouched.
        task_dataset.data = dataset.data[mask]
        task_datasets.append(task_dataset)
    return task_datasets

def change_labels(current_labels, target_labels, targets):
    """Return a copy of ``targets`` with each ``current_labels[i]`` replaced
    by ``target_labels[i]``.

    The input ``targets`` is left unmodified.
    """
    remapped = copy.deepcopy(targets)
    for idx, old_label in enumerate(current_labels):
        # Every mask is computed on the untouched input, so a value written
        # by an earlier remap is never matched again (label swaps work).
        remapped[targets == old_label] = target_labels[idx]
    return remapped


def split_dataset_by_labels_cifar10_100(dataset_1, dataset_2, task_labels, label_offset=10):
    """Split a two-dataset benchmark into per-task datasets.

    The first task draws its samples from ``dataset_1`` with labels kept
    as-is. Every later task draws from ``dataset_2``, and its labels are
    shifted by ``label_offset`` so they do not collide with the first
    task's labels.

    Targets are stored as a list of Python ``int``. Previously they were
    floats (via ``FloatTensor``), which collate into a float64 tensor that
    class-index losses such as ``cross_entropy`` reject.

    Each returned dataset is a *shallow* copy of its source with fresh
    ``data`` and ``targets``. Other attributes are shared rather than
    deep-copied.

    Args:
        dataset_1: Source for task 0; needs ``data`` (array) and ``targets``.
        dataset_2: Source for tasks 1..N; same requirements.
        task_labels: One sequence of original labels per task.
        label_offset: Added to ``dataset_2`` labels (default 10, the
            CIFAR-10 class count).

    Returns:
        List of per-task datasets.
    """
    task_datasets = []
    for task_id, labels in enumerate(task_labels):
        if task_id == 0:
            source, offset = dataset_1, 0
        else:
            source, offset = dataset_2, label_offset
        source_targets = np.asarray(source.targets, dtype=np.int64)
        # np.isin replaces np.in1d, which is deprecated in NumPy 2.0.
        mask = np.isin(source_targets, labels)
        task_dataset = copy.copy(source)
        # .tolist() on an int64 array yields plain Python ints.
        task_dataset.targets = (source_targets[mask] + offset).tolist()
        task_dataset.data = source.data[mask]
        task_datasets.append(task_dataset)
    return task_datasets



