"""
Regression experiment using CAVIA
"""
import copy
import os
import time

import numpy as np
import scipy.stats as st
import torch
import torch.nn.functional as F
import torch.optim as optim

import datetime
import wandb

import utils
import tasks_sine, tasks_celebA
from cavia_model import CaviaModel
from logger import Logger


def run(args, log_interval=5000, rerun=False):
    """Meta-train a CAVIA-style regression model and return its training logger.

    Supported ``args.algo`` values:
      * ``'cavia'``    - plain CAVIA: each task adapts its own context parameters.
      * ``'pn_cavia'`` - every task's inner-loop loss additionally gets
                         ``0.5 * args.coef * ||mean of the batch's contexts||^2``.
      * ``'cn_cavia'`` - after each inner step the contexts of the whole task batch are
                         rescaled so that their joint L2 norm is at most ``args.radius``.

    Args:
        args: experiment namespace (algo, task, seed, learning rates, batch sizes, ...).
        log_interval: evaluate, log and save every ``log_interval`` meta-iterations.
        rerun: if False and a result file for this run already exists, load and
            return it instead of training.

    Returns:
        The ``Logger`` holding train/valid/test losses and the best-validation model,
        or the previously saved object when an existing result is reused.

    Raises:
        AssertionError: if ``args.algo`` is not one of the supported algorithms.
        NotImplementedError: if ``args.task`` is not ``'sine'`` or ``'celeba'``.
    """
    assert args.algo in ('cavia', 'pn_cavia', 'cn_cavia')

    current_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    wandb.init(project='Sine Regression')
    wandb.run.name = current_time
    wandb.run.save()

    wandb.config.update(args)

    # result directory: <code_root>/<task>_result_files/<algo>/n_param=<k>/seed=<s>[/<coef|radius>]
    code_root = os.path.dirname(os.path.realpath(__file__))
    path = '{}/{}_result_files/'.format(code_root, args.task)
    path = os.path.join(path, args.algo, f'n_param={args.num_context_params}', f'seed={args.seed}')
    if args.algo == 'pn_cavia':
        path = os.path.join(path, f'{args.coef}')
    elif args.algo == 'cn_cavia':
        path = os.path.join(path, f'{args.radius}')
    elif args.algo != 'cavia':
        # only reachable when asserts are stripped (python -O)
        raise NotImplementedError
    os.makedirs(path, exist_ok=True)

    # NOTE(review): the file name is the start timestamp, so this "already ran" check only
    # matches a run started within the same second - confirm whether that is intended.
    path = os.path.join(path, current_time)
    if os.path.exists(path + '.pkl') and not rerun:
        return utils.load_obj(path)

    start_time = time.time()
    utils.set_seed(args.seed)

    # --- initialise everything ---

    # get the task family
    if args.task == 'sine':
        task_family_train = tasks_sine.RegressionTasksSinusoidal()
        task_family_valid = tasks_sine.RegressionTasksSinusoidal()
        task_family_test = tasks_sine.RegressionTasksSinusoidal()
    elif args.task == 'celeba':
        task_family_train = tasks_celebA.CelebADataset('train', device=args.device)
        task_family_valid = tasks_celebA.CelebADataset('valid', device=args.device)
        task_family_test = tasks_celebA.CelebADataset('test', device=args.device)
    else:
        raise NotImplementedError

    # initialise network
    model = CaviaModel(n_in=task_family_train.num_inputs,
                       n_out=task_family_train.num_outputs,
                       num_context_params=args.num_context_params,
                       n_hidden=args.num_hidden_layers,
                       device=args.device
                       ).to(args.device)

    # meta-optimiser over the shared parameters only
    # (context parameters are *not* registered parameters of the model)
    meta_optimiser = optim.Adam(model.parameters(), args.lr_meta)

    # initialise loggers
    logger = Logger()
    logger.best_valid_model = copy.deepcopy(model)

    # algorithm-specific inner loop; each adds the clipped per-task meta-gradients
    # into the `meta_gradient` list it is given
    compute_meta_gradient = {
        'cavia': _cavia_meta_gradient,
        'pn_cavia': _pn_cavia_meta_gradient,
        'cn_cavia': _cn_cavia_meta_gradient,
    }[args.algo]

    # --- main training loop ---

    for i_iter in range(args.n_iter):

        # one zero accumulator per shared parameter (not per state_dict entry, which
        # would also count buffers)
        meta_gradient = [torch.zeros_like(param) for param in model.parameters()]

        # sample tasks
        target_functions = task_family_train.sample_tasks(args.tasks_per_metaupdate)

        # --- inner loop + meta-gradient ---
        compute_meta_gradient(args, model, task_family_train, target_functions, meta_gradient)

        # ------------ meta update ------------

        # assign the task-averaged meta-gradient
        for param, grad in zip(model.parameters(), meta_gradient):
            param.grad = grad / args.tasks_per_metaupdate

        # do update step on shared model
        meta_optimiser.step()

        # reset context params
        model.reset_context_params()

        # ------------ logging ------------

        if i_iter % log_interval == 0:

            # evaluate on training set
            loss_mean, loss_conf = eval_cavia(args, copy.deepcopy(model), task_family=task_family_train,
                                              num_updates=args.num_inner_updates)
            logger.train_loss.append(loss_mean)
            logger.train_conf.append(loss_conf)

            # evaluate on validation set
            loss_mean, loss_conf = eval_cavia(args, copy.deepcopy(model), task_family=task_family_valid,
                                              num_updates=args.num_inner_updates)
            logger.valid_loss.append(loss_mean)
            logger.valid_conf.append(loss_conf)

            # evaluate on test set
            loss_mean, loss_conf = eval_cavia(args, copy.deepcopy(model), task_family=task_family_test,
                                              num_updates=args.num_inner_updates)
            logger.test_loss.append(loss_mean)
            logger.test_conf.append(loss_conf)

            # save logging results
            utils.save_obj(logger, path)

            # keep the model with the lowest validation loss so far
            if logger.valid_loss[-1] == np.min(logger.valid_loss):
                print('saving best model at iter', i_iter)
                logger.best_valid_model = copy.deepcopy(model)

            # visualise results
            if args.task == 'celeba':
                task_family_train.visualise(task_family_train, task_family_test, copy.deepcopy(logger.best_valid_model),
                                            args, i_iter)

            # print current results
            logger.print_info(i_iter, start_time)
            if i_iter != 0:
                logger.wandb_log(i_iter, start_time)
            start_time = time.time()

    return logger


