""" Signal-to-noise ratio convergence for EinStein vs IWAE-Stein. Experiment from [1] section 4.
** Reference **
    [1] Tighter Variational Bounds are Not Necessarily Better, (2019)
        Tom Rainforth, Adam R. Kosiorek, Tuan Anh Le, Chris J. Maddison, Maximilian Igl, Frank Wood and Yee Whye Teh
"""
from functools import partial
import itertools

import matplotlib
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from tqdm import tqdm

import jax
from jax import jit, lax, numpy as jnp, random

from experiments.steinvi_forces import SteinVIForces
import numpyro
from numpyro import param, plate, sample
from numpyro.contrib.einstein import RBFKernel, SteinVI
from numpyro.distributions import MultivariateNormal, Normal
from numpyro.handlers import seed, substitute, trace
from numpyro.infer import ELBO, RenyiELBO, Trace_ELBO
from numpyro.optim import SGD

# Run JAX on the CPU platform.
jax.config.update("jax_platform_name", "cpu")
# Seed for the root PRNG key that every experiment key below is split from.
rng_seed = 42
# Dimension of the latent variable and of each observation.
latent_dim = 2
# Size of the "data" plate when no observations are supplied; also the divisor
# used by AvgELBO.loss.
num_data_points = 64
# Number of gradient draws per SNR estimate in estimate_snr.
num_draws = int(1e4)
# Independent keys for: the prior location, data generation, the experiment
# loops, and the particle perturbation.
loc_key, data_key, snr_key, offset_key = random.split(random.PRNGKey(rng_seed), 4)
# Prior mean of the latent, drawn once from a standard normal.
true_loc = Normal(0, 1).sample(loc_key, (latent_dim,))

# Perturbation added to the "bias" parameters in estimate_snr: after stacking,
# row 0 is a small random offset (scale 0.01) and row 1 is all zeros.
bias_perturb = Normal(0.0, 0.01).sample(offset_key, (latent_dim,))
bias_perturb = jnp.stack((bias_perturb, jnp.zeros_like(bias_perturb)))
matplotlib.rcParams.update({"font.size": 15})

# Base-10 exponent of the density floor used when plotting in run_dynamics.
min_exp = -8

# Replacement value for densities below 10.**min_exp; chosen slightly above that
# floor.
min_drawn_value = 1.000001 * 10.0**min_exp  # above 10.**min_exp


def snr(grads):
    """Return the element-wise signal-to-noise ratio of gradient samples.

    The samples are reduced along axis 0: the result is ``|mean / std|`` where
    both statistics are taken over that leading axis.

    :param grads: array whose leading axis indexes independent gradient draws.
    :return: array with the leading axis removed.
    """
    sample_mean = jnp.mean(grads, axis=0)
    sample_std = jnp.std(grads, axis=0)
    ratio = sample_mean / sample_std
    return jnp.abs(ratio)


def model_global(obs):
    """Model with a single latent shared by every observation.

    The latent has prior ``MultivariateNormal(true_loc, I)``; each row of the
    "data" plate is observed as ``MultivariateNormal(latent, I)``.

    :param obs: observations, or ``None`` to sample ``num_data_points`` of them.
    """
    plate_size = num_data_points if obs is None else obs.shape[0]

    prior = MultivariateNormal(true_loc, jnp.eye(latent_dim))
    latent = sample("latent", prior)

    with plate("data", plate_size):
        likelihood = MultivariateNormal(latent, jnp.eye(latent_dim))
        sample("obs", likelihood, obs=obs)


def guide_global(obs):
    """Guide for ``model_global``: a Gaussian over "latent" with learnable mean.

    The mean is the parameter "bias" (initialised from a standard normal) and
    the covariance is fixed to ``3/2 * I``.

    :param obs: unused; present so the guide shares the model's signature.
    """

    def init_bias(rng_key):
        return Normal(jnp.zeros(latent_dim), 1.0).sample(rng_key)

    bias = param("bias", init_bias)
    proposal = MultivariateNormal(bias, 3 / 2 * jnp.eye(latent_dim))
    sample("latent", proposal)


