import os
from pathlib import Path
import unittest
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
from torch.distributions import constraints
import pyro
from pyro import distributions as dist
from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO, Predictive
from pyro.nn import PyroModule, PyroSample
from pyro.optim import Adam, SGD, ClippedAdam
from pyro import poutine
from sbsep.bnn import BNN
from sbsep.callback import CallbackSimple
from sbsep.dist import bnn_helper


class SimpleBNN(PyroModule):
    """Small Bayesian neural network with a hand-written variational guide.

    The model draws an observation-noise scale from an InverseGamma prior and
    delegates the network itself to ``bnn_helper`` (from ``sbsep.dist``), which
    presumably samples the per-layer weight/bias sites and returns the network
    output -- TODO confirm against ``bnn_helper``. The guide mirrors those
    sites with learnable ``pyro.param`` locations and positive scales.

    Args:
        dims: layer widths, input width first and output width last.
        name: prefix for every sample and param site name (``"<name>#..."``).
    """

    def __init__(self, dims=(1, 32, 1), name="a"):
        super().__init__()
        self.name = name
        self.dims = dims
        self.nlayers = len(dims) - 1
        # NOTE(review): `activation` and `nsigma` are not read anywhere in
        # this class; kept for compatibility with external users.
        self.activation = nn.ReLU()
        self.nsigma = torch.tensor(1e-4)

        # One tuple per layer: (weight_loc, weight_scale, bias_loc, bias_scale)
        # with zero locations and unit scales.
        self.prior_params = [
            (
                torch.zeros(n_in, n_out),
                torch.ones(n_in, n_out),
                torch.zeros(n_out),
                torch.ones(n_out),
            )
            for n_in, n_out in zip(dims, dims[1:])
        ]

    def forward(self, data):
        """Model: sample the noise scale, compute the network mean, score ``y``.

        Args:
            data: pair ``(x, y)``; ``y`` may be ``None`` to sample predictions.

        Returns:
            The value at the ``"obs"`` sample site.
        """
        x, y = data
        noise_prior = dist.InverseGamma(torch.tensor(30.0), torch.tensor(1.45))
        sigma = pyro.sample(f"{self.name}#sigma", noise_prior)
        mean = bnn_helper(self.prior_params, x, name=self.name)
        likelihood = dist.Normal(mean, sigma).to_event(1)
        return pyro.sample("obs", likelihood, obs=y)

    def guide(self, data):
        """Variational guide covering the noise site and the network sites.

        Args:
            data: pair ``(x, y)``; only ``x`` is used.
        """
        x, _ = data
        prefix = self.name
        positive = constraints.positive

        alpha = pyro.param(
            f"{prefix}#sigma_alpha", torch.tensor(100.0), constraint=positive
        )
        beta = pyro.param(
            f"{prefix}#sigma_beta", torch.tensor(1.0), constraint=positive
        )
        pyro.sample(f"{prefix}#sigma", dist.InverseGamma(alpha, beta))

        # Per-layer variational parameters, created in the same order as the
        # model's prior tuples; scales start small (0.01).
        layer_params = []
        for k, (n_in, n_out) in enumerate(zip(self.dims, self.dims[1:])):
            layer_params.append(
                (
                    pyro.param(f"{prefix}#weight_loc_{k}", torch.zeros(n_in, n_out)),
                    pyro.param(
                        f"{prefix}#weight_scale_{k}",
                        0.01 * torch.ones(n_in, n_out),
                        constraint=positive,
                    ),
                    pyro.param(f"{prefix}#bias_loc_{k}", torch.zeros(n_out)),
                    pyro.param(
                        f"{prefix}#bias_scale_{k}",
                        0.01 * torch.ones(n_out),
                        constraint=positive,
                    ),
                )
            )
        # Called for its sample sites; the returned mean is not needed here.
        bnn_helper(layer_params, x, name=prefix)


