import logging
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam
from sbsep.callback import CallbackSimple
from pyro import poutine
import numpy as np
import os
import gzip
import pickle

logger = logging.getLogger(__name__)


class InferenceEngine:
    """Run pyro SVI on ``model`` in restart segments with a decaying learning rate.

    Each call to :meth:`infer` trains up to epoch index
    ``(eon + 1) * nepochs - 1``. It repeatedly starts a fresh optimizer
    (:meth:`run_atomic`) whenever the callback raises an early-stop flag, and
    multiplies the learning rate by ``lr_factor`` between restarts.

    The model object is used as follows below:
    - it is passed to ``SVI``;
    - it provides ``guide``, ``save_state_to_history(epoch)``, ``rms``,
      ``get_data_evidence(data, n)``, ``history``, ``name`` and ``save(folder)``.
    """

    def __init__(
        self,
        model,
        data,
        nepochs,
        callback: CallbackSimple,
        data_folder,
        seed=17,
        nwarmup=50,
        lr0=1e-2,
        lr_factor=0.8,
    ):
        """
        :param model: pyro model; the guide is taken from ``model.guide``
        :param data: sequence of positional arguments unpacked into ``SVI.step``
        :param nepochs: number of epochs added per :meth:`infer` call (eon)
        :param callback: called as ``callback(epoch, metrics)`` and must return
            ``(flag_positive, flag_small)``; its ``counter`` is reset per
            segment, and its ``param_history`` (``"loss"``/``"rloss"``) and
            ``small_flag_window`` are read for logging
        :param data_folder: folder where the history pickle is written
            (``~`` is expanded for that file) and which is passed to
            ``model.save``
        :param seed: pyro RNG seed, re-applied at the start of every
            :meth:`infer`
        :param nwarmup: early-stop flags are ignored while the absolute epoch
            index is ``<= nwarmup``
        :param lr0: learning rate at the start of each :meth:`infer` call
        :param lr_factor: multiplicative learning-rate decay between restarts
        """
        self.model = model
        self.data = data
        self.nepochs = nepochs
        self.callback = callback
        self.data_folder = data_folder
        self.seed = seed
        self.nwarmup = nwarmup
        self.lr0 = lr0
        self.lr_factor = lr_factor
        # Index of the last epoch executed by run_atomic (set by infer()).
        self.cepoch = 0
        # Eon counter; incremented at the end of infer() unless the first
        # segment aborts on a NaN loss.
        self.eon = 0
        # NOTE(review): returned by infer() but never updated in this class;
        # the recorded metrics live in self.callback.param_history.
        self.parameter_history = CallbackSimple()

    def run_atomic(self, lr, cepoch=0, hide_sites=(), save_evidence=False):
        """Run one SVI segment with a freshly created ClippedAdam optimizer.

        The segment steps from epoch ``cepoch`` up to
        ``(eon + 1) * nepochs - 1``. It stops early in two cases:
        - on a NaN loss;
        - once ``epoch > nwarmup`` and the callback raises either of its flags.

        :param lr: learning rate for this segment's optimizer
        :param cepoch: first epoch index to run
        :param hide_sites: site names given to ``poutine.block(hide=...)`` when
            wrapping ``model.guide``
        :param save_evidence: if True, also store
            ``model.get_data_evidence(data, 100)`` under ``"evidence"`` in the
            per-epoch metrics
        :return: ``(epoch, loss)`` of the last epoch executed
        :raises ValueError: if no epoch is left to run (``cepoch`` is already
            at or past the segment end, e.g. because ``nepochs <= 0``)
        """
        end = (self.eon + 1) * self.nepochs
        if cepoch >= end:
            # Without this guard the loop body never runs and the final
            # ``return epoch, loss`` fails with UnboundLocalError.
            raise ValueError(
                f"run_atomic: no epochs left to run (cepoch={cepoch}, end={end}, "
                f"nepochs={self.nepochs}, eon={self.eon})"
            )

        logger.info(f"| ra lr {lr:.3e} <<<")
        optimizer = ClippedAdam({"lr": lr})
        csvi = SVI(
            self.model,
            poutine.block(self.model.guide, hide=hide_sites),
            optimizer,
            loss=Trace_ELBO(),
        )

        self.callback.counter = 0
        for epoch in range(cepoch, end):
            loss = csvi.step(*self.data)
            self.model.save_state_to_history(epoch)
            if np.isnan(loss):
                logger.error(" nan loss detected")
                return epoch, loss

            global_metrics = {"loss": loss, "rms": self.model.rms}
            if save_evidence:
                global_metrics["evidence"] = self.model.get_data_evidence(
                    self.data, 100
                )
            flag_positive, flag_small = self.callback(epoch, global_metrics)

            # The warmup check uses the absolute epoch index, so it only
            # delays early stopping during the first eon.
            if epoch > self.nwarmup and (flag_small or flag_positive):
                logger.info(f"| epoch: {epoch:<5} | loss >>> {loss:.3e} <<<")
                logger.info(
                    f"| learning: {lr:.3e} | flag_positive: {int(flag_positive)}"
                    f" | flag_small: {int(flag_small)}"
                )
                self._log_history_tail("| loss tail:  ", "loss")
                self._log_history_tail("| rloss tail: ", "rloss")
                break
        return epoch, loss

    def _log_history_tail(self, prefix, key):
        """Log the last ``small_flag_window`` values of ``callback.param_history[key]``."""
        window = self.callback.small_flag_window
        tail = self.callback.param_history[key][-window:]
        logger.info(prefix + " | ".join(f"{x:>+.3e}" for x in tail))

    def infer(self, hide_sites=(), clear_store=True, save_evidence=False):
        """Train for one eon, restarting the optimizer with a decaying learning rate.

        If the first segment returns a NaN loss, the method returns
        immediately: nothing is saved and ``eon`` is not advanced. Otherwise:
        - restarts continue until the eon's last epoch is reached or a NaN
          loss occurs;
        - ``eon`` is incremented;
        - the model history and callback metrics are pickled to
          ``<data_folder>/<model.name>_history.pkl.gz``;
        - ``model.save(data_folder)`` is called.

        :param hide_sites: passed to :meth:`run_atomic` (``poutine.block``)
        :param clear_store: clear pyro's global param store before training
        :param save_evidence: passed to :meth:`run_atomic`
        :return: ``(self.parameter_history, self.model)``
        """
        if clear_store:
            pyro.clear_param_store()
        pyro.set_rng_seed(self.seed)

        lr = self.lr0
        self.cepoch, closs = self.run_atomic(
            lr, self.cepoch, hide_sites=hide_sites, save_evidence=save_evidence
        )
        if np.isnan(closs):
            return self.parameter_history, self.model

        # NOTE(review): two behaviours kept as-is to preserve training
        # results — confirm whether both are intended.
        # - run_atomic returns the index of the last epoch it executed, and
        #   each restart begins at that same index, so the boundary epoch is
        #   stepped twice.
        # - lr0 is used for two consecutive segments, because the decay is
        #   applied after each call.
        last_epoch = (self.eon + 1) * self.nepochs - 1
        while self.cepoch < last_epoch:
            self.cepoch, closs = self.run_atomic(
                lr, self.cepoch, hide_sites=hide_sites, save_evidence=save_evidence
            )
            if np.isnan(closs):
                break
            lr *= self.lr_factor

        self.eon += 1
        self._save_history()
        return self.parameter_history, self.model

    def _save_history(self):
        """Pickle model/callback histories into data_folder, then save the model."""
        folder = os.path.expanduser(self.data_folder)
        # Create the output folder so the first save does not fail with
        # FileNotFoundError; an empty path means the current directory.
        if folder:
            os.makedirs(folder, exist_ok=True)
        state = {
            "cbnn history": self.model.history,
            "params": self.callback.param_history,
        }
        filename = os.path.join(folder, f"{self.model.name}_history.pkl.gz")
        with gzip.open(filename, "wb") as file:
            pickle.dump(state, file)

        self.model.save(self.data_folder)
