"""Class to run the experiments"""
# from time import time

import numpy as np
import torch
import copy
import os

from codes.utils.checkpointable import Checkpointable
from typing import Optional, Iterable


class CheckpointableExperiment(Checkpointable):
    """Checkpointable experiment that trains and evaluates ``model`` over rule worlds.

    Dataloaders are stored as ``self.dataloaders[world_split][rule_world]``.
    ``world_split`` is one of ``"train"``, ``"valid"``, ``"test"`` and selects
    which group of worlds (``data[0]``, ``data[1]``, ``data[2]``) the entry
    came from. Each value is whatever ``graph_world.get_dataloaders`` returns.
    :meth:`train` and :meth:`evaluate` index it by a data split name
    (``"train"`` / ``"valid"`` / ``"test"``).
    """

    def __init__(self, config, model, data, logbook):
        """Build the experiment and collect dataloaders for every rule world.

        :param config: experiment config; ``config.model``, ``config.general``
            and ``config.logger`` sections are read.
        :param model: the model to train; it must provide
            ``get_optimizers_and_schedulers``, ``loss``, ``predict``, ``accuracy``,
            ``save``, ``load`` and ``load_model``.
        :param data: indexable with 0/1/2 for the train/valid/test world groups;
            each group is an iterable of graph worlds exposing ``rule_world`` and
            ``get_dataloaders(modes=...)``.
        :param logbook: logger receiving message and metric logs.
        """
        self.config = config
        self.logbook = logbook
        self.support_modes = self.config.model.modes
        self.device = self.config.general.device
        self.model = model
        if self.config.logger.watch_model:
            self.logbook.watch_model(model=self.model)
        self._mode = None
        self.data = data
        modes = ["train", "valid", "test"]
        self.dataloaders = {}
        self.optimizer, self.scheduler = self.get_scheduler_and_optimizer()
        # Epoch to be used when checkpointing the model
        self.epoch = None
        # data[i] holds the graph worlds of modes[i]; every world provides its
        # own train/valid/test dataloaders.
        for mi, mode in enumerate(modes):
            self.dataloaders[mode] = {}
            for graph_world in self.data[mi]:
                rule_world = graph_world.rule_world
                self.dataloaders[mode][rule_world] = graph_world.get_dataloaders(
                    modes=["train", "valid", "test"]
                )
        self.reset_experiment()
        # Global optimisation-step counter, reported as "minibatch" in metric logs.
        self.train_step = 0

    def reset_experiment(self):
        """Reset the experiment mode and epoch counter."""
        self._mode = None
        self.epoch = 0

    def get_scheduler_and_optimizer(self):
        """Return ``(optimizer, scheduler)``: the first of each from the model.

        Note the order: the optimizer comes first despite the method name.
        """
        optimizers, schedulers = self.model.get_optimizers_and_schedulers()
        return optimizers[0], schedulers[0]

    def run(self):
        """Train on each train rule world in turn, resuming from ``self.epoch``.

        For every train rule world a new optimizer/scheduler pair is created.
        The model is then trained for epochs ``self.epoch .. num_epochs - 1``.
        After each epoch the experiment is periodically checkpointed and
        evaluated on the valid split of all train worlds; that loss is passed to
        ``scheduler.step``. It is also evaluated on the test split of all valid
        and all test worlds.
        """
        if self.config.model.should_load_model:
            self.model.load_model()
        epoch_to_start_from = self.epoch
        if self.config.model.should_train:
            self.model.train()
            for train_rule_world in self.dataloaders["train"]:
                train_data = self.dataloaders["train"][train_rule_world]
                self.logbook.write_message_logs(f"Training rule {train_rule_world}")
                self.optimizer, self.scheduler = self.get_scheduler_and_optimizer()
                for current_epoch in range(
                    epoch_to_start_from, self.config.model.num_epochs
                ):
                    self.train(train_data, train_rule_world, current_epoch)
                    self.epoch = current_epoch
                    self.periodic_save(epoch=current_epoch)
                    metrics = self.eval(
                        self.dataloaders["train"],
                        epoch=current_epoch,
                        mode="valid",
                        data_mode="train",
                    )
                    self.scheduler.step(metrics["loss"])
                    self.eval(
                        self.dataloaders["valid"],
                        epoch=current_epoch,
                        mode="test",
                        data_mode="valid",
                    )
                    self.eval(
                        self.dataloaders["test"],
                        epoch=current_epoch,
                        mode="test",
                        data_mode="test",
                    )

    def run_supervised(self):
        """Independently train a model on each train rule world.

        Every world starts from a deep copy of the initial (optionally loaded)
        model and gets a fresh optimizer/scheduler. After each epoch the model is
        evaluated on that world's valid split, and the loss drives
        ``scheduler.step``.
        """
        if self.config.model.should_load_model:
            self.model.load_model()
        starting_model = copy.deepcopy(self.model)

        if self.config.model.should_train:
            for task_idx, train_rule_world in enumerate(self.dataloaders["train"]):
                train_data = self.dataloaders["train"][train_rule_world]
                self.logbook.write_message_logs(f"Training rule {train_rule_world}")
                self.logbook.write_message_logs(
                    f"Starting to train the model on {train_rule_world}"
                )
                self.model = copy.deepcopy(starting_model)
                self.optimizer, self.scheduler = self.get_scheduler_and_optimizer()
                for epoch in range(self.config.model.num_epochs):
                    self.train(train_data, train_rule_world, epoch, task_idx=task_idx)
                    metrics = self.eval(
                        {train_rule_world: train_data},
                        epoch=epoch,
                        mode="valid",
                        data_mode="train",
                        task_idx=task_idx,
                    )
                    self.scheduler.step(metrics["loss"])

    def run_sequential_multitask_training(self):
        """Supervised case I: train one model on all train worlds sequentially.

        This is the sequential multitask setting:
        TRAIN on the train worlds, one after another;
        EVALUATE on the current train world (valid and test splits).
        The optimizer learning rate is reset via ``model.reset_optim_lr`` between
        worlds.
        """
        if self.config.model.should_load_model:
            self.model.load_model()
        self.optimizer, self.scheduler = self.get_scheduler_and_optimizer()
        if self.config.model.should_train:
            for task_idx, train_rule_world in enumerate(self.dataloaders["train"]):
                self.logbook.write_message_logs(f"Training rule {train_rule_world}")
                self.logbook.write_message_logs(
                    f"Starting to train the model on {train_rule_world}"
                )
                train_data = self.dataloaders["train"][train_rule_world]
                for epoch in range(self.config.model.num_epochs):
                    self.train(train_data, train_rule_world, epoch, task_idx=task_idx)
                    metrics = self.eval(
                        {train_rule_world: train_data},
                        epoch=epoch,
                        mode="valid",
                        data_mode="train",
                        task_idx=task_idx,
                    )
                    self.scheduler.step(metrics["loss"])
                    self.eval(
                        {train_rule_world: train_data},
                        epoch=epoch,
                        mode="test",
                        data_mode="train",
                        task_idx=task_idx,
                    )
                self.optimizer = self.model.reset_optim_lr(self.optimizer)

    def run_sequential_zeroshot_transfer(self):
        """Supervised case II: train one model sequentially and transfer it zero-shot.

        TRAIN on the train worlds, one after another.
        EVALUATE after every epoch on the test split of all test worlds,
        skipping a test world whose key equals the current train world.
        """
        if self.config.model.should_load_model:
            self.model.load_model()
        self.optimizer, self.scheduler = self.get_scheduler_and_optimizer()

        if self.config.model.should_train:
            for task_idx, train_rule_world in enumerate(self.dataloaders["train"]):
                train_data = self.dataloaders["train"][train_rule_world]
                self.logbook.write_message_logs(f"Training rule {train_rule_world}")
                self.logbook.write_message_logs(
                    f"Starting to train the model on {train_rule_world}"
                )
                for epoch in range(self.config.model.num_epochs):
                    self.train(train_data, train_rule_world, epoch, task_idx=task_idx)
                    metrics = self.eval(
                        {train_rule_world: train_data},
                        epoch=epoch,
                        mode="valid",
                        data_mode="train",
                        task_idx=task_idx,
                    )
                    self.scheduler.step(metrics["loss"])
                    self.eval(
                        self.dataloaders["test"],
                        epoch=epoch,
                        mode="test",
                        data_mode="test",
                        task_idx=task_idx,
                        skip_world=train_rule_world,
                    )

    def run_sequential_fewshot_transfer(self, full_shot=False):
        """Supervised case III: train on each train world, then adapt to each test world.

        For each train rule world, the model is first trained for
        ``num_epochs`` epochs. A snapshot is then adapted separately to every
        test rule world. Each adaptation starts from the same snapshot and gets
        a fresh optimizer/scheduler, and is evaluated on that world's test split.
        The mean test loss/accuracy over test worlds is logged with mode
        ``"test_test"``. Intended as an apples-to-apples comparison with MAML.

        :param full_shot: if True, adapt with ``num_epochs`` full passes over each
            test world's train split. Otherwise adapt with a single pass over the
            first ``1 / batch_size`` fraction of its batches (few-shot).
        """
        if self.config.model.should_load_model:
            self.model.load_model()

        if self.config.model.should_train:
            for task_idx, train_rule_world in enumerate(self.dataloaders["train"]):
                train_data = self.dataloaders["train"][train_rule_world]
                self.logbook.write_message_logs(f"Training rule {train_rule_world}")
                self.logbook.write_message_logs(
                    f"Starting to train the model on {train_rule_world}"
                )
                self.optimizer, self.scheduler = self.get_scheduler_and_optimizer()
                # Bound up-front so the adaptation phase below still has an epoch
                # to report when num_epochs == 0 (otherwise a NameError).
                epoch = 0
                for epoch in range(self.config.model.num_epochs):
                    self.train(train_data, train_rule_world, epoch, task_idx=task_idx)
                    metrics = self.eval(
                        {train_rule_world: train_data},
                        epoch=epoch,
                        mode="valid",
                        data_mode="train",
                        task_idx=task_idx,
                    )
                    self.scheduler.step(metrics["loss"])

                pretrained_model = copy.deepcopy(self.model)
                per_world = {
                    "loss": [],
                    "accuracy": [],
                }
                # Separate index name so the outer task_idx is not shadowed.
                for test_idx, test_rule_world in enumerate(self.dataloaders["test"]):
                    # Every test world is adapted from the same pretrained snapshot.
                    self.model = copy.deepcopy(pretrained_model)
                    self.optimizer, self.scheduler = self.get_scheduler_and_optimizer()
                    test_data = self.dataloaders["test"][test_rule_world]
                    if full_shot:
                        for ep in range(self.config.model.num_epochs):
                            # Report the adaptation epoch, not the stale
                            # pretraining epoch.
                            self.train(
                                test_data,
                                test_rule_world,
                                ep,
                                task_idx=test_idx,
                                train_per=1.0,
                            )
                    else:
                        self.train(
                            test_data,
                            test_rule_world,
                            epoch,
                            task_idx=test_idx,
                            train_per=1.0 / self.config.general.batch_size,
                        )
                    metrics = self.evaluate(
                        test_data,
                        test_rule_world,
                        epoch=epoch,
                        mode="test",
                        report=False,
                    )
                    per_world["loss"].append(metrics["loss"])
                    per_world["accuracy"].append(metrics["accuracy"])
                # Restore the pretrained snapshot before moving to the next world.
                self.model = copy.deepcopy(pretrained_model)
                self.logbook.write_metric_logs(
                    {
                        "mode": "test_test",
                        "minibatch": self.train_step,
                        "loss": np.mean(per_world["loss"]),
                        "accuracy": np.mean(per_world["accuracy"]),
                    }
                )

    def run_transfer_pretrained_zeroshot_transfer(self):
        """Supervised case III (variant): train on all train worlds, transfer zero-shot.

        The model is trained sequentially on every train world. After each
        epoch it is evaluated on the valid split of all train worlds, and that
        loss drives ``scheduler.step``. After each train world it is evaluated
        on the valid and test splits of all test worlds, with the logged epoch
        set to ``num_epochs``.
        """
        if self.config.model.should_load_model:
            self.model.load_model()
        self.optimizer, self.scheduler = self.get_scheduler_and_optimizer()

        if self.config.model.should_train:
            for task_idx, train_rule_world in enumerate(self.dataloaders["train"]):
                train_data = self.dataloaders["train"][train_rule_world]
                self.logbook.write_message_logs(f"Training rule {train_rule_world}")
                self.logbook.write_message_logs(
                    f"Starting to train the model on {train_rule_world}"
                )
                for epoch in range(self.config.model.num_epochs):
                    self.train(train_data, train_rule_world, epoch, task_idx=task_idx)
                    metrics = self.eval(
                        self.dataloaders["train"],
                        epoch=epoch,
                        mode="valid",
                        data_mode="train",
                        task_idx=task_idx,
                    )
                    self.scheduler.step(metrics["loss"])
                self.eval(
                    self.dataloaders["test"],
                    epoch=self.config.model.num_epochs,
                    mode="valid",
                    data_mode="test",
                    task_idx=task_idx,
                )
                self.eval(
                    self.dataloaders["test"],
                    epoch=self.config.model.num_epochs,
                    mode="test",
                    data_mode="test",
                    task_idx=task_idx,
                )
                self.optimizer = self.model.reset_optim_lr(self.optimizer)

    def run_transfer_pretrained_sequential_multitask_transfer(self):
        """Supervised case IV: pretrain on all train worlds, then fine-tune on test worlds.

        Phase 1: train sequentially on every train world. The loss on the valid
        split of all train worlds drives ``scheduler.step``.
        Phase 2: continue training sequentially on every test world. After each
        epoch, evaluate on the valid and test splits of all test worlds.
        The optimizer learning rate is reset via ``model.reset_optim_lr`` after
        each world.
        """
        if self.config.model.should_load_model:
            self.model.load_model()
        self.optimizer, self.scheduler = self.get_scheduler_and_optimizer()

        if self.config.model.should_train:
            for task_idx, train_rule_world in enumerate(self.dataloaders["train"]):
                train_data = self.dataloaders["train"][train_rule_world]
                self.logbook.write_message_logs(f"Training rule {train_rule_world}")
                self.logbook.write_message_logs(
                    f"Starting to train the model on {train_rule_world}"
                )
                for epoch in range(self.config.model.num_epochs):
                    self.train(train_data, train_rule_world, epoch, task_idx=task_idx)
                    metrics = self.eval(
                        self.dataloaders["train"],
                        epoch=epoch,
                        mode="valid",
                        data_mode="train",
                        task_idx=task_idx,
                    )
                    self.scheduler.step(metrics["loss"])
                self.optimizer = self.model.reset_optim_lr(self.optimizer)
            # Now train sequentially on the test worlds.
            for task_idx, test_rule_world in enumerate(self.dataloaders["test"]):
                test_data = self.dataloaders["test"][test_rule_world]
                self.logbook.write_message_logs(f"Training rule {test_rule_world}")
                self.logbook.write_message_logs(
                    f"Starting to train the model on {test_rule_world}"
                )
                for epoch in range(self.config.model.num_epochs):
                    self.train(test_data, test_rule_world, epoch, task_idx=task_idx)
                    metrics = self.eval(
                        self.dataloaders["test"],
                        epoch=epoch,
                        mode="valid",
                        data_mode="test",
                        task_idx=task_idx,
                    )
                    self.scheduler.step(metrics["loss"])
                    self.eval(
                        self.dataloaders["test"],
                        epoch=epoch,
                        mode="test",
                        data_mode="test",
                        task_idx=task_idx,
                    )
                self.optimizer = self.model.reset_optim_lr(self.optimizer)

    def run_random_model(self):
        """Evaluate the (untrained or loaded) model on the test split of each test world."""
        if self.config.model.should_load_model:
            self.model.load_model()
        if self.config.model.should_train:
            for task_idx, test_rule_world in enumerate(self.dataloaders["test"]):
                self.eval(
                    {test_rule_world: self.dataloaders["test"][test_rule_world]},
                    epoch=1,
                    mode="test",
                    data_mode="test",
                    task_idx=task_idx,
                )

    def eval(
        self,
        data,
        epoch,
        mode="valid",
        data_mode="train",
        task_idx=None,
        skip_world=None,
        report=True,
    ):
        """Evaluate every rule world in ``data`` and average the metrics.

        :param data: mapping ``rule_world -> dataloaders`` (see class docstring).
        :param epoch: epoch number forwarded to :meth:`evaluate`.
        :param mode: data split evaluated within each world (e.g. ``"valid"``).
        :param data_mode: label of the world group; logged mode is
            ``f"{data_mode}_{mode}"``.
        :param task_idx: optional task index added to the logged metrics
            (including index 0).
        :param skip_world: if truthy, the rule world with this key is skipped.
        :param report: whether to write the averaged metrics to the logbook.
        :return: dict with ``mode``, ``minibatch``, mean ``loss`` and mean
            ``accuracy`` (NaN if no world was evaluated), plus ``task_idx`` when
            given.
        """
        losses = []
        accuracies = []
        for val_rule_world, valid_data in data.items():
            if skip_world and skip_world == val_rule_world:
                continue
            metrics = self.evaluate(
                valid_data, val_rule_world, epoch=epoch, mode=mode, report=False
            )
            losses.append(metrics["loss"])
            accuracies.append(metrics["accuracy"])
        all_metrics = {
            "mode": "{}_{}".format(data_mode, mode),
            "minibatch": self.train_step,
            "loss": np.mean(losses),
            "accuracy": np.mean(accuracies),
        }
        # `is not None` so that the first task (index 0) is tagged as well.
        if task_idx is not None:
            all_metrics["task_idx"] = task_idx
        if report:
            self.logbook.write_metric_logs(all_metrics)
        return all_metrics

    def train(
        self, data, rule_world, epoch=0, train_per=1.0, report=True, task_idx=None
    ):
        """Run one training pass over (a prefix of) ``data["train"]``.

        :param data: dataloaders of one rule world, indexed by split name.
        :param rule_world: rule world identifier, included in logged metrics.
        :param epoch: epoch number included in logged metrics.
        :param train_per: fraction of the batches to train on. Only the first
            ``len(loader) * train_per`` batches are used.
        :param report: whether to log running loss/accuracy. Logging happens
            every ``config.logger.remote.frequency`` batches and on the last
            trained batch.
        :param task_idx: optional task index added to logged metrics.
        :return: None; increments ``self.train_step`` once per trained batch.
        """
        mode = "train"
        epoch_loss = []
        epoch_acc = []
        self.model.train()
        loader = data[mode]
        num_batches = len(loader)
        # Batches with index >= this bound are not trained on.
        num_batches_to_train = min(num_batches, num_batches * train_per)

        for batch_idx, batch in enumerate(loader):
            if batch_idx >= num_batches_to_train:
                # Stop rather than `continue`: the remaining batches would only
                # be loaded and discarded.
                break
            batch.to(self.config.general.device)
            logits = self.model(batch)
            loss = self.model.loss(logits, batch.targets)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            epoch_loss.append(loss.cpu().detach().item())
            predictions, _ = self.model.predict(logits)
            epoch_acc.append(
                self.model.accuracy(predictions, batch.targets).cpu().detach().item()
            )
            if report:
                # Last batch that is actually trained on, so the trailing partial
                # window is always reported even when train_per < 1.
                is_last = batch_idx + 1 >= num_batches_to_train
                if batch_idx % self.config.logger.remote.frequency == 0 or is_last:
                    metrics = {
                        "mode": mode,
                        "minibatch": self.train_step,
                        "loss": np.mean(epoch_loss),
                        "accuracy": np.mean(epoch_acc),
                        "epoch": epoch,
                        "rule_world": rule_world,
                    }
                    if task_idx is not None:
                        metrics["task_idx"] = task_idx
                    self.logbook.write_metric_logs(metrics)
                    epoch_loss = []
                    epoch_acc = []
            self.train_step += 1

    @torch.no_grad()
    def evaluate(
        self, data, rule_world, epoch=0, mode="valid", top_mode="train", report=True
    ):
        """Evaluate the model on ``data[mode]`` of a single rule world (no gradients).

        :param data: dataloaders of one rule world, indexed by split name.
        :param rule_world: rule world identifier, included in the metrics.
        :param epoch: epoch number included in the metrics.
        :param mode: split to evaluate; must not be ``"train"``.
        :param top_mode: free-form label included in the metrics.
        :param report: whether to write the metrics to the logbook.
        :return: dict with mean ``loss`` and ``accuracy`` over batches plus
            bookkeeping fields.
        """
        assert mode != "train"
        self.model.eval()
        epoch_loss = []
        epoch_acc = []
        for batch in data[mode]:
            batch.to(self.config.general.device)
            logits = self.model(batch)
            loss = self.model.loss(logits, batch.targets)
            predictions, _ = self.model.predict(logits)
            epoch_loss.append(loss.cpu().detach().item())
            epoch_acc.append(
                self.model.accuracy(predictions, batch.targets).cpu().detach().item()
            )

        metrics = {
            "mode": mode,
            "minibatch": self.train_step,
            "epoch": epoch,
            "accuracy": np.mean(epoch_acc),
            "loss": np.mean(epoch_loss),
            "top_mode": top_mode,
            "rule_world": rule_world,
        }
        if report:
            self.logbook.write_metric_logs(metrics)
        return metrics

    def save(self, epoch: Optional[int]) -> None:
        """Save the model and optimizer; ``epoch=None`` means ``self.epoch``."""
        if epoch is None:
            epoch = self.epoch
            self.logbook.write_message_logs("Save the epoch: " + str(epoch))
        self.model.save(epoch, optimizers=[self.optimizer])

    def load(self, epoch: Optional[int]) -> None:
        """Load the entire experiment.

        If ``<save_dir>/epoch.tar`` does not exist, this is a new experiment:
        ``self.epoch`` is set to 0 and nothing is loaded. Otherwise, when
        ``epoch`` is None, the epoch is read from the ``"current"`` entry of that
        file.
        """
        path_to_load_epoch_state = os.path.join(self.config.model.save_dir, "epoch.tar")
        if not os.path.exists(path_to_load_epoch_state):
            # New experiment. Nothing to load.
            self.epoch = 0
            self.logbook.write_message_logs("No model to load")
            return
        if epoch is None:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            epoch_state = torch.load(path_to_load_epoch_state)
            epoch = epoch_state["current"]
        self.logbook.write_message_logs(
            "Found model for epoch {} to load".format(epoch)
        )
        self.epoch = epoch
        self._load_model()

    def _load_model(self, epoch=None) -> None:
        """Load model, optimizer and scheduler state for ``epoch`` (default ``self.epoch``)."""
        optimizers, schedulers = self.model.load(
            epoch=self.epoch if epoch is None else epoch,
            optimizers=[self.optimizer],
            schedulers=[self.scheduler],
        )
        self.optimizer = optimizers[0]
        self.scheduler = schedulers[0]

    def periodic_save(self, epoch):
        """Periodically save the experiment.

        A thin utility on top of :meth:`save`: it saves only when
        ``config.model.persist_frquency`` is positive and divides ``epoch``.
        """
        if (
            self.config.model.persist_frquency > 0
            and epoch % self.config.model.persist_frquency == 0
        ):
            self.save(epoch)


