import jax
import haiku as hk
import jax.numpy as jnp


# Globally enable 64-bit (double-precision) arrays for all JAX computations in
# this process (module-import side effect). NOTE(review): presumably done so the
# precision-matrix accumulation/inversion in `laplace_inference` runs in float64
# — confirm this global switch is intended for every importer of this module.
jax.config.update("jax_enable_x64", True)


def laplace_inference(
    model,
    mean_params,
    dataloader,
    key,
    config
):
    """
    Compute covariance under Laplace approximation.

    The posterior precision over the stochastic parameters starts from an
    isotropic prior precision ``1 / prior_scale**2`` and, for every batch,
    has ``J^T H J`` subtracted from it. Here ``J`` is the Jacobian of the
    network outputs w.r.t. the stochastic parameters and ``H`` is the matrix
    returned by ``likelihood_hessian``. The covariance is the inverse of the
    accumulated precision.

    params:
    - model (Model): neural network; must expose ``ll_scale`` and
      ``apply_fn(params, key, x)``.
    - mean_params (jax.tree_util.pytree): parameters of the neural network.
    - dataloader (DataLoader): iterable of ``(x, y)`` batches; ``y`` is unused.
    - key (jax.random.PRNGKey): random key, split once per batch.
    - config (dict): configuration dictionary; reads ``config["laplace"]``
      entries ``stochastic_layers``, ``prior_scale`` and ``cov_type``.

    returns:
    - cov (jnp.array): covariance under Laplace approximation. It is a
      ``(dim, dim)`` matrix when ``cov_type == "full"`` and a ``(dim,)``
      vector of variances when ``cov_type == "diag"``. ``dim`` is the number
      of stochastic parameters.

    raises:
    - ValueError: if ``cov_type`` is neither ``"full"`` nor ``"diag"``.
    """
    # Get configuration
    laplace_config = config["laplace"]
    stochastic_layers = laplace_config["stochastic_layers"]
    prior_scale = laplace_config["prior_scale"]
    cov_type = laplace_config["cov_type"]

    # Fail fast, before any (expensive) pass over the data.
    if cov_type not in ("full", "diag"):
        raise ValueError(
            f"Unknown cov_type {cov_type!r}; expected 'full' or 'diag'."
        )

    # Get likelihood scale
    ll_scale = model.ll_scale

    # Split parameters; only the stochastic ones get a posterior covariance.
    stochastic_params, static_params = split_parameters(mean_params, stochastic_layers)
    dim = hk.data_structures.tree_size(stochastic_params)

    # Contribution of the prior to the precision (isotropic, variance prior_scale**2).
    precision = jnp.ones((dim,)) / prior_scale**2
    if cov_type == "full":
        precision = jnp.diag(precision)

    for x, _ in dataloader:
        # Fresh randomness for each batch's forward pass.
        key, key1 = jax.random.split(key)
        batch_size = x.shape[0]

        # Hessian of the likelihood, shape (batch_size, batch_size).
        H = likelihood_hessian(x, ll_scale)

        # Differentiate only w.r.t. the stochastic parameters; static ones are closed over.
        def forward(p):
            return model.apply_fn(join_parameters(p, static_params), key1, x)

        _vjp = jax.vjp(forward, stochastic_params)[1]

        # Each cotangent is one row of H with a trailing axis of size 1.
        # NOTE(review): this assumes the model output for a batch has shape
        # (batch_size, 1) — confirm against model.apply_fn.
        # Row i of the vmapped VJP is H[i] @ J, so the stack is H @ J with shape
        # (batch_size, dim). Transposing gives J^T H because H is symmetric.
        H_J = jax.vmap(_vjp)(jnp.expand_dims(H, axis=-1))[0]
        Jt_H = _flatten_batched(H_J, batch_size).T

        # Second VJP pass: row k is Jt_H[k] @ J, i.e. J^T H J with shape (dim, dim).
        Jt_H_J = jax.vmap(_vjp)(jnp.expand_dims(Jt_H, axis=-1))[0]
        Jt_H_J = _flatten_batched(Jt_H_J, dim)
        if cov_type == "diag":
            Jt_H_J = jnp.diagonal(Jt_H_J)

        # H is negative (see likelihood_hessian), so subtracting adds curvature.
        precision = precision - Jt_H_J

    # Compute covariance
    if cov_type == "full":
        return jnp.linalg.inv(precision)
    return 1 / precision


def _flatten_batched(tree, num_rows):
    """Concatenate the leaves of a batched pytree into a ``(num_rows, -1)`` matrix."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.concatenate([leaf.reshape(num_rows, -1) for leaf in leaves], axis=-1)


def likelihood_hessian(
    x,
    ll_scale
):
    """
    Compute the Hessian of the log-likelihood w.r.t. the network outputs.

    Only the batch size of ``x`` (its leading dimension) is used. The result is
    the scaled identity ``(-1 / ll_scale**2) * I`` of size ``x.shape[0]``.
    NOTE(review): this matches a Gaussian likelihood with standard deviation
    ``ll_scale`` and one output per example — presumably the intended model;
    confirm.

    params:
    - x (jnp.array): input batch; only ``x.shape[0]`` is read.
    - ll_scale (float): likelihood scale.

    returns:
    - hessian (jnp.array): ``(x.shape[0], x.shape[0])`` diagonal matrix.
    """
    num_points = x.shape[0]
    # Same evaluation order as `-1 / ll_scale**2 * I`: scalar first, then the identity.
    diag_value = -1 / ll_scale**2
    return diag_value * jnp.eye(num_points)


def split_parameters(
    mean_params,
    stochastic_layers
):
    """
    Partition the network parameters into stochastic and static subsets.

    Every parameter is routed by its module's layer index, looked up in
    ``stochastic_layers``. The index is whatever follows character 23 of the
    module name, parsed with ``int``. An empty remainder maps to layer 0.

    params:
    - mean_params (jax.tree_util.pytree): parameters of the BNN.
    - stochastic_layers (list[Bool]): per-layer flags indicating which layers are stochastic.

    returns:
    - stochastic_params (jax.tree_util.pytree): parameters of the stochastic layers.
    - static_params (jax.tree_util.pytree): all remaining parameters.
    """
    def _is_stochastic(module_name, param_name, value):
        # NOTE(review): 23 is presumably the length of a fixed module-name
        # prefix preceding Haiku's numeric instance suffix — confirm against
        # the model definition; names with other prefixes will mis-index.
        suffix = module_name[23:]
        layer_index = int(suffix) if suffix else 0
        return stochastic_layers[layer_index]

    selected, remaining = hk.data_structures.partition(_is_stochastic, mean_params)
    return selected, remaining
    

def join_parameters(
    stochastic_params,
    static_params
):
    """
    Recombine stochastic and static parameters into one parameter tree.

    Counterpart of ``split_parameters``. It delegates to
    ``hk.data_structures.merge`` with the stochastic parameters passed first.

    params:
    - stochastic_params (jax.tree_util.pytree): stochastic parameters.
    - static_params (jax.tree_util.pytree): static parameters.

    returns:
    - params (jax.tree_util.pytree): merged parameters of the BNN.
    """
    merge = hk.data_structures.merge
    merged = merge(stochastic_params, static_params)
    return merged