def model_local(obs):
    """Model with one latent per observation.

    Inside the "data" plate, each latent has prior
    ``MultivariateNormal(true_loc, I)`` and its observation is
    ``MultivariateNormal(latent, I)``.

    :param obs: observations, or ``None`` to sample ``num_data_points`` of them.
    """
    plate_size = num_data_points if obs is None else obs.shape[0]

    with plate("data", plate_size):
        prior = MultivariateNormal(true_loc, jnp.eye(latent_dim))
        latent = sample("latent", prior)

        likelihood = MultivariateNormal(latent, jnp.eye(latent_dim))
        sample("obs", likelihood, obs=obs)


def guide_local(obs):
    """Guide for ``model_local``: per-datum Gaussians sharing one learnable mean.

    The mean is the parameter "bias" (initialised from a standard normal); the
    covariance is fixed to ``3/2 * I``. "latent" is sampled inside the "data"
    plate.

    :param obs: observations, or ``None`` to use ``num_data_points`` as the
        plate size.
    """

    def init_bias(rng_key):
        return Normal(jnp.zeros(latent_dim), 1.0).sample(rng_key)

    bias = param("bias", init_bias)
    plate_size = num_data_points if obs is None else obs.shape[0]

    with plate("data", plate_size):
        proposal = MultivariateNormal(bias, 3 / 2 * jnp.eye(latent_dim))
        sample("latent", proposal)


class AvgELBO(ELBO):
    """ELBO wrapper that reports the wrapped objective per data point.

    The loss computed by ``inner_elbo`` is divided by the module-level
    ``num_data_points``.
    """

    def __init__(self, inner_elbo: ELBO):
        """Wrap ``inner_elbo`` and reuse its ``num_particles`` setting.

        :param inner_elbo: the objective whose loss is to be averaged.
        """
        self.inner_elbo = inner_elbo
        super().__init__(inner_elbo.num_particles)

    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
        """Return the inner objective's loss divided by ``num_data_points``.

        All arguments are forwarded unchanged to ``inner_elbo.loss``.
        """
        elbo = self.inner_elbo.loss(rng_key, param_map, model, guide, *args, **kwargs)
        return elbo / num_data_points

    def loss_with_mutable_state(self, *args, **kwargs):
        """Defer to the parent class implementation of this method."""
        # ``super()`` has to be *called* to get the bound proxy object; the bare
        # ``super`` type has no ``loss_with_mutable_state`` attribute and would
        # raise AttributeError here.
        return super().loss_with_mutable_state(*args, **kwargs)