def _accumulate_meta_gradient(args, model, task_family, target_function, meta_gradient, retain_graph=False):
    """Add one task's clipped meta-gradient into ``meta_gradient`` (in place).

    The meta-loss is the MSE on freshly sampled test inputs, evaluated with whatever
    ``model.context_params`` currently holds (the task's adapted context), and is
    differentiated w.r.t. the model's shared parameters.
    """
    # get test data, outputs after update, and the correct targets
    test_inputs = task_family.sample_inputs(args.k_meta_test, args.use_ordered_pixels).to(args.device)
    test_outputs = model(test_inputs)
    test_targets = target_function(test_inputs)

    # compute loss after updating context (will backprop through inner loop)
    loss_meta = F.mse_loss(test_outputs, test_targets)
    task_grad = torch.autograd.grad(loss_meta, model.parameters(), retain_graph=retain_graph)

    for accumulator, grad in zip(meta_gradient, task_grad):
        # clip the gradient element-wise to [-10, 10] before accumulating
        accumulator.add_(grad.detach().clamp_(-10, 10))


def _init_task_batch(args, task_family):
    """Return zero-initialised per-task contexts and one training input batch per task."""
    contexts = [torch.zeros(args.num_context_params, device=args.device, requires_grad=True)
                for _ in range(args.tasks_per_metaupdate)]
    train_inputs = [task_family.sample_inputs(args.k_meta_train, args.use_ordered_pixels).to(args.device)
                    for _ in range(args.tasks_per_metaupdate)]
    return contexts, train_inputs


def _cavia_meta_gradient(args, model, task_family, target_functions, meta_gradient):
    """Plain CAVIA: adapt each task's context independently, then accumulate its meta-gradient."""
    for t in range(args.tasks_per_metaupdate):

        # reset private network weights
        model.reset_context_params()

        # get data for current task
        train_inputs = task_family.sample_inputs(args.k_meta_train, args.use_ordered_pixels).to(args.device)

        for _ in range(args.num_inner_updates):
            train_outputs = model(train_inputs)
            train_targets = target_functions[t](train_inputs)
            task_loss = F.mse_loss(train_outputs, train_targets)

            # gradient wrt context params (kept in the graph unless first-order)
            task_gradients = \
                torch.autograd.grad(task_loss, model.context_params, create_graph=not args.first_order)[0]

            # update context params (this will set up the computation graph correctly)
            model.context_params = model.context_params - args.lr_inner * task_gradients

        _accumulate_meta_gradient(args, model, task_family, target_functions[t], meta_gradient)


def _pn_cavia_meta_gradient(args, model, task_family, target_functions, meta_gradient):
    """pn_cavia: synchronous CAVIA inner loop with a penalty on the batch-mean context.

    At every inner step each task's loss is its MSE plus
    ``0.5 * args.coef * ||mean of all tasks' contexts||^2``, where the mean is taken
    over the contexts from the previous step; all tasks are then updated together.
    """
    contexts, train_inputs = _init_task_batch(args, task_family)

    for _ in range(args.num_inner_updates):
        average_context_params = torch.mean(torch.stack(contexts, dim=0), dim=0)

        updated = []
        for t in range(args.tasks_per_metaupdate):
            model.context_params = contexts[t]

            train_outputs = model(train_inputs[t])
            train_targets = target_functions[t](train_inputs[t])
            task_loss_train = F.mse_loss(train_outputs, train_targets)
            regu_loss_train = 0.5 * args.coef * torch.sum(average_context_params ** 2)

            # gradient wrt this task's context params
            task_gradients = torch.autograd.grad(task_loss_train + regu_loss_train, contexts[t],
                                                 create_graph=not args.first_order)[0]
            updated.append(contexts[t] - args.lr_inner * task_gradients)
        contexts = updated

    for t in range(args.tasks_per_metaupdate):
        model.reset_context_params()
        # load this task's adapted context parameters
        model.context_params = contexts[t]
        # the tasks' inner-loop graphs share nodes (the batch mean), so keep the graph
        # alive for the following tasks' backward passes
        _accumulate_meta_gradient(args, model, task_family, target_functions[t], meta_gradient,
                                  retain_graph=True)


