"""Class to run the experiments"""
# from time import time

import numpy as np
import torch

from codes.logbook.logbook import LogBook
from codes.utils.meta_data import MetaDataset
import copy


class MAMLExperiment:
    """MAML style training Experiment Class

    Holds the config, the model, the data handle, a ``LogBook`` and one
    ``MetaDataset`` per split, and drives the train / save / evaluate loop.

    NOTE(review): ``train`` and ``evaluate`` read ``self.dataloaders`` and
    ``train`` reads ``self.optimizer``, but nothing in this class assigns
    either attribute, and ``setup_maml`` is never called from within this
    class. They have to be attached to the instance before ``run`` is used.
    """

    def __init__(self, config, model, data):
        """
        Build the experiment.

        :param config: experiment configuration; the ``model``, ``general``
            and ``logger`` sections are read by this class
        :param model: model to train / evaluate
        :param data: data handle, stored as ``self.data``
        """
        self.config = config
        self.logbook = LogBook(self.config)
        self.support_modes = self.config.model.modes
        self.device = self.config.general.device
        self.model = model
        if self.config.logger.watch_model:
            self.logbook.watch_model(model=self.model)
        self._mode = None
        self.data = data
        # One meta-dataset per split, each built with its own split name.
        # NOTE(review): assumes MetaDataset accepts "valid" and "test", the
        # same split names evaluate() uses in this class - confirm.
        self.train_task = MetaDataset(config, mode="train")
        self.valid_task = MetaDataset(config, mode="valid")
        self.test_task = MetaDataset(config, mode="test")
        self.reset_experiment()
        # Counter reported as "minibatch" in the metric logs; train() bumps
        # it once per logged chunk of batches.
        self.train_step = 0
        # torch.autograd.set_detect_anomaly(mode=True)

    def reset_experiment(self):
        """Reset the experiment (clears the current mode)"""
        self._mode = None

    def setup_maml(self, model):
        """
        Setup the inner and outer models here
        Setup optimizers here

        The given model is kept as the inner model, a deep copy of it becomes
        the outer model, and the first optimizer returned by the outer
        model's ``get_optimizers()`` is used as the meta optimizer.

        :param model: model to use as the inner model
        :return: None
        """
        self.model_inner = model
        self.model_outer = copy.deepcopy(model)
        self.meta_optimizer = self.model_outer.get_optimizers()[0]

    def run(self):
        """
        Method to run the experiment

        Optionally loads the model, then (when training is enabled) runs
        train -> save -> validate for every epoch, and finally evaluates on
        the "test" split.
        """
        if self.config.model.should_load_model:
            self.model.load_model()
        # Stays 0 when training is disabled or num_epochs is 0, so the final
        # test evaluation always has an epoch value to report.
        epoch = 0
        if self.config.model.should_train:
            self.model.train()
            for epoch in range(self.config.model.num_epochs):
                self.train(epoch)
                self.save(epochs=epoch)
                self.evaluate(epoch=epoch)

        # now compute test scores
        self.evaluate(epoch=epoch, mode="test")

    def train(self, epoch=0):
        """
        Method to train for one pass over ``self.dataloaders["train"]``

        Each batch is a ``(graphs, queries, labels)`` triple. Running loss and
        accuracy are logged every ``config.logger.remote.frequency`` batches
        and on the last batch, then reset.

        :param epoch: epoch index written into the logged metrics
        :return: None
        """
        mode = "train"
        running_loss = []
        running_acc = []
        self.model.train()

        loader = self.dataloaders[mode]
        num_batches = len(loader)
        log_every = self.config.logger.remote.frequency

        for batch_idx, (graphs, queries, labels) in enumerate(loader):
            logits = self.model(graphs, queries)
            labels = labels.long().squeeze(1).to(self.config.general.device)
            loss = self.model.loss(logits, labels)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # .item() already yields a detached Python number.
            running_loss.append(loss.item())
            predictions, _ = self.model.predict(logits)
            running_acc.append(self.model.accuracy(predictions, labels).item())

            is_last = batch_idx + 1 >= num_batches
            if batch_idx % log_every == 0 or is_last:
                self.logbook.write_metric_logs(
                    {
                        "mode": "train",
                        "minibatch": self.train_step,
                        "loss": np.mean(running_loss),
                        "loss_std": np.std(running_loss),
                        "accuracy": np.mean(running_acc),
                        "epoch": epoch,
                    }
                )
                # Start a fresh window so each log entry only covers the
                # batches seen since the previous entry.
                running_loss = []
                running_acc = []
                self.train_step += 1

    def evaluate(self, epoch=0, mode="valid"):
        """
        Method to run the evaluation over ``self.dataloaders[mode]``

        Runs with gradients disabled and the model in eval mode, then logs
        one metrics entry aggregated over all batches of the split.

        :param epoch: epoch index written into the logged metrics
        :param mode: key of the dataloader to evaluate on
        :return: None
        """
        with torch.no_grad():
            self.model.eval()
            losses = []
            accuracies = []
            for graphs, queries, labels in self.dataloaders[mode]:
                labels = labels.long().squeeze(1).to(self.config.general.device)
                logits = self.model(graphs, queries)
                loss = self.model.loss(logits, labels)
                predictions, _ = self.model.predict(logits)
                losses.append(loss.item())
                accuracies.append(self.model.accuracy(predictions, labels).item())

            self.logbook.write_metric_logs(
                {
                    "mode": mode,
                    "minibatch": self.train_step,
                    "epoch": epoch,
                    "accuracy": np.mean(accuracies),
                    "accuracy_std": np.std(accuracies),
                    "loss": np.mean(losses),
                }
            )

    def save(self, epochs):
        """
        Method to save the experiment

        The model is saved only when the persist frequency is positive and
        ``epochs`` is a multiple of it.

        :param epochs: current epoch index, forwarded to ``save_model``
        :return: None
        """
        # "persist_frquency" (sic) is the config key as read here; keep the
        # spelling in sync with the config.
        frequency = self.config.model.persist_frquency
        if frequency > 0 and epochs % frequency == 0:
            self.model.save_model(epochs)


def prepare_and_run_experiment(config, model, data):
    """
    Primary method to interact with the Experiments

    Builds a ``MAMLExperiment`` from the given arguments and runs it.

    :param config: experiment configuration, forwarded to the experiment
    :param model: model to train / evaluate, forwarded to the experiment
    :param data: data handle, forwarded to the experiment
    :return: None
    """
    # This module defines MAMLExperiment only; no `Experiment` name is
    # defined or imported here.
    experiment = MAMLExperiment(config, model, data)
    experiment.run()