def estimate_snr():
    """Estimate and plot the gradient signal-to-noise ratio of several objectives.

    Data are generated from ``model_global`` with the latent fixed to
    ``true_loc``. A two-particle ``SteinVI`` run of one step provides the
    parameters, to which ``bias_perturb`` is added. For each objective and each
    entry of the draw budget, ``num_draws`` gradients are computed with
    independent keys, and the SNR of the first particle's "bias" gradient is
    recorded. The mean SNR over the latent dimensions (with a +/- one standard
    deviation band) is plotted against the budget on log-log axes and written
    to ``snr_global.png``.

    Side effects: enables ``jax_debug_nans``, advances the module-level
    ``snr_key`` and writes the figure file.
    """
    # Use the ``jax.config`` object directly, as the module top level already
    # does; the ``jax.config`` submodule import is deprecated and is not
    # available on newer JAX releases.
    jax.config.update("jax_debug_nans", True)

    global snr_key, data_key

    with seed(rng_seed=data_key), substitute(data={"latent": true_loc}), trace() as tr:
        model_global(None)

    data = tr["obs"]["value"]

    stein = SteinVI(
        model_global,
        guide_global,
        SGD(step_size=1.0),
        AvgELBO(Trace_ELBO(num_particles=10_000)),
        RBFKernel(),
        num_particles=2,
    )
    # The first key of this split is never consumed, but the three-way split is
    # kept so the downstream key sequence is unchanged.
    _, init_key, snr_key = random.split(snr_key, 3)

    res = stein.run(init_key, 1, data)
    draw_budget = [1, 10, 100, 1000]
    _, ax = plt.subplots(1, 1)

    # Build a fresh mapping rather than writing the perturbed value back into
    # ``res.params``.
    params = {**res.params, "bias": res.params["bias"] + bias_perturb}

    objectives = (
        (r"IWAE($\alpha=0.0$)", partial(RenyiELBO, alpha=0.0)),
        (r"Hellinger($\alpha=.5$)", partial(RenyiELBO, alpha=0.5)),
        (r"ELBO ($\alpha=1$)", Trace_ELBO),
        (r"$\alpha=2$", partial(RenyiELBO, alpha=2.0)),
        (r"$\alpha=10$", partial(RenyiELBO, alpha=10.0)),
    )

    # Axis configuration does not depend on the objective, so set it once.
    ax.set_yscale("log")
    ax.set_xscale("log")
    ax.set_xlabel("$KM$")
    ax.set_ylabel(r"SNR$(\phi_1)$")

    trace_palette = itertools.cycle(sns.color_palette())
    for name, loss in tqdm(objectives):
        snrs = []
        for num_part in draw_budget:
            rng_key, init_key, snr_key = random.split(snr_key, 3)

            stein = SteinVI(
                model_global,
                guide_global,
                SGD(step_size=1.0),
                AvgELBO(loss(num_particles=num_part)),
                RBFKernel(),
                num_particles=2,
            )
            stein.init(init_key, data)

            def stein_grad(rng_key):
                # Only the gradients are needed; the loss value is discarded.
                _, grads = stein._svgd_loss_and_grads(
                    rng_key,
                    params,
                    data,
                    **stein.static_kwargs,
                )
                return grads

            @jit
            def jitted_map(keys):
                return lax.map(stein_grad, keys)

            grads = jitted_map(random.split(rng_key, num_draws))

            snrs.append(snr(grads["bias"][:, 0, :]))  # look at single particle

        snrs = jnp.array(snrs)
        # Summarise across the latent dimensions for each budget entry.
        snr_means = snrs.mean(1)
        snr_stds = snrs.std(1)

        color = next(trace_palette)

        ax.plot(draw_budget, snr_means, color=color, linewidth=3, label=name)

        ax.fill_between(
            draw_budget,
            snr_means - snr_stds,
            snr_means + snr_stds,
            color=color,
            alpha=0.2,
        )
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.legend(prop={"size": 15})
    plt.tight_layout(pad=0)
    plt.savefig("snr_global.png")


