import os
import sys
import argparse
from os.path import dirname, realpath
from pathlib import Path
import torch
from sbsep.cbnn import CBNN
from sbsep.plotting import (
    plot_loss,
    plot_prediction,
    plot_basic,
    plot_param,
    plot_weights_dist,
)
import logging
from copy import deepcopy
from typing import Optional, List
from sbsep.sample import compute_pred
import gzip
import pickle
from sbsep.config import load_yaml, SBConfigFactory

# Process-wide setting: floating-point tensors created after this point without
# an explicit dtype default to float64 (torch.DoubleTensor), including those
# created inside the sbsep code called below.
# NOTE(review): set_default_tensor_type is deprecated in recent PyTorch releases;
# torch.set_default_dtype(torch.float64) is the suggested replacement -- confirm
# against the pinned torch version before switching.
torch.set_default_tensor_type(torch.DoubleTensor)


# Module-level logger; output format/level are set by logging.basicConfig when
# this file is run as a script.
logger = logging.getLogger(__name__)


class Plotter:
    """Output locations and fixed plot settings used by this script.

    Both folders are resolved relative to the directory that contains this
    file, and are created (including parents) on construction if missing.

    Attributes:
        seed: fixed integer seed, 17 (not referenced elsewhere in this script).
        data_folder: path string ``<script dir>/../data``.
        fig_folder: path string ``<script dir>/./figs``; passed to the plotting
            helpers as the figure output folder.
        color: colour code string ``"r"`` (not referenced elsewhere in this script).
    """

    def __init__(self):
        super().__init__()
        self.seed = 17

        # Anchor both folders to the script location, not the current
        # working directory, so the script can be launched from anywhere.
        script_dir = os.path.dirname(os.path.realpath(__file__))
        self.data_folder = os.path.join(script_dir, "../data")
        self.fig_folder = os.path.join(script_dir, "./figs")

        for folder in (self.data_folder, self.fig_folder):
            Path(folder).mkdir(parents=True, exist_ok=True)

        self.color = "r"


def _select_steps(hist, tentative_steps, nhistories):
    """Return the recorded steps of one layer's history that should be plotted.

    Args:
        hist: mapping from step to the snapshot recorded at that step.
        tentative_steps: steps explicitly requested by the caller, or None.
        nhistories: divisor for the stride used to thin the recorded steps
            when ``tentative_steps`` is None.

    Returns:
        A list of steps, in plotting order, every one of which is a key of
        ``hist``. Requested steps that are missing are logged and dropped.
    """
    if tentative_steps is None:
        keys = list(hist.keys())
        # Evenly thin the recorded steps; both max(..., 1) guards keep the
        # stride a positive integer for short histories or nhistories <= 0.
        stride = max(len(keys) // max(nhistories, 1), 1)
        return keys[::stride]

    selected = []
    for step in tentative_steps:
        if step in hist:
            selected.append(step)
        else:
            logger.error("step %s not in history. History keys %s", step, hist.keys())
    return selected


def main(
    name,
    hpath,
    config,
    tentative_steps: Optional[List[int]] = None,
    nhistories: int = 8,
) -> None:
    """Plot parameters, loss and per-layer prediction histories of a saved model.

    Loads ``<hpath>/<name>_history.pkl.gz`` (a gzip-compressed pickle) and plots
    ``history["params"]`` and the loss. Then, for every layer in
    ``history["cbnn history"]``, it overlays on one axis the predictions
    obtained from the parameter snapshots at the selected steps. The figure
    folder is handed to the plotting helper on the last plotted step.

    Args:
        name: model name; used in the history file name and as figure prefix.
        hpath: folder containing the history file (``~`` is expanded).
        config: sbsep config; ``config.cbnns[layer_name]`` must provide the
            CBNN config for every layer stored in the history.
        tentative_steps: steps to plot for every layer. Steps missing from a
            layer's history are logged and skipped. If None, each layer uses
            an evenly spaced subset of its *own* recorded steps.
        nhistories: controls how densely steps are sampled when
            ``tentative_steps`` is None (stride = recorded steps // nhistories).

    Raises:
        OSError: if the history file cannot be opened.
    """
    pl = Plotter()

    hpath_full = os.path.join(os.path.expanduser(hpath), f"{name}_history.pkl.gz")

    # NOTE: unpickling can execute arbitrary code -- only load history files
    # produced by a trusted run.
    with gzip.open(hpath_full, "rb") as fp:
        history = pickle.load(fp)

    plot_param(history["params"], [], pl.fig_folder, prefix=name)
    plot_loss(history["params"], pl.fig_folder, prefix=name)

    for lname, hist in history["cbnn history"].items():
        # Select steps per layer without overwriting the caller's argument, so
        # one layer's recorded steps never leak into the next layer.
        steps = _select_steps(hist, tentative_steps, nhistories)
        if not steps:
            logger.warning("no steps to plot for layer %s", lname)
            continue

        bnn_loaded = CBNN(config.cbnns[lname], observe=False, debug=True)

        ax = None
        last_index = len(steps) - 1
        for index, step in enumerate(steps):
            # Deep-copy so whatever the model does with its initial values
            # leaves the loaded history untouched.
            bnn_loaded.state.var_init_value = deepcopy(hist[step])
            prediction = compute_pred(bnn_loaded, nsamples=200, guide_based=True)

            # Every step is drawn onto the same ``ax``. fig_folder is passed
            # only for the last *plotted* step (assumes plot_prediction writes
            # the figure only when fig_folder is set), so a missing requested
            # step can no longer prevent the figure from being saved.
            ax = plot_prediction(
                predicted_dist=prediction,
                label=f"fit step {step}",
                name=f"history_{name}_{lname}",
                fig_folder=pl.fig_folder if index == last_index else None,
                xlabel="spectral coordinate",
                ax=ax,
                plot_envelope=True,
            )


if __name__ == "__main__":
    # No log file is used, so `filemode` (ignored without `filename`) is omitted.
    logging.basicConfig(
        format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        stream=sys.stdout,
    )

    parser = argparse.ArgumentParser(
        description="Plot the history stored in <model-name>_history.pkl.gz."
    )
    # The first three options are required: without them the run fails later
    # with an unhelpful error (e.g. load_yaml(None)).
    parser.add_argument(
        "--history-folder",
        type=str,
        required=True,
        help="folder containing <model-name>_history.pkl.gz",
    )
    parser.add_argument(
        "--config-path", type=str, required=True, help="path to yaml config"
    )
    parser.add_argument("--model-name", type=str, required=True, help="model name")
    parser.add_argument(
        "--plot-steps",
        nargs="+",
        type=int,
        help="history steps to plot (default: evenly spaced steps per layer)",
    )

    # Parse once; the original parsed the command line twice.
    args = parser.parse_args()

    config = SBConfigFactory.get_sb_config(load_yaml(args.config_path))

    main(
        args.model_name,
        hpath=args.history_folder,
        config=config,
        tentative_steps=args.plot_steps,
    )
