import jax

import jax.numpy as jnp
import jax.scipy as jsp

from functools import partial


@partial(jax.jit, static_argnums=(1, 5, 6, 8))
def nelbo_objective(
    params,
    model,
    x,
    y,
    key,
    mc_samples,
    is_training,
    ll_scale,
    n_samples
):
    """
    Negative ELBO loss for a stochastic network with a Gaussian likelihood.

    Gaussian log densities of ``y`` under ``N(f_hat, ll_scale**2)`` are
    averaged over axis 0 of the model output, then over all remaining
    entries. The result is scaled by ``n_samples`` to give the expected log
    likelihood. The KL term returned by ``model.predict_f`` is subtracted to
    form the ELBO, and the negated ELBO is returned as the loss.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): stochastic neural network exposing ``predict_f``. It is
        a static jit argument, so it must be hashable.
    - x (jnp.array): input data.
    - y (jnp.array): labels.
    - key (jax.random.PRNGKey): random key.
    - mc_samples (int): number of MC samples used to estimate the expected
        log likelihood (static).
    - is_training (bool): whether it's in training mode (static).
    - ll_scale (float): standard deviation of the Gaussian likelihood.
    - n_samples (int): total number of training samples (static).

    returns:
    - loss (float): negative ELBO, to be minimized.
    - aux (dict): diagnostics with keys "expected_ll", "kl" and "elbo".
    """
    # NOTE(review): axis 0 of the prediction is presumably the MC-sample axis
    # (mc_samples is forwarded to predict_f); the remaining axes look like
    # (batch, n_outputs) -- TODO confirm against Model.predict_f.
    predictions, kl = model.predict_f(
        params, x, key, mc_samples, is_training, stochastic=True
    )

    log_density = jsp.stats.norm.logpdf(y, loc=predictions, scale=ll_scale)
    # Two-stage mean: over axis 0 first, then over everything left.
    # NOTE(review): output dimensions are averaged rather than summed, which
    # scales the likelihood by 1/n_outputs relative to the KL when there is
    # more than one output -- confirm this weighting is intended.
    per_point = log_density.mean(axis=0).mean()
    # Scale the minibatch average up to the size of the full training set.
    expected_ll = n_samples * per_point

    elbo = expected_ll - kl
    aux = {"expected_ll": expected_ll, "kl": kl, "elbo": elbo}
    return -elbo, aux


@partial(jax.jit, static_argnums=(1, 5, 6, 8))
def expected_ll_objective(
    params,
    model,
    x,
    y,
    key,
    mc_samples,
    is_training,
    ll_rho,
    n_samples
):
    """
    Negative expected log-likelihood loss with a softplus-parameterized scale.

    The Gaussian likelihood standard deviation is ``softplus(ll_rho)``, so
    ``ll_rho`` may be any real value. Log densities of ``y`` under
    ``N(f_hat, scale**2)`` are averaged over axis 0 of the model output and
    then over all remaining entries. The result is scaled by ``n_samples``.
    The KL term returned by ``model.predict_f`` is ignored, so no KL penalty
    is applied here.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): stochastic neural network exposing ``predict_f``. It is
        a static jit argument, so it must be hashable.
    - x (jnp.array): input data.
    - y (jnp.array): labels.
    - key (jax.random.PRNGKey): random key.
    - mc_samples (int): number of MC samples used to estimate the expected
        log likelihood (static).
    - is_training (bool): whether it's in training mode (static).
    - ll_rho (float): unconstrained (pre-softplus) likelihood scale.
    - n_samples (int): total number of training samples (static).

    returns:
    - loss (float): negative expected log likelihood, to be minimized.
    - aux (dict): diagnostics with the single key "expected_ll".
    """
    # Softplus maps the unconstrained ll_rho to a positive scale.
    scale = jax.nn.softplus(ll_rho)

    # NOTE(review): axis 0 of the prediction is presumably the MC-sample axis
    # (mc_samples is forwarded to predict_f) -- TODO confirm.
    predictions, _unused_kl = model.predict_f(
        params, x, key, mc_samples, is_training, stochastic=True
    )

    log_density = jsp.stats.norm.logpdf(y, loc=predictions, scale=scale)
    # Mean over axis 0, then over the rest, rescaled to the full dataset size.
    expected_ll = n_samples * log_density.mean(axis=0).mean()

    aux = {"expected_ll": expected_ll}
    return -expected_ll, aux
