import pyro
from statsmodels.tsa.stattools import adfuller
import torch
import numpy as np
import logging

logger = logging.getLogger(__name__)


class Callback:
    """
    Record pyro parameter / loss history during training and test convergence.

    An instance is called once per optimisation step as ``callback(step, losses)``.
    Every call stores a detached copy of each value in the pyro PARAM_STORE plus
    the latest loss. Windowed means of that history can be accumulated with
    :meth:`get_mean_parameters_values` and are compared by
    :meth:`convergence_reached` when ``early_stop`` is enabled.

    Parameters
    ----------
    check_convergence_every:
        Step interval between convergence checks; also used to normalise the
        relative parameter change in :meth:`convergence_reached`.
    convergence_window:
        Number of most recent snapshots averaged by the windowed means. Must be
        smaller than ``n_warmup``.
    ave_window:
        Lag, in recorded means, between the two values compared by
        :meth:`convergence_reached`; also used to normalise the change.
    n_warmup:
        Number of initial steps before convergence checks are meant to start
        (referenced by the currently disabled code in ``__call__``).
    max_steps:
        Maximum number of training steps; with ``log_n_times`` it sets
        ``write_log_every``.
    log_n_times:
        Approximate number of log lines over ``max_steps``.
    tol_variation_parameters:
        Threshold that the normalised relative change of every parameter element
        must fall below for the parameters to count as converged.
    name_param_variance_monitoring:
        Names of monitored entries; defaults (in ``on_train_start``) to every
        key of the recorded history.
    model, data:
        Stored on the instance; not otherwise used by this class.
    early_stop:
        If False, :meth:`convergence_reached` always returns False.
    verbose:
        Stored on the instance; only referenced by commented-out code.
    """

    def __init__(
        self,
        check_convergence_every=10,
        convergence_window=100,
        ave_window=10,
        n_warmup=200,
        max_steps=15000,
        log_n_times=10,
        tol_variation_parameters=0.03,
        name_param_variance_monitoring=None,
        model=None,
        data=None,
        early_stop=False,
        verbose=False,
    ):
        self.param_history = {}
        self.grad_history = {}
        self.check_convergence_every = check_convergence_every
        self.convergence_window = convergence_window
        self.ave_window = ave_window
        self.tol_variation_parameters = tol_variation_parameters
        self.max_steps = max_steps
        # Clamp to 1 so a ``step % write_log_every`` check can never divide by
        # zero when log_n_times > max_steps.
        self.write_log_every = max(1, max_steps // log_n_times)
        self.n_warmup = n_warmup
        assert (
            self.convergence_window < self.n_warmup
        ), "convergence_window must be smaller than n_warmup"

        self.name_param_variance_monitoring = name_param_variance_monitoring
        self.mean_interval_loss_history = []
        # name -> list of windowed means, appended by get_mean_parameters_values
        self.mean_params = {}

        self.model = model
        self.data = data
        self.early_stop = early_stop
        # list of (ADF statistic, p-value) tuples appended by compute_adf
        self.adf = []

        self.verbose = verbose

    @staticmethod
    def init_param_history() -> dict:
        """
        Initialize the recording of parameters in the pyro PARAM_STORE during training

        Returns
        -------
        param_history:
            initialized param_history with parameters names from the pyro PARAM_STORE
            (each mapped to an empty list) plus a "loss" entry
        """
        param_history = {name: [] for name in pyro.get_param_store().keys()}
        param_history["loss"] = []
        return param_history

    def on_train_start(self):
        """Reset all histories from the current content of the pyro PARAM_STORE."""
        self.param_history = self.init_param_history()
        self.grad_history = self.init_param_history()
        self.mean_params = self.init_param_history()

        if self.name_param_variance_monitoring is None:
            self.name_param_variance_monitoring = tuple(self.param_history.keys())

    def record_state(self, losses):
        """Append a detached copy of every PARAM_STORE value and the last loss.

        Parameters
        ----------
        losses:
            Sequence of loss values; only ``losses[-1]`` is recorded.
        """
        for name, value in pyro.get_param_store().items():
            # setdefault: parameters created after on_train_start (e.g. lazily
            # initialised ones) would otherwise raise KeyError.
            self.param_history.setdefault(name, []).append(value.detach().clone())
        self.param_history.setdefault("loss", []).append(torch.tensor(losses[-1]))
        # _, self.hooks = get_param_grad(self.grad_history, self.hooks)

    def get_metrics(self, losses):
        """Update the metrics used by the convergence test (windowed means only)."""
        self.get_mean_parameters_values()
        # self.compute_adf(losses)
        # self.compute_hess()

    def compute_adf(self, losses):
        """Append (ADF statistic, p-value) for the last ``convergence_window`` losses."""
        self.adf += [adfuller(losses[-self.convergence_window :])[:2]]

    def _windowed_means(self, window):
        """Return ``{name: mean of the last n snapshots}`` for every history entry.

        ``window is True`` selects ``n = convergence_window``; any other value is
        used directly as ``n`` (``window=False``/``0`` slices ``[-0:]`` and thus
        averages the whole history).
        """
        n = self.convergence_window if window is True else window
        return {
            name: torch.mean(torch.stack(values)[-n:], dim=0)
            for name, values in self.param_history.items()
        }

    def get_mean_parameters_values(self, window=True):
        """Append the current windowed mean of every history entry to ``mean_params``.

        Parameters
        ----------
        window:
            ``True`` to average the last ``convergence_window`` snapshots, or an
            integer number of snapshots.
        """
        for name, mean in self._windowed_means(window).items():
            # setdefault: entries that appeared after on_train_start.
            self.mean_params.setdefault(name, []).append(mean)

    def optimal_parameters(self, window=True):
        """Return the windowed mean of every history entry (parameters and loss).

        Parameters
        ----------
        window:
            ``True`` to average the last ``convergence_window`` snapshots, or an
            integer number of snapshots.

        Returns
        -------
        dict mapping each recorded name to its averaged tensor.
        """
        return self._windowed_means(window)

    def convergence_reached(self):
        """Return True when every parameter's windowed mean has stabilised.

        For each non-loss entry of ``mean_params`` the element-wise change between
        the mean recorded ``ave_window`` checks ago and the latest mean, relative
        to the latest mean and divided by ``check_convergence_every * ave_window``,
        must be below ``tol_variation_parameters``.

        Returns False when ``early_stop`` is disabled, when no parameter means
        have been recorded, or when fewer than ``ave_window`` means exist.
        """
        if not self.early_stop:
            return False

        histories = {k: v for k, v in self.mean_params.items() if k != "loss"}
        # Nothing recorded yet: do not report convergence (all([]) would be True).
        if not histories:
            return False
        # v[-ave_window] needs at least ave_window recorded means.
        if any(len(v) < self.ave_window for v in histories.values()):
            return False

        relative_mean_grads = {}
        for name, v in histories.items():
            previous, last = v[-self.ave_window], v[-1]
            # Clamp the denominator to the smallest positive normal float so
            # parameters that are exactly zero give a finite value; 0/0 would be
            # NaN, which never compares below the tolerance.
            denom = last.abs().clamp_min(torch.finfo(last.dtype).tiny)
            relative_mean_grads[name] = (
                (previous - last).abs()
                / denom
                / self.check_convergence_every
                / self.ave_window
            )

        params_converged = all(
            bool((x < self.tol_variation_parameters).all())
            for x in relative_mean_grads.values()
        )

        # NOTE: an additional loss-stationarity criterion based on the ADF
        # p-values gathered by compute_adf() exists but is commented out:
        #
        # loss_adf_pvals = [y for x, y in self.adf][-5:]
        # loss_trend_vanished = all([x < 1e-10 for x in loss_adf_pvals])
        # return params_converged and loss_trend_vanished
        return params_converged

    def __call__(self, step, losses):
        """Record the current training state and report convergence.

        The convergence check and periodic logging are currently commented out,
        so this returns False.

        Parameters
        ----------
        step:
            Current optimisation step; ``0`` (re)initialises the histories.
        losses:
            Sequence of loss values; the last one is recorded.
        """
        model_converged = False
        # Also initialise lazily if the first call is not step 0; otherwise
        # record_state would write into an uninitialised history.
        if step == 0 or not self.param_history:
            self.on_train_start()

        self.record_state(losses)

        # if step % self.check_convergence_every == 0 and step > self.n_warmup:
        #     self.get_metrics(losses)
        # if (
        #     step % self.check_convergence_every == 0
        #     and step > self.n_warmup + self.check_convergence_every * self.ave_window
        # ):
        #     model_converged = self.convergence_reached()
        # if step % self.write_log_every == 0:
        #     logger.info(
        #         f" Step : {step} | ELBO : {torch.mean(torch.tensor(losses[-self.write_log_every:])) : .2f}"
        #     )

        return model_converged


class CallbackSimple:
    """Lightweight training monitor based on relative changes of the loss.

    Called once per step as ``callback(step, global_metrics)``. Each call records
    the non-skipped pyro parameters and the supplied metrics, maintains a running
    mean of the loss (``aloss``) and the relative step-to-step changes of the raw
    loss (``rloss``) and of the running mean (``raloss``), and returns two flags:

    * ``flag_positive``: the last ``positive_flag_window`` values of ``rloss``
      are all > 0;
    * ``flag_small``: the last ``small_flag_window`` values of ``rloss`` are all
      smaller than ``eps`` in absolute value.

    Both flags stay False until more than ``warmup`` calls have been made and
    enough ``rloss`` values exist.

    :param positive_flag_window: number of trailing ``rloss`` values checked for
        ``flag_positive``
    :param averaging_window: number of trailing losses averaged into ``aloss``
    :param warmup: number of calls before any flag can be raised
    :param eps: absolute threshold on ``rloss`` for ``flag_small``
    :param small_flag_window: number of trailing ``rloss`` values checked for
        ``flag_small``
    :param log_every: log the loss when ``step`` is a multiple of this value;
        ``0``/``None`` disables logging
    :param skip_morpheme: parameters whose name contains any of these substrings
        are not recorded
    """

    _DEFAULT_SKIP_MORPHEME = (
        "weight_loc_",
        "bias_loc_",
        "weight_scale_",
        "bias_scale_",
    )
    # Metric keys that always exist in the history.
    _METRIC_KEYS = ("loss", "rms", "evidence", "aloss", "rloss", "raloss")

    def __init__(
        self,
        positive_flag_window=2,
        *,
        averaging_window=50,
        warmup=30,
        eps=1e-5,
        small_flag_window=3,
        log_every=20,
        skip_morpheme=None,
    ):
        self.averaging_window = averaging_window
        self.warmup = warmup
        self.eps = eps
        self.positive_flag_window = positive_flag_window
        self.small_flag_window = small_flag_window
        self.log_every = log_every
        self.counter = 0
        # Must be assigned before init_param_history(), which reads it; the
        # previous order raised AttributeError whenever the pyro param store
        # was already non-empty at construction time.
        self.skip_morpheme = list(
            self._DEFAULT_SKIP_MORPHEME if skip_morpheme is None else skip_morpheme
        )
        self.param_history = self.init_param_history()

    def _is_skipped(self, name) -> bool:
        """Return True if ``name`` contains any substring from ``skip_morpheme``."""
        return any(morpheme in name for morpheme in self.skip_morpheme)

    def init_param_history(self) -> dict:
        """Return a fresh history: one empty list per non-skipped pyro parameter
        plus one per metric key (loss, rms, evidence, aloss, rloss, raloss)."""
        param_history = {
            name: []
            for name in pyro.get_param_store().keys()
            if not self._is_skipped(name)
        }
        for key in self._METRIC_KEYS:
            param_history[key] = []
        return param_history

    def on_train_start(self):
        """Reset the history from the current pyro param store."""
        self.param_history = self.init_param_history()

    def record_state(self, global_metrics):
        """Append detached copies of non-skipped parameters and every metric value.

        :param global_metrics: mapping of metric name to value for this step
        """
        for name, value in pyro.get_param_store().items():
            if not self._is_skipped(name):
                # setdefault: parameters created after on_train_start would
                # otherwise raise KeyError.
                self.param_history.setdefault(name, []).append(
                    value.detach().clone()
                )
        for key, value in global_metrics.items():
            self.param_history.setdefault(key, []).append(value)

    def __call__(self, step, global_metrics: dict):
        """
        Record one training step and compute the stopping flags.

        :param step: current step; ``0`` resets the history
        :param global_metrics: such as loss, evidence etc; must contain "loss"
        :return: tuple ``(flag_positive, flag_small)``
        """
        self.counter += 1
        if step == 0:
            self.on_train_start()
        self.record_state(global_metrics)
        if self.log_every and step % self.log_every == 0:
            logger.info(
                f"| epoch: {step:<5} | loss >>> {global_metrics['loss']:.3e} <<<"
            )

        history = self.param_history
        losses = history["loss"]
        if len(losses) > self.averaging_window + 1:
            history["aloss"].append(np.mean(losses[-self.averaging_window :]))

        # Relative changes are only tracked once at least two running means
        # exist, so "rloss" and "raloss" always have equal length.
        averaged = history["aloss"]
        if len(averaged) > 1:
            history["rloss"].append((losses[-1] - losses[-2]) / losses[-1])
            history["raloss"].append((averaged[-1] - averaged[-2]) / averaged[-1])

        flag_positive = False
        flag_small = False
        rloss = history["rloss"]
        if self.counter > self.warmup and len(rloss) > max(
            self.positive_flag_window, self.small_flag_window
        ):
            flag_positive = all(x > 0 for x in rloss[-self.positive_flag_window :])
            flag_small = all(
                abs(x) < self.eps for x in rloss[-self.small_flag_window :]
            )
        return flag_positive, flag_small
