import os
from pathlib import Path
import unittest
import matplotlib.pyplot as plt
import seaborn as sns
from sbsep.util import get_data_random
from sbsep.sb import SBMixture
import pyro
from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO, Predictive
from pyro.optim import Adam, SGD, ClippedAdam
from sbsep.callback import CallbackSimple
from tqdm.auto import trange
from matplotlib import pyplot
from sbsep.config import SBConfigFactory, load_yaml


class TestSB(unittest.TestCase):
    """Smoke tests for ``SBMixture``: model/guide execution and an SVI training run."""

    sns.set_style("darkgrid")
    figs_folder = "./figs"
    cpath = os.path.dirname(os.path.realpath(__file__))
    figpath = os.path.join(cpath, figs_folder)
    # Create the figure output folder up front so ``plot_losses`` can save into it.
    # (Previously ``Path(cpath)`` was created instead, which is the test file's own
    # directory and always exists, so ``figpath`` was never created.)
    path = Path(figpath)
    path.mkdir(parents=True, exist_ok=True)

    seed = 13
    # Fallback number of SVI steps when ``test_inference`` receives a falsy ``nepochs``.
    n_epochs = 100
    coordinate, tsignal, composition, _ = get_data_random(5000)

    config_file = load_yaml(os.path.join(cpath, "./conf/sb_config.yaml"))
    config = SBConfigFactory.get_sb_config(config_file)
    folder = os.path.join(cpath, "../data")

    def test_main(self):
        """Build an ``SBMixture`` and run one ``forward`` and one ``guide`` pass on random data."""
        sbm = SBMixture(self.config, self.folder)

        coordinate, tsignal, _, _ = get_data_random(5000)
        # First column of the last axis -> ``scoord``; remaining columns -> ``pcoord``.
        scoord = coordinate[..., :1]
        pcoord = coordinate[..., 1:]

        sbm.forward(scoord, pcoord, tsignal)
        sbm.guide(scoord, pcoord, tsignal)

    @unittest.skip("")
    def test_inference(self, nepochs=100):
        """Fit ``SBMixture`` with SVI (``ClippedAdam`` + ``Trace_ELBO``) and plot the losses.

        Args:
            nepochs: number of SVI steps to run; when falsy, ``self.n_epochs`` is used.
        """
        sbm = SBMixture(self.config, self.folder)
        # ``get_data_random`` is unpacked into four values elsewhere in this class;
        # unpacking only three here raised ValueError.
        # NOTE(review): assumes the seeded call returns the same 4-tuple.
        coordinate, tsignal, _, _ = get_data_random(1000, self.seed)

        scoord = sbm.scaler_xcoord.scale(coordinate[..., :1])
        pcoord = sbm.scaler_ecoord.scale(coordinate[..., 1:])

        sbm.init_normw(os.path.join(self.cpath, "../run/post_params.pkl"))

        pyro.clear_param_store()
        pyro.set_rng_seed(self.seed)
        optimizer = ClippedAdam({"lr": 1e-4})

        # Always resolve an epoch count. The old ``if nepochs:`` left
        # ``self.n_epochs`` unset when ``nepochs`` was falsy.
        n_epochs = nepochs or self.n_epochs

        callback = CallbackSimple()
        losses = []
        svi = SVI(sbm, sbm.guide, optimizer, loss=Trace_ELBO())

        bar = trange(n_epochs)
        for epoch in bar:
            loss = svi.step(scoord, pcoord, tsignal)
            losses.append(loss)
            callback(epoch, losses)
            # Report the loss normalised by the number of rows in ``scoord``.
            bar.set_postfix(loss=f"{loss / scoord.shape[0]:.3f}")

        self.plot_losses(losses)

    def plot_losses(self, losses):
        """Plot ``losses`` on a log-scaled y axis and save to ``<figpath>/sb_loss.pdf``.

        Args:
            losses: sequence of per-step loss values.
        """
        fig = plt.figure(figsize=(6, 3), dpi=100)
        fig.set_facecolor("white")
        plt.plot(losses)
        plt.xlabel("iters")
        plt.ylabel("loss")
        plt.yscale("log")
        plt.title("Convergence of SVI")
        fig.savefig(os.path.join(self.figpath, "sb_loss.pdf"), dpi=200)
        plt.close(fig)


if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main(module="__main__")
