"""Class to run the MamlSignatureExperiment"""
# from time import time

import argparse
import copy
import os
import pickle as pkl
import random
import time
from typing import Iterable, Optional, Tuple

import numpy as np
import scipy.stats as st
import torch
import torch.nn.functional as F
import torch.optim as optim
from box import Box
from ipdb import set_trace
from torch import nn
from tqdm import tqdm

from codes.experiment.checkpointable_experiment import CheckpointableExperiment
from codes.logbook.logbook import LogBook
from codes.model.meta_learning.utils import bootstrap_task_family, get_input_and_target

# from codes.model.dgl.classifiers import RelationGATMAMLClassifier as Model
# from codes.model.gat.edge_gat import GatedGatEncoder as Model
from codes.model.models import RepresentationFn, CompositionFn
from codes.utils.checkpointable import Checkpointable
from codes.utils.config import get_config
from codes.utils.data import GraphBatch, graph_batch_iterator
from codes.utils.util import make_dir

# Module-level flag. Nothing in the visible part of this file reads it.
# NOTE(review): confirm whether another module imports it before removing.
TRAIN_MODE = True


class MetaExperiment(Checkpointable):
    """Experiment Class for MAML 
    """

    def __init__(self, config: Box, logbook: LogBook):
        """Store the configuration, load the task families and build both models.

        Args:
            config: Experiment configuration.
            logbook: Logger that receives the metric logs.
        """
        self.config, self.logbook = config, logbook
        # The task families are loaded before the models are built because
        # ``load_data`` writes ``config.model.num_classes``.
        self.load_data()
        composition_fn, representation_fn = self.bootstrap_model()
        self.model_outer = composition_fn
        self.signature_model = representation_fn
        self.epoch_to_start_from = 0

    def bootstrap_model(self) -> Tuple[nn.Module, nn.Module]:
        """Instantiate the two models used by the experiment.

        Returns:
            A ``(model_outer, signature_model)`` pair: a ``CompositionFn`` and a
            ``RepresentationFn``, both constructed from ``self.config``.
        """
        model_outer = CompositionFn(self.config)
        signature_model = RepresentationFn(self.config)
        return model_outer, signature_model

    def register_optim_sched(
        self, skip_composition_registry=False, skip_representation_registry=False
    ):
        if skip_composition_registry:
            self.meta_optimiser = optim.Adam(
                self.signature_model.weights,
                self.config.model.optim.learning_rate,
                weight_decay=self.config.model.optim.weight_decay,
            )
        elif skip_representation_registry:
            self.meta_optimiser = optim.Adam(
                self.model_outer.weights,
                self.config.model.optim.learning_rate,
                weight_decay=self.config.model.optim.weight_decay,
            )
        else:
            self.meta_optimiser = optim.Adam(
                self.model_outer.weights + self.signature_model.weights,
                self.config.model.optim.learning_rate,
                weight_decay=self.config.model.optim.weight_decay,
            )

    # def build_model(self) -> nn.Module:
    #     model = Model(config=self.config).to(self.config.general.device)
    #     model.model_config_key = "gat"
    #     return model

    def load_data(self) -> None:
        """Load the task families and record the largest class count.

        The first three families returned by ``bootstrap_task_family`` are
        stored as the train, valid and test families (in that order), and
        ``config.model.num_classes`` is set to the maximum ``num_classes``
        over all returned families.
        """
        families = bootstrap_task_family(config=self.config)
        largest_class_count = max(family.num_classes for family in families)
        self.task_family_train = families[0]
        self.task_family_valid = families[1]
        self.task_family_test = families[2]
        self.config.model.num_classes = largest_class_count

    def run(self):
        """Meta-train the composition and representation functions.

        Before training, optional pretrained models are loaded, weights are
        frozen if the configuration asks for it, and the meta-optimiser is
        registered.

        Each meta-iteration, from ``self.epoch_to_start_from`` up to
        ``config.model.num_epochs``, does the following:

        1. For each of ``config.model.tasks_per_metaupdate`` sampled tasks:
           reset the inner models to clones of the current outer weights, take
           ``config.model.num_inner_updates`` gradient steps on a "train" batch
           (cross-entropy plus the sum of parameter norms scaled by
           ``config.model.optim.inner_weight_decay``), then compute the
           cross-entropy on a "valid" batch of the same task and accumulate
           its clamped gradient w.r.t. the outer weights.
        2. Write the task-averaged gradient into ``.grad`` of the outer
           weights and step ``self.meta_optimiser``.
        3. Log the mean inner and meta losses and call ``periodic_save``.
        """
        # Outer-model weights whose ``.grad`` is never written in the meta update.
        skip_gradients = ["node_embedding", "relation_embedding"]
        # todo-correct registry
        if (
            "use_composition_fn" in self.config.model
            and self.config.model.use_composition_fn
        ):
            self.load_only_composition()
        if (
            "use_representation_fn" in self.config.model
            and self.config.model.use_representation_fn
        ):
            self.load_only_representation()
        if self.config.model.freeze_composition_fn:
            self.model_outer.freeze_weights()
        if self.config.model.freeze_representation_fn:
            self.signature_model.freeze_weights()

        # register optimizer
        self.register_optim_sched(
            skip_composition_registry=self.config.model.freeze_composition_fn,
            skip_representation_registry=self.config.model.freeze_representation_fn,
        )
        step = 0

        # Inner (task-adapted) copies of both models.
        model_inner = copy.deepcopy(self.model_outer)
        signature_model_inner = copy.deepcopy(self.signature_model)
        model_inner.train()
        signature_model_inner.train()
        self.model_outer.train()
        self.signature_model.train()

        for i_iter in range(self.epoch_to_start_from, self.config.model.num_epochs):
            # ------------ training ------------
            # Clones of the outer weights; the inner models are reset to these
            # at the start of every task.
            copy_weights = [w.clone() for w in self.model_outer.weights]
            copy_signature_weights = [w.clone() for w in self.signature_model.weights]
            # One accumulator slot per outer weight: composition weights first,
            # then representation weights.
            meta_gradient = [
                0 for _ in range(len(copy_weights) + len(copy_signature_weights))
            ]

            # for logging
            inner_loss = []
            meta_loss = []
            for _ in range(self.config.model.tasks_per_metaupdate):
                # get data for current task
                (
                    graphs,
                    queries,
                    targets,
                    target_function,
                    world_graph,
                ) = get_input_and_target(
                    config=self.config,
                    batch_size=self.config.general.batch_size,
                    task_family=self.task_family_train,
                    target_function=None,
                    mode="train",
                )

                task_inner_loss = []

                batch = GraphBatch(
                    graphs=graphs,
                    queries=queries,
                    targets=targets,
                    world_graphs=[world_graph],
                )
                # reset the inner models to the current outer weights
                model_inner.set_weights([w.clone() for w in copy_weights])
                signature_model_inner.set_weights(
                    [w.clone() for w in copy_signature_weights]
                )
                batch.to(self.config.general.device)

                # ------------ inner updates on the current task ------------
                for _ in range(self.config.model.num_inner_updates):
                    rel_emb, _ = signature_model_inner(batch)
                    outputs = model_inner(batch, rel_emb)

                    loss_task = F.cross_entropy(outputs, batch.targets)
                    params = [w for w in model_inner.weights if w.requires_grad] + [
                        w for w in signature_model_inner.weights if w.requires_grad
                    ]
                    # The norm penalty touches every tensor in ``params``, so
                    # each of them receives a gradient below.
                    l2_reg = (
                        sum([torch.norm(param) for param in params])
                        * self.config.model.optim.inner_weight_decay
                    )
                    loss_task += l2_reg
                    # create_graph=True keeps the update differentiable so the
                    # meta-gradient can flow back through the inner steps.
                    grads = torch.autograd.grad(
                        loss_task,
                        params,
                        create_graph=True,
                        retain_graph=True,
                        allow_unused=True,
                    )

                    # ``grads`` is ordered like ``params``: composition
                    # weights first, then representation weights.
                    next_grad = self._apply_inner_update(model_inner, grads, 0)
                    self._apply_inner_update(signature_model_inner, grads, next_grad)
                    task_inner_loss.append(loss_task.item())

                # ------------ meta-gradient on the held-out loss of this task ------------
                (
                    graphs,
                    queries,
                    targets,
                    target_function,
                    world_graph,
                ) = get_input_and_target(
                    config=self.config,
                    batch_size=self.config.general.batch_size,  # config.model.k_meta_test,
                    task_family=self.task_family_train,
                    target_function=target_function,
                    mode="valid",
                )
                batch = GraphBatch(
                    graphs=graphs,
                    queries=queries,
                    targets=targets,
                    world_graphs=[world_graph],
                )
                batch.to(self.config.general.device)

                # outputs of the adapted models
                test_rel, _ = signature_model_inner(batch)
                test_outputs = model_inner(batch, test_rel)

                # compute loss (will backprop through inner loop)
                loss_meta = F.cross_entropy(test_outputs, batch.targets)
                # gradient w.r.t. the *outer* weights
                params = [w for w in self.model_outer.weights if w.requires_grad] + [
                    w for w in self.signature_model.weights if w.requires_grad
                ]
                # allow_unused=False: autograd raises if any outer weight did
                # not take part in the loss, so every entry is a tensor.
                task_grads = torch.autograd.grad(loss_meta, params, allow_unused=False)
                task_id = 0
                all_weights = self.model_outer.weights + self.signature_model.weights
                for i, weight in enumerate(all_weights):
                    if not weight.requires_grad:
                        continue
                    meta_gradient[i] += (
                        task_grads[task_id]
                        .detach()
                        .clamp_(-self.config.model.clamp, self.config.model.clamp)
                    )
                    task_id += 1

                meta_loss.append(loss_meta.item())
                inner_loss.append(np.mean(task_inner_loss))

            # ------------ meta update ------------
            self.meta_optimiser.zero_grad()

            # assign the task-averaged meta-gradient to the composition weights
            weights = self.model_outer.weights
            for i in range(len(weights)):
                if self.model_outer.weight_names[i] in skip_gradients:
                    continue
                if not weights[i].requires_grad:
                    continue
                weights[i].grad = (
                    meta_gradient[i] / self.config.model.tasks_per_metaupdate
                )
            self.model_outer.set_weights(weights)
            # ... and to the representation weights, whose accumulator slots
            # follow the composition ones
            offset = len(self.model_outer.weights)
            weights = self.signature_model.weights
            for i in range(len(weights)):
                if not weights[i].requires_grad:
                    continue
                weights[i].grad = (
                    meta_gradient[offset + i] / self.config.model.tasks_per_metaupdate
                )
            self.signature_model.set_weights(weights)

            # do update step on outer model
            self.meta_optimiser.step()

            # ------------ logging ------------
            metric_dict = {
                "loss_inner": np.mean(inner_loss),
                "loss_meta": np.mean(meta_loss),
                "mode": "train",
                "steps": i_iter,
                "epoch_idx": i_iter,
                "minibatch": step,
            }
            self.logbook.write_metric_logs(metric_dict)
            step += 1

            # ------------ save models ------------
            self.periodic_save(epoch=i_iter)
            # NOTE: periodic validation via ``eval`` is currently disabled here.
            torch.cuda.empty_cache()

    def _apply_inner_update(self, model, grads, grads_id):
        """Take one clamped gradient step on the trainable weights of ``model``.

        Args:
            model: Inner model exposing ``weights`` and ``set_weights``.
            grads: Gradients for the trainable weights of all inner models,
                in the order in which the models are updated.
            grads_id: Index in ``grads`` of the first gradient of ``model``.

        Returns:
            The index in ``grads`` following the last gradient consumed.
        """
        first_order = self.config.model.first_order
        weights = model.weights
        for idx in range(len(weights)):
            if not weights[idx].requires_grad:
                continue
            grad = grads[grads_id]
            if first_order:
                # First-order approximation: do not backprop through the gradient.
                grad = grad.detach()
            weights[idx] = weights[idx] - self.config.model.lr_inner * grad.clamp_(
                -self.config.model.clamp, self.config.model.clamp
            )
            grads_id += 1
        model.set_weights(weights)
        return grads_id

    def periodic_save(self, epoch: int):
        if self.config.model.persist_frequency > 0:
            if epoch % self.config.model.persist_frequency == 0:
                self.save_model(epoch=epoch)

    def save_model(self, epoch: Optional[int] = None) -> None:
        if epoch is None:
            epoch = 0
        self.save(
            self.config.model.save_dir,
            epoch,
            self.model_outer,
            self.signature_model,
            [self.meta_optimiser],
        )

    def load_model(self, epoch=None) -> Tuple[int, nn.Module]:
        model_outer = self.model_outer
        signature_model = self.signature_model
        # meta_optimiser = self.meta_optimiser
        model_outer, signature_model, optimizers, epoch = self.load(
            self.config.model.save_dir, epoch, model_outer, signature_model
        )

        self.model_outer = model_outer
        self.signature_model = signature_model
        # self.meta_optimiser = meta_optimiser
        self.epoch_to_start_from = epoch if epoch else self.epoch_to_start_from
        # return epoch, model_outer, signature_model, meta_optimiser

    def load_only_composition(self, epoch: Optional[int] = None) -> None:
        """Load only composition model
            epoch {Optional[int]} -- [description] (default: {None})
        """
        self.model_outer, _, _, _ = self.load(
            self.config.model.load_dir, epoch, self.model_outer, None, None
        )

    def load_only_representation(self, epoch: Optional[int] = None) -> None:
        """Load only representation model
            epoch {Optional[int]} -- [description] (default: {None})
        """
        _, self.signature_model, _, _ = self.load(
            self.config.model.load_dir, epoch, representation_fn=self.signature_model
        )

    def evaluate(self, epoch=0, step=0):
        """Evaluate the current models on the validation and test task families.

        For each family (validation first, then test) the module-level ``eval``
        is run three times on shallow copies of both models: with 0,
        ``config.model.num_inner_updates`` and 100 adaptation steps. Each
        result is written to the logbook with metric keys suffixed ``_0``,
        ``_k`` and ``_100`` respectively.

        Args:
            epoch: Epoch id recorded under "steps" and "epoch_idx".
            step: Value recorded under "minibatch" for the first log entry; it
                is incremented after every entry.

        Returns:
            The step counter after all six entries were written.
        """
        # (metric-key suffix, number of adaptation steps)
        shots = (
            ("0", 0),
            ("k", self.config.model.num_inner_updates),
            ("100", 100),
        )
        families = (
            ("valid", self.task_family_valid),
            ("test", self.task_family_test),
        )
        for mode, task_family in families:
            for suffix, num_updates in shots:
                loss_mean, loss_conf, acc = eval(
                    self.config,
                    copy.copy(self.model_outer),
                    copy.copy(self.signature_model),
                    task_family=task_family,
                    num_updates=num_updates,
                )
                metric_dict = {
                    "loss_mean_" + suffix: loss_mean,
                    "loss_conf_" + suffix: loss_conf,
                    "mode": mode,
                    "steps": epoch,
                    "epoch_idx": epoch,
                    "accuracy_" + suffix: acc,
                    "minibatch": step,
                }
                self.logbook.write_metric_logs(metric_dict)
                step += 1
        return step