def run_dynamics():
    """Plot the attractive and repulsive Stein forces acting on two particles.

    Data are generated from ``model_global`` with the latent fixed to
    ``true_loc``. A two-particle ``SteinVIForces`` run of 32 steps provides the
    particle locations (the "bias" parameters). The attractive and repulsive
    forces are averaged over 100 independent keys and drawn as arrows (blue:
    attractive, red: repulsive, black: their sum) over a log-scaled contour
    plot of the closed-form Normal-Normal posterior. The figure is written to
    ``snr_config.png``.

    Side effects: enables ``jax_debug_nans``, advances the module-level
    ``snr_key`` and writes the figure file.
    """
    # Use the ``jax.config`` object directly, as the module top level already
    # does; the ``jax.config`` submodule import is deprecated and is not
    # available on newer JAX releases.
    jax.config.update("jax_debug_nans", True)

    global snr_key, data_key
    # The plot below is two-dimensional.
    assert latent_dim == 2

    with seed(rng_seed=data_key), substitute(data={"latent": true_loc}), trace() as tr:
        model_global(None)

    data = tr["obs"]["value"]
    opt_loc = data.mean(0)
    num_stein_particles = 2
    stein = SteinVIForces(
        model_global,
        guide_global,
        SGD(step_size=1.0),
        AvgELBO(Trace_ELBO(num_particles=10_000)),
        RBFKernel(),
        num_particles=num_stein_particles,
    )
    rng_key, init_key, snr_key = random.split(snr_key, 3)

    # Closed-form posterior for a unit-variance Normal prior and likelihood:
    # https://www2.bcs.rochester.edu/sites/jacobslab/cheat_sheet/bayes_Normal_Normal.pdf
    post_loc = num_data_points / (num_data_points + 1) * opt_loc + true_loc / (
        num_data_points + 1
    )
    # ``Normal`` takes a scale, so this is the square root of the posterior
    # variance 1 / (n + 1).
    post_std = jnp.sqrt(1 / (num_data_points + 1))

    posterior_log_prob = Normal(post_loc, post_std).to_event(1).log_prob

    res = stein.run(init_key, 32, data)
    particles = res.params["bias"]

    def stein_grad(rng_key):
        attrac_grads, rep_grads = stein._svgd_forces(
            rng_key,
            res.params,
            data,
            **stein.static_kwargs,
        )
        return attrac_grads, rep_grads

    @jit
    def jitted_map(keys):
        return lax.map(stein_grad, keys)

    # Average each force over 100 independent keys.
    attrac_grads, rep_grads = jitted_map(random.split(rng_key, 100))
    attrac_grads, rep_grads = attrac_grads["bias"].mean(0), rep_grads["bias"].mean(0)

    # Evaluate the posterior density on a grid of +/- 1 around the posterior mean.
    boundary_box = np.array([-1, 1.0])
    x = jnp.linspace(*(post_loc[0] + boundary_box), 1000)
    y = jnp.linspace(*(post_loc[1] + boundary_box), 1000)
    x, y = np.meshgrid(x, y)
    fig, ax = plt.subplots(figsize=(5, 5))
    zi = jnp.exp(posterior_log_prob(jnp.stack((x, y), axis=-1)))
    # Densities below the floor are replaced by a value just above it before
    # the log-normalised contour plot.
    zi_masked = np.where(zi < 10.0**min_exp, min_drawn_value, zi)
    unit = 1.0
    ax.contourf(x, y, zi_masked, norm=LogNorm())

    # Arrows are drawn in the order attractive, repulsive, total; only the
    # first total-force arrow carries a legend label.
    force_styles = (
        (attrac_grads, "b", None),
        (rep_grads, "r", None),
        (rep_grads + attrac_grads, "k", r"$S^H_\Phi$"),
    )
    for forces, color, label in force_styles:
        for idx in range(num_stein_particles):
            label_kwargs = {"label": label} if (label is not None and idx == 0) else {}
            ax.quiver(
                *particles[idx],
                *forces[idx],
                scale=unit,
                color=color,
                **label_kwargs,
            )

    # Dashed circle centred on the posterior mean whose radius is the mean
    # distance of the particles from it.
    ax.add_artist(
        plt.Circle(
            post_loc,
            jnp.linalg.norm(post_loc[None, ...] - particles, axis=-1).mean(0),
            color="k",
            linestyle="--",
            fill=False,
            linewidth=2,
        )
    )

    ax.scatter(*particles[0], s=100, label=r"$\phi_1$")
    ax.scatter(*particles[1], s=100, label=r"$\phi_2$")
    # Black cross: prior mean; blue cross: posterior mean.
    ax.scatter(*true_loc, marker="X", color="black", s=75)
    ax.scatter(
        *post_loc,
        marker="X",
        color="blue",
        s=75,
    )
    ax.set_aspect("equal", adjustable="box")
    fig.patch.set_visible(False)
    ax.axis("off")
    plt.tight_layout(pad=0.0)
    ax.legend()
    fig.savefig("snr_config.png")


# Script entry point: select the CPU platform for NumPyro and produce the
# force-field figure (snr_config.png).
if __name__ == "__main__":
    numpyro.set_platform("cpu")
    run_dynamics()
    # The SNR experiment (snr_global.png) is disabled; uncomment to run it.
    # estimate_snr()
