# -*- coding: utf-8 -*-

import os
import gc
import signal
import h5py
import sys
import torch as to
import matplotlib.pyplot as plt
from abc import ABC
from typing import Union, Iterable, Optional, Dict, Tuple, Callable
from tvem.exp import EEMConfig, FullEMConfig, ExpConfig, Training
from tvem.utils import H5Logger
from tvem.utils.parallel import barrier, gather_from_processes, pprint
from mloutil.viz import matrix_as_image
from mloutil.prepost.overlapping_patches import (
    OverlappingPatches,
    MultiDimOverlappingPatches,
    mean_merger,
    median_merger,
)
from utils import (
    get_merge_epochs,
    get_comm_rank,
    get_log_blacklist,
    stdout_logger,
    StoppingCriterion,
    free_energy_vs_eval_metric_lineplot,
)
from params import defaults


def _close_open_h5_files():
    """Close every ``h5py.File`` instance currently tracked by the garbage collector.

    Closing is best effort: any exception raised while closing a file (or while
    reporting that it was closed) is ignored and the scan continues.
    """
    tracked_h5_files = (
        candidate for candidate in gc.get_objects() if isinstance(candidate, h5py.File)
    )
    for h5_file in tracked_h5_files:
        try:
            h5_file.close()
            print(f"Closed {h5_file}")
        except Exception:
            # Deliberately ignored: this runs during shutdown, a failing handle must
            # not prevent the remaining files from being closed.
            continue


