"""
Meta-learning experiment using MAML, trained and evaluated with a
cross-entropy (classification) loss.
"""
import copy

import dgl
import numpy as np
import scipy.stats as st
import torch
import torch.nn.functional as F
import torch.optim as optim

from codes.logbook.logbook import LogBook

# from codes.model.dgl.classifiers import RelationGATMAMLClassifier as Model
from codes.model.pyg.edge_gat import GatEncoder as Model
from codes.model.meta_learning.utils import get_input_and_target, bootstrap_task_family
from codes.utils.config import get_config
from codes.utils.data import GraphBatch
from tqdm import tqdm
import argparse
import time


def _evaluate_and_log(config, logbook, model_outer, task_family, mode, i_iter, step):
    """Evaluate the outer model on one task family and write the metrics.

    A shallow copy of ``model_outer`` is handed to ``eval`` so that the
    ``weights`` attribute re-assigned during adaptation lands on the copy.

    Args:
        config: experiment configuration (``config.model.num_inner_updates``
            is used as the number of adaptation steps).
        logbook: object exposing ``write_metric_logs(dict)``.
        model_outer: the meta-learned model.
        task_family: task family to evaluate on.
        mode: value logged under the ``"mode"`` key.
        i_iter: current meta-iteration, logged as ``"steps"`` and
            ``"epoch_idx"``.
        step: running log counter, logged as ``"minibatch"``.
    """
    loss_mean, loss_conf, acc = eval(
        config,
        copy.copy(model_outer),
        task_family=task_family,
        num_updates=config.model.num_inner_updates,
    )
    logbook.write_metric_logs(
        {
            "loss_mean": loss_mean,
            "loss_conf": loss_conf,
            "mode": mode,
            "steps": i_iter,
            "epoch_idx": i_iter,
            "accuracy": acc,
            "minibatch": step,
        }
    )


