import os
import sys
import argparse
from os.path import dirname, realpath
from pathlib import Path
import torch
from sbsep.util import get_data_random
from sbsep.sb import SBMixture
from sbsep.callback import CallbackSimple
from sbsep.config import SBConfig, load_yaml, SBConfigFactory
from sbsep.engine import InferenceEngine
from sbsep.plotting import plot_loss, plot_prediction
import logging
from sbsep.sample import compute_pred

# Make float64 the default floating-point dtype for every tensor created below.
# torch.set_default_tensor_type is deprecated (PyTorch >= 2.1); set_default_dtype
# is the supported way to request the same double-precision default.
torch.set_default_dtype(torch.float64)


class TestSB:
    """Wire a synthetic data set, an ``SBMixture`` model and an inference engine together.

    Attributes set by the constructor:
        seed: fixed seed (17) handed to ``get_data_random``.
        model: the ``SBMixture`` built from ``config`` and ``folder``.
        scoord / pcoord: first column and remaining columns of the coordinate
            array returned by ``get_data_random`` (split on the last axis).
        tsignal: the total signal returned by ``get_data_random``.
        parray_gt, spectrum_signal_gt, spectrum_bg_gt, sarray_gt,
        norm_space_gt, pa_gt, pb_gt: ground-truth values returned alongside
            the synthetic data; used later as plot overlays.
        data_folder / fig_folder: output directories next to this script,
            created if missing.
        color: plotting colour tag ("r").
        callback: ``CallbackSimple`` shared with the engine.
        ie: the ``InferenceEngine`` driving the optimisation.
    """

    def __init__(self, config: SBConfig, folder: str):
        """Build the model, draw the synthetic data and prepare the engine.

        Args:
            config: SB configuration; ``config.modeling.synthetic_nsample`` and
                the ``config.train`` fields ``positive_flag_window``,
                ``nepochs``, ``lr_factor`` and ``lr0`` are read here.
            folder: passed straight through to ``SBMixture``.
        """
        super().__init__()
        self.seed = 17

        # The model comes first: its include_space_scaling flag is an input
        # to the synthetic data generator.
        self.model = SBMixture(config, folder)

        synthetic = get_data_random(
            config.modeling.synthetic_nsample,
            self.seed,
            self.model.include_space_scaling,
        )
        coords, total_signal, p_truth, s_truth = synthetic
        self.parray_gt, self.spectrum_signal_gt, self.spectrum_bg_gt = p_truth
        self.sarray_gt, self.norm_space_gt, self.pa_gt, self.pb_gt = s_truth

        # Inputs for inference: split the coordinate array on its last axis.
        self.scoord = coords[..., :1]
        self.pcoord = coords[..., 1:]
        self.tsignal = total_signal

        # Output folders are anchored at this script's directory, not the CWD.
        here = os.path.dirname(os.path.realpath(__file__))
        self.data_folder = os.path.join(here, "../data")
        self.fig_folder = os.path.join(here, "./figs")
        for out_dir in (self.data_folder, self.fig_folder):
            Path(out_dir).mkdir(parents=True, exist_ok=True)

        self.color = "r"

        train_cfg = config.train
        self.callback = CallbackSimple(
            positive_flag_window=train_cfg.positive_flag_window
        )
        self.ie = InferenceEngine(
            model=self.model,
            data=(self.scoord, self.pcoord, self.tsignal),
            nepochs=train_cfg.nepochs,
            callback=self.callback,
            data_folder=self.data_folder,
            lr_factor=train_cfg.lr_factor,
            lr0=train_cfg.lr0,
        )


