import jax 

import jax.scipy as jsp

from functools import partial


@partial(jax.jit, static_argnums=(1, 6, 7))
def neg_log_posterior_objective(
    params,
    model,
    x,
    y,
    key,
    ll_scale,
    prior_scale,
    n_samples,
):
    """
    Negative log-posterior objective.

    Uses a Gaussian likelihood with fixed scale ``ll_scale`` around the model
    output, and a zero-mean isotropic Gaussian prior with scale ``prior_scale``
    on all parameters. The prior is included only up to an additive constant.

    params:
    - params (jax.tree_util.pytree): parameters of the mlp.
    - model (Model): neural network; must expose ``apply_fn(params, key, x)``.
      It is a static argument of ``jax.jit``, so it must be hashable.
    - x (jax.numpy.ndarray): feature batch.
    - y (jax.numpy.ndarray): label batch.
    - key (jax.random.PRNGKey): random key forwarded to ``model.apply_fn``.
    - ll_scale (float): scale of the likelihood model.
    - prior_scale (float): scale of the prior (static under ``jax.jit``).
    - n_samples (int): number of training samples (static under ``jax.jit``).

    returns:
    - neg_log_posterior (float): negative (unnormalized) log-posterior.
    - other_info (dict): ``{"log_likelihood": ..., "log_posterior": ...}``.
    """
    # Local import: only needed while tracing, keeps the block self-contained.
    import jax.numpy as jnp

    f_hat = model.apply_fn(params, key, x)

    # .mean() averages over every element of the batch. Multiplying by
    # n_samples rescales that average to an estimate for the full training set.
    log_likelihood = jsp.stats.norm.logpdf(
        y,
        loc=f_hat,
        scale=ll_scale,
    ).mean() * n_samples

    # Squared L2 norm of all parameters, computed directly. Avoids two
    # problems with the sqrt-then-square l2_norm path:
    # - its gradient is NaN when every parameter is zero;
    # - it relies on jax.example_libraries, which `import jax` alone is not
    #   guaranteed to load.
    sq_norm = sum(
        jnp.vdot(leaf, leaf) for leaf in jax.tree_util.tree_leaves(params)
    )
    log_prior = -0.5 * sq_norm / prior_scale**2

    log_posterior = log_likelihood + log_prior

    return (
        -log_posterior,
        {"log_likelihood": log_likelihood, "log_posterior": log_posterior},
    )


@partial(jax.jit, static_argnums=(1, 6))
def neg_log_likelihood_objective(
    params,
    model,
    x,
    y,
    key,
    ll_rho,
    n_samples,
):
    """
    Negative log-likelihood objective with a learnable likelihood scale.

    The likelihood is Gaussian around the model output. Its scale is
    ``softplus(ll_rho)``, so ``ll_rho`` is an unconstrained (pre-activation)
    value that can be optimized alongside ``params``.

    params:
    - params (jax.tree_util.pytree): parameters of the mlp.
    - model (Model): neural network; must expose ``apply_fn(params, key, x)``.
      It is a static argument of ``jax.jit``, so it must be hashable.
    - x (jax.numpy.ndarray): feature batch.
    - y (jax.numpy.ndarray): label batch.
    - key (jax.random.PRNGKey): random key forwarded to ``model.apply_fn``.
    - ll_rho (float): pre-activated scale of the likelihood model.
    - n_samples (int): number of training samples (static under ``jax.jit``).

    returns:
    - neg_log_likelihood (float): negative of the rescaled log-likelihood.
    - other_info (dict): ``{"log_likelihood": log-likelihood}``.
    """
    predictions = model.apply_fn(params, key, x)

    # softplus maps the unconstrained ll_rho to a positive scale.
    per_point_log_density = jsp.stats.norm.logpdf(
        y, loc=predictions, scale=jax.nn.softplus(ll_rho)
    )

    # Average over every element, then rescale the batch mean to the
    # full training set of n_samples points.
    total_log_likelihood = per_point_log_density.mean() * n_samples

    info = {"log_likelihood": total_log_likelihood}
    return -total_log_likelihood, info
