import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Width, in degrees, of the uniform jitter that the "positive correlated" task-set
# generators add to their main angle (it scales a U[0, 1) draw; it is a range, not a
# statistical variance).
ANGLE_VARIANCE = 10
# Degrees in a full circle; random task angles are drawn as U[0, 1) * CIRCLE_ANGLE.
CIRCLE_ANGLE = 360


def get_dataset(num_samples, clf_angle, num_dims=2):
    """Sample one synthetic binary-classification task from a standard Gaussian.

    Train and test inputs are drawn independently (train first, then test) from
    N(0, I) in ``num_dims`` dimensions. Labels follow a linear rule on the first
    two coordinates, parameterised by ``clf_angle`` (degrees):

    * ``num_dims == 2``: label 1 when ``x[1] >= sin(angle) * x[0]``
    * otherwise:         label 1 when ``x[0] >= cos(angle) * x[1]``

    NOTE(review): the coefficient is sin/cos of the angle rather than tan, so the
    boundary is not literally a line at ``clf_angle`` degrees - confirm intended.

    Args:
        num_samples: number of points in each of the train and test sets.
        clf_angle: angle in degrees controlling the decision boundary.
        num_dims: input dimensionality (must be at least 2).

    Returns:
        ``(train_x, train_y, test_x, test_y)``: inputs are float32 tensors of
        shape ``(num_samples, num_dims)``; labels are one-hot long tensors of
        shape ``(num_samples, 2)``.
    """
    mean_vec = np.zeros(num_dims)
    cov_mat = np.eye(num_dims)
    # Draw order (train, then test) is kept so seeded runs stay reproducible.
    train_points = np.random.multivariate_normal(mean=mean_vec, cov=cov_mat, size=num_samples)
    test_points = np.random.multivariate_normal(mean=mean_vec, cov=cov_mat, size=num_samples)

    theta = np.radians(clf_angle)
    if num_dims == 2:
        coeff, lhs_col, rhs_col = np.sin(theta), 1, 0
    else:
        coeff, lhs_col, rhs_col = np.cos(theta), 0, 1

    def one_hot_labels(points):
        # Boolean side-of-boundary test, converted to {0, 1} class indices.
        is_positive = points[:, lhs_col] >= coeff * points[:, rhs_col]
        class_ids = torch.as_tensor(is_positive.astype(np.int64), dtype=torch.long)
        return F.one_hot(class_ids, num_classes=2)

    train_labels = one_hot_labels(train_points)
    test_labels = one_hot_labels(test_points)
    return (
        torch.tensor(train_points, dtype=torch.float32),
        train_labels,
        torch.tensor(test_points, dtype=torch.float32),
        test_labels,
    )


def get_positive_correlated_taskset(num_samples, num_tasks, num_dims=2):
    """Lazily generate ``num_tasks`` datasets whose angles cluster together.

    A single base angle is drawn as ``U[0, 1) * CIRCLE_ANGLE - ANGLE_VARIANCE``;
    every task then uses that base plus its own ``U[0, 1) * ANGLE_VARIANCE``
    offset. Each yielded item is the 4-tuple produced by ``get_dataset``.
    """
    base_angle = np.random.random() * CIRCLE_ANGLE - ANGLE_VARIANCE
    for _ in range(num_tasks):
        # The jitter is drawn before the dataset itself is sampled.
        jitter = np.random.random() * ANGLE_VARIANCE
        yield get_dataset(num_samples, base_angle + jitter, num_dims)


def get_rand_taskset(num_samples, num_tasks, num_dims=2):
    """Lazily generate ``num_tasks`` datasets with unrelated angles.

    Each task draws its own angle as ``U[0, 1) * CIRCLE_ANGLE`` and yields the
    4-tuple produced by ``get_dataset`` for that angle.
    """
    for _ in range(num_tasks):
        task_angle = np.random.random() * CIRCLE_ANGLE
        yield get_dataset(num_samples, task_angle, num_dims)


