import argparse
import logging
import os
import sys
from pathlib import Path
import torch
from os.path import dirname, realpath
from sbsep.callback import CallbackSimple
from sbsep.config import load_yaml, BNNConfig, BNNConfigFactory
from sbsep.cbnn import CBNN
from sbsep.engine import InferenceEngine
from sbsep.plotting import plot_loss, plot_prediction
from sbsep.sample import compute_pred

# Switch PyTorch's default floating dtype to double precision as soon as this
# module is imported, i.e. before any model or tensor below is created. This is
# a process-wide setting, not local to this module.
torch.set_default_dtype(torch.double)  # torch.double is an alias of torch.float64

# Module-level logger. When run as a script, records propagate to the root
# handler that logging.basicConfig installs in the __main__ block.
logger = logging.getLogger(__name__)


class TestCBNNInference:
    """Bundle a CBNN model, its training table and an InferenceEngine for one run.

    On construction this:

    * builds ``CBNN(config, observe=True)``;
    * copies the epoch count and the ``(xa, xb)`` domain from *config*;
    * takes ``config.table.xs`` / ``config.table.ys``, turning 1-D inputs
      into trailing-singleton (column) shape via ``unsqueeze(-1)``;
    * creates ``../data`` and ``./figs`` folders relative to this script's
      directory (not the current working directory);
    * wires an ``InferenceEngine`` over the model, the data pair, the epoch
      count, a ``CallbackSimple`` instance and the data folder.

    Note: the ``Test`` prefix is historical; this is a helper used by
    ``main``, not a pytest test case.
    """

    def __init__(self, config: BNNConfig):
        super().__init__()
        self.model = CBNN(config, observe=True)

        # NOTE(review): stored but never applied to an RNG within this file.
        self.seed = 17

        self.nepochs = config.train.nepochs
        self.xa, self.xb = config.architecture.domain

        def as_column(values):
            # Rank-1 inputs gain a trailing axis; higher-rank inputs pass
            # through untouched. Presumably torch tensors (uses unsqueeze).
            return values.unsqueeze(-1) if len(values.shape) == 1 else values

        self.xs = as_column(config.table.xs)
        self.ys = as_column(config.table.ys)

        # Anchor output folders at this file's directory so the script works
        # from any CWD. The joined strings keep "../" and "./" segments as-is.
        script_dir = os.path.dirname(os.path.realpath(__file__))
        self.data_folder = os.path.join(script_dir, "../data")
        self.fig_folder = os.path.join(script_dir, "./figs")
        for folder in (self.data_folder, self.fig_folder):
            Path(folder).mkdir(parents=True, exist_ok=True)

        self.color = "r"
        self.callback = CallbackSimple()
        self.ie = InferenceEngine(
            self.model,
            (self.xs, self.ys),
            self.nepochs,
            self.callback,
            self.data_folder,
        )


def main(config, save_evidence):
    """Train one CBNN, plot its fit, then reload it from disk and overlay that fit.

    Steps:

    1. Build a ``TestCBNNInference`` for *config* and run ``ie.infer``,
       forwarding *save_evidence*.
    2. Plot the recorded parameter/loss history into the figure folder.
    3. Draw 200 prediction samples from the trained model
       (``guide_based=True``) and plot them against the training data.
    4. Reload the model from the data folder with
       ``load_var_init_as_prior=True``, ``observe=False`` and ``debug=True``.
       Then draw 1000 samples and plot them on the same axes.

    Only the second ``plot_prediction`` call receives ``fig_folder``.
    Presumably that is where the combined figure gets written; confirm
    against ``sbsep.plotting``.

    Args:
        config: a ``BNNConfig`` produced by ``BNNConfigFactory``.
        save_evidence: forwarded to ``InferenceEngine.infer``.
    """
    run = TestCBNNInference(config)
    run.ie.infer(save_evidence=save_evidence)

    trained = run.model
    observed = (run.xs, run.ys)
    figures = run.fig_folder

    plot_loss(run.callback.param_history, fig_folder=figures, prefix=trained.name)

    fit_axes = plot_prediction(
        data=observed,
        predicted_dist=compute_pred(trained, nsamples=200, guide_based=True),
        label="observe fit",
        name=trained.name,
    )

    reloaded = CBNN.load_from_file(
        run.data_folder,
        trained.name,
        load_var_init_as_prior=True,
        observe=False,
        debug=True,
    )

    plot_prediction(
        data=observed,
        predicted_dist=compute_pred(reloaded, nsamples=1000),
        label="loaded",
        name=f"pre_{reloaded.name}",
        ax=fit_axes,
        fig_folder=figures,
    )


if __name__ == "__main__":
    # Command-line entry point: load a YAML config and train one or more CBNNs.
    #
    # Two config layouts are accepted:
    #   * a mapping with a "cbnns" list; each entry becomes a BNNConfig, and
    #     only entries whose train.pretrain is truthy are run;
    #   * a mapping that is itself a single BNN config.

    # Log to stdout with millisecond timestamps. Note: basicConfig ignores
    # `filemode` unless `filename` is given, so it is not passed here.
    logging.basicConfig(
        format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        stream=sys.stdout,
    )

    parser = argparse.ArgumentParser()
    # Required: without it load_yaml(None) would fail with an obscure traceback
    # instead of a usage message.
    parser.add_argument(
        "--config-path", type=str, required=True, help="path to yaml config"
    )
    parser.add_argument(
        "--evidence", action="store_true", help="save and plot evidence"
    )
    args = parser.parse_args()

    config_file = load_yaml(args.config_path)

    # Check for a mapping *before* the membership test. An empty YAML file
    # (None) would otherwise raise TypeError on `in`, and other non-mappings
    # used to be silently ignored.
    if not hasattr(config_file, "items"):
        parser.error(f"config {args.config_path!r} is not a YAML mapping")

    if "cbnns" in config_file:
        for index, conf in enumerate(config_file["cbnns"]):
            config = BNNConfigFactory.get_bnn_config(conf)
            if config.train.pretrain:
                main(config, args.evidence)
            else:
                logger.info("skipping cbnns[%d]: train.pretrain is not set", index)
    else:
        main(BNNConfigFactory.get_bnn_config(config_file), args.evidence)
