import torch
import torch.utils.data as data
from torchvision.datasets import MNIST, CIFAR10, SVHN, CIFAR100
import torchvision.transforms as transforms
import numpy as np
from scipy import ndimage

# Module-level device, chosen once at import time: the first CUDA GPU
# ("cuda:0") when CUDA is available, otherwise the CPU. In this file it is
# only read by get_single_dataloader to decide how many DataLoader workers
# to use.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# TODO: This is inconsistent with the fact that we specify the device in config file

def rotate_mnist(train_image, angle):
    """Rotate each 2-D plane of a 3-D image tensor by ``angle`` degrees.

    Args:
        train_image: 3-D ``torch.Tensor``; the first axis is iterated over
            and every ``[i]`` slice is rotated independently.
        angle: Rotation angle in degrees. It is converted with ``int()``
            before use, so any fractional part is truncated.

    Returns:
        A CPU ``torch.float32`` tensor holding the rotated planes. Each plane
        keeps its size (``reshape=False``) and is resampled with linear
        interpolation (``order=1``).

    Raises:
        ValueError: If ``train_image`` does not have exactly three dimensions.
    """
    # Work on a float64 NumPy copy so scipy interpolates in floating point.
    pixels = train_image.numpy().astype('float')
    # Unpacking doubles as the "exactly three dimensions" check.
    _planes, _rows, _cols = pixels.shape

    rotated_planes = []
    for plane in pixels:
        rotated_planes.append(
            ndimage.rotate(plane,
                           int(angle),
                           axes=(0, 1),
                           reshape=False,
                           order=1))

    stacked = np.array(rotated_planes)
    return torch.from_numpy(stacked).to(torch.float32)


def _dataset_spec(dataset):
    """Return ``(dataset_class, mean, std, dim, outchannel)`` for ``dataset``."""
    # CIFAR10 and CIFAR100 use the same normalization statistics here.
    cifar_mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    cifar_std = [x / 255 for x in [63.0, 62.1, 66.7]]

    if dataset == 'MNIST':
        return MNIST, (0.1307, ), (0.3081, ), 28, 1
    if dataset == 'CIFAR10':
        return CIFAR10, cifar_mean, cifar_std, 32, 3
    if dataset == 'SVHN':
        return SVHN, [0.4377, 0.4438, 0.4728], [0.198, 0.201, 0.197], 32, 3
    if dataset == 'CIFAR100':
        return CIFAR100, cifar_mean, cifar_std, 32, 3
    raise ValueError('Unknown dataset: {!r}'.format(dataset))


def _build_transform(transform, mean, std, dim, outchannel, angle,
                     seed_permutation):
    """Build the torchvision transform pipeline selected by ``transform``."""
    if transform is None:
        # Default pipeline: the only one that normalizes with mean/std.
        return transforms.Compose([
            transforms.Resize((dim, dim)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    if transform == '3channel':
        return transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor()
        ])

    if transform == '1channel':
        return transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Grayscale(num_output_channels=1)
        ])

    if transform == 'rotate':
        return transforms.Compose([
            transforms.Resize((dim, dim)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: rotate_mnist(x, angle))
        ])

    if transform == 'permutation':
        # A dedicated generator keeps the pixel permutation reproducible
        # without touching the global torch RNG state.
        gen = torch.Generator()
        gen.manual_seed(seed_permutation)
        idx_permute = torch.randperm(outchannel * dim * dim, generator=gen)

        return transforms.Compose([
            transforms.Resize((dim, dim)),
            transforms.ToTensor(),
            transforms.Lambda(
                lambda x: x.view(-1)[idx_permute].view(outchannel, dim, dim))
        ])

    # The original code raised a plain string here, which Python rejects with
    # an unrelated TypeError; raise a real exception carrying the message.
    raise ValueError('unknown transformation: {!r}'.format(transform))


def _labels_as_tensor(ds):
    """Return the class labels of ``ds`` as a tensor."""
    labels = getattr(ds, 'targets', None)
    if labels is None:
        # torchvision's SVHN stores its labels in ``labels``, not ``targets``.
        labels = ds.labels
    # as_tensor avoids the copy-construct warning when labels is a tensor.
    return torch.as_tensor(labels)