def get_positive_correlated_with_distractors_taskset(num_samples, num_tasks, frac_distractors, num_dims=2):
    """Lazily generate clustered tasks, some of which are flipped "distractors".

    A base angle is drawn once (``U[0, 1) * CIRCLE_ANGLE - ANGLE_VARIANCE``).
    Each task uses the base plus a ``U[0, 1) * ANGLE_VARIANCE`` offset; with
    probability governed by ``frac_distractors`` the task is a distractor and
    half a circle (``CIRCLE_ANGLE / 2``) is added on top.

    The distractor decision uses Python's ``random`` module while the angles
    use NumPy's global RNG, so both must be seeded for reproducible output.
    """
    base_angle = np.random.random() * CIRCLE_ANGLE - ANGLE_VARIANCE
    half_turn = CIRCLE_ANGLE / 2
    for _ in range(num_tasks):
        # Decide first, then draw the jitter - same RNG order for both outcomes.
        is_distractor = random.random() <= frac_distractors
        task_angle = base_angle + np.random.random() * ANGLE_VARIANCE
        if is_distractor:
            task_angle = task_angle + half_turn
        yield get_dataset(num_samples, task_angle, num_dims)


def get_gradual_shift_taskset(num_samples, num_tasks, num_dims=2):
    """Lazily generate ``num_tasks`` datasets that all share one random angle.

    The angle is drawn once as ``U[0, 1) * CIRCLE_ANGLE``; each yielded item is
    the 4-tuple produced by ``get_dataset`` for that angle, so tasks differ only
    in their sampled points.

    NOTE(review): despite the name, no per-task shift is applied to the angle -
    confirm whether a gradual drift was intended here.
    """
    shared_angle = np.random.random() * CIRCLE_ANGLE
    for _ in range(num_tasks):
        yield get_dataset(num_samples, shared_angle, num_dims)


def get_orthogonal_taskset(num_samples, num_tasks, alternating=True, num_dims=2):
    """Lazily generate tasks that use one of two angles 90 degrees apart.

    A base angle is drawn as ``U[0, 1) * CIRCLE_ANGLE``. When ``alternating``
    is true, even-indexed tasks use the base angle and odd-indexed tasks use
    base + 90. Otherwise the first ``num_tasks // 2`` tasks use the base angle
    and the remaining ones use base + 90.
    """
    base_angle = np.random.random() * CIRCLE_ANGLE
    rotated_angle = base_angle + 90
    first_half = num_tasks // 2
    for task_idx in range(num_tasks):
        if alternating:
            use_base = task_idx % 2 == 0
        else:
            use_base = task_idx < first_half
        yield get_dataset(num_samples, base_angle if use_base else rotated_angle, num_dims)


def reshape_and_onehot(train_data, train_labels, total_labels, reshape_array):
    """Shape a batch of inputs and one-hot encode its labels.

    Args:
        train_data: tensor whose first dimension indexes samples.
        train_labels: integer class-index tensor accepted by ``F.one_hot``.
        total_labels: number of classes in the one-hot encoding.
        reshape_array: per-sample target shape, or ``None`` to flatten every
            sample into a vector.

    Returns:
        ``(data, one_hot_labels)``.
    """
    if reshape_array is not None:
        shaped = train_data.reshape(-1, *reshape_array)
    else:
        shaped = train_data.flatten(1)
    return shaped, F.one_hot(train_labels, total_labels)


def get_random_subsample(dataset, num_samples, total_labels, allowed_labels=None, reshape_array=None,
                         is_binarized=False):
    """Draw a random subset of ``dataset`` and return ``(data, one_hot_labels)``.

    ``dataset`` must expose tensor attributes ``data`` and ``targets``.

    * ``allowed_labels is None``: ``num_samples`` examples are picked uniformly
      at random and keep their original labels (one-hot over ``total_labels``).
    * otherwise: up to ``num_samples // total_labels`` examples are drawn per
      class. With ``total_labels == 2`` the classes are the entries of
      ``allowed_labels`` and the label is 1 for ``allowed_labels[0]``, else 0.
      With any other ``total_labels`` the classes are ``range(total_labels)``
      and the label is 1 when the original class is in ``allowed_labels``.

    Args:
        dataset: object with ``data`` and ``targets`` tensors.
        num_samples: total number of examples requested.
        total_labels: number of classes to sample from.
        allowed_labels: labels defining the positive class (see above).
        reshape_array: per-sample shape, or ``None`` to flatten each sample.
        is_binarized: when true the one-hot encoding has 2 classes, otherwise
            ``total_labels`` classes.
    """
    targets = dataset.targets
    if allowed_labels is None:
        chosen = torch.randperm(len(targets))[:num_samples]
        return reshape_and_onehot(dataset.data[chosen].float(), targets[chosen], total_labels, reshape_array)

    binary_task = total_labels == 2
    candidate_labels = allowed_labels if binary_task else range(total_labels)
    per_class_quota = num_samples // total_labels

    # Positions of every candidate class, then one shuffle per class (in class order).
    class_positions = [torch.nonzero(targets == label) for label in candidate_labels]
    class_shuffles = [torch.randperm(positions.shape[0]) for positions in class_positions]
    picked = [class_positions[k][class_shuffles[k][:per_class_quota]] for k in range(total_labels)]

    picked_labels = torch.cat([targets[idx] for idx in picked])
    if binary_task:
        membership = picked_labels == allowed_labels[0]
    else:
        membership = torch.isin(picked_labels, torch.tensor(allowed_labels))
    label_tensor = membership.long().flatten()

    picked_data = torch.cat([dataset.data[idx] for idx in picked], dim=0).float()
    num_classes = 2 if is_binarized else total_labels
    return reshape_and_onehot(picked_data, label_tensor, num_classes, reshape_array)