def run(config, logbook):
    """Train a model with MAML and periodically log train/valid/test metrics.

    For every meta-iteration, ``config.model.tasks_per_metaupdate`` tasks are
    sampled from the training task family. For each task the inner model is
    reset to a clone of the outer weights, adapted for
    ``config.model.num_inner_updates`` gradient steps on a "train" batch, and
    the cross-entropy loss on a "valid" batch of the same task is
    differentiated w.r.t. the *outer* weights. The averaged gradients are
    applied to the outer model with Adam.

    Args:
        config: experiment configuration. ``config.model.num_classes`` is
            overwritten with the maximum ``num_classes`` over all task
            families.
        logbook: object exposing ``write_metric_logs(dict)``.
    """
    # weights whose meta-gradient is never assigned (the optimiser sees no
    # new gradient for them)
    skip_gradients = ["node_embedding"]
    if config.model.gat.share_relation_emb:
        skip_gradients.append("relation_embedding")

    task_families = bootstrap_task_family(config=config)
    num_classes = max([t.num_classes for t in task_families])
    task_family_train = task_families[0]
    task_family_valid = task_families[1]
    task_family_test = task_families[2]
    config.model.num_classes = num_classes
    step = 0

    # (logged mode, task family) pairs, evaluated in this order
    eval_splits = (
        ("train", task_family_train),
        ("valid", task_family_valid),
        ("test", task_family_test),
    )

    # initialise network
    model_inner = Model(config=config).to(config.general.device)
    model_outer = copy.deepcopy(model_inner)

    # initialise meta-optimiser
    meta_optimiser = optim.Adam(model_outer.weights, config.model.optim.learning_rate)

    log_interval = config.logger.remote.frequency

    pb_e = tqdm(total=config.model.num_epochs)
    for i_iter in range(config.model.num_epochs):

        # snapshot of the outer weights; clones stay connected to
        # model_outer.weights in the autograd graph
        copy_weights = [w.clone() for w in model_outer.weights]

        # cumulative meta-gradient, one slot per weight
        meta_gradient = [0 for _ in copy_weights]

        # sample tasks
        pb_t = tqdm(total=config.model.tasks_per_metaupdate)
        for _ in range(config.model.tasks_per_metaupdate):
            # reset inner network weights to the current outer weights
            model_inner.weights = [w.clone() for w in copy_weights]

            # get data for current task
            graphs, queries, targets, target_function = get_input_and_target(
                config=config,
                batch_size=config.general.batch_size,
                task_family=task_family_train,
                target_function=None,
                mode="train",
            )

            batch = GraphBatch(graphs=graphs, queries=queries, targets=targets)
            batch.to(config.general.device)

            pb_i = tqdm(total=config.model.num_inner_updates)
            # inner updates - the same batch is reused for every step
            for _ in range(config.model.num_inner_updates):
                outputs = model_inner(batch)

                # compute loss for current task
                loss_task = F.cross_entropy(outputs, batch.targets)

                # compute the gradient wrt current inner weights
                # NOTE(review): the graph is created even when first_order is
                # set, although the gradients are detached below.
                grads = torch.autograd.grad(
                    loss_task, model_inner.weights, create_graph=True, retain_graph=True
                )

                # update the inner weights out-of-place so the computation
                # graph back to the outer weights is preserved
                for pi, grad in enumerate(grads):
                    if config.model.first_order:
                        # first-order MAML: ignore second derivatives
                        grad = grad.detach()
                    # out-of-place clamp: never mutates the tensors returned
                    # by autograd
                    model_inner.weights[pi] = model_inner.weights[
                        pi
                    ] - config.model.lr_inner * grad.clamp(-5, 5)
                pb_i.set_description("Inner loss : {}".format(loss_task.item()))
                pb_i.update(1)

            # ------------ compute meta-gradient on test loss of current task ------------

            pb_i.close()
            # get held-out data for the same task (same target_function)
            graphs, queries, targets, _ = get_input_and_target(
                config=config,
                batch_size=config.general.batch_size,  # config.model.k_meta_test,
                task_family=task_family_train,
                target_function=target_function,
                mode="valid",
            )

            batch = GraphBatch(graphs=graphs, queries=queries, targets=targets)
            batch.to(config.general.device)

            # get outputs after update
            test_outputs = model_inner(batch)

            # compute loss (will backprop through inner loop)
            loss_meta = F.cross_entropy(test_outputs, batch.targets)
            # compute gradient w.r.t. *outer model*
            task_grads = torch.autograd.grad(loss_meta, model_outer.weights)
            for i, task_grad in enumerate(task_grads):
                meta_gradient[i] += task_grad.detach()

            pb_t.set_description("Meta loss : {}".format(loss_meta.item()))
            pb_t.update(1)

        pb_t.close()
        # ------------ meta update ------------

        meta_optimiser.zero_grad()

        # assign the task-averaged meta-gradient to every non-skipped weight
        for name, weight, grad_sum in zip(
            model_outer.weight_names, model_outer.weights, meta_gradient
        ):
            if name in skip_gradients:
                continue
            weight.grad = grad_sum / config.model.tasks_per_metaupdate

        # do update step on outer model
        meta_optimiser.step()
        pb_e.update(1)

        # ------------ logging ------------

        if i_iter % log_interval == 0:
            for mode, task_family in eval_splits:
                _evaluate_and_log(
                    config, logbook, model_outer, task_family, mode, i_iter, step
                )
                step += 1
    pb_e.close()


