from jax import numpy as jnp
from jax.nn import sigmoid, tanh
from jax.nn.initializers import glorot_normal

from numpyro import param, plate, sample
from numpyro.distributions import Bernoulli, Normal


def glorot_normal_initializer(*shape):
    """Build a deferred Glorot-normal initializer for an array of ``shape``.

    Args:
        *shape: Dimensions of the array to create, given positionally.

    Returns:
        A callable taking a JAX PRNG key and returning a freshly drawn array
        of the requested shape.
    """
    # The initializer is configured with in_axis=-1 and out_axis=-2.
    draw = glorot_normal(in_axis=-1, out_axis=-2)

    def initialize(rng_key):
        return draw(rng_key, shape)

    return initialize


def normal_encoder(x, params):
    """Map ``x`` to the mean and standard deviation of a Normal distribution.

    The input goes through two tanh-activated affine layers, followed by two
    separate affine output heads.

    Args:
        x: Input array, right-multiplied by ``params["l1"]``.
        params: Mapping holding the weights ``"l1"``, ``"l2"``, ``"mean"``,
            ``"std"`` and the matching biases ``"l1_b"``, ``"l2_b"``,
            ``"mean_b"``, ``"std_b"``.

    Returns:
        A ``(mean, std)`` tuple.
    """
    hidden = x
    for weight_key, bias_key in (("l1", "l1_b"), ("l2", "l2_b")):
        hidden = tanh(jnp.matmul(hidden, params[weight_key]) + params[bias_key])

    loc = jnp.matmul(hidden, params["mean"]) + params["mean_b"]
    # Exponentiating the head's output keeps the scale strictly positive.
    log_scale = jnp.matmul(hidden, params["std"]) + params["std_b"]
    scale = jnp.exp(log_scale)
    return loc, scale


def bernoulli_encoder(x, params):
    """Map ``x`` to Bernoulli success probabilities.

    The input goes through two tanh-activated affine layers, then an affine
    output layer squashed by a sigmoid.

    Args:
        x: Input array, right-multiplied by ``params["l1"]``.
        params: Mapping holding the weights ``"l1"``, ``"l2"``, ``"mean"`` and
            the matching biases ``"l1_b"``, ``"l2_b"``, ``"mean_b"``.

    Returns:
        The sigmoid of the output layer's activations.
    """
    hidden = x
    for weight_key, bias_key in (("l1", "l1_b"), ("l2", "l2_b")):
        hidden = tanh(jnp.matmul(hidden, params[weight_key]) + params[bias_key])

    logits = jnp.matmul(hidden, params["mean"]) + params["mean_b"]
    return sigmoid(logits)


def add_ne_params(in_dim, hidden_dim, out_dim, name=""):
    """Register the parameters consumed by ``normal_encoder``.

    Args:
        in_dim: Size of the encoder input.
        hidden_dim: Width of both hidden layers.
        out_dim: Size of the mean and std outputs.
        name: Prefix for every registered parameter name, giving names such
            as ``"<name>_l1"``.

    Returns:
        Dict keyed by ``"l1"``, ``"l2"``, ``"mean"``, ``"std"`` and their
        ``"_b"`` bias counterparts, holding the registered parameter values.
    """
    weight_shapes = (
        ("l1", (in_dim, hidden_dim)),
        ("l2", (hidden_dim, hidden_dim)),
        ("mean", (hidden_dim, out_dim)),
        ("std", (hidden_dim, out_dim)),
    )
    bias_sizes = (
        ("l1_b", hidden_dim),
        ("l2_b", hidden_dim),
        ("mean_b", out_dim),
        ("std_b", out_dim),
    )

    registered = {}
    # Weights are registered before biases, in this fixed order.
    for key, shape in weight_shapes:
        registered[key] = param(f"{name}_{key}", glorot_normal_initializer(*shape))
    for key, size in bias_sizes:
        registered[key] = param(f"{name}_{key}", jnp.zeros(size))
    return registered