def get_multiclass_permuted_mnist(num_samples, num_tasks, flatten=True):
    """Lazily generate ``num_tasks`` 10-class permuted-MNIST tasks.

    Every task draws a fresh random pixel permutation and applies the same
    permutation to a random training subsample and to the full test set.

    The permutation is applied directly to the image tensors. Installing it as
    a torchvision ``transform`` has no effect here, because transforms run only
    in ``Dataset.__getitem__`` while this function reads ``dataset.data``. For
    the same reason pixel values are the raw stored intensities cast to float,
    not the normalised values configured in ``load_MNIST``.

    Args:
        num_samples: number of training examples per task.
        num_tasks: number of tasks to generate.
        flatten: if true, images are returned as 784-vectors, otherwise with
            shape ``(1, 28, 28)``.

    Yields:
        ``(train_x, train_y, test_x, test_y)`` with float inputs and one-hot
        labels over 10 classes.
    """
    if num_tasks <= 0:
        return

    im_shape = (1, 28, 28)
    num_pixels = im_shape[0] * im_shape[1] * im_shape[2]

    # The dataset is only read, so a single load serves every task.
    train_dataset, test_dataset = load_MNIST([], "./data", target_trans=[])
    test_flat = test_dataset.data.float().flatten(1)

    for _ in range(num_tasks):
        pixel_perm = torch.randperm(num_pixels)
        d_train, l_train = get_random_subsample(train_dataset, num_samples, 10, reshape_array=None)
        # Column k of the output holds original pixel pixel_perm[k].
        d_train = d_train[:, pixel_perm]
        d_test = test_flat[:, pixel_perm]
        if not flatten:
            d_train = d_train.reshape(-1, *im_shape)
            d_test = d_test.reshape(-1, *im_shape)
        yield d_train, l_train, d_test, F.one_hot(test_dataset.targets, 10)


def get_split_mnist(num_samples, num_tasks, grouped=True, flatten=True):
    """Lazily generate exactly ``num_tasks`` binary tasks built from MNIST digits.

    Digits are repeatedly split at random into a chosen set of 5 and the
    remaining 5:

    * ``grouped=True``: each split gives one task over all 10 digits, where the
      label says whether the digit is in the chosen set
      (``num_samples // 10`` training examples per digit).
    * ``grouped=False``: each split gives up to 5 one-vs-one tasks, pairing each
      chosen digit with one of the remaining digits
      (``num_samples // 2`` training examples per digit). Label assignment
      follows ``get_random_subsample``, which marks the first label of the pair
      (the "remaining" digit) as class 1.

    Generation stops as soon as ``num_tasks`` tasks have been yielded, even in
    the middle of a group of five. Test data is subsampled from the MNIST test
    split by the same per-class rule.

    Args:
        num_samples: training examples requested per task.
        num_tasks: number of tasks to yield.
        grouped: see above.
        flatten: if true, images are 784-vectors, otherwise ``(1, 28, 28)``.

    Yields:
        ``(train_x, train_y, test_x, test_y)`` with one-hot labels over 2 classes.
    """
    if num_tasks <= 0:
        return

    num_classes, group_size = 10, 5
    sample_shape = None if flatten else (1, 28, 28)

    # The datasets are only read, never modified: load once instead of per group.
    train_dataset, test_dataset = load_MNIST([], "./data", target_trans=[])
    num_test = len(test_dataset.targets)
    all_classes = np.arange(num_classes)

    produced = 0
    while produced < num_tasks:
        labels = np.random.choice(all_classes, size=group_size, replace=False)
        if grouped:
            trn_data, trn_labels = get_random_subsample(train_dataset, num_samples, num_classes, labels,
                                                        reshape_array=sample_shape, is_binarized=True)
            tst_data, tst_labels = get_random_subsample(test_dataset, num_test, num_classes, labels,
                                                        reshape_array=sample_shape, is_binarized=True)
            yield trn_data, trn_labels, tst_data, tst_labels
            produced += 1
            continue

        other_labels = np.random.permutation(all_classes[~np.isin(all_classes, labels)])
        for other, chosen in zip(other_labels, labels):
            if produced >= num_tasks:
                # Do not overshoot the requested number of tasks.
                return
            pair = [other, chosen]
            trn_data, trn_labels = get_random_subsample(train_dataset, num_samples, 2, pair,
                                                        reshape_array=sample_shape)
            tst_data, tst_labels = get_random_subsample(test_dataset, num_test, 2, pair,
                                                        reshape_array=sample_shape)
            yield trn_data, trn_labels, tst_data, tst_labels
            produced += 1