def prepare_and_run_experiment(config, model, data, logbook):
    """Build a CheckpointableExperiment and run the routine for ``config.general.train_mode``.

    Recognised train modes and the routine each one triggers:
        supervised -> run_supervised
        seq_mult   -> run_sequential_multitask_training
        seq_zero   -> run_sequential_zeroshot_transfer
        seq_full   -> run_sequential_fewshot_transfer(full_shot=True)
        seq_few    -> run_sequential_fewshot_transfer

    :raises NotImplementedError: when ``train_mode`` is none of the above. The
        experiment is still constructed before this check.
    """
    exp = CheckpointableExperiment(config, model, data, logbook)
    dispatch = {
        "supervised": exp.run_supervised,
        "seq_mult": exp.run_sequential_multitask_training,
        "seq_zero": exp.run_sequential_zeroshot_transfer,
        "seq_full": lambda: exp.run_sequential_fewshot_transfer(full_shot=True),
        "seq_few": exp.run_sequential_fewshot_transfer,
    }
    chosen = dispatch.get(config.general.train_mode)
    if chosen is None:
        raise NotImplementedError(
            "training mode not implemented. should be either one of \n supervised / seq_mult / seq_zero / seq_full / seq_few"
        )
    chosen()
    # Entry points not wired to any train_mode:
    # exp.run(), exp.run_transfer_pretrained_zeroshot_transfer(),
    # exp.run_transfer_pretrained_sequential_multitask_transfer(),
    # exp.run_random_model()
