import jax

import jax.numpy as jnp
import jax.scipy as jsp

from functools import partial

from models.FVI.training_utils.utils import SpectralScoreEstimator


@partial(jax.jit, static_argnums=(1,2,8))
def neg_felbo_objective(
    params,
    model,
    prior,
    x,
    y,
    x_context,
    key,
    ll_scale,
    mc_samples
):
    """
    Negative functional ELBO objective.

    The objective is `-(expected_ll - kl_div)`, where `expected_ll` is a Monte
    Carlo estimate of the Gaussian log likelihood on `(x, y)` and `kl_div` is
    the value returned by `kl_divergence` on the context points, divided by
    the number of rows of `x`.

    `model`, `prior` and `mc_samples` are static arguments of the jitted
    function (static_argnums=(1,2,8)), so they must be hashable and a new
    value triggers a re-trace.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): stochastic neural network; only `model.predict_f` is used.
    - prior (Prior): prior distribution on context points, forwarded to `kl_divergence`.
    - x (jnp.ndarray): input data used to calculate the expected log likelihood in the ELBO.
      Indexed as `x[:k, :]`, so it must be 2-D.
    - y (jnp.ndarray): targets used to calculate the expected log likelihood in the ELBO.
    - x_context (jnp.ndarray): context points sampled from feature distribution.
    - key (jax.random.PRNGKey): random key.
    - ll_scale (float): scale of the likelihood.
    - mc_samples (int): number of Monte Carlo samples for the likelihood term.

    returns:
    - loss (float): negative functional ELBO (the quantity to minimise).
    - aux (dict): {"expected_ll": expected log likelihood,
                   "kl_div": scaled KL divergence,
                   "felbo": functional ELBO}.
    """
    # Fixed settings of the KL term (values unchanged from the original code).
    n_context = 500          # target number of context points
    n_context_samples = 100  # function samples drawn at the context points
    jitter = 0.01            # std of the noise added for numerical stability

    key_context, key_jitter, key_ll = jax.random.split(key, 3)

    # Pad the context set with leading training inputs up to `n_context` points.
    # The count is clamped at zero: a negative slice stop would otherwise select
    # all but the last rows of `x` when more than `n_context` context points are
    # supplied. If `x` has fewer rows than needed, fewer points are appended.
    n_missing = max(0, n_context - x_context.shape[0])
    x_context = jnp.concatenate([x_context, x[:n_missing, :]], axis=0)

    # Compute KL divergence
    f_context = model.predict_f(
        params, x_context, key_context, mc_samples=n_context_samples
    ).reshape(n_context_samples, -1)
    f_context += jitter * jax.random.normal(key_jitter, f_context.shape)
    kl_div = kl_divergence(x_context, f_context, prior) / x.shape[0]

    # Compute expected log likelihood: average over axis 0 of the model output
    # first, then over the remaining axes.
    f_hat = model.predict_f(params, x, key_ll, mc_samples)
    expected_ll = jnp.mean(
        jsp.stats.norm.logpdf(y, loc=f_hat, scale=ll_scale),
        axis=0
    ).mean()

    felbo = expected_ll - kl_div

    return (
        -felbo,
        {"expected_ll": expected_ll, "kl_div": kl_div, "felbo": felbo}
    )


@partial(jax.jit, static_argnums=(2,3,4,5))
def kl_divergence(
    x_context,
    f_context, 
    prior, 
    eta=0., 
    n_eigen_threshold=0.99, 
    jitter=0.01
):
    """
    KL term of the functional ELBO, evaluated on a set of context points.

    The result is `cross_entropy - entropy_surrogate`, where the cross entropy
    is the negated mean Gaussian log density of `f_context` under the prior
    and the entropy surrogate is built from the output of
    `SpectralScoreEstimator.compute_gradients`.

    `prior`, `eta`, `n_eigen_threshold` and `jitter` are static arguments of
    the jitted function (static_argnums=(2,3,4,5)).

    params:
    - x_context (jnp.ndarray): context points passed to `prior`.
    - f_context (jnp.ndarray): function samples at the context points
      (presumably one row per sample -- confirm against the caller).
    - prior (Prior): callable returning `(prior_mean, prior_cov)` for `x_context`.
    - eta (float): forwarded to `SpectralScoreEstimator`.
    - n_eigen_threshold (float): forwarded to `SpectralScoreEstimator`.
    - jitter (float): `jitter**2` is added to the diagonal of the prior covariance.

    returns:
    - kl (float): cross entropy minus entropy surrogate.
    """
    # Entropy surrogate. The estimator output is wrapped in stop_gradient, so
    # gradients flow only through `f_context`.
    # NOTE(review): `compute_gradients` presumably estimates the score of the
    # sample distribution -- confirm against SpectralScoreEstimator.
    score_estimator = SpectralScoreEstimator(
        eta=eta, n_eigen_threshold=n_eigen_threshold
    )
    score = score_estimator.compute_gradients(f_context)
    weighted = jax.lax.stop_gradient(-score) * f_context
    entropy_surrogate = jnp.mean(jnp.sum(weighted, -1))

    # Cross entropy under the Gaussian prior, with diagonal jitter added to
    # its covariance.
    prior_mean, prior_cov = prior(x_context)
    dim = prior_mean.shape[-1]
    prior_cov += jitter**2 * jnp.eye(dim)
    log_prior = jsp.stats.multivariate_normal.logpdf(
        f_context,
        prior_mean.reshape(1, -1),
        prior_cov.reshape(1, dim, -1),
    )
    cross_entropy = -log_prior.mean()

    return cross_entropy - entropy_surrogate


@partial(jax.jit, static_argnums=(1,6))
def expected_ll_objective(
    params,
    model,
    x,
    y,
    key,
    ll_rho,
    mc_samples
):
    """
    Negative expected log-likelihood objective.

    `model` and `mc_samples` are static arguments of the jitted function
    (static_argnums=(1,6)).

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): stochastic neural network; only `model.predict_f` is used.
    - x (jnp.ndarray): input data used to calculate the expected log likelihood.
    - y (jnp.ndarray): targets used to calculate the expected log likelihood.
    - key (jax.random.PRNGKey): random key.
    - ll_rho (float): pre-activated likelihood scale; the scale is softplus(ll_rho).
    - mc_samples (int): number of Monte Carlo samples.

    returns:
    - loss (float): negative expected log likelihood (the quantity to minimise).
    - aux (dict): {"expected_ll": expected log likelihood}.
    """
    # Model outputs at the training inputs.
    f_samples = model.predict_f(params, x, key, mc_samples)

    # Gaussian log density of the targets; the softplus keeps the scale positive.
    log_lik = jsp.stats.norm.logpdf(
        y, loc=f_samples, scale=jax.nn.softplus(ll_rho)
    )

    # Average over axis 0 of the model output first, then over what remains.
    expected_ll = log_lik.mean(axis=0).mean()

    return -expected_ll, {"expected_ll": expected_ll}
