import unittest
import torch
import pyro
from os.path import join, dirname, realpath
import pathlib
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import Adam, SGD
import numpy as np
from pyro import poutine
from sbsep.bnn import BNN
from sbsep.callback import Callback

import matplotlib.pyplot as plt
import seaborn as sns


class TestBNNInference(unittest.TestCase):
    """End-to-end checks of ``BNN``: guide-site blocking and two-phase SVI fitting.

    A noisy cosine ``y = cos(k * x) + sigma * eps`` with ``x`` drawn uniformly on
    ``[0, pi)`` is generated once, at class-definition time, after seeding torch,
    so every test sees the same data. When ``plot`` is true, diagnostic PDFs are
    written to ``figs_folder`` next to this file.
    """

    figs_folder = "./figs"
    cpath = join(dirname(realpath(__file__)), figs_folder)
    path = pathlib.Path(cpath)

    nspace = 92  # number of training points
    sigma = 0.2  # std of the additive Gaussian noise on y
    k = 5  # angular frequency of the target cosine
    plot = True

    n_epochs = 1000  # SVI steps per training phase
    seed = 13

    # Seed before drawing the synthetic data so it is identical across runs.
    torch.manual_seed(seed)

    pi = torch.acos(torch.zeros(1)).item() * 2  # acos(0) == pi / 2
    xa, xb = 0.0, pi

    xdata = xa + (xb - xa) * torch.rand(nspace)
    ydata = torch.cos(k * xdata) + sigma * torch.randn_like(xdata)
    # Add a trailing feature axis: (nspace,) -> (nspace, 1).
    xdata, ydata = xdata.unsqueeze(-1), ydata.unsqueeze(-1)
    data = xdata, ydata
    dims = (1, 16, 32, 16, 1)  # passed to BNN; presumably layer widths in -> out
    name = "a"  # prefix of the BNN's pyro site names, e.g. f"{name}#sigma_alpha"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Directory containing this test file; kept for callers that read it.
        self.dir_path = dirname(realpath(__file__))

    @classmethod
    def setUpClass(cls):
        """Create the figure folder and, if plotting, save a scatter of the data.

        Done once per class rather than once per test instance, so the figure is
        produced a single time and then closed.
        """
        super().setUpClass()
        cls.path.mkdir(parents=True, exist_ok=True)
        if cls.plot:
            plt.style.use(cls._seaborn_style())
            fig = plt.figure(figsize=(5, 5))
            plt.scatter(cls.xdata.detach().numpy(), cls.ydata.detach().numpy())
            plt.tick_params(axis="both", labelsize=15)
            fig.savefig(join(cls.cpath, "bnn_inference_data.pdf"))
            plt.close(fig)

    @staticmethod
    def _seaborn_style():
        """Return the name of matplotlib's bundled seaborn style for this install.

        matplotlib 3.6 renamed the ``"seaborn"`` style to ``"seaborn-v0_8"`` and
        later releases dropped the old name, so fall back only when needed.
        """
        return "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"

    def _local_site_names(self, name):
        """Return the per-layer guide site names to block.

        The result is ``f"{name}#{kind}_{k}"`` for ``kind`` in weight_loc,
        bias_loc, weight_scale, bias_scale (in that order), and each kind covers
        ``k in range(len(self.dims))``.
        """
        # NOTE(review): dims implies len(dims) - 1 layer transitions, so the last
        # index is presumably not a real site; poutine.block ignores unknown names.
        kinds = ("weight_loc", "bias_loc", "weight_scale", "bias_scale")
        return [f"{name}#{kind}_{k}" for kind in kinds for k in range(len(self.dims))]

    def test_poutine_block(self):
        """``poutine.block`` ``expose``/``hide`` controls which guide sites are traced."""
        name = self.name
        bnn = BNN(dims=self.dims, name=name)

        # With expose=, every site not listed is hidden from the outer trace.
        exposed = {f"{name}#sigma_alpha", f"{name}#sigma_beta", f"{name}#sigma", "data"}
        tv1 = poutine.trace(
            poutine.block(bnn.guide, expose=sorted(exposed))
        ).get_trace(self.data)
        recorded = {
            site_name
            for site_name, site in tv1.nodes.items()
            if site["type"] in ("sample", "param")
        }
        self.assertEqual(recorded - exposed, set())

        # With hide=, exactly the listed sites are removed from the trace.
        to_hide = self._local_site_names(name)
        tv2 = poutine.trace(poutine.block(bnn.guide, hide=to_hide)).get_trace(self.data)
        self.assertEqual(set(tv2.nodes) & set(to_hide), set())

    def _run_svi(self, svi, callback, losses):
        """Take ``self.n_epochs`` SVI steps, appending each loss to ``losses``.

        After every step ``callback`` receives the global step index
        (``len(losses) - 1``, continuous across phases) and the full history.

        Raises:
            ValueError: if a step returns a NaN loss.
        """
        for _ in range(self.n_epochs):
            loss = svi.step(self.data)
            if np.isnan(loss):
                raise ValueError("nan loss detected")
            losses.append(loss)
            # NOTE(review): Callback is built with early_stop=True, but its return
            # value is ignored, so every phase runs all n_epochs steps. Confirm
            # whether the loop should break once the callback reports convergence.
            callback(len(losses) - 1, losses)

    def test_inference(self):
        """Fit the BNN with two SVI phases and, if plotting, save diagnostics.

        Phase 1 blocks the ``{name}#sigma_alpha`` / ``{name}#sigma_beta`` guide
        sites. Phase 2 blocks the per-layer loc/scale sites instead, presumably to
        alternate which guide parameters are optimised (TODO confirm). Raises
        ``ValueError`` if any step yields a NaN loss.
        """
        pyro.set_rng_seed(self.seed)
        name = self.name
        bnn = BNN(dims=self.dims, name=name)

        pyro.clear_param_store()

        # One optimizer shared by both phases.
        optimizer = pyro.optim.ClippedAdam({"lr": 1e-2})

        callback = Callback(
            check_convergence_every=10,
            convergence_window=20,
            tol_variation_parameters=1e-3,
            max_steps=1000,
            model=bnn,
            data=self.data,
            early_stop=True,
            n_warmup=50,
            verbose=True,
        )

        losses = []

        svi = SVI(
            bnn,
            poutine.block(
                bnn.guide, hide=[f"{name}#sigma_alpha", f"{name}#sigma_beta"]
            ),
            optimizer,
            loss=Trace_ELBO(),
        )
        self._run_svi(svi, callback, losses)

        svi = SVI(
            bnn,
            poutine.block(bnn.guide, hide=self._local_site_names(name)),
            optimizer,
            loss=Trace_ELBO(),
        )
        self._run_svi(svi, callback, losses)

        if self.plot:
            self._plot_loss(losses)

        predictive = Predictive(bnn, guide=bnn.guide, num_samples=1000)
        # Extend slightly beyond the training interval [0, pi).
        x_test = torch.linspace(-0.1, self.pi + 0.1, 30).unsqueeze(-1)
        preds = predictive((x_test, None))

        if self.plot:
            self._plot_params(callback, preds, name)
            self._plot_prediction(x_test, preds)

    def _plot_loss(self, losses):
        """Save the loss curve (log y-axis) to ``bnn_inference_loss.pdf``."""
        fig = plt.figure(figsize=(5, 5))
        plt.plot(losses)
        plt.xlabel("epoch", fontsize=20)
        plt.ylabel("loss", fontsize=20)
        plt.tick_params(axis="both", labelsize=15)
        plt.yscale("log")
        fig.savefig(join(self.cpath, "bnn_inference_loss.pdf"))
        plt.close(fig)

    def _plot_params(self, callback, preds, name):
        """Save parameter trajectories and the predictive ``{name}#sigma`` distribution."""
        param_history = {
            key: [value.cpu().numpy() for value in values]
            for key, values in callback.param_history.items()
        }

        fig = plt.figure(figsize=(5, 5))
        plt.plot(param_history[f"{name}#sigma_alpha"])
        plt.plot(param_history[f"{name}#sigma_beta"])
        fig.savefig(join(self.cpath, "bnn_inference_param_sigma_alpha_beta.pdf"))
        plt.close(fig)

        # displot is figure-level: it creates its own figure, which becomes the
        # current one, so no separate plt.figure() is needed.
        sns.displot(preds[f"{name}#sigma"].detach().numpy())
        plt.savefig(join(self.cpath, "bnn_inference_param_sigma_dist.pdf"))
        plt.close()

        # Element 0 (along the first axis) of f"{name}#weight_loc_0" over time.
        fig = plt.figure(figsize=(5, 5))
        plt.plot([snapshot[0] for snapshot in param_history[f"{name}#weight_loc_0"]])
        fig.savefig(join(self.cpath, "bnn_inference_param_weight_loc.pdf"))
        plt.close(fig)

    def _plot_prediction(self, x_test, preds):
        """Save data, predictive mean and a +/- 2 std band to ``bnn_inference_prediction.pdf``."""
        y_pred = preds["obs"].detach().squeeze(-1).numpy()
        # Predictive stacks draws along the leading axis.
        y_means = y_pred.mean(axis=0)
        y_stds = y_pred.std(axis=0)
        x_plot = x_test.squeeze(-1).numpy()

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.plot(self.xdata.numpy(), self.ydata.numpy(), "o", markersize=5)
        ax.plot(x_plot, y_means)
        ax.fill_between(
            x_plot,
            y_means - 2 * y_stds,
            y_means + 2 * y_stds,
            alpha=0.5,
        )
        ax.set_xlabel("x", fontsize=20)
        ax.set_ylabel("y", fontsize=20)
        ax.tick_params(axis="both", labelsize=15)
        ax.set_ylim(-5, 5)
        fig.savefig(join(self.cpath, "bnn_inference_prediction.pdf"))
        plt.close(fig)


if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main(module="__main__")