def _restrict_to_classes(ds, partition):
    """Return a ``Subset`` of ``ds`` holding only the classes in ``partition``."""
    labels = _labels_as_tensor(ds)

    if isinstance(partition, tuple):
        # Tuple: half-open class range [start_class, end_class).
        start_class = partition[0]
        end_class = partition[1]
        keep = (labels >= start_class) & (labels < end_class)
    else:
        # Any other sequence: explicit list of class ids to keep.
        if len(partition) == 0:
            raise ValueError('partition list must contain at least one class')
        keep = torch.zeros_like(labels, dtype=torch.bool)
        for class_id in partition:
            keep |= labels == class_id

    return data.Subset(ds, np.flatnonzero(keep.numpy()))


def get_single_dataloader(dataset, datadir, train_bs, test_bs, transform,
                          angle, seed_permutation, partition):
    """Build the train and test ``DataLoader`` for one task.

    Args:
        dataset: One of ``'MNIST'``, ``'CIFAR10'``, ``'SVHN'``, ``'CIFAR100'``.
        datadir: Root directory passed to the torchvision dataset; the data
            is downloaded there (``download=True``).
        train_bs: Batch size of the training loader.
        test_bs: Batch size of the test loader.
        transform: ``None`` for resize + ``ToTensor`` + normalization, or one
            of ``'3channel'``, ``'1channel'``, ``'rotate'``, ``'permutation'``.
        angle: Rotation angle, used only when ``transform == 'rotate'``.
        seed_permutation: Seed of the pixel permutation, used only when
            ``transform == 'permutation'``.
        partition: ``None`` to keep every class; a tuple ``(start, end)`` to
            keep classes with ``start <= label < end``; or a non-empty list
            of class ids to keep.

    Returns:
        ``(train_dl, test_dl)``: two unshuffled ``DataLoader`` objects.

    Raises:
        ValueError: If ``dataset`` or ``transform`` is not recognized, or if
            ``partition`` is an empty list.
    """
    dl_obj, mean, std, dim, outchannel = _dataset_spec(dataset)
    transform = _build_transform(transform, mean, std, dim, outchannel, angle,
                                 seed_permutation)

    if dataset == 'SVHN':
        # SVHN selects its split by name instead of a ``train`` flag.
        train_ds = dl_obj(datadir,
                          split='train',
                          transform=transform,
                          target_transform=None,
                          download=True)
        test_ds = dl_obj(datadir,
                         split='test',
                         transform=transform,
                         target_transform=None,
                         download=True)
    else:
        train_ds = dl_obj(datadir,
                          train=True,
                          transform=transform,
                          download=True)
        test_ds = dl_obj(datadir,
                         train=False,
                         transform=transform,
                         download=True)

    # Worker processes are only used when the module-level device is not CPU.
    if device.type == "cpu":
        num_workers_train = 0
        num_workers_test = 0
    else:
        num_workers_train = 16
        num_workers_test = 4

    if partition is not None:
        print('partition', partition)
        train_ds = _restrict_to_classes(train_ds, partition)
        test_ds = _restrict_to_classes(test_ds, partition)

    train_dl = data.DataLoader(dataset=train_ds,
                               batch_size=train_bs,
                               shuffle=False,
                               num_workers=num_workers_train)
    test_dl = data.DataLoader(dataset=test_ds,
                              batch_size=test_bs,
                              shuffle=False,
                              num_workers=num_workers_test)

    return train_dl, test_dl


def generate_tasks(tasks_description, datadir, tr_batch_size, ts_batch_size,sim):
    """Create the train and test dataloaders for every described task.

    Args:
        tasks_description: Iterable of 5-tuples
            ``(dataset, transform, angle, seed_permutation, partition)``,
            each forwarded to ``get_single_dataloader``.
        datadir: Dataset root directory shared by all tasks.
        tr_batch_size: Batch size used for every training loader.
        ts_batch_size: Batch size used for every test loader.
        sim: Not used by this function; kept in the signature for callers.

    Returns:
        ``(train_dl_list, test_dl_list)``: two lists with one loader per
        task, in the order of ``tasks_description``.
    """
    loader_pairs = [
        get_single_dataloader(name, datadir, tr_batch_size, ts_batch_size,
                              task_transform, task_angle, perm_seed,
                              class_partition)
        for (name, task_transform, task_angle, perm_seed,
             class_partition) in tasks_description
    ]

    # Split the (train, test) pairs into two parallel lists.
    train_dl_list = [train_loader for train_loader, _ in loader_pairs]
    test_dl_list = [test_loader for _, test_loader in loader_pairs]

    return train_dl_list, test_dl_list