def load_MNIST(final_input_trans, data_path, target_trans=None):
    """Build the torchvision MNIST train and test datasets.

    The input pipeline is ``ToTensor`` followed by ``Normalize((0.5,), (0.5,))``,
    which maps the [0, 1] range to [-1, 1] (the dataset statistics
    0.1307 / 0.3081 are deliberately not used), followed by any transforms in
    ``final_input_trans``.

    Args:
        final_input_trans: extra input transforms appended to the pipeline; may
            be empty or ``None``.
        data_path: directory under which the ``MNIST`` folder is stored.
        target_trans: list of target transforms; ``None`` means no transforms.

    Returns:
        ``(train_dataset, test_dataset)``. The training split is created with
        ``download=True``; the test split is opened without requesting a download.
    """
    pipeline = [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    if final_input_trans:
        pipeline.extend(final_input_trans)

    # Avoid a shared mutable default: build a fresh list when none is supplied.
    target_pipeline = [] if target_trans is None else target_trans

    root_path = os.path.join(data_path, 'MNIST')

    train_dataset = datasets.MNIST(root_path, train=True, download=True,
                                   transform=transforms.Compose(pipeline),
                                   target_transform=transforms.Compose(target_pipeline))

    test_dataset = datasets.MNIST(root_path, train=False,
                                  transform=transforms.Compose(pipeline),
                                  target_transform=transforms.Compose(target_pipeline))

    return train_dataset, test_dataset


def create_pixel_permute_trans(input_shape):
    """Return a transform that applies one fixed random pixel permutation.

    The permutation covers ``input_shape[0] * input_shape[1] * input_shape[2]``
    positions and is drawn once, when this factory is called. The returned
    callable forwards an image to ``permute_pixels`` with that permutation,
    flagging the image as colour when ``input_shape[0] > 1``.
    """
    num_elements = input_shape[0] * input_shape[1] * input_shape[2]
    permutation = torch.randperm(num_elements)

    def apply_permutation(image):
        return permute_pixels(image, permutation, is_color=(input_shape[0] > 1))

    return apply_permutation


def create_limited_pixel_permute_trans(input_shape, num_pixels):
    """Return a transform applying a permutation built from a few random swaps.

    Starting from the identity ordering of the
    ``input_shape[0] * input_shape[1] * input_shape[2]`` pixel positions,
    ``num_pixels`` random pairwise swaps are performed (positions drawn with
    ``np.random.randint``). The returned callable forwards an image to
    ``permute_pixels`` with the resulting index tensor, flagging the image as
    colour when ``input_shape[0] > 1``.

    Args:
        input_shape: ``(C, H, W)`` shape of the images to be permuted.
        num_pixels: number of random swaps to perform.
    """
    input_size = input_shape[0] * input_shape[1] * input_shape[2]

    # Swap on a plain Python list. Swapping through `tmp = tensor[i]` is wrong:
    # integer indexing returns a view of the tensor's storage, so the saved
    # value is overwritten by the first assignment and an index gets duplicated
    # instead of exchanged.
    order = list(range(input_size))
    for _ in range(num_pixels):
        i1 = np.random.randint(0, input_size)
        i2 = np.random.randint(0, input_size)
        order[i1], order[i2] = order[i2], order[i1]
    inds_permute = torch.tensor(order, dtype=torch.long)

    def transform_func(x):
        return permute_pixels(x, inds_permute, is_color=(input_shape[0] > 1))

    return transform_func


def create_label_permute_trans(n_class):
    """Return a target transform that relabels classes by a fixed random permutation.

    A permutation of ``range(n_class)`` is drawn once, when this factory is
    called; the returned callable maps a target to the permutation entry at
    that index.
    """
    label_map = torch.randperm(n_class)

    def relabel(target):
        return label_map[target]

    return relabel


def permute_pixels(x, inds_permute, is_color=False):
    """Reorder the pixels of a single ``(C, H, W)`` image tensor.

    The image is flattened in C, H, W order, element ``k`` of the result is
    taken from flat position ``inds_permute[k]``, and the result is shaped back
    to ``(channels, H, W)``.

    The channel count is inferred from the number of selected elements, so
    images with any number of channels are supported, and non-contiguous
    inputs are accepted.

    Args:
        x: image tensor with three dimensions.
        inds_permute: index tensor selecting flat pixel positions.
        is_color: kept for backward compatibility; only consulted (3 channels
            if true, else 1) when ``H * W == 0`` and the channel count cannot
            be inferred.

    Returns:
        A new tensor of shape ``(channels, H, W)``.
    """
    im_H = x.shape[1]
    im_W = x.shape[2]
    # reshape (unlike view) also works for non-contiguous inputs.
    permuted = x.reshape(-1)[inds_permute]
    pixels_per_channel = im_H * im_W
    if pixels_per_channel:
        channels = permuted.numel() // pixels_per_channel
    else:
        channels = 3 if is_color else 1
    return permuted.reshape(channels, im_H, im_W)

def get_multiclass_permuted_cifar10(num_samples, num_tasks, flatten=True):
    """Lazily generate ``num_tasks`` 10-class permuted-CIFAR-10 tasks.

    Every task draws a fresh random pixel permutation and applies the same
    permutation to a random training subsample and to the full test set.

    The permutation is applied directly to the image tensors. Installing it as
    a torchvision ``transform`` has no effect here, because transforms run only
    in ``Dataset.__getitem__`` while this function reads ``dataset.data``. For
    the same reason pixel values are the raw stored intensities cast to float.
    torchvision stores CIFAR images as ``(N, 32, 32, 3)``; they are moved to
    channel-first ``(N, 3, 32, 32)`` before flattening and permuting.

    Args:
        num_samples: number of training examples per task.
        num_tasks: number of tasks to generate.
        flatten: if true, images are returned as 3072-vectors, otherwise with
            shape ``(3, 32, 32)``.

    Yields:
        ``(train_x, train_y, test_x, test_y)`` with float inputs and one-hot
        labels over 10 classes.
    """
    if num_tasks <= 0:
        return

    im_shape = (3, 32, 32)
    num_pixels = im_shape[0] * im_shape[1] * im_shape[2]

    # The dataset is only read, so load and convert it once for all tasks.
    train_dataset, test_dataset = load_cifar([], "./data", target_trans=[])
    for split in (train_dataset, test_dataset):
        # HWC numpy array -> CHW tensor, so tensor indexing and reshapes line up.
        split.data = torch.tensor(split.data).permute(0, 3, 1, 2).contiguous()
        split.targets = torch.tensor(split.targets)
    test_flat = test_dataset.data.float().flatten(1)

    for _ in range(num_tasks):
        pixel_perm = torch.randperm(num_pixels)
        d_train, l_train = get_random_subsample(train_dataset, num_samples, 10, reshape_array=None)
        # Column k of the output holds original element pixel_perm[k].
        d_train = d_train[:, pixel_perm]
        d_test = test_flat[:, pixel_perm]
        if not flatten:
            d_train = d_train.reshape(-1, *im_shape)
            d_test = d_test.reshape(-1, *im_shape)
        yield d_train, l_train, d_test, F.one_hot(test_dataset.targets, 10)

def get_split_cifar10(num_samples, num_tasks, grouped=True, flatten=True):
    """Lazily generate exactly ``num_tasks`` binary tasks built from CIFAR-10 classes.

    Classes are repeatedly split at random into a chosen set of 5 and the
    remaining 5:

    * ``grouped=True``: each split gives one task over all 10 classes, where
      the label says whether the class is in the chosen set
      (``num_samples // 10`` training examples per class).
    * ``grouped=False``: each split gives up to 5 one-vs-one tasks, pairing each
      chosen class with one of the remaining classes
      (``num_samples // 2`` training examples per class). Label assignment
      follows ``get_random_subsample``, which marks the first label of the pair
      (the "remaining" class) as class 1.

    Generation stops as soon as ``num_tasks`` tasks have been yielded, even in
    the middle of a group of five.

    torchvision stores CIFAR images as ``(N, 32, 32, 3)``. They are moved to
    channel-first layout before use, so ``flatten=False`` yields real
    ``(3, 32, 32)`` images and ``flatten=True`` yields their flattening.

    Args:
        num_samples: training examples requested per task.
        num_tasks: number of tasks to yield.
        grouped: see above.
        flatten: if true, images are 3072-vectors, otherwise ``(3, 32, 32)``.

    Yields:
        ``(train_x, train_y, test_x, test_y)`` with one-hot labels over 2 classes.
    """
    if num_tasks <= 0:
        return

    num_classes, group_size = 10, 5
    sample_shape = None if flatten else (3, 32, 32)

    # The datasets are only read afterwards: load and convert once, not per group.
    train_dataset, test_dataset = load_cifar([], "./data", target_trans=[])
    for split in (train_dataset, test_dataset):
        # HWC numpy array -> CHW tensor; a plain reshape would scramble the images.
        split.data = torch.tensor(split.data).permute(0, 3, 1, 2).contiguous()
        split.targets = torch.tensor(split.targets)
    num_test = len(test_dataset.targets)
    all_classes = np.arange(num_classes)

    produced = 0
    while produced < num_tasks:
        labels = np.random.choice(all_classes, size=group_size, replace=False)
        if grouped:
            trn_data, trn_labels = get_random_subsample(train_dataset, num_samples, num_classes, labels,
                                                        reshape_array=sample_shape, is_binarized=True)
            tst_data, tst_labels = get_random_subsample(test_dataset, num_test, num_classes, labels,
                                                        reshape_array=sample_shape, is_binarized=True)
            yield trn_data, trn_labels, tst_data, tst_labels
            produced += 1
            continue

        other_labels = np.random.permutation(all_classes[~np.isin(all_classes, labels)])
        for other, chosen in zip(other_labels, labels):
            if produced >= num_tasks:
                # Do not overshoot the requested number of tasks.
                return
            pair = [other, chosen]
            trn_data, trn_labels = get_random_subsample(train_dataset, num_samples, 2, pair,
                                                        reshape_array=sample_shape)
            tst_data, tst_labels = get_random_subsample(test_dataset, num_test, 2, pair,
                                                        reshape_array=sample_shape)
            yield trn_data, trn_labels, tst_data, tst_labels
            produced += 1

def get_split_cifar100(num_samples, num_tasks, grouped=True, flatten=True):
    """Lazily generate exactly ``num_tasks`` binary tasks built from CIFAR-100 classes.

    Classes are repeatedly split at random into a chosen set of 50 and the
    remaining 50:

    * ``grouped=True``: each split gives one task over all 100 classes, where
      the label says whether the class is in the chosen set
      (``num_samples // 100`` training examples per class).
    * ``grouped=False``: each split gives up to 50 one-vs-one tasks, pairing
      each chosen class with one of the remaining classes
      (``num_samples // 2`` training examples per class). Label assignment
      follows ``get_random_subsample``, which marks the first label of the pair
      (the "remaining" class) as class 1.

    Generation stops as soon as ``num_tasks`` tasks have been yielded, even in
    the middle of a group.

    torchvision stores CIFAR images as ``(N, 32, 32, 3)``. They are moved to
    channel-first layout before use, so ``flatten=False`` yields real
    ``(3, 32, 32)`` images and ``flatten=True`` yields their flattening.

    Args:
        num_samples: training examples requested per task.
        num_tasks: number of tasks to yield.
        grouped: see above.
        flatten: if true, images are 3072-vectors, otherwise ``(3, 32, 32)``.

    Yields:
        ``(train_x, train_y, test_x, test_y)`` with one-hot labels over 2 classes.
    """
    if num_tasks <= 0:
        return

    num_classes, group_size = 100, 50
    sample_shape = None if flatten else (3, 32, 32)

    # The datasets are only read afterwards: load and convert once, not per group.
    train_dataset, test_dataset = load_cifar100([], "./data", target_trans=[])
    for split in (train_dataset, test_dataset):
        # HWC numpy array -> CHW tensor; a plain reshape would scramble the images.
        split.data = torch.tensor(split.data).permute(0, 3, 1, 2).contiguous()
        split.targets = torch.tensor(split.targets)
    num_test = len(test_dataset.targets)
    all_classes = np.arange(num_classes)

    produced = 0
    while produced < num_tasks:
        labels = np.random.choice(all_classes, size=group_size, replace=False)
        if grouped:
            trn_data, trn_labels = get_random_subsample(train_dataset, num_samples, num_classes, labels,
                                                        reshape_array=sample_shape, is_binarized=True)
            tst_data, tst_labels = get_random_subsample(test_dataset, num_test, num_classes, labels,
                                                        reshape_array=sample_shape, is_binarized=True)
            yield trn_data, trn_labels, tst_data, tst_labels
            produced += 1
            continue

        other_labels = np.random.permutation(all_classes[~np.isin(all_classes, labels)])
        for other, chosen in zip(other_labels, labels):
            if produced >= num_tasks:
                # Do not overshoot the requested number of tasks.
                return
            pair = [other, chosen]
            trn_data, trn_labels = get_random_subsample(train_dataset, num_samples, 2, pair,
                                                        reshape_array=sample_shape)
            tst_data, tst_labels = get_random_subsample(test_dataset, num_test, 2, pair,
                                                        reshape_array=sample_shape)
            yield trn_data, trn_labels, tst_data, tst_labels
            produced += 1


def load_cifar(final_input_trans, data_path, target_trans=None):
    """Build the torchvision CIFAR-10 train and test datasets.

    The input pipeline is ``ToTensor`` followed by a per-channel
    ``Normalize(0.5, 0.5)``, which maps the [0, 1] range to [-1, 1], followed
    by any transforms in ``final_input_trans``.

    Args:
        final_input_trans: extra input transforms appended to the pipeline; may
            be empty or ``None``.
        data_path: directory under which the ``CIFAR10`` folder is stored.
        target_trans: list of target transforms; ``None`` means no transforms.

    Returns:
        ``(train_dataset, test_dataset)``. Neither split requests a download
        (the training split passes ``download=False`` explicitly), so the
        CIFAR-10 files must already be present under ``data_path``.
    """
    pipeline = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5,), (0.5, 0.5, 0.5,)),
    ]
    if final_input_trans:
        pipeline.extend(final_input_trans)

    # Avoid a shared mutable default: build a fresh list when none is supplied.
    target_pipeline = [] if target_trans is None else target_trans

    root_path = os.path.join(data_path, 'CIFAR10')

    train_dataset = datasets.CIFAR10(root_path, train=True, download=False,
                                     transform=transforms.Compose(pipeline),
                                     target_transform=transforms.Compose(target_pipeline))

    test_dataset = datasets.CIFAR10(root_path, train=False,
                                    transform=transforms.Compose(pipeline),
                                    target_transform=transforms.Compose(target_pipeline))

    return train_dataset, test_dataset