def add_be_params(in_dim, hidden_dim, out_dim, name=""):
    """Register the parameters consumed by ``bernoulli_encoder``.

    Args:
        in_dim: Size of the encoder input.
        hidden_dim: Width of both hidden layers.
        out_dim: Size of the output layer.
        name: Prefix for every registered parameter name, giving names such
            as ``"<name>_l1"``.

    Returns:
        Dict keyed by ``"l1"``, ``"l2"``, ``"mean"`` and their ``"_b"`` bias
        counterparts, holding the registered parameter values.
    """
    weight_shapes = (
        ("l1", (in_dim, hidden_dim)),
        ("l2", (hidden_dim, hidden_dim)),
        ("mean", (hidden_dim, out_dim)),
    )
    bias_sizes = (
        ("l1_b", hidden_dim),
        ("l2_b", hidden_dim),
        ("mean_b", out_dim),
    )

    registered = {}
    # Weights are registered before biases, in this fixed order.
    for key, shape in weight_shapes:
        registered[key] = param(f"{name}_{key}", glorot_normal_initializer(*shape))
    for key, size in bias_sizes:
        registered[key] = param(f"{name}_{key}", jnp.zeros(size))
    return registered


def model(
    x,
    subsample_size=2000,
    h1_hidden_dim=200,
    h1_latent_dim=100,
    h2_hidden_dim=100,
    h2_latent_dim=50,
    num_data=1,
    img_dim=784,
):
    """Generative model: ``h2 ~ Normal(0, 1)`` -> ``h1`` -> Bernoulli image.

    Args:
        x: Observed data indexed along its first axis by the plate indices,
            or ``None`` to leave the ``"img"`` site unobserved.
        subsample_size: Subsample size passed to the ``"data"`` plate.
        h1_hidden_dim: Not used by the model; kept so the signature matches
            ``guide``.
        h1_latent_dim: Size of the ``h1`` latent.
        h2_hidden_dim: Hidden width of both the ``h2 -> h1`` network and the
            output network.
        h2_latent_dim: Size of the ``h2`` latent.
        num_data: Size of the ``"data"`` plate. NOTE(review): ``guide`` sizes
            the same plate from ``x.shape[0]``, so callers presumably need to
            pass a matching value -- confirm at the call site.
        img_dim: Number of Bernoulli outputs per data point.
    """
    h1_params = add_ne_params(
        h2_latent_dim, h2_hidden_dim, h1_latent_dim, name="h1_model"
    )
    out_params = add_be_params(h1_latent_dim, h2_hidden_dim, img_dim, name="out")
    with plate("data", num_data, subsample_size=subsample_size, dim=-1) as idx:
        # Bind xbatch on every path: previously it was left undefined when
        # x was None, so the obs=xbatch argument below raised
        # UnboundLocalError. With obs=None the "img" site is sampled.
        xbatch = x[idx] if x is not None else None
        h2 = sample("h2", Normal(0, 1).expand((h2_latent_dim,)).to_event(1))
        h1 = sample("h1", Normal(*normal_encoder(h2, h1_params)).to_event(1))
        sample(
            "img",
            Bernoulli(probs=bernoulli_encoder(h1, out_params)).to_event(1),
            obs=xbatch,
        )


def guide(
    x,
    subsample_size=2000,
    h1_hidden_dim=200,
    h1_latent_dim=100,
    h2_hidden_dim=100,
    h2_latent_dim=50,
    num_data=1,
    img_dim=784,
):
    """Amortized guide: encode ``x`` into ``h1``, then ``h1`` into ``h2``.

    Args:
        x: Observed data; its first axis sizes the ``"data"`` plate and its
            last axis sizes the first encoder's input.
        subsample_size: Subsample size passed to the ``"data"`` plate.
        h1_hidden_dim: Hidden width of the ``x -> h1`` encoder.
        h1_latent_dim: Size of the ``h1`` latent.
        h2_hidden_dim: Hidden width of the ``h1 -> h2`` encoder.
        h2_latent_dim: Size of the ``h2`` latent.
        num_data: Not used by the guide; kept so the signature matches
            ``model``.
        img_dim: Not used by the guide; kept so the signature matches
            ``model``.
    """
    feature_dim = x.shape[-1]
    h1_encoder = add_ne_params(
        feature_dim, h1_hidden_dim, h1_latent_dim, name="h1_guide"
    )
    h2_encoder = add_ne_params(
        h1_latent_dim, h2_hidden_dim, h2_latent_dim, name="h2_guide"
    )

    data_plate = plate("data", x.shape[0], subsample_size=subsample_size, dim=-1)
    with data_plate as batch_idx:
        minibatch = x[batch_idx]

        h1_loc, h1_scale = normal_encoder(minibatch, h1_encoder)
        h1 = sample("h1", Normal(h1_loc, h1_scale).to_event(1))

        h2_loc, h2_scale = normal_encoder(h1, h2_encoder)
        sample("h2", Normal(h2_loc, h2_scale).to_event(1))
