import torch
import torch.optim as optim  # Where the optimization modules are
from torchvision.transforms import ToTensor

import wandb
from utils import ShuffleTransform
from utils import compute_acc, compute_replay_dataset
import torchvision  # To be able to access standard datasets more easily
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
from networks import generate_net, get_learnable_parameters_as_vector, set_learnable_parameters_from_vector
import sys
from hippo import HippoApprox


def _mean_chunk_accuracy(net, inputs, targets, chunk_size, device):
    """Return ``compute_acc`` averaged over ``chunk_size``-sized chunks of a dataset.

    ``inputs`` and ``targets`` are split along dim 0, each chunk is moved to
    ``device``, passed through ``net`` and scored with ``compute_acc``. Every
    chunk contributes with the same weight, whatever its size.
    """
    input_chunks = torch.split(inputs, chunk_size, dim=0)
    target_chunks = torch.split(targets, chunk_size, dim=0)
    total = 0.
    for x_chunk, y_chunk in zip(input_chunks, target_chunks):
        x_chunk, y_chunk = x_chunk.to(device), y_chunk.to(device)
        total += compute_acc(net(x_chunk), y_chunk)
    return total / len(input_chunks)


def run(opt):
    """Train a network on a sequence of pixel-permutation tasks and log accuracies.

    For every task the pixel permutation is updated (except for the first
    task), the whole train/test sets are materialised under that permutation
    for evaluation, and the network is trained for ``epochs`` epochs. Every
    ``log_every`` iterations the accuracy on the current task and the
    cumulative test accuracy over all tasks seen so far are printed and, if a
    wandb project is given, logged to wandb.

    Args:
        opt (dict): configuration with the keys ``'lr'``, ``'batch_size'``,
            ``'epochs'``, ``'n_tasks'``, ``'log_every'``, ``'net'``,
            ``'dataset'`` (``'mnist'``, ``'fmnist'`` or ``'cifar10'``),
            ``'device'``, ``'n_perm_pix'``, ``'wandb_project'`` (``None``
            disables wandb), ``'wandb_group'``, ``'hippo'``, ``'tau'``,
            ``'order'``, ``'weight_decay'`` and ``'replay'``.

    Raises:
        KeyError: if a required key is missing from ``opt``.
        ValueError: if ``hippo`` and ``replay`` are both enabled or both
            disabled, or if ``dataset`` is not supported.
    """
    ##### Hyperparameters #######
    lr = opt['lr']
    batch_size = opt['batch_size']
    epochs = opt['epochs']
    n_tasks = opt['n_tasks']
    sample_freq = opt['log_every']
    network = opt['net']
    dataset = opt['dataset']
    device = torch.device(opt['device'])
    n_perm_pix = opt['n_perm_pix']
    wandb_project = opt['wandb_project']
    wandb_group = opt['wandb_group']
    hippo = opt['hippo']
    tau = opt['tau']
    approx_order = opt['order']
    weight_decay = opt['weight_decay']
    replay = opt["replay"]
    ##############################
    test_batch = 10000  # chunk size used when evaluating on whole datasets

    # Exactly one of the two continual-learning strategies must be active.
    # An explicit raise (instead of assert) keeps the check under ``python -O``.
    if bool(hippo) == bool(replay):
        raise ValueError("exactly one of opt['hippo'] and opt['replay'] must be enabled")

    # Dataset class together with the image geometry needed by the shuffler.
    dataset_specs = {
        'mnist': (torchvision.datasets.MNIST, 28, 1),
        'fmnist': (torchvision.datasets.FashionMNIST, 28, 1),
        'cifar10': (torchvision.datasets.CIFAR10, 32, 3),
    }
    # Fail early with a clear message, before any wandb run is created.
    if dataset not in dataset_specs:
        raise ValueError(f"unsupported dataset {dataset!r}; expected one of {sorted(dataset_specs)}")
    dataset_cls, image_size, channels = dataset_specs[dataset]

    # Configure wandb
    use_wandb = wandb_project is not None
    if use_wandb:
        wandb_log_dict = {}  # metrics accumulated between two wandb.log calls
        wandb.login()
        wdb_config_dict = {'lr': lr, 'batch_size': batch_size, 'epochs': epochs, 'n_tasks': n_tasks,
                           'network': network, 'n_perm_pix': n_perm_pix, 'hippo': hippo,
                           'approx_order': approx_order}
        wandb.init(project=wandb_project, config=wdb_config_dict, group=wandb_group)

    # Object that handles the shuffling of the pixels of the images
    shuffle = ShuffleTransform(image_size=image_size, channels=channels)
    # Transformation applied to every image when it is read from the dataset; since it is
    # applied on access, changing the shuffle indices changes the data of the task
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(shuffle.shuffle_image),
    ])

    # Dataset objects with the transformation defined above
    trainset = dataset_cls(root='./data', train=True, download=True, transform=transform,
                           target_transform=torch.as_tensor)
    testset = dataset_cls(root='./data', train=False, download=True, transform=transform,
                          target_transform=torch.as_tensor)

    # Materialised train/test sets of every task seen so far, used for evaluation
    test = []
    train = []

    # Neural architecture
    if dataset == 'cifar10':
        net = generate_net(network, device=device, pretrained=False)
    else:
        net = generate_net(network, device=device, hidden_size=[100])

    # Loss function and optimizer
    loss = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)

    # The online approximation of the weights is only needed in hippo mode
    HP = HippoApprox(order=approx_order, opt={'tau': tau}, device=device) if hippo else None

    for t in range(n_tasks):
        # The permutation must be updated BEFORE the evaluation snapshots below are taken:
        # the transform is applied on access, so snapshotting first would store the data of
        # the previous permutation while training on the new one. The first task keeps the
        # initial indices of the shuffler.
        if t > 0:
            shuffle.update_pixels_indices(n_perm_pix)

        # Materialise the whole train and test set of the current task to compute metrics on them
        trainset_whole = next(iter(DataLoader(trainset, batch_size=len(trainset), shuffle=False)))
        testset_whole = next(iter(DataLoader(testset, batch_size=len(testset), shuffle=False)))

        # Each entry is a two-element batch: the images and the labels
        train.append(trainset_whole)
        test.append(testset_whole)

        # If we are using replay buffer
        if replay and t > 0:
            last_trainset = TensorDataset(train[t - 1][0], train[t - 1][1])
            replay_dataset = compute_replay_dataset(model=net, last_trainset=last_trainset, order=approx_order,
                                                    tasks=n_tasks, old_replay_dataset=None)

            # NOTE(review): trainset is rebound here, so the snapshots of later tasks also
            # contain the replayed samples of earlier tasks — confirm this is intended.
            trainset = ConcatDataset([trainset, replay_dataset])

        # Dataloader used to perform learning inside the task
        train_dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)
        N = len(train_dataloader)  # number of iterations per epoch, used for the global step

        for k in range(epochs):
            for i, (x_batch, y_batch) in enumerate(train_dataloader):

                iteration = N * epochs * t + N * k + i

                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                optimizer.zero_grad()

                outputs = net(x_batch)

                if k == 0 and i == 0 and use_wandb:
                    # First channel of the first image of the task, as a visual sanity check
                    image = x_batch[0, 0, :]
                    wandb_log_dict.update({'Example Image': wandb.Image(image)})

                Risk_batch = loss(outputs, y_batch)

                Risk_batch.backward()
                optimizer.step()

                if iteration % sample_freq != 0:
                    continue

                # NOTE(review): the network mode is not switched to eval() here; confirm this
                # is intended for architectures with BatchNorm/Dropout layers.
                with torch.no_grad():
                    # Accuracy on the train and test set of the current task
                    x_tr, y_tr = train[t]
                    acc_train = _mean_chunk_accuracy(net, x_tr, y_tr, test_batch, device)

                    x_te, y_te = test[t]
                    acc_test = _mean_chunk_accuracy(net, x_te, y_te, test_batch, device)

                    if use_wandb:
                        wandb_log_dict.update({'acc_train': acc_train})
                        wandb_log_dict.update({'acc_test': acc_test})

                    # Cumulative accuracy: the test accuracy of every task seen so far is
                    # computed and the results are averaged over the number of tasks
                    cml_acc = 0.
                    cml_acc_hippo = 0.
                    for m in range(t):
                        X, Y = test[m]
                        cml_acc += _mean_chunk_accuracy(net, X, Y, test_batch, device)
                        if hippo:
                            # Evaluate the past task with the approximated weights; the swap
                            # is done once per task and the live weights are always restored
                            saved_weights = get_learnable_parameters_as_vector(net)
                            w_approx = HP.compute_approximation(
                                (m + 1) * epochs * tau)  # TODO: check if this is correct
                            set_learnable_parameters_from_vector(net, w_approx)
                            try:
                                cml_acc_hippo += _mean_chunk_accuracy(net, X, Y, test_batch, device)
                            finally:
                                set_learnable_parameters_from_vector(net, saved_weights)
                    cml_acc += acc_test
                    cml_acc /= t + 1
                    if hippo:
                        # The current task is always scored with the live weights
                        cml_acc_hippo += acc_test
                        cml_acc_hippo /= t + 1

                    if use_wandb:
                        wandb_log_dict.update({'cml_test_acc': cml_acc})
                        if hippo:
                            wandb_log_dict.update({'cml_test_acc_hippo': cml_acc_hippo})
                        wandb.log(wandb_log_dict, step=iteration)

                    # ``!s`` keeps the str() rendering of the values (e.g. for tensors)
                    summary = (f'step: {iteration!s} task: {t!s} acc_train: {acc_train!s} '
                               f'acc_test: {acc_test!s} acc_cml: {cml_acc!s}')
                    if hippo:
                        summary += f' cml_test_acc_hippo: {cml_acc_hippo!s}'
                    print(summary)

            # At the end of every epoch update the coefficients of the online approx of the weights
            if hippo:
                with torch.no_grad():
                    HP.update_coefficients(get_learnable_parameters_as_vector(net))



