def load_cifar100(final_input_trans, data_path, target_trans=None):
    """Build the torchvision CIFAR-100 train and test datasets.

    The input pipeline is ``ToTensor`` followed by a per-channel
    ``Normalize(0.5, 0.5)``, which maps the [0, 1] range to [-1, 1], followed
    by any transforms in ``final_input_trans``.

    Args:
        final_input_trans: extra input transforms appended to the pipeline; may
            be empty or ``None``.
        data_path: directory under which the ``CIFAR100`` folder is stored.
        target_trans: list of target transforms; ``None`` means no transforms.

    Returns:
        ``(train_dataset, test_dataset)``; both splits are created with
        ``download=True``.
    """
    pipeline = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5,), (0.5, 0.5, 0.5,)),
    ]
    if final_input_trans:
        pipeline.extend(final_input_trans)

    # Avoid a shared mutable default: build a fresh list when none is supplied.
    target_pipeline = [] if target_trans is None else target_trans

    root_path = os.path.join(data_path, 'CIFAR100')

    train_dataset = datasets.CIFAR100(root_path, train=True, download=True,
                                      transform=transforms.Compose(pipeline),
                                      target_transform=transforms.Compose(target_pipeline))

    test_dataset = datasets.CIFAR100(root_path, train=False, download=True,
                                     transform=transforms.Compose(pipeline),
                                     target_transform=transforms.Compose(target_pipeline))

    return train_dataset, test_dataset
