import copy
import sys

import torch
import torch.optim as optim  # Where the optimization modules are
import torchvision  # To be able to access standard datasets more easily
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms

import wandb
from networks import generate_net, SmallNet, get_learnable_parameters_as_vector, set_learnable_parameters_from_vector
from data.CUB200 import Cub200
from utils import ShuffleTransform, SequentialDataset, cifar100classes_dict, compute_replay_dataset, \
    ElasticWeightConsolidationLoss
from utils import compute_acc

from hippo import HippoApprox


def _hippo_update_epochs(update_every, epochs):
    """Return the sorted epoch indices (within one task) after which the HiPPO coefficients are updated.

    A positive ``update_every`` updates after epochs 0, update_every, 2*update_every, ...
    Non-positive values select one of the fixed schedules below. Entries are
    de-duplicated and restricted to ``range(epochs)``, so ``len(result)`` is exactly
    the number of updates performed during one task.

    Raises:
        NotImplementedError: if ``update_every`` is non-positive and not a known schedule.
    """
    if update_every > 0:
        return list(range(0, epochs, update_every))
    schedules = {
        -1: [epochs - 1],
        -2: [0, epochs - 1],
        -3: [0, 1, epochs - 1],
        -4: [0, 1, epochs // 2, epochs - 1],
        -5: [0, 1, 2, epochs // 2, epochs - 1],
        -10: [0, 1, 2, 3, epochs // 4, 2 * epochs // 4, 3 * epochs // 4, epochs - 3, epochs - 2, epochs - 1],
    }
    if update_every not in schedules:
        raise NotImplementedError(f"Unsupported update_every value for hippo: {update_every}")
    # For small `epochs` the schedules may repeat an epoch or contain negative ones;
    # both would make len() overcount the updates actually performed.
    return sorted({e for e in schedules[update_every] if 0 <= e < epochs})


def _load_datasets(dataset, network, n_tasks):
    """Instantiate the train/test sets of ``dataset`` and split its classes into tasks.

    Returns ``(trainset, testset, class_sequence, n_classes)`` where ``class_sequence[t]``
    lists the labels of the classes that make up task ``t``.

    Raises:
        NotImplementedError: if ``dataset`` is not 'mnist', 'cifar100' or 'cub200'.
    """
    # Resize/crop/normalize pipeline (ImageNet statistics) for the images fed to the pretrained ResNets
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    if dataset == 'mnist':
        assert network == "small", "Only small network working with mnist"
        transform = transforms.Compose([transforms.ToTensor()])
        trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
        # 5 tasks made of pairs of consecutive digits: [0, 1], [2, 3], ..., [8, 9]
        class_sequence = [[i, i + 1] for i in range(0, 10, 2)]
        return trainset, testset, class_sequence, 10

    if dataset == 'cifar100':
        assert n_tasks <= 20, "For cifar100 the number of tasks must be less or equal to the number of super classes (20)"
        trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform, target_transform=lambda x: torch.as_tensor(x))
        testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
        # One task per super class, made of the label indices of its fine classes
        class_sequence = [[trainset.classes.index(fc) for fc in fine_classes]
                          for fine_classes in cifar100classes_dict.values()]
        return trainset, testset, class_sequence, 100

    if dataset == 'cub200':
        trainset = Cub200(root='./data', train=True, download=True, transform=transform)
        testset = Cub200(root='./data', train=False, download=True, transform=transform)
        # Random (unseeded) class order, split into consecutive tasks of up to 5 classes
        permutation = torch.randperm(len(trainset.classes))
        class_sequence = [permutation[i:i + 5] for i in range(0, len(permutation), 5)]
        return trainset, testset, class_sequence, 200

    raise NotImplementedError("The requested dataset is not implemented")


def _indices_for_classes(dataset, classes):
    """Return the indices of the samples of ``dataset`` whose label is in ``classes``.

    Labels and classes are compared as Python ints, so ``dataset.targets`` may be a list
    or a tensor and ``classes`` may hold ints or 0-d tensors (as produced by slicing
    ``torch.randperm`` for cub200). A set lookup keeps this linear in the dataset size.
    """
    wanted = {int(c) for c in classes}
    targets = dataset.targets
    if isinstance(targets, torch.Tensor):
        targets = targets.tolist()  # avoid one tensor indexing op per sample
    return [i for i in range(len(dataset)) if int(targets[i]) in wanted]


def _evaluate(net, train, test, t, device, hp, updates_per_task):
    """Return ``(acc_train, acc_test, cml_acc, cml_acc_hippo)`` for the current task ``t``.

    ``acc_train`` / ``acc_test`` are measured on task ``t`` with the current weights.
    ``cml_acc`` is the unweighted mean over tasks ``0..t`` of the test accuracy with the
    current weights. ``cml_acc_hippo`` is the same mean, except that for every previous
    task ``m`` the weights are temporarily replaced by the HiPPO approximation after
    ``(m + 1) * updates_per_task`` coefficient updates. ``hp`` is ``None`` when HiPPO is
    disabled, in which case both means coincide.

    The network is evaluated in eval mode under ``torch.no_grad``. Its mode and weights
    are restored before returning, even if an exception is raised.
    """
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            x_tr, y_tr = train.get_arrays(t)
            acc_train = compute_acc(net(x_tr.to(device)), y_tr.to(device))

            x_te, y_te = test.get_arrays(t)
            acc_test = compute_acc(net(x_te.to(device)), y_te.to(device))

            cml_acc = 0.
            cml_acc_hippo = 0.
            for m in range(t):
                x_m, y_m = test.get_arrays(m)
                x_m, y_m = x_m.to(device), y_m.to(device)
                acc = compute_acc(net(x_m), y_m)
                if hp is None:
                    acc_hippo = acc
                else:
                    saved_weights = get_learnable_parameters_as_vector(net)
                    try:
                        # Number of coefficient updates performed up to the end of task m
                        w_approx = hp.compute_approximation((m + 1) * updates_per_task)
                        set_learnable_parameters_from_vector(net, w_approx)
                        acc_hippo = compute_acc(net(x_m), y_m)
                    finally:
                        set_learnable_parameters_from_vector(net, saved_weights)
                cml_acc += acc
                cml_acc_hippo += acc_hippo
            cml_acc += acc_test
            cml_acc_hippo += acc_test
    finally:
        net.train(was_training)
    return acc_train, acc_test, cml_acc / (t + 1), cml_acc_hippo / (t + 1)


def run(opt):
    """Train an MLP head on a sequence of class-incremental tasks and report its accuracy.

    ``opt`` is a dict with the keys 'lr', 'batch_size', 'epochs', 'n_tasks', 'log_every',
    'net' ('resnet18', 'resnet50', or 'small' for mnist), 'device', 'wandb_project'
    (``None`` disables wandb), 'wandb_group', 'hippo', 'order',
    'dataset' ('mnist', 'cifar100' or 'cub200'), 'update_every', 'replay', 'ewc' and
    'ewc_lambda'.

    Each task is trained for 'epochs' epochs. Every 'log_every' iterations, the following
    are printed (and logged to wandb when enabled):
    - the train/test accuracy on the current task;
    - the cumulative test accuracy over all tasks seen so far, computed with the current
      weights and with the HiPPO approximation of the weights of each previous task.

    'update_every' controls when the HiPPO coefficients are updated (after an epoch). A
    positive value updates after every 'update_every'-th epoch; -1, -2, -3, -4, -5 and
    -10 select fixed per-task schedules.

    Raises:
        NotImplementedError: if 'hippo' and 'replay' are both enabled, if 'hippo' is
            enabled with an unsupported 'update_every', or if the dataset is unknown.
        ValueError: if 'n_tasks' exceeds the number of tasks the dataset provides.
    """
    ##### Hyperparameters #######
    lr = opt['lr']
    batch_size = opt['batch_size']
    epochs = opt['epochs']
    n_tasks = opt['n_tasks']
    sample_freq = opt['log_every']
    network = opt['net']
    device = torch.device(opt['device'])
    wandb_project = opt['wandb_project']
    wandb_group = opt['wandb_group']
    hippo = opt['hippo']
    approx_order = opt['order']
    dataset = opt['dataset']
    update_every = opt['update_every']
    replay = opt['replay']
    ewc = opt['ewc']
    ewc_lambda = opt['ewc_lambda']

    if hippo and replay:
        raise NotImplementedError("Hippo and replay cannot be used together")

    # Resolve the HiPPO schedule before creating a wandb run or downloading any data,
    # so that an unsupported 'update_every' fails immediately.
    update_epochs = _hippo_update_epochs(update_every, epochs) if hippo else []
    # Number of coefficient updates actually performed during one task
    updates_per_task = len(update_epochs)

    # Configure wandb
    wandb_log_dict = {}  # values accumulated until the next wandb.log call
    if wandb_project is not None:
        wandb.login()
        wdb_config_dict = {'lr': lr, 'batch_size': batch_size, 'epochs': epochs, 'n_tasks': n_tasks,
                           'network': network, 'order': approx_order, 'dataset': dataset, 'hippo': hippo,
                           # also record the settings that distinguish update-schedule/replay/EWC runs
                           'update_every': update_every, 'replay': replay, 'ewc': ewc, 'ewc_lambda': ewc_lambda}
        wandb.init(project=wandb_project, config=wdb_config_dict, group=wandb_group)

    trainset, testset, class_sequence, n_classes = _load_datasets(dataset, network, n_tasks)
    if n_tasks > len(class_sequence):
        raise ValueError(f"n_tasks={n_tasks} exceeds the {len(class_sequence)} tasks available for {dataset}")

    # Select the architecture: a pretrained ResNet used as frozen feature extractor, or none for mnist
    assert network == "resnet18" or network == "resnet50" or (dataset == "mnist" and network == "small"), \
        "The network must be either a pretrained Resnet18 or Resnet50 or a small network in case of mnist"
    if network != "small":
        preprocessor = generate_net(network, device=device)
        input_features = preprocessor.fc.in_features
        preprocessor.fc = torch.nn.Identity()  # drop the classification layer: the preprocessor outputs features
    else:
        preprocessor = None
        input_features = 784  # 28x28 mnist image
    # Per-task train/test sets (SequentialDataset is defined in utils.py)
    train = SequentialDataset(preprocessor=preprocessor)
    test = SequentialDataset(preprocessor=preprocessor)

    # The trainable head (MLP) used for classification
    net = SmallNet(device, n_classes, input_features, act=torch.nn.LeakyReLU())

    loss = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr)
    ewc_loss = ElasticWeightConsolidationLoss()

    # Online approximation of the trajectory of the head's weights
    HP = HippoApprox(order=approx_order, opt={'tau': 1}, device=device)

    replay_set = None
    iteration = 0
    for t in range(n_tasks):
        # Select the slice of the datasets made of the classes of task t
        class_sequence[t] = sorted(class_sequence[t])
        cur_trainset = torch.utils.data.Subset(trainset, _indices_for_classes(trainset, class_sequence[t]))
        train.append(cur_trainset, filename=f"{dataset}_{network}_train_{class_sequence[t]}")
        cur_testset = torch.utils.data.Subset(testset, _indices_for_classes(testset, class_sequence[t]))
        test.append(cur_testset, filename=f"{dataset}_{network}_test_{class_sequence[t]}")

        # With replay, samples from previous tasks are added to the current training set
        if replay and t > 0:
            replay_set = compute_replay_dataset(net, train.get_dataset(t - 1), approx_order, n_tasks, replay_set)
            cur_trainset = ConcatDataset([train.get_dataset(t), replay_set])
        else:
            cur_trainset = train.get_dataset(t)

        train_dataloader = DataLoader(cur_trainset, batch_size=batch_size, shuffle=True)
        n_batches = len(train_dataloader)

        # compute_replay_dataset receives the network and may change its mode; train in train mode
        net.train()
        for k in range(epochs):
            for i, (x_batch, y_batch) in enumerate(train_dataloader):
                iteration += 1
                inputs, y_batch = x_batch.to(device), y_batch.to(device)

                optimizer.zero_grad()
                outputs = net(inputs)

                # Log one example input at the start of each task (mnist only)
                if k == 0 and i == 0 and wandb_project is not None and dataset == "mnist":
                    image = x_batch[0].repeat(3, 1, 1).view(3, 28, 28)
                    wandb_log_dict['Example Image'] = wandb.Image(image)

                risk_batch = loss(outputs, y_batch)
                # Elastic weight consolidation penalty, from the second task on
                if ewc and t > 0:
                    risk_batch = risk_batch + ewc_lambda * ewc_loss(net)
                risk_batch.backward()
                optimizer.step()

                # Update the EWC state at the last training iteration of the task
                if ewc and k == epochs - 1 and i == n_batches - 1:
                    ewc_loss.update(net)

                if iteration % sample_freq == 0:
                    acc_train, acc_test, cml_acc, cml_acc_hippo = _evaluate(
                        net, train, test, t, device, HP if hippo else None, updates_per_task)
                    if wandb_project is not None:
                        wandb_log_dict.update({'acc_train': acc_train, 'acc_test': acc_test,
                                               'cml_test_acc': cml_acc, 'cml_test_acc_hippo': cml_acc_hippo})
                        wandb.log(wandb_log_dict, step=iteration)
                        # Start afresh so one-off entries (e.g. the example image) are not re-uploaded
                        wandb_log_dict.clear()

                    print(f'step: {iteration} task: {t} '
                          f'acc_train: {acc_train:.2f} acc_test: {acc_test:.2f} '
                          f'acc_cml: {cml_acc:.2f} acc_cml_hippo: {cml_acc_hippo:.2f}')

            # Update the HiPPO coefficients at the end of the scheduled epochs
            if hippo and k in update_epochs:
                with torch.no_grad():
                    HP.update_coefficients(get_learnable_parameters_as_vector(net))

        if ewc:
            ewc_loss.reset()




