import os
import json
import time
from distutils.dir_util import copy_tree
from pytorch_lightning.callbacks.base import Callback

from spaghettini import quick_register

from src.utils.saving_loading import get_most_recent_checkpoint_filepath
from src.utils.misc import getnow

# File stem and extension of the best-checkpoint file; they are joined as "<name>.<extension>" when it is saved.
BEST_CHECKPOINT_NAME = "checkpoint_best_val_error"
BEST_CHECKPOINT_EXTENSION = "bestckpt"
# JSON file written next to the best checkpoint; it holds the scalar/string metrics logged at that moment and is
# read back on start-up to recover the best metric value reached by a previous run.
METRICS_SUMMARY_NAME = "metrics_summary.json"
# Default metric key whose value decides whether a checkpoint counts as the new "best" (lower is better).
BEST_IN_DIST_VAL_ERROR_KEY = "validation/average_in_distribution_val_error"
# Very large default used for the periodic saving intervals (`save_every_n_epoch`, `save_every_t_min`).
BIG = int(1e20)


@quick_register
class SimpleCheckpointer(Callback):
    """Callback that saves regular checkpoints and tracks the best checkpoint seen so far.

    Regular checkpoints are written to every configured checkpoint directory whenever one of the epoch / time /
    explicit-epoch / explicit-step conditions holds. In addition, whenever the monitored metric (read from
    `trainer.logger.experiment.summary`) drops below the best value seen so far, a "best" checkpoint and a JSON
    summary of the logged metrics are written to a `best_checkpoint_dir` sub-directory of every checkpoint directory.

    NOTE(review): `trainer.logger.experiment.summary` is presumably a Weights & Biases run summary - confirm
    against the logger configured for training.

    Args:
        tmp_ckpt_dir: Temporary checkpoint directory. Falls back to `final_ckpt_dir` when None.
        final_ckpt_dir: Final checkpoint directory. May be None, in which case only `tmp_ckpt_dir` is used.
        dirs_dict: Passed to `get_most_recent_checkpoint_filepath` to look for a previous best checkpoint.
        save_every_n_epoch: Save when `current_epoch % save_every_n_epoch == 0` (at most once per epoch).
        save_every_t_min: Save when more than this many minutes have passed since the last regular save.
        save_at_epochs: Epoch indices at which to save. None (default) means no explicit epochs.
        save_at_steps: Global step indices at which to save. None (default) means no explicit steps.
        best_checkpoint_determining_metric_key: Key of the logged metric to minimise.
        save_initialization: If True, save a checkpoint in `on_fit_start`.
        keep_all_checkpoints: If True, make each checkpoint name unique (epoch, step and `getnow()` suffix)
            instead of overwriting the same file.
    """

    def __init__(self, tmp_ckpt_dir, final_ckpt_dir=None, dirs_dict=None, save_every_n_epoch=BIG, save_every_t_min=BIG,
                 save_at_epochs=None, save_at_steps=None,
                 best_checkpoint_determining_metric_key=BEST_IN_DIST_VAL_ERROR_KEY,
                 save_initialization=False, keep_all_checkpoints=False):
        super().__init__()
        self.tmp_ckpt_dir = tmp_ckpt_dir if tmp_ckpt_dir is not None else final_ckpt_dir
        self.final_ckpt_dir = final_ckpt_dir
        self.dirs_dict = dirs_dict

        self.save_every_n_epoch = save_every_n_epoch
        self.save_every_t_min = save_every_t_min
        # None stands for "no explicit epochs/steps"; a fresh list per instance avoids sharing one default object.
        self.save_at_epochs = [] if save_at_epochs is None else save_at_epochs
        self.save_at_steps = [] if save_at_steps is None else save_at_steps

        self.best_checkpoint_determining_metric_key = best_checkpoint_determining_metric_key

        self.save_initialization = save_initialization
        self.keep_all_checkpoints = keep_all_checkpoints

        # Keep track of at which steps and epochs checkpoints are saved, so we don't unnecessarily save checkpoints.
        self.checkpointing_status = dict(last_epoch=-1, last_step=-1, last_time=time.time())

        # ____ Keep track of the best in-distribution validation error computed so far. ____
        # Find the latest best checkpoint.
        best_ckpt_found_path = get_most_recent_checkpoint_filepath(dirs_dict, checkpoint_type="best")

        if best_ckpt_found_path is not None:
            # Work with the directory that contains the checkpoint file; the metrics summary lives next to it.
            best_ckpt_found_path = os.path.split(best_ckpt_found_path)[0]
            print(f"Previous best checkpoint found at path: {best_ckpt_found_path}")

            best_val_err_json_path = os.path.join(best_ckpt_found_path, METRICS_SUMMARY_NAME)
            with open(best_val_err_json_path, "r", encoding="utf-8") as f:
                json_dict = json.load(f)
            best_val_err = json_dict[best_checkpoint_determining_metric_key]
            self.best_in_dist_val_error_so_far = best_val_err
            print(f"The best in-distribution validation error of the previous best checkpoint is: {best_val_err}")

            # Copy the previous best-performing checkpoint directory into the current temporary directory.
            # The first returned directory always belongs to `self.tmp_ckpt_dir`.
            best_ckpt_dir_tmp = self._create_best_checkpoint_dirs()[0]
            copy_tree(src=best_ckpt_found_path, dst=best_ckpt_dir_tmp)
            print(f"Copied the previous best ckpt from {best_ckpt_found_path} to {best_ckpt_dir_tmp}")
        else:
            print(f"No previous best checkpoint found. ")
            self.best_in_dist_val_error_so_far = float('inf')

    @staticmethod
    def _construct_best_ckpt_dir(path):
        """Return the path of the best-checkpoint sub-directory inside `path`."""
        return os.path.join(path, "best_checkpoint_dir")

    def _unique_ckpt_dirs(self):
        """Return the configured checkpoint directories (temporary first), without None entries or duplicates.

        Raises:
            ValueError: If neither `tmp_ckpt_dir` nor `final_ckpt_dir` was provided.
        """
        unique_dirs = []
        seen = set()
        for ckpt_dir in (self.tmp_ckpt_dir, self.final_ckpt_dir):
            if ckpt_dir is None:
                continue
            normalized = os.path.abspath(ckpt_dir)
            if normalized not in seen:
                seen.add(normalized)
                unique_dirs.append(ckpt_dir)

        if not unique_dirs:
            raise ValueError("At least one of tmp_ckpt_dir and final_ckpt_dir must be provided to save checkpoints.")
        return unique_dirs

    def _save_checkpoints(self, trainer, checkpoint_name="checkpoint.ckpt"):
        """Save a regular checkpoint to every checkpoint directory and record when it was saved.

        Args:
            trainer: Trainer whose `save_checkpoint`, `current_epoch` and `global_step` are used.
            checkpoint_name: Base file name. When `keep_all_checkpoints` is set, an epoch/step/`getnow()` suffix
                is inserted before the extension so earlier checkpoints are not overwritten.
        """
        if self.keep_all_checkpoints:
            stem, extension = os.path.splitext(checkpoint_name)
            curr_epoch, curr_step = trainer.current_epoch, trainer.global_step
            checkpoint_name = stem + f"__epoch_{curr_epoch:05d}__step_{curr_step:07d}__" + str(getnow()) + extension

        print(
            f"\n({self.__class__.__name__}) Saving checkpoint at step {trainer.global_step}. Checkpoint name: {checkpoint_name}\n")
        # Each distinct directory is written exactly once, even if tmp and final point to the same place.
        for ckpt_dir in self._unique_ckpt_dirs():
            trainer.save_checkpoint(os.path.join(ckpt_dir, checkpoint_name))

        # Update checkpointing status.
        self.checkpointing_status = dict(last_epoch=trainer.current_epoch, last_step=trainer.global_step,
                                         last_time=time.time())

    def best_checkpoint_saving_routine(self, trainer, pl_module):
        """Save a "best" checkpoint plus a metrics summary if the monitored metric has improved.

        The current metric value is read from `trainer.logger.experiment.summary`; lower is better.
        `pl_module` is accepted for symmetry with the callback hooks and is not used.
        """
        current_in_dist_val_error = trainer.logger.experiment.summary[self.best_checkpoint_determining_metric_key]
        if not current_in_dist_val_error < self.best_in_dist_val_error_so_far:
            return

        # Update the best validation error statistic.
        self.best_in_dist_val_error_so_far = current_in_dist_val_error

        # Only plain scalars and strings are kept so the summary is JSON-serialisable. isinstance (rather than an
        # exact type match) also keeps subclasses such as numpy.float64, so the monitored metric is not dropped.
        latest_logged_metrics = {k: v for k, v in dict(trainer.logger.experiment.summary).items()
                                 if isinstance(v, (int, float, str))}
        save_json = json.dumps(latest_logged_metrics, indent=4)
        checkpoint_name = f"{BEST_CHECKPOINT_NAME}.{BEST_CHECKPOINT_EXTENSION}"

        # Save the checkpoint, as well as the metrics summary, into every best-checkpoint directory.
        for save_dir in self._create_best_checkpoint_dirs():
            trainer.save_checkpoint(os.path.join(save_dir, checkpoint_name))

            best_val_error_value_filename = os.path.join(save_dir, METRICS_SUMMARY_NAME)
            with open(best_val_error_value_filename, "w", encoding="utf-8") as f:
                f.write(save_json)

    def checkpointing_routine(self, trainer, pl_module):
        """Save a regular checkpoint if any saving condition holds, then run the best-checkpoint routine.

        Raises:
            AttributeError: If the monitored metric key has not been logged to the experiment summary.
        """
        mins_since_last_saved = (time.time() - self.checkpointing_status["last_time"]) / 60
        new_epoch = trainer.current_epoch != self.checkpointing_status["last_epoch"]
        new_step = trainer.global_step != self.checkpointing_status["last_step"]

        # Check if the checkpoint saving conditions are satisfied.
        conditions = [
            trainer.current_epoch % self.save_every_n_epoch == 0 and new_epoch,  # Epoch check.
            mins_since_last_saved > self.save_every_t_min,  # Time check.
            trainer.current_epoch in self.save_at_epochs and new_epoch,  # Save at epoch check.
            trainer.global_step in self.save_at_steps and new_step,  # Save at step check.
        ]

        if any(conditions):
            self._save_checkpoints(trainer=trainer)

        # Save another checkpoint if the current one achieves the best in distribution validation error.
        if self.best_checkpoint_determining_metric_key not in trainer.logger.experiment.summary.keys():
            message = f"Training loop must log the key {self.best_checkpoint_determining_metric_key} " \
                      f"for this key to be used for best checkpoint determining. "
            raise AttributeError(message)
        self.best_checkpoint_saving_routine(trainer=trainer, pl_module=pl_module)

    def on_fit_start(self, trainer, pl_module):
        """Optionally save a checkpoint of the model before training starts."""
        if self.save_initialization:
            self._save_checkpoints(trainer=trainer, checkpoint_name="at_initialization_checkpoint.ckpt")

    def on_train_epoch_end(self, trainer, pl_module):
        """Run the checkpointing routine at the end of every training epoch."""
        self.checkpointing_routine(trainer=trainer, pl_module=pl_module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, unused=0):
        """Run the checkpointing routine at the end of every training batch."""
        self.checkpointing_routine(trainer=trainer, pl_module=pl_module)

    def on_train_end(self, trainer, pl_module):
        """Save one final regular checkpoint when training finishes."""
        print(f"Reached the end of training. Saving the final checkpoint. ")
        self._save_checkpoints(trainer=trainer)

    def _create_best_checkpoint_dirs(self):
        """Create (if needed) and return the best-checkpoint directory of every distinct checkpoint directory.

        The directory derived from `tmp_ckpt_dir` comes first. When the temporary and final directories coincide,
        or `final_ckpt_dir` is None, a single-element list is returned.
        """
        best_ckpt_dirs = [self._construct_best_ckpt_dir(path=ckpt_dir) for ckpt_dir in self._unique_ckpt_dirs()]
        for best_ckpt_dir in best_ckpt_dirs:
            os.makedirs(best_ckpt_dir, exist_ok=True)

        return best_ckpt_dirs