def eval(config, model, task_family, num_updates, n_tasks=100, return_gradnorm=False):
    """Adapt ``model`` on every task of ``task_family`` and measure test loss.

    For each task the weights are reset to their value on entry, adapted with
    ``num_updates`` plain gradient steps (step size ``config.model.lr_inner``)
    on one "train" batch, and then scored with cross-entropy and accuracy on
    the "test" input range of the task family, in batches of
    ``config.general.batch_size``. ``model.weights`` is restored to clones of
    the entry weights before returning.

    NOTE: this function shadows the builtin ``eval``; the name is kept for
    callers.

    Args:
        config: experiment configuration.
        model: model exposing ``weights``, ``predict`` and ``accuracy``.
        task_family: task family exposing ``graphworld_list`` and
            ``get_input_range``.
        num_updates: number of adaptation steps per task.
        n_tasks: ignored; the number of tasks is always
            ``len(task_family.graphworld_list)``. Kept for interface
            compatibility.
        return_gradnorm: if true, also return the mean gradient norm.

    Returns:
        ``(loss_mean, loss_conf, acc)``, or
        ``(loss_mean, loss_conf, gradnorm_mean, acc)`` when
        ``return_gradnorm`` is set. ``loss_conf`` is the mean absolute
        distance of the 95% Student-t interval bounds from ``loss_mean``.
    """
    # copy weights of network
    copy_weights = [w.clone() for w in model.weights]

    # logging
    losses = []
    acc = []
    gradnorms = []

    # --- inner loop ---
    n_tasks = len(task_family.graphworld_list)
    for task_id in range(n_tasks):

        # reset network weights
        model.weights = [w.clone() for w in copy_weights]

        graphs, queries, targets, target_function = get_input_and_target(
            config=config,
            batch_size=config.general.batch_size,  # 4,
            task_family=task_family,
            target_function=None,
            mode="train",
            task_id=task_id,
        )

        batch = GraphBatch(graphs=graphs, queries=queries, targets=targets)
        batch.to(config.general.device)

        # ------------ update on current task ------------

        for _ in range(num_updates):

            outputs = model(batch)

            # compute loss for current task
            task_loss = F.cross_entropy(outputs, batch.targets)

            # gradient wrt the current (adapted) weights
            grads = torch.autograd.grad(task_loss, model.weights)

            gradnorms.append(np.mean(np.array([g.norm().item() for g in grads])))

            # plain first-order step; gradients are detached
            for i, grad in enumerate(grads):
                model.weights[i] = (
                    model.weights[i] - config.model.lr_inner * grad.detach()
                )

        # ------------ logging ------------

        # score the adapted weights on the "test" input range
        graphs, queries, indices = task_family.get_input_range(mode="test")
        batch_size = config.general.batch_size
        # only scalar metrics are extracted below, so no autograd graph is needed
        with torch.no_grad():
            for bi in range(0, len(graphs), batch_size):
                b_graphs = graphs[bi : bi + batch_size]
                b_queries = queries[bi : bi + batch_size]
                b_indices = indices[bi : bi + batch_size]
                batch = GraphBatch(
                    b_graphs, b_queries, target_function(b_indices, "test")
                )
                batch.to(config.general.device)
                logits = model(batch)
                losses.append(F.cross_entropy(logits, batch.targets).item())
                # calculate accuracy
                predictions, conf = model.predict(logits)
                acc.append(model.accuracy(predictions, batch.targets).item())

    # reset network weights
    model.weights = [w.clone() for w in copy_weights]

    losses_mean = np.mean(losses)
    losses_conf = st.t.interval(
        0.95, len(losses) - 1, loc=losses_mean, scale=st.sem(losses)
    )
    # st.t.interval returns a (low, high) pair; convert explicitly before
    # the arithmetic
    conf_halfwidth = np.mean(np.abs(np.asarray(losses_conf) - losses_mean))
    acc = np.mean(acc)

    if not return_gradnorm:
        return losses_mean, conf_halfwidth, acc
    return (
        losses_mean,
        conf_halfwidth,
        np.mean(gradnorms),
        acc,
    )


if __name__ == "__main__":
    # Script entry point: read the config id from the command line, load the
    # matching configuration, build its logbook and start training.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config_id", type=str, default="debug", help="config id")
    config = get_config(cli.parse_args().config_id)
    run(config, LogBook(config=config))
