import jax

import jax.numpy as jnp
import jax.scipy as jsp

from functools import partial
from tensorflow_probability.substrates import jax as tfp
from tensorflow_probability.substrates.jax import distributions as tfd


@partial(jax.jit, static_argnums=(1,2,8,9))
def neg_gelbo_objective(
    params,
    model,
    prior,
    x,
    y,
    x_context,
    key,
    ll_scale,
    n_samples, 
    kl_gamma
):
    """
    Negative generalized functional ELBO objective.

    The generalized ELBO is ``expected_ll - reg_kl``, where ``reg_kl`` is the
    regularized KL divergence between the model and the prior evaluated at the
    context points, and ``expected_ll`` is the rescaled expected Gaussian log
    likelihood on the batch ``(x, y)``.

    ``model``, ``prior``, ``n_samples`` and ``kl_gamma`` are static arguments
    of the jitted function: they must be hashable, and a new value triggers a
    re-trace.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): stochastic neural network; must provide
      ``f_distribution(params, x, key)`` and
      ``f_diag_distribution(params, x, key)``.
    - prior (Prior): prior distribution on context points, called as
      ``prior(x=x_context)`` and expected to return ``(mean, cov)``.
    - x (jnp.ndarray): input data used to calculate the expected log likelihood in the ELBO.
    - y (jnp.ndarray): targets used to calculate the expected log likelihood in the ELBO.
      Flattened to a column internally, so ``(N,)`` and ``(N, 1)`` are equivalent.
    - x_context (jnp.ndarray): context points sampled from feature distribution.
    - key (jax.random.PRNGKey): random key; split once, one half per model call.
    - ll_scale (float): scale (standard deviation) of the Gaussian likelihood.
    - n_samples (int): total number of training samples; the batch estimate
      is rescaled by ``n_samples / x.shape[0]``.
    - kl_gamma (float): KL divergence regularization coefficient.

    returns:
    - loss: negative generalized ELBO (``-gelbo``).
    - aux (dict): ``{"expected_ll": ..., "reg_kl": ..., "gelbo": ...}``.
    """
    kl_key, ll_key = jax.random.split(key)

    # Compute KL divergence between model and prior at the context points
    p_mean, p_cov = prior(x=x_context)
    q_mean, q_cov = model.f_distribution(params, x_context, kl_key)
    reg_kl = regularized_kl(q_mean, p_mean, q_cov, p_cov, kl_gamma)

    # Compute expected log likelihood.
    # Both the predictive mean and the targets are flattened to columns so
    # that a 1-D `y` cannot broadcast against the column of means into an
    # (N, N) grid and silently inflate the sum.
    mean, diag_cov = model.f_diag_distribution(params, x, ll_key)
    expected_ll = jsp.stats.norm.logpdf(
        y.reshape(-1, 1),
        loc=mean.reshape(-1, 1),
        scale=ll_scale
    ).sum()
    # Variance correction: log-density at the mean minus 0.5 * var / scale^2.
    expected_ll -= 0.5 * diag_cov.sum() / ll_scale**2
    # Rescale the mini-batch estimate to the full dataset size.
    expected_ll *= n_samples / x.shape[0]

    gelbo = expected_ll - reg_kl

    return (
        -gelbo,
        {"expected_ll": expected_ll, "reg_kl": reg_kl, "gelbo": gelbo}
    )


@partial(jax.jit, static_argnums=(1,6))
def expected_ll_objective(
    params,
    model,
    x,
    y,
    key,
    ll_rho,
    n_samples
):
    """
    Negative expected log-likelihood objective.

    ``model`` and ``n_samples`` are static arguments of the jitted function:
    they must be hashable, and a new value triggers a re-trace.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): stochastic neural network; must provide
      ``f_diag_distribution(params, x, key)`` returning ``(mean, diag_cov)``.
    - x (jnp.ndarray): input data used to calculate the expected log likelihood.
    - y (jnp.ndarray): targets used to calculate the expected log likelihood.
      Flattened to a column internally, so ``(N,)`` and ``(N, 1)`` are equivalent.
    - key (jax.random.PRNGKey): random key forwarded to the model.
    - ll_rho (float): pre-activated likelihood scale; the scale is ``softplus(ll_rho)``.
    - n_samples (int): total number of training samples; the batch estimate
      is rescaled by ``n_samples / x.shape[0]``.

    returns:
    - loss: negative expected log likelihood (``-expected_ll``).
    - aux (dict): ``{"expected_ll": expected_ll}``.
    """
    # Compute likelihood scale (softplus keeps it strictly positive)
    ll_scale = jax.nn.softplus(ll_rho)

    # Compute expected log likelihood.
    # Both the predictive mean and the targets are flattened to columns so
    # that a 1-D `y` cannot broadcast against the column of means into an
    # (N, N) grid and silently inflate the sum.
    mean, diag_cov = model.f_diag_distribution(params, x, key)
    expected_ll = jsp.stats.norm.logpdf(
        y.reshape(-1, 1),
        loc=mean.reshape(-1, 1),
        scale=ll_scale
    ).sum()
    # Variance correction: log-density at the mean minus 0.5 * var / scale^2.
    expected_ll -= 0.5 * diag_cov.sum() / ll_scale**2
    # Rescale the mini-batch estimate to the full dataset size.
    expected_ll *= n_samples / x.shape[0]

    return -expected_ll, {"expected_ll": expected_ll}


@partial(jax.jit, static_argnums=(4,))
def regularized_kl( 
    q_mean, 
    p_mean, 
    q_cov, 
    p_cov, 
    kl_gamma
):
    """
    Compute the regularized KL(q || p) between two multivariate Gaussians.

    Both covariances receive the same diagonal term ``kl_gamma * dims`` (with
    ``dims = q_mean.shape[0]``) before the divergence is evaluated.

    ``kl_gamma`` is a static argument of the jitted function: it must be
    hashable, and a new value triggers a re-trace.

    params:
    - q_mean (jnp.ndarray): mean of Gaussian distribution q.
    - p_mean (jnp.ndarray): mean of Gaussian distribution p.
    - q_cov (jnp.ndarray): covariance of Gaussian distribution q, 2-D array.
    - p_cov (jnp.ndarray): covariance of Gaussian distribution p, 2-D array.
    - kl_gamma (float): KL divergence regularization coefficient.

    returns:
    - reg_kl (jnp.ndarray): the regularized KL divergence.
    """
    dims = q_mean.shape[0]

    # Identical diagonal regularizer for q and p.
    jitter = jnp.eye(dims) * kl_gamma * dims

    # MultivariateNormalFullCovariance is deprecated in TFP; the documented
    # replacement is MultivariateNormalTriL parameterized by the Cholesky
    # factor of the covariance matrix.
    q = tfd.MultivariateNormalTriL(
        loc=q_mean.transpose(),
        scale_tril=jnp.linalg.cholesky(q_cov + jitter),
        validate_args=False,
        allow_nan_stats=True,
    )
    p = tfd.MultivariateNormalTriL(
        loc=p_mean.transpose(),
        scale_tril=jnp.linalg.cholesky(p_cov + jitter),
        validate_args=False,
        allow_nan_stats=True,
    )

    reg_kl = tfd.kl_divergence(q, p, allow_nan_stats=False)

    return reg_kl