def _cn_cavia_meta_gradient(args, model, task_family, target_functions, meta_gradient):
    """cn_cavia: synchronous CAVIA inner loop with the batch of contexts projected onto an L2 ball.

    After every inner step, if the joint L2 norm of all tasks' contexts exceeds
    ``args.radius``, every context is rescaled by ``args.radius / norm``.
    """
    contexts, train_inputs = _init_task_batch(args, task_family)

    for _ in range(args.num_inner_updates):
        updated = []
        for t in range(args.tasks_per_metaupdate):
            model.context_params = contexts[t]

            train_outputs = model(train_inputs[t])
            train_targets = target_functions[t](train_inputs[t])
            task_loss_train = F.mse_loss(train_outputs, train_targets)

            # gradient wrt this task's context params
            task_gradients = torch.autograd.grad(task_loss_train, contexts[t],
                                                 create_graph=not args.first_order)[0]
            updated.append(contexts[t] - args.lr_inner * task_gradients)

        # joint L2 norm of the concatenated contexts of the whole batch
        joint_norm = torch.sqrt(sum(context.norm(2) ** 2 for context in updated))
        ratio = args.radius / joint_norm if joint_norm > args.radius else 1
        contexts = [ratio * context for context in updated]

    for t in range(args.tasks_per_metaupdate):
        model.reset_context_params()
        # load this task's adapted context parameters
        model.context_params = contexts[t]
        # the tasks' inner-loop graphs share nodes (the joint norm), so keep the graph
        # alive for the following tasks' backward passes
        _accumulate_meta_gradient(args, model, task_family, target_functions[t], meta_gradient,
                                  retain_graph=True)


def eval_cavia(args, model, task_family, num_updates, n_tasks=100, return_gradnorm=False):
    """Evaluate few-shot context adaptation on freshly sampled tasks.

    For each of ``n_tasks`` tasks: reset the context parameters, take ``num_updates``
    gradient steps on ``args.k_shot_eval`` sampled inputs, then measure the MSE over the
    task family's full input range (``task_family.get_input_range()``).

    Args:
        args: namespace providing ``device``, ``k_shot_eval``, ``use_ordered_pixels``
            and ``lr_inner``.
        model: model to adapt; its ``context_params`` are overwritten, so pass a copy.
            It is left in training mode on return.
        task_family: provides ``get_input_range``, ``sample_task`` and ``sample_inputs``.
        num_updates: number of inner-loop gradient steps per task.
        n_tasks: number of evaluation tasks.
        return_gradnorm: also return the mean L2 norm of the inner-loop context gradients.

    Returns:
        ``(mean_loss, conf)`` where ``conf`` is the half-width of the 95% Student-t
        confidence interval of the mean loss, or ``(mean_loss, conf, mean_gradnorm)``
        when ``return_gradnorm`` is True.
    """
    input_range = task_family.get_input_range().to(args.device)

    losses = []
    gradnorms = []

    for _ in range(n_tasks):
        target_function = task_family.sample_task()
        model.reset_context_params()

        # data used for adaptation on this task
        curr_inputs = task_family.sample_inputs(args.k_shot_eval, args.use_ordered_pixels).to(args.device)
        curr_targets = target_function(curr_inputs)

        for _ in range(num_updates):
            task_loss = F.mse_loss(model(curr_inputs), curr_targets)

            # Evaluation never backpropagates into the shared weights, so only the gradient
            # *value* is needed: no second-order graph, whatever args.first_order says.
            task_gradients = torch.autograd.grad(task_loss, model.context_params)[0]
            model.context_params = model.context_params - args.lr_inner * task_gradients

            # norm of the whole context gradient, not just of its first entry
            gradnorms.append(task_gradients.norm().item())

        # true loss on the entire input range
        model.eval()
        with torch.no_grad():
            losses.append(F.mse_loss(model(input_range), target_function(input_range)).item())
        model.train()

    losses_mean = np.mean(losses)
    losses_conf = st.t.interval(0.95, len(losses) - 1, loc=losses_mean, scale=st.sem(losses))
    # the interval is symmetric around the mean, so this is its half-width
    conf_half_width = np.mean(np.abs(np.asarray(losses_conf) - losses_mean))
    if return_gradnorm:
        return losses_mean, conf_half_width, np.mean(gradnorms)
    return losses_mean, conf_half_width