def main(config, folder, save_evidence):
    """Run the two-stage inference on synthetic data and plot the results.

    Steps: build a ``TestSB`` harness, run inference once (optionally hiding
    the ``tsignal_norm`` parameter sites), run it a second time with no hidden
    sites and without clearing the store, then plot the loss history and
    prior/posterior predictions for ``vec_signal``, ``dir_weight`` and, when
    space scaling is enabled, ``tsignal_norm``.

    Args:
        config: SB configuration. ``config.modeling.normalized_space`` is read
            and ``config.train.lr0`` is multiplied by 0.1 in place.
        folder: forwarded to ``TestSB`` (and from there to ``SBMixture``).
        save_evidence: forwarded to ``InferenceEngine.infer`` on both runs.
    """
    # Value passed as ``nsamples`` to every compute_pred call below.
    nsamples = 200

    ti = TestSB(config, folder)
    hide_sites = []
    if not config.modeling.normalized_space:
        hide_sites += ti.model.tsignal_norm.guess_parameter_names()

    ti.ie.infer(hide_sites=hide_sites, save_evidence=save_evidence)
    print("1st run done")

    # NOTE(review): the engine was constructed with the *previous* value of
    # config.train.lr0 (passed by value in TestSB.__init__). Whether this
    # in-place update reaches the second run depends on InferenceEngine /
    # SBMixture internals not visible here -- confirm.
    config.train.lr0 *= 0.1
    ti.ie.infer(hide_sites=[], save_evidence=save_evidence, clear_store=False)
    print("2nd run done")

    plot_loss(ti.callback.param_history, prefix=ti.model.name, fig_folder=ti.fig_folder)

    ti.model.vec_signal.debug = True

    # sample signal/noise spectral predictions
    prediction = compute_pred(ti.model.vec_signal, nsamples=nsamples, guide_based=True)
    prediction_prior = compute_pred(
        ti.model.vec_signal, nsamples=nsamples, guide_based=False
    )

    # plot signal/noise spectral predictions; the posterior call returns the
    # axes that the prior call then draws onto.
    ax = plot_prediction(
        predicted_dist=prediction,
        data=[(ti.parray_gt, ti.spectrum_bg_gt), (ti.parray_gt, ti.spectrum_signal_gt)],
        label="posterior",
        xlabel="spectral coordinate",
        hatch="/",
        # plot_envelope=False
    )

    plot_prediction(
        predicted_dist=prediction_prior,
        label="prior",
        name=f"sb_{ti.model.vec_signal.name}",
        fig_folder=ti.fig_folder,
        xlabel="spectral coordinate",
        ax=ax,
        # alpha_envelope=0.1,
        plot_envelope=False,
    )

    ti.model.dir_weight.debug = True

    prediction = compute_pred(ti.model.dir_weight, nsamples=nsamples, guide_based=True)
    prediction_prior = compute_pred(
        ti.model.dir_weight, nsamples=nsamples, guide_based=False
    )

    ax = plot_prediction(
        predicted_dist=prediction,
        data=[(ti.sarray_gt, ti.pa_gt), (ti.sarray_gt, ti.pb_gt)],
        label="posterior",
        hatch="/",
    )

    plot_prediction(
        predicted_dist=prediction_prior,
        label="prior",
        name=f"sb_{ti.model.dir_weight.name}",
        fig_folder=ti.fig_folder,
        ax=ax,
        alpha_envelope=0.1,
    )

    if ti.model.include_space_scaling:
        ti.model.tsignal_norm.debug = True

        prediction = compute_pred(
            ti.model.tsignal_norm,
            nsamples=nsamples,
        )

        plot_prediction(
            predicted_dist=prediction,
            data=(ti.sarray_gt, ti.norm_space_gt),
            label="observe fit",
            name=f"sb_{ti.model.tsignal_norm.name}",
            fig_folder=ti.fig_folder,
        )

    # ti.model.dir_norm.debug = True
    #
    # prediction = compute_pred(
    #     ti.model.dir_norm,
    #     nsamples=200,
    #     guide_based=True
    # )
    #
    # ax = plot_prediction(
    #     predicted_dist=prediction,
    #     # data=(ti.sarray_gt, ti.),
    #     label="observe fit",
    #     name=f"sb_{ti.model.dir_norm.name}",
    #     fig_folder=ti.fig_folder,
    # )


if __name__ == "__main__":
    # Directory containing this script; a relative --pretrained-folder is
    # resolved against it (an absolute one is returned unchanged by join).
    cpath = dirname(realpath(__file__))

    logging.basicConfig(
        # filename="train_sb.log",
        format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        # filemode only takes effect together with filename; with a stream
        # handler it is ignored. filename and stream cannot be combined, so
        # drop stream= if the file log above is ever re-enabled.
        filemode="w",
        stream=sys.stdout,
    )

    parser = argparse.ArgumentParser()

    # NOTE(review): when --config-path is omitted, load_yaml receives None;
    # confirm load_yaml handles that or make this option required as well.
    parser.add_argument("--config-path", type=str, help="path to yaml config")
    # Required: the value is always joined onto the script directory below,
    # and os.path.join raises TypeError on None.
    parser.add_argument(
        "--pretrained-folder",
        type=str,
        required=True,
        help="folder to pretrained pickles nns",
    )

    parser.add_argument(
        "--evidence", action="store_true", help="user to save and plot evidence"
    )

    args = parser.parse_args()

    config_file = load_yaml(args.config_path)

    pretrained_folder = os.path.join(cpath, args.pretrained_folder)

    sb_config = SBConfigFactory.get_sb_config(config_file)
    main(sb_config, pretrained_folder, args.evidence)