class TestSimpleBNNInference:
    """Fit a BNN with SVI on a synthetic 1-D dataset and plot the results.

    This is a helper driven by ``TestEngine``, not a test case itself: it
    defines ``__init__`` so pytest does not collect it. The synthetic datasets
    are generated once, at class-definition time, from a fixed torch seed.
    """

    nspace = 92  # number of training points
    sigma = 0.2  # std of the Gaussian noise added to the synthetic targets
    k = 5  # frequency of dataset 1 / amplitude of dataset 2
    plot = True

    n_epochs = 1000
    seed = 13

    # Executed once when the class is defined, so the data are reproducible.
    torch.manual_seed(seed)

    pi = torch.acos(torch.zeros(1)).item() * 2  # 2 * acos(0) == pi
    xa, xb = 0.0, pi

    xdata1 = xa + (xb - xa) * torch.rand(nspace)
    xdata2 = xdata1

    # first dataset
    ydata1 = torch.cos(k * xdata1) + sigma * torch.randn_like(xdata1)
    xdata1, ydata1 = xdata1.unsqueeze(-1), ydata1.unsqueeze(-1)
    data1 = xdata1, ydata1

    # trying a simpler function
    ydata2 = k * torch.exp(-xdata2 / 1.0) + sigma * torch.randn_like(xdata2)
    xdata2, ydata2 = xdata2.unsqueeze(-1), ydata2.unsqueeze(-1)
    # second dataset from the simpler function
    data2 = xdata2, ydata2

    dims1 = (1, 16, 32, 16, 1)

    # try a simpler BNN architecture
    dims2 = (1, 8, 8, 1)

    def __init__(self, name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bnn = None
        self.name = name

    def test_inference(self, nepochs=100):
        """Build the BNN selected by ``self.name`` and fit it to ``data2``.

        ``"simplebnn"`` selects the local ``SimpleBNN``; any other name selects
        ``sbsep.bnn.BNN``. Also sets ``self.bnn``, ``self.color`` and
        ``self.callback``.

        Args:
            nepochs: number of SVI steps; a falsy value keeps ``self.n_epochs``.

        Returns:
            list of the loss returned by each SVI step, in order.
        """
        name = self.name
        if name != "simplebnn":
            self.bnn = BNN(dims=self.dims2, name=name)
            self.color = "b"
        else:
            self.bnn = SimpleBNN(dims=self.dims2, name=name)
            self.color = "r"

        bnn = self.bnn

        pyro.clear_param_store()
        pyro.set_rng_seed(self.seed)
        optimizer = pyro.optim.ClippedAdam({"lr": 1e-2})

        if nepochs:
            self.n_epochs = nepochs

        self.callback = CallbackSimple()

        # Blocking the noise-guide param sites hides them from SVI's trace,
        # so the optimizer does not update them.
        csvi = SVI(
            bnn,
            poutine.block(
                bnn.guide, hide=[f"{name}#sigma_alpha", f"{name}#sigma_beta"]
            ),
            optimizer,
            loss=Trace_ELBO(),
        )

        losses = []
        for epoch in range(self.n_epochs):
            loss = csvi.step(self.data2)
            # Previously the losses were never recorded, so callers always got
            # an empty list and the loss plot was blank.
            losses.append(loss)
            self.callback(epoch, {"loss": loss})

        return losses

    def plot_loss(self, ls, ax=None):
        """Plot a loss curve (log y-axis) on ``ax``, or on a new figure.

        Args:
            ls: sequence of per-epoch losses.
            ax: matplotlib Axes to draw on; a new 10x5 figure if ``None``.

        Returns:
            The Axes that was drawn on.
        """
        if ax is None:
            _, ax = plt.subplots(figsize=(10, 5))
        # Draw on the given Axes explicitly rather than pyplot's "current" one.
        ax.plot(ls, color=self.color)

        ax.set_xlabel("epoch", fontsize=20)
        ax.set_ylabel("loss", fontsize=20)
        ax.tick_params(axis="both", labelsize=15)
        ax.set_yscale("log")
        return ax

    def compute_pred(self, nsamples=1000):
        """Draw predictive samples on a grid slightly wider than ``[0, pi]``.

        Stores the 30-point test grid in ``self.x_test`` (needed by
        ``plot_preds``).

        Args:
            nsamples: number of predictive samples.

        Returns:
            1-tuple ``(preds,)`` where ``preds`` is the dict returned by
            ``pyro.infer.Predictive``.
        """
        predictive_simple = Predictive(
            self.bnn, guide=self.bnn.guide, num_samples=nsamples
        )

        self.x_test = torch.linspace(-0.1, self.pi + 0.1, 30).unsqueeze(-1)
        data_test = self.x_test, None  # y=None -> "obs" is sampled
        preds = predictive_simple(data_test)
        return (preds,)

    def plot_preds(self, preds, ax=None):
        """Plot training data, predictive mean and a +/-2 std band.

        Requires ``compute_pred`` to have been called (uses ``self.x_test``).

        Args:
            preds: dict of predictive samples containing an ``"obs"`` entry.
            ax: matplotlib Axes to draw on; a new 7x7 figure if ``None``.

        Returns:
            The Axes that was drawn on.
        """
        y_pred = preds["obs"].detach().squeeze(-1).numpy()
        y_means = y_pred.mean(axis=0)
        y_stds = y_pred.std(axis=0)

        if ax is None:
            _, ax = plt.subplots(figsize=(7, 7))

        x_test = self.x_test.squeeze(-1)
        ax.plot(self.xdata2, self.ydata2, "o", markersize=5, alpha=0.5)
        ax.plot(x_test, y_means, color=self.color, label=self.name)
        ax.fill_between(
            x_test,
            y_means - 2 * y_stds,
            y_means + 2 * y_stds,
            alpha=0.5,
            color=self.color,
        )

        ax.set_xlabel("x", fontsize=20)
        ax.set_ylabel("y", fontsize=20)
        ax.tick_params(axis="both", labelsize=15)
        # A single y-limit: the earlier set_ylim(-1.5, 1.5) was dead code,
        # always overridden by this one.
        ya = 6.0
        ax.set_ylim(-ya / 3.0, ya)
        return ax


class TestEngine(unittest.TestCase):
    """Fit ``SimpleBNN`` and ``sbsep.bnn.BNN`` on the same data and save plots.

    The output directory ``figs/`` next to this file is created when the
    class is defined.
    """

    sns.set_style("darkgrid")
    figs_folder = "./figs"
    cpath = os.path.join(os.path.dirname(os.path.realpath(__file__)), figs_folder)
    path = Path(cpath)
    path.mkdir(parents=True, exist_ok=True)

    def test_main(self):
        n_inf = 200

        ti = TestSimpleBNNInference(name="simplebnn")
        loss = ti.test_inference(nepochs=n_inf)
        preds = ti.compute_pred(nsamples=100)

        ti2 = TestSimpleBNNInference(name="prevbnn")
        loss_prev = ti2.test_inference(nepochs=n_inf)
        preds_prev = ti2.compute_pred(nsamples=100)

        # Light sanity checks: both predictive dicts expose the observed site.
        self.assertIn("obs", preds[0])
        self.assertIn("obs", preds_prev[0])

        # Save the figure that owns the Axes (not pyplot's implicit "current"
        # figure), then close it so repeated runs do not leak open figures.
        ax = ti.plot_loss(loss)
        ax = ti2.plot_loss(loss_prev, ax=ax)
        fig = ax.figure
        fig.savefig(os.path.join(self.cpath, "test_bnn_helper_loss.pdf"))
        plt.close(fig)

        ax = ti.plot_preds(preds[0])
        ax = ti2.plot_preds(preds_prev[0], ax=ax)
        # plot_preds labels each mean curve with the model name; show them.
        ax.legend()
        fig = ax.figure
        fig.savefig(os.path.join(self.cpath, "test_bnn_helper_prediction.pdf"))
        plt.close(fig)


if __name__ == "__main__":
    # Direct execution: discover and run the TestCase classes in this module.
    unittest.main(module="__main__")