class ImageReconstruction(ABC):
    """Train a model on patch data and periodically merge patch reconstructions.

    The experiment wraps a ``tvem.exp.Training`` run. At the epochs selected through
    ``merge_every`` the per-patch reconstructions held by the trainer are gathered,
    merged into full images (mean and median merging), scored with
    ``eval_metric_fn``, visualized and appended to ``reco_logger``. Visualization,
    logging and H5 file clean-up are only performed by the process whose
    communicator rank is 0.
    """

    def __init__(
        self,
        data_file: str,
        patches: Optional[Union[OverlappingPatches, MultiDimOverlappingPatches]],
        model,
        estep_conf: Union[EEMConfig, FullEMConfig],
        output_directory: str,
        batch_size: int,
        no_epochs: int,
        merge_every: Optional[int],
        eval_metric_fn: Optional[Callable],
        eval_metric_name: str,
        eval_metric_label: str,
        interactive: bool,
        interactive_pause: float,
        reco_logger: Optional[H5Logger],
        stop_if_eval_metric_diff_negative_in_x_of_y_epochs: Optional[Tuple[int, int]] = None,
        keep_training_data_file: bool = False,
        keep_training_output_file: bool = False,
        keep_reco_file: bool = False,
        log_blacklist: Iterable[str] = (
            "train_lpj",
            "train_states",
            "train_subs",
            "train_reconstruction",
            "THETA",
        ),
        warmup_Esteps: int = 0,
    ):
        """Store the experiment configuration and derive reconstruction/merge epochs.

        Args:
            data_file: Training data file handed to ``Training``; removed at the end
                of ``run`` unless ``keep_training_data_file`` is set.
            patches: Patch container used to query the image shape and to merge
                reconstructed patches. May be ``None``, in which case no
                reconstructed-image figure is prepared.
            model: Model to train. Must expose ``theta``; must be a ``Reconstructor``
                when ``merge_every`` is given.
            estep_conf: E-step configuration (``EEMConfig`` or ``FullEMConfig``).
            output_directory: Directory receiving ``training.h5``, ``terminal.txt``
                and the figures.
            batch_size: Batch size forwarded to ``ExpConfig``.
            no_epochs: Number of epochs passed to ``Training.run``.
            merge_every: Forwarded to ``get_merge_epochs``; ``None`` disables merging.
            eval_metric_fn: Callable applied to each merged image.
            eval_metric_name: Name used in figure names and as reco-logger key.
            eval_metric_label: Y-axis label of the evaluation metric plot.
            interactive: Whether figures are drawn interactively.
            interactive_pause: Pause (passed to ``plt.pause``) after each redraw.
            reco_logger: Logger receiving merged images and metric values.
            stop_if_eval_metric_diff_negative_in_x_of_y_epochs: Positional arguments
                for ``StoppingCriterion``; ``None`` disables early stopping.
            keep_training_data_file: Keep ``data_file`` after the run.
            keep_training_output_file: Keep ``training.h5`` after the run.
            keep_reco_file: Keep the reco logger's file after the run.
            log_blacklist: Forwarded to ``get_log_blacklist`` together with
                ``model.theta``.
            warmup_Esteps: Forwarded to ``ExpConfig``.
        """
        if merge_every is not None:
            # NOTE(review): `Reconstructor` was never imported by this module, so this
            # check used to raise NameError. It is imported here from TVEM's model
            # protocols - confirm the module path against the installed TVEM version.
            from tvem.utils.model_protocols import Reconstructor

            assert isinstance(
                model, Reconstructor
            ), "Reconstruction requested: model must be instance of Reconstructor"
        assert isinstance(
            estep_conf, (EEMConfig, FullEMConfig)
        ), "estep_conf must be one of (EEMConfig, FullEMConfig)"

        self.data_file = data_file
        self.patches = patches
        self.model = model
        self.estep_conf = estep_conf
        self.output_directory = output_directory
        self.batch_size = batch_size
        self.no_epochs = no_epochs
        self.warmup_Esteps = warmup_Esteps
        self.eval_metric_fn = eval_metric_fn
        self.eval_metric_name = eval_metric_name
        self.eval_metric_label = eval_metric_label
        self.stop_if_eval_metric_diff_negative_in_x_of_y_epochs = (
            stop_if_eval_metric_diff_negative_in_x_of_y_epochs
        )
        self.interactive = interactive
        self.interactive_pause = interactive_pause
        self.reco_logger = reco_logger
        self.keep_training_data_file = keep_training_data_file
        self.keep_training_output_file = keep_training_output_file
        self.keep_reco_file = keep_reco_file
        self.log_blacklist = get_log_blacklist(log_blacklist, model.theta)
        self.training_file = f"{output_directory}/training.h5"

        self._set_reco_and_merge_epochs(merge_every)

    def _set_reco_and_merge_epochs(self, merge_every: Optional[int]):
        """Set ``reco_epochs`` (every epoch) and ``merge_epochs`` (``None`` if disabled)."""
        self.reco_epochs = to.arange(self.no_epochs)  # M-steps require data w/o missing values
        self.merge_epochs = (
            get_merge_epochs(merge_every, self.no_epochs) if merge_every is not None else None
        )

    def _get_training(self) -> Training:
        """Build the ``Training`` object from the stored configuration."""
        exp_config = ExpConfig(
            batch_size=self.batch_size,
            shuffle=defaults.shuffle,
            warmup_Esteps=self.warmup_Esteps,
            output=self.training_file,
            reco_epochs=self.reco_epochs,
            log_only_latest_theta=True,
            log_blacklist=self.log_blacklist,
        )
        return Training(
            conf=exp_config,
            estep_conf=self.estep_conf,
            model=self.model,
            train_data_file=self.data_file,
        )

    def init_viz(self):
        """Initialize visualizations (rank 0 only).

        Creates the free-energy-vs-metric line plot and, if ``patches`` is available,
        a NaN-filled placeholder figure for the reconstructed image.
        """
        comm_rank = get_comm_rank()
        if comm_rank != 0:
            return

        # free energy vs. eval_metric
        self._free_energy_xdata = []
        self._free_energy_ydata = []
        self._eval_metric_xdata = []
        self._eval_metric_ydata = []
        self.free_energy_eval_metric_viz = free_energy_vs_eval_metric_lineplot(
            self._free_energy_xdata,
            self._free_energy_ydata,
            self._eval_metric_xdata,
            self._eval_metric_ydata,
            figure_name=f"free_energy_vs_{self.eval_metric_name}",
            eval_metric_ylabel=self.eval_metric_label,
            linestyle_free_energy=defaults.linestyle_free_energy_eval_metric_plot,
            linestyle_eval_metric=defaults.linestyle_free_energy_eval_metric_plot,
            marker_free_energy=defaults.marker_free_energy_eval_metric_plot,
            marker_eval_metric=defaults.marker_free_energy_eval_metric_plot,
            markersize=defaults.markersize_free_energy_eval_metric_plot,
            fontsize=defaults.fontsize_free_energy_eval_metric_plot,
            ticksize=defaults.ticksize_free_energy_eval_metric_plot,
            use_tex_fonts=defaults.tex_fonts,
            dpi=defaults.dpi,
        )

        # prepare visualization of reconstructed image
        if self.patches is None:
            # `patches` is optional: without it the image shape is unknown and no
            # reconstruction can be merged, so there is nothing to set up.
            self.reco_img_viz = None
        else:
            placeholder_reco_image = to.zeros(self.patches.get_image_shape())
            placeholder_reco_image[:] = float("nan")
            self.reco_img_viz = matrix_as_image(
                cdata=placeholder_reco_image,
                dpi=defaults.dpi,
                colormap="jet",
                figure_name="placeholder_reco_image_means",
            )

        if self.interactive:
            plt.draw()
            plt.show()
            plt.pause(self.interactive_pause)

    def _merge_patches(self) -> Optional[Dict[str, to.Tensor]]:
        """Gather the trainer's patch reconstructions and merge them into images.

        The gather is executed by every process. Only rank 0 merges and returns a dict
        with keys ``"means"`` and ``"medians"``; all other ranks return ``None``.
        """
        trainer = self.training.trainer
        assert hasattr(
            trainer, "train_reconstruction"
        ), "Cannot generate image reconstruction - Trainer has no train_reconstruction available"
        comm_rank = get_comm_rank()
        reconstructed_data_points = gather_from_processes(trainer.train_reconstruction)
        if comm_rank != 0:
            return None

        # Transpose/move to CPU once instead of once per merge method.
        patch_data = reconstructed_data_points.t().cpu()
        merge_methods = {"means": mean_merger, "medians": median_merger}
        return {
            descr: self.patches.set_and_merge(patch_data, merge_method=fn)
            for descr, fn in merge_methods.items()
        }

    def _compute_eval_metric(self, reco_imgs_dict):
        """Apply ``eval_metric_fn`` to every merged image (rank 0 only).

        Returns a dict with the same keys as ``reco_imgs_dict``, or ``None`` on
        non-zero ranks and when no ``eval_metric_fn`` was provided.
        """
        comm_rank = get_comm_rank()
        if comm_rank != 0:
            return None
        if self.eval_metric_fn is None:
            # `eval_metric_fn` is optional; without it there is nothing to evaluate.
            return None

        assert self.patches.get_image_shape() == reco_imgs_dict["means"].shape
        assert self.patches.get_image_shape() == reco_imgs_dict["medians"].shape

        return {
            f"{descr}": self.eval_metric_fn(reco_img)
            for descr, reco_img in reco_imgs_dict.items()
        }

    def viz_epoch(
        self,
        free_energy: float,
        reco_imgs_dict: Optional[Dict[str, to.Tensor]],
        eval_metric_dict: Optional[Dict[str, to.Tensor]],
    ):
        """Visualize epoch (rank 0 only).

        Records the free energy of the epoch; when available, also records the metric
        of the mean-merged image, refreshes the line plot and shows the mean-merged
        reconstruction.
        """
        comm_rank = get_comm_rank()
        if comm_rank != 0:
            return
        ind_epoch = len(self._free_energy_ydata)
        self._free_energy_xdata.append(ind_epoch)
        self._free_energy_ydata.append(free_energy)

        if eval_metric_dict is not None:
            assert "means" in eval_metric_dict
            self._eval_metric_xdata.append(ind_epoch)
            self._eval_metric_ydata.append(eval_metric_dict["means"])

            self.free_energy_eval_metric_viz.update(
                self._free_energy_xdata,
                self._free_energy_ydata,
                self._eval_metric_xdata,
                self._eval_metric_ydata,
                figure_name=f"free_energy_vs_{self.eval_metric_name}",
                focus_free_energy_last_x=defaults.focus_free_energy_last_x,
                focus_eval_metric_last_x=defaults.focus_eval_metric_last_x,
                output_directory=self.output_directory,
            )

        # visualize reconstructed image always when available
        if reco_imgs_dict is not None:
            assert "means" in reco_imgs_dict
            self.reco_img_viz.update(
                cdata=reco_imgs_dict["means"],
                figure_name=f"reconstructed_image_means_epoch{ind_epoch}",
                output_directory=self.output_directory,
            )

        if self.interactive:
            plt.draw()
            plt.show()
            plt.pause(self.interactive_pause)

    def log_reco_epoch(
        self,
        reco_imgs_dict: Optional[Dict[str, to.Tensor]],
        eval_metric_dict: Optional[Dict[str, to.Tensor]],
    ):
        """Append merged images and the metric value to the reco logger (rank 0 only).

        Nothing is logged when either dict is ``None`` or no ``reco_logger`` was given.
        """
        comm_rank = get_comm_rank()
        if comm_rank != 0:
            return
        if reco_imgs_dict is None or eval_metric_dict is None:
            return
        reco_logger = self.reco_logger
        if reco_logger is None:
            # `reco_logger` is optional; previously this path failed with AttributeError.
            return
        # NOTE(review): run() calls this after viz_epoch() has already appended the
        # current epoch's free energy, so this index is one larger than the
        # `ind_epoch` used in viz_epoch - confirm that this is intended.
        ind_epoch = len(self._free_energy_ydata)

        append_dict = {
            "reco_epochs": to.tensor([ind_epoch]),
            "reconstructed_image_means": reco_imgs_dict["means"],
            "reconstructed_image_medians": reco_imgs_dict["medians"],
            f"{self.eval_metric_name}": to.tensor([eval_metric_dict["means"]]),
        }
        reco_logger.append(**append_dict)
        reco_logger.write()
        for k in append_dict:
            print(f"Appended {k} to {reco_logger._fname}")

    def _check_remove_h5_files(self):
        """Delete the H5 files that were not requested to be kept (rank 0 only).

        Files that do not exist are skipped silently.
        """
        comm_rank = get_comm_rank()
        if comm_rank != 0:
            return
        to_remove = []
        if not self.keep_training_data_file:
            to_remove.append(self.data_file)
        if not self.keep_training_output_file:
            to_remove.append(self.training_file)
        if not self.keep_reco_file and self.reco_logger is not None:
            to_remove.append(self.reco_logger._fname)
        for f in to_remove:
            if not f:
                # skip empty file names
                continue
            try:
                os.remove(f)
                print(f"Removed {f}")
            except FileNotFoundError:
                pass

    def run(self):
        """Run training, merging, evaluation, visualization and logging.

        Replaces ``sys.stdout`` with a logger writing to ``terminal.txt`` in the
        output directory, installs a SIGTERM handler that closes open H5 files,
        removes the files not requested to be kept and exits with status 99, and then
        iterates over the training epochs. The loop ends early when the stopping
        criterion fires. H5 files not requested to be kept are removed at the end.
        """
        sys.stdout = stdout_logger(f"{self.output_directory}/terminal.txt")

        # Define how to handle sigterm signal (e.g. on job cancellation) and register signal handler
        def sigterm_handler(signum, frame):
            print("Job interrupted.\nWill close open H5 files now.")
            _close_open_h5_files()
            try:
                self._check_remove_h5_files()
            except Exception as err:
                # Best effort: clean-up problems must not prevent the exit below.
                print(f"Could not remove H5 files: {err}")
            print("Will exit independently now.")
            sys.exit(99)

        signal.signal(signal.SIGTERM, sigterm_handler)

        self.training = self._get_training()

        if self.interactive:
            plt.ion()
        self.init_viz()
        barrier()

        stopping_criterion = (
            StoppingCriterion(*self.stop_if_eval_metric_diff_negative_in_x_of_y_epochs)
            if self.stop_if_eval_metric_diff_negative_in_x_of_y_epochs is not None
            else None
        )

        for ind_epoch, epoch in enumerate(self.training.run(self.no_epochs)):
            epoch.print()
            # Do not count the first epoch (initial free energy evaluation). Without
            # merge epochs (merge_every=None) nothing is ever merged; the membership
            # test on None used to raise TypeError.
            do_merge = self.merge_epochs is not None and (ind_epoch - 1) in self.merge_epochs
            reco_imgs_dict = self._merge_patches() if do_merge else None
            barrier()
            eval_metric_dict = self._compute_eval_metric(reco_imgs_dict) if do_merge else None
            barrier()
            self.viz_epoch(
                free_energy=epoch._results["train_F"],
                reco_imgs_dict=reco_imgs_dict,
                eval_metric_dict=eval_metric_dict,
            )
            self.log_reco_epoch(reco_imgs_dict, eval_metric_dict)
            barrier()

            # NOTE(review): eval_metric_dict is only non-None on rank 0, so only rank 0
            # can leave the loop here - confirm that the remaining ranks do not block
            # in a collective call afterwards.
            if (
                stopping_criterion is not None
                and eval_metric_dict is not None
                and stopping_criterion.check(eval_metric_dict["means"])
            ):
                pprint("Will stop now")
                break
            if eval_metric_dict is not None:
                pprint("\t" + f'{self.eval_metric_name}:{eval_metric_dict["means"]:.1f} dB')

        if self.interactive:
            plt.ioff()

        self._check_remove_h5_files()
