import unittest
import torch
import pyro
from pathlib import Path
import os
from torch.distributions import constraints
from pyro import distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import Adam, SGD
from pyro import poutine
import matplotlib.pyplot as plt

from sbsep.dist import BNNLayerDist, linear_
from test.util import generate_bnn_weights


class TestBNNLayerDist(unittest.TestCase):
    """Shape checks and an SVI smoke test for ``BNNLayerDist`` and ``linear_``.

    The shape tests share the fixture returned by ``generate_bnn_weights()``.
    ``test_inference`` builds its own synthetic data set, runs SVI and writes
    two diagnostic PDFs into the ``figs`` directory next to this file.
    """

    # Output directory for the figures, resolved relative to this test file.
    figs_folder = "./figs"
    cpath = os.path.join(os.path.dirname(os.path.realpath(__file__)), figs_folder)
    path = Path(cpath)
    # Created while the class body executes (i.e. when the module is imported).
    path.mkdir(parents=True, exist_ok=True)
    # Shared fixture.  The tests below index ``params[0]`` as a
    # (w_loc, w_scale, bias_loc, bias_scale) tuple and use ``dims[0]`` /
    # ``dims[1]`` as the input / output width of that first layer.
    name, dims, params, input_data = generate_bnn_weights()

    def test_linear_aux2(self):
        """``linear_`` on a (15, 77, dims[0]) input yields (15, 77, dims[1])."""
        w_loc, _w_scale, bias_loc, _bias_scale = self.params[0]
        inp = torch.zeros((15, 77, self.dims[0]))
        out = linear_(inp, w_loc, bias_loc)
        self.assertEqual(tuple(out.shape), (15, 77, self.dims[1]))

    def test_bnn_layer_dist(self):
        """A sample drawn inside a data plate has shape (n_data, out_features)."""
        n_data = self.input_data.shape[0]
        with pyro.plate("data", n_data):
            layer = BNNLayerDist(self.params[0], self.input_data).expand(
                (self.input_data.size(0),)
            )
            sample = pyro.sample("bnn_layer", layer)
            # The output width is read off the last dim of the weight location.
            self.assertEqual(
                tuple(sample.shape), (n_data, self.params[0][0].shape[-1])
            )

    def test_bnn_layer_dist_logprob(self):
        """``log_prob`` returns one value per data row."""
        n_data = self.input_data.shape[0]
        layer = BNNLayerDist(self.params[0], self.input_data).expand(
            (self.input_data.size(0),)
        )

        ysample = torch.ones((self.input_data.size(0), self.dims[1]))
        y_logprob = layer.log_prob(ysample)
        self.assertEqual(tuple(y_logprob.shape), (n_data,))

    def test_inference(self):
        """Run SVI for a single ``BNNLayerDist`` layer on synthetic linear data.

        This is a smoke test: it checks that 1000 SVI steps and a
        ``Predictive`` pass run without raising, and saves a loss curve and a
        prediction plot for visual inspection.  It makes no numerical
        assertions about the quality of the fit.
        """
        torch.manual_seed(17)
        name = "b"
        dims = (3, 7)
        pp_dim = 30

        sigma = 0.2
        sigma_w = 0.1
        # acos(0) == pi / 2, so doubling it gives pi.
        pi = torch.acos(torch.zeros(1)).item() * 2
        xa, xb = 0.0, pi

        nspace = dims[0] * pp_dim

        # NOTE: the order of the random draws below determines the generated
        # data for the fixed seed; do not reorder them.
        k = 3 + sigma_w * torch.randn(dims)

        # Inputs uniform on [xa, xb), shaped (pp_dim, dims[0]).
        xdata = (xa + (xb - xa) * torch.rand(nspace)).reshape(-1, dims[0])
        # Noisy linear targets: x @ k + sigma * eps, shaped (pp_dim, dims[1]).
        ydata = torch.einsum("...k,kj->...j", xdata, k) + sigma * torch.randn(
            pp_dim, dims[1]
        )
        data = xdata, ydata
        # Noise-free targets, used as the x-axis of the prediction plot.
        sxdata = torch.matmul(xdata, k)

        def model_simple(data):
            """Model: InverseGamma noise scale, BNN layer output, Normal obs."""
            x, y = data
            sigma = pyro.sample(
                f"{name}#sigma",
                dist.InverseGamma(torch.tensor(30.0), torch.tensor(1.45)),
            )

            weight_loc = torch.zeros(dims)
            weight_scale = 2e0 * torch.ones(dims)
            bias_loc = torch.zeros(dims[1])
            bias_scale = 2e0 * torch.ones(dims[1])

            prior_params = (weight_loc, weight_scale, bias_loc, bias_scale)
            with pyro.plate("data", x.size(0)):
                my = pyro.sample(
                    f"{name}#layer_0",
                    BNNLayerDist(prior_params, x, final=False),
                )
                # With ``y is None`` (prediction time) this site is sampled.
                obs = pyro.sample("obs", dist.Normal(my, sigma).to_event(1), obs=y)
            return obs

        def guide_simple(data):
            """Guide with learnable parameters for sigma and the layer."""
            x, _y = data

            sigma_alpha = pyro.param(
                f"{name}#sigma_alpha",
                torch.tensor(100.0),
                constraint=constraints.positive,
            )
            sigma_beta = pyro.param(
                f"{name}#sigma_beta", torch.tensor(1.0), constraint=constraints.positive
            )

            pyro.sample(f"{name}#sigma", dist.InverseGamma(sigma_alpha, sigma_beta))

            guess_weight_loc = pyro.param(f"{name}#weight_loc_0", torch.zeros(dims))
            guess_weight_scale = pyro.param(
                f"{name}#weight_scale_0",
                1e-2 * torch.ones(dims),
                constraint=constraints.positive,
            )
            guess_bias_loc = pyro.param(f"{name}#bias_loc_0", torch.zeros(dims[1]))
            guess_bias_scale = pyro.param(
                f"{name}#bias_scale_0",
                1e-2 * torch.ones(dims[1]),
                constraint=constraints.positive,
            )

            guide_params = (
                guess_weight_loc,
                guess_weight_scale,
                guess_bias_loc,
                guess_bias_scale,
            )

            with pyro.plate("data", x.size(0)):
                pyro.sample(
                    f"{name}#layer_0",
                    BNNLayerDist(guide_params, x, final=False),
                )

        n_epochs = 1000
        optimizer = pyro.optim.ClippedAdam({"lr": 1e-2})
        pyro.clear_param_store()

        # NOTE(review): these names expand to "bsigma_alpha" / "bsigma_beta",
        # but the guide registers "b#sigma_alpha" / "b#sigma_beta".  The hide
        # list therefore matches no site and this block currently has no
        # effect (both sigma parameters are trained).  The names are left
        # untouched because adding the "#" would freeze those parameters and
        # change the fit -- confirm the intent before changing them.
        svi_simple = SVI(
            model_simple,
            poutine.block(
                guide_simple,
                hide=[f"{name}sigma_alpha", f"{name}sigma_beta"],
            ),
            optimizer,
            loss=Trace_ELBO(),
        )

        ls_simple = [svi_simple.step(data) for _ in range(n_epochs)]
        self._save_loss_plot(ls_simple)

        # Point estimates: the learned location parameters of the guide.
        params = pyro.get_param_store()
        weight_point = params["b#weight_loc_0"]
        bias_point = params["b#bias_loc_0"]

        predictive_simple_bnn = Predictive(
            model_simple, guide=guide_simple, num_samples=100
        )

        # 121 test rows; every input feature sweeps the same grid on [xa, xb].
        x_test = torch.stack([torch.linspace(xa, xb, 121)] * dims[0]).T

        data_test = x_test, None
        preds_simple_bnn = predictive_simple_bnn(data_test)

        ydim_towatch = 0

        sx_test = (
            torch.matmul(x_test, weight_point).detach().numpy()
            + bias_point.detach().numpy()
        )
        y_pred_simple_bnn = preds_simple_bnn["obs"].squeeze(0).detach().numpy()
        # Mean / std over axis 0 (presumably the posterior-sample axis).
        y_means_simple_bnn = y_pred_simple_bnn[..., ydim_towatch].mean(0)
        y_stds_simple_bnn = y_pred_simple_bnn[..., ydim_towatch].std(0)

        self._save_prediction_plot(
            sxdata[:, ydim_towatch],
            ydata[..., ydim_towatch],
            sx_test[:, ydim_towatch],
            y_means_simple_bnn,
            y_stds_simple_bnn,
        )

    def _save_loss_plot(self, losses):
        """Save the per-step SVI losses (log y-axis) as a PDF in ``cpath``."""
        fig, ax = plt.subplots(figsize=(5, 5))
        try:
            ax.plot(losses, color="red", label="simple")
            ax.tick_params(axis="both", labelsize=16)
            ax.set_xlabel("epochs", fontsize=16)
            ax.set_ylabel("loss", fontsize=16)
            ax.legend(loc="best", fontsize=16)
            ax.set_yscale("log")
            fig.savefig(os.path.join(self.cpath, "test_bnn_layer_loss.pdf"))
        finally:
            # Release the figure so it does not stay alive for later tests.
            plt.close(fig)

    def _save_prediction_plot(self, x_train, y_train, x_pred, y_mean, y_std):
        """Save training points and predictive mean +/- 2 std as a PDF."""
        fig, ax = plt.subplots(figsize=(7, 7))
        try:
            ax.plot(x_train, y_train, "x", markersize=5)
            ax.scatter(
                x_pred,
                y_mean,
                color="red",
                label="simple BNN",
                alpha=0.5,
            )
            ax.fill_between(
                x_pred,
                y_mean - 2 * y_std,
                y_mean + 2 * y_std,
                alpha=0.2,
                color="red",
            )
            # The scatter carries a label, so actually show it.
            ax.legend(loc="best")
            fig.savefig(os.path.join(self.cpath, "test_bnn_layer_prediction.pdf"))
        finally:
            plt.close(fig)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
