import jax

import jax.numpy as jnp
import jax.scipy as jsp

from functools import partial
from tensorflow_probability.substrates import jax as tfp
from tensorflow_probability.substrates.jax import distributions as tfd


@partial(jax.jit, static_argnums=(1,6,7,9))
def nelbo_fsvi_objective(
    params,
    model,
    x,
    y,
    x_context,
    key,
    mc_samples,
    is_training,
    ll_scale,
    n_samples
):
    """
    Negative functional ELBO objective.

    The objective is the KL term minus the expected log likelihood, so it
    can be minimised directly. `model`, `mc_samples`, `is_training` and
    `n_samples` are static arguments under `jax.jit`.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): stochastic neural network; must provide
        `f_distribution`, `prior_f_distribution` and `predict_f`.
    - x (jnp.array): input data of the current task, used to calculate
        the expected log likelihood term in the ELBO.
    - y (jnp.array): labels, used to
        calculate the expected log likelihood term in the ELBO.
    - x_context (jnp.array): context points sampled from the feature
        distribution; the function distributions are evaluated
        separately for each entry along its leading axis.
    - key (jax.random.PRNGKey): random key, split three ways below.
    - mc_samples (int): the number of MC samples to estimate the expected
        log likelihood.
    - is_training (bool): forwarded unchanged to `model.predict_f`
        (presumably toggles training-mode behaviour; confirm with Model).
    - ll_scale (float): scale of the Gaussian likelihood model.
    - n_samples (int): total number of training samples.

    returns:
    - loss (float): negative functional ELBO.
    - aux (dict): the entries "expected_ll", "kl" and "felbo".
    """
    posterior_key, prior_key, predict_key = jax.random.split(key, num=3)

    # KL term: both function distributions are mapped over the leading
    # axis of x_context, sharing params and the random key across entries.
    context_axes = dict(in_axes=(None, 0, None), out_axes=(0, 0))
    posterior_fn = jax.vmap(model.f_distribution, **context_axes)
    prior_fn = jax.vmap(model.prior_f_distribution, **context_axes)

    q_mean, q_cov = posterior_fn(params, x_context, posterior_key)
    # No gradient flows into the parameters through the prior branch.
    frozen_params = jax.lax.stop_gradient(params)
    p_mean, p_cov = prior_fn(frozen_params, x_context, prior_key)

    batched_kl = jax.vmap(kl_divergence, in_axes=(0, 0, 0, 0))
    kl_values = batched_kl(q_mean, p_mean, q_cov, p_cov)
    # The largest of the vmapped KL values is used as the regulariser.
    kl = kl_values.max()

    # Expected log likelihood: average the Gaussian log density over
    # axis 0 first, then over the remaining entries, and rescale by the
    # number of training samples.
    f_hat = model.predict_f(
        params, x, predict_key, mc_samples, is_training, stochastic=True
    )
    log_density = jsp.stats.norm.logpdf(y, loc=f_hat, scale=ll_scale)
    axis0_average = jnp.mean(log_density, axis=0)
    expected_ll = n_samples * axis0_average.mean()

    felbo = expected_ll - kl
    aux = {"expected_ll": expected_ll, "kl": kl, "felbo": felbo}

    return -felbo, aux


@partial(jax.jit, static_argnums=(1,5,6,8))
def expected_ll_objective(
    params,
    model,
    x,
    y,
    key,
    mc_samples,
    is_training,
    ll_rho,
    n_samples
):
    """
    Negative expected log-likelihood objective.

    `model`, `mc_samples`, `is_training` and `n_samples` are static
    arguments under `jax.jit`.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): stochastic neural network; must provide `predict_f`.
    - x (jnp.array): input data of the current task, used to calculate
        the expected log likelihood.
    - y (jnp.array): labels, used to
        calculate the expected log likelihood.
    - key (jax.random.PRNGKey): random key, passed to `model.predict_f`.
    - mc_samples (int): the number of MC samples to estimate the expected
        log likelihood.
    - is_training (bool): forwarded unchanged to `model.predict_f`
        (presumably toggles training-mode behaviour; confirm with Model).
    - ll_rho (float): unconstrained likelihood parameter; the likelihood
        scale is softplus(ll_rho).
    - n_samples (int): total number of training samples.

    returns:
    - loss (float): negative expected log likelihood.
    - aux (dict): the entry "expected_ll".
    """
    # softplus maps the unconstrained parameter to a positive scale
    noise_scale = jax.nn.softplus(ll_rho)

    predictions = model.predict_f(
        params, x, key, mc_samples, is_training, stochastic=True
    )
    log_density = jsp.stats.norm.logpdf(y, loc=predictions, scale=noise_scale)

    # Average over axis 0 first (presumably the MC-sample axis), then over
    # the remaining entries, and rescale by the number of training samples.
    axis0_average = jnp.mean(log_density, axis=0)
    expected_ll = n_samples * axis0_average.mean()

    aux = {"expected_ll": expected_ll}
    return -expected_ll, aux


@jax.jit
def kl_divergence(
    q_mean,
    p_mean,
    q_cov,
    p_cov
):
    """
    Return KL(q || p) between two multivariate Gaussian distributions.

    A small jitter is added to the diagonal of both covariances before
    they are Cholesky-factorised. The distributions are built with
    `MultivariateNormalTriL`, the replacement named in the deprecation
    notice of `MultivariateNormalFullCovariance`.

    params:
    - q_mean (jnp.array): mean of Gaussian distribution q; the size of its
        leading axis sets the size of the jitter matrix.
    - p_mean (jnp.array): mean of Gaussian distribution p.
    - q_cov (jnp.array): covariance of Gaussian distribution q, 2-D array.
    - p_cov (jnp.array): covariance of Gaussian distribution p, 2-D array.

    returns:
    - kl_div (jnp.array): regularized KL divergence.
    """
    kl_jitter = 1e-10
    dims = q_mean.shape[0]

    # NOTE(review): under JAX's default float32 a 1e-10 jitter is below
    # machine precision relative to O(1) diagonal entries, so it may have
    # no effect there - confirm whether a larger value is intended.
    jitter = jnp.eye(dims) * kl_jitter

    # The means are transposed so that the axis matching the covariance
    # size comes last, where the distribution expects the event axis.
    q = tfd.MultivariateNormalTriL(
        loc=q_mean.transpose(),
        scale_tril=jnp.linalg.cholesky(q_cov + jitter),
        validate_args=False,
        allow_nan_stats=True,
    )
    p = tfd.MultivariateNormalTriL(
        loc=p_mean.transpose(),
        scale_tril=jnp.linalg.cholesky(p_cov + jitter),
        validate_args=False,
        allow_nan_stats=True,
    )
    kl = tfd.kl_divergence(q, p, allow_nan_stats=False)

    return kl