def make_signature_function(config: Box) -> nn.Module:
    """Build a ``NodeGatEncoder`` tagged with the "signature_gat" config key.

    NOTE(review): ``NodeGatEncoder`` is not imported in the visible part of
    this module, so calling this function would raise ``NameError`` unless the
    name is provided elsewhere; confirm before use.
    """
    encoder = NodeGatEncoder(config)
    setattr(encoder, "model_config_key", "signature_gat")
    return encoder


def eval(
    config,
    model,
    signature_model,
    task_family,
    num_updates,
    n_tasks=100,
    return_gradnorm=False,
):
    """Adapt the models on a task's train split and score them on its test split.

    For the evaluated task the weights are reset to their values on entry,
    ``num_updates`` clamped gradient steps are taken on one "train" batch, and
    cross-entropy loss and accuracy are then averaged over the batches of the
    "test" input range.

    NOTE(review): the ``n_tasks`` argument is overridden by
    ``len(task_family.graphworld_list)``, and the task loop stops after the
    first task ("only testing one world now").

    Args:
        config: Experiment configuration.
        model: Composition model; called as ``model(batch, rel_emb)``.
        signature_model: Representation model; called as
            ``signature_model(batch)``.
        task_family: Task family to evaluate on.
        num_updates: Number of adaptation steps (0 means no adaptation).
        n_tasks: Currently ignored, see the note above.
        return_gradnorm: If true, also return the mean gradient norm.

    Returns:
        ``(loss_mean, loss_conf, acc)``, or
        ``(loss_mean, loss_conf, mean_gradnorm, acc)`` when
        ``return_gradnorm`` is true. ``loss_conf`` is the mean absolute
        distance of the 95% t-interval bounds from ``loss_mean``.
    """
    # copy weights of network
    copy_weights = [w.clone() for w in model.weights]
    signature_weights = [w.clone() for w in signature_model.weights]

    # logging
    losses = []
    acc = []
    gradnorms = []

    n_tasks = len(task_family.graphworld_list)
    pb_t = tqdm(total=n_tasks)
    for task_id in range(n_tasks):

        # reset network weights
        model.set_weights([w.clone() for w in copy_weights])
        signature_model.set_weights([w.clone() for w in signature_weights])

        # ---- K shot adaptation on the train split -----
        graphs, queries, targets, target_function, world_graph = get_input_and_target(
            config=config,
            batch_size=config.general.batch_size,  # 4,
            task_family=task_family,
            target_function=None,
            mode="train",
            task_id=task_id,
        )
        batch = GraphBatch(
            graphs=graphs, queries=queries, targets=targets, world_graphs=[world_graph],
        )
        batch.to(config.general.device)

        # ------------ update on current task ------------
        for ut in range(num_updates):
            rel_emb, _ = signature_model(batch)
            outputs = model(batch, rel_emb)

            # compute loss for current task
            task_loss = F.cross_entropy(outputs, batch.targets)

            # Only differentiate w.r.t. weights that require grad: autograd
            # raises on tensors that do not (e.g. frozen weights).
            trainable = [
                w for w in model.weights + signature_model.weights if w.requires_grad
            ]
            grads = torch.autograd.grad(task_loss, trainable)
            grads = [
                torch.clamp(g, -config.model.clamp, config.model.clamp) for g in grads
            ]
            if torch.isnan(task_loss):
                # debugging aid: drop into the debugger on a NaN loss
                set_trace()
            gradnorms.append(np.mean(np.array([g.norm().item() for g in grads])))

            # NOTE(review): gradients are computed for both models but only the
            # weights of ``model`` are updated here - confirm this is intended.
            grad_idx = 0
            for i in range(len(model.weights)):
                if not model.weights[i].requires_grad:
                    continue
                model.weights[i] = (
                    model.weights[i] - config.model.lr_inner * grads[grad_idx].detach()
                )
                grad_idx += 1
            pb_t.set_description("U : {}, Loss : {}".format(ut, task_loss.item()))

        # ------------ Inference on the test split ------------
        graphs, queries, indices = task_family.get_input_range(mode="test")
        task_losses = []
        task_acc = []
        batch_size = config.general.batch_size
        pb_inf = tqdm(total=len(graphs))
        # No gradients are needed for scoring.
        with torch.no_grad():
            for bi in range(0, len(graphs), batch_size):
                b_graphs = graphs[bi : bi + batch_size]
                b_queries = queries[bi : bi + batch_size]
                b_indices = indices[bi : bi + batch_size]
                batch = GraphBatch(
                    b_graphs,
                    b_queries,
                    target_function(b_indices, "test"),
                    world_graphs=[world_graph],
                )
                batch.to(config.general.device)
                rel_emb, _ = signature_model(batch)
                logits = model(batch, rel_emb)
                loss = F.cross_entropy(logits, batch.targets)
                if torch.isnan(loss):
                    # debugging aid: drop into the debugger on a NaN loss
                    set_trace()

                task_losses.append(loss.detach().item())

                # calculate accuracy
                predictions, conf = model.predict(logits)
                task_acc.append(
                    model.accuracy(predictions, batch.targets).cpu().detach().item()
                )
                # advance by the actual batch length (the last batch may be short)
                pb_inf.update(len(b_graphs))
        pb_inf.close()
        task_acc = np.mean(task_acc)
        acc.append(task_acc)
        task_loss = np.mean(task_losses)
        losses.append(task_loss)
        pb_t.set_description(
            "Task Accuracy : {}, Task Loss : {}".format(task_acc, task_loss)
        )
        pb_t.update(1)
        # only testing one world now
        break
    pb_t.close()

    # reset network weights
    model.weights = [w.clone() for w in copy_weights]

    losses_mean = np.mean(losses)
    losses_conf = st.t.interval(
        0.95, len(losses) - 1, loc=losses_mean, scale=st.sem(losses)
    )
    acc = np.mean(acc)

    if not return_gradnorm:
        return losses_mean, np.mean(np.abs(losses_conf - losses_mean)), acc
    else:
        return (
            losses_mean,
            np.mean(np.abs(losses_conf - losses_mean)),
            np.mean(gradnorms),
            acc,
        )
