import jax 

import jax.scipy as jsp
import jax.numpy as jnp

from functools import partial
from jax.example_libraries.optimizers import adam 

from models.Laplace.training_utils.objective import neg_log_posterior_objective, neg_log_likelihood_objective


def fit_model(
    key, 
    params, 
    model, 
    config, 
    train_dataloader,
    val_dataloader
):
    """
    Fit the model by alternating Adam updates of the likelihood scale and of
    the network parameters, with periodic validation and early stopping.

    params:
    - key (jax.random.PRNGKey): random key.
    - params (jax.tree_util.pytree): parameters of the neural network.
    - model (Model): neural network. Its `training_steps` attribute is read
      to resume the optimizer step counter and is overwritten on return.
    - config (dict): configuration dictionary (reads the "laplace" section).
    - train_dataloader (DataLoader): wrapper for the training data. 
    - val_dataloader (DataLoader): wrapper for the validation data.

    returns:
    - params (jax.tree_util.pytree): updated parameters (the best validated
      parameters if early stopping triggered).
    - ll_scale (float): likelihood scale (softplus of the learned rho).
    - loss (dict): "log_likelihood" and "log_posterior" summed over the most
      recent validation pass (both 0. if no epoch was run).
    """
    # Read configuration
    laplace_cfg = config["laplace"]
    train_cfg = laplace_cfg["training"]
    nb_epochs = train_cfg["nb_epochs"]
    prior_scale = laplace_cfg["prior_scale"]
    early_stopping_patience = train_cfg["patience"]
    max_grad_norm = train_cfg["max_grad_norm"]
    adam_hparams = (train_cfg["lr"], train_cfg["b1"], train_cfg["b2"], train_cfg["eps"])

    # Validation runs every `eval_every` epochs and on the final epoch. The
    # early-stopping counter advances by the same amount, so `patience` is
    # expressed in epochs.
    eval_every = 100

    # Initialize optimizer for the network parameters
    nn_opt_init, nn_opt_update, nn_get_params = adam(*adam_hparams)
    nn_opt_state = nn_opt_init(params)

    # Initialize optimizer for the unconstrained likelihood parameter rho,
    # where ll_scale = softplus(rho).
    ll_scale = laplace_cfg["likelihood_scale"]
    # Stable inverse softplus: log(exp(s) - 1) == s + log(1 - exp(-s)).
    # The naive form overflows for large s and loses precision for small s.
    ll_rho = ll_scale + jnp.log(-jnp.expm1(-ll_scale))
    ll_opt_init, ll_opt_update, ll_get_params = adam(*adam_hparams)
    ll_opt_state = ll_opt_init(ll_rho)

    # Number of training and validation samples
    n_train_samples = len(train_dataloader.dataset)
    n_val_samples = len(val_dataloader.dataset)

    # Early stopping initialization (-jnp.inf replaces the removed jnp.NINF)
    opt_log_posterior, no_improve_count = -jnp.inf, 0
    opt_params, opt_ll_scale = None, None

    # Validation metrics returned to the caller; they stay 0. if no epoch runs.
    log_likelihood, log_posterior = 0., 0.

    # Training loop 
    print("Previous training steps: ", model.training_steps, flush=True)
    step = model.training_steps
    for epoch in range(nb_epochs):
        for x, y in train_dataloader:
            # Derive independent keys for each stochastic call. `key` itself
            # is only ever split, never consumed, so no key is reused.
            key, key_ll, key_nn = jax.random.split(key, 3)

            # Update the likelihood scale
            ll_rho, ll_opt_state, _ = update_ll(
                params,
                ll_opt_state,
                ll_get_params,
                x,
                y,
                key_ll,
                ll_opt_update,
                model,
                ll_rho,
                step,
                n_train_samples,
                max_grad_norm
            )
            ll_scale = jax.nn.softplus(ll_rho)

            # Update the network parameters using the freshly updated scale
            params, nn_opt_state, _ = update_nn(
                params,
                nn_opt_state,
                nn_get_params,
                x,
                y,
                key_nn,
                nn_opt_update,
                model,
                ll_scale,
                prior_scale,
                step,
                n_train_samples,
                max_grad_norm
            )
            step += 1

        if epoch % eval_every != 0 and epoch != nb_epochs - 1:
            continue

        # Evaluation on the validation set
        log_likelihood, log_posterior = 0., 0.
        for x, y in val_dataloader:
            key, key_val = jax.random.split(key)
            _, val_info = neg_log_posterior_objective(
                params,
                model,
                x,
                y,
                key_val,
                ll_scale,
                prior_scale,
                n_val_samples
            )
            log_likelihood += val_info["log_likelihood"]
            log_posterior += val_info["log_posterior"]

        # Log validation loss 
        print(f"Epoch {epoch} - val log_posterior: {log_posterior} - val log_likelihood {log_likelihood}", flush=True)

        # Early stopping
        if log_posterior > opt_log_posterior:
            opt_log_posterior = log_posterior
            opt_params = params
            opt_ll_scale = ll_scale
            no_improve_count = 0
        else:
            no_improve_count += eval_every
            if no_improve_count >= early_stopping_patience:
                # Roll back to the best validated state. If no validation ever
                # improved on -inf (e.g. NaN losses), keep the current state
                # rather than returning None.
                if opt_params is not None:
                    params = opt_params
                    ll_scale = opt_ll_scale
                print("Early stopping.", flush=True)
                break
    model.training_steps = step
    print("Likelihood scale:", ll_scale, flush=True)

    return params, ll_scale, {"log_likelihood": log_likelihood, "log_posterior": log_posterior}


@partial(jax.jit, static_argnums=(2,6,7,9,11,12))
def update_nn(
    params, 
    nn_opt_state,
    nn_get_params,
    x_batch,
    y_batch, 
    key, 
    nn_opt_update,
    model,
    ll_scale, 
    prior_scale,
    step, 
    n_samples,
    max_grad_norm
):
    """
    Run one optimizer step on the network parameters.

    The step has three parts:
    1. Differentiate `neg_log_posterior_objective` with respect to `params`.
    2. If `max_grad_norm` is not None, rescale the gradient pytree whenever
       its global L2 norm reaches `max_grad_norm`.
    3. Apply the optimizer update.

    params:
    - params (jax.tree_util.pytree): current network parameters.
    - nn_opt_state (jax.tree_util.pytree): optimizer state.
    - nn_get_params (callable): maps an optimizer state to parameters (static).
    - x_batch (jax.numpy.ndarray): batch of inputs.
    - y_batch (jax.numpy.ndarray): batch of targets.
    - key (jax.random.PRNGKey): random key forwarded to the objective.
    - nn_opt_update (callable): optimizer update function (static).
    - model (Model): NN model (static).
    - ll_scale (float): scale of the likelihood model.
    - prior_scale (float): scale of the prior (static).
    - step (int): optimizer step index.
    - n_samples (int): number of training samples (static).
    - max_grad_norm (float or None): clipping threshold; None disables
      clipping (static, so the branch is resolved at trace time).

    returns:
    - params (jax.tree_util.pytree): updated parameters.
    - nn_opt_state (jax.tree_util.pytree): updated optimizer state.
    - other_info (dict): auxiliary output of the objective.
    """
    objective_grad = jax.grad(neg_log_posterior_objective, has_aux=True)
    grads, other_info = objective_grad(
        params, model, x_batch, y_batch, key, ll_scale, prior_scale, n_samples
    )

    if max_grad_norm is not None:
        grad_norm = jax.example_libraries.optimizers.l2_norm(grads)

        def _rescale(g):
            # Keep g as-is below the threshold; otherwise shrink it so the
            # global norm becomes (approximately) max_grad_norm.
            return jnp.where(grad_norm < max_grad_norm, g, g * max_grad_norm / (grad_norm + 1e-6))

        grads = jax.tree_util.tree_map(_rescale, grads)

    new_state = nn_opt_update(step, grads, nn_opt_state)
    return nn_get_params(new_state), new_state, other_info


@partial(jax.jit, static_argnums=(2,6,7,10,11))
def update_ll(
    params, 
    ll_opt_state,
    ll_get_params,
    x_batch,
    y_batch, 
    key, 
    ll_opt_update,
    model,
    ll_rho, 
    step, 
    n_samples,
    max_grad_norm
):
    """
    Run one optimizer step on the pre-activation likelihood-scale parameter.

    The step has three parts:
    1. Differentiate `neg_log_likelihood_objective` with respect to `ll_rho`
       only (positional argument 5 of the objective). The network `params`
       are passed through unchanged.
    2. If `max_grad_norm` is not None, rescale the gradient whenever its L2
       norm reaches `max_grad_norm`.
    3. Apply the optimizer update.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN (not updated).
    - ll_opt_state (jax.tree_util.pytree): optimizer state for ll_rho.
    - ll_get_params (callable): maps an optimizer state to ll_rho (static).
    - x_batch (jax.numpy.ndarray): batch of inputs.
    - y_batch (jax.numpy.ndarray): batch of targets.
    - key (jax.random.PRNGKey): random key forwarded to the objective.
    - ll_opt_update (callable): optimizer update function (static).
    - model (Model): NN model (static).
    - ll_rho (float): pre-activated scale of the likelihood model.
    - step (int): optimizer step index.
    - n_samples (int): number of training samples (static).
    - max_grad_norm (float or None): clipping threshold; None disables
      clipping (static).

    returns:
    - ll_rho (float): updated pre-activated likelihood scale.
    - ll_opt_state (jax.tree_util.pytree): updated optimizer state.
    - other_info (dict): auxiliary output of the objective.
    """
    rho_grad_fn = jax.grad(neg_log_likelihood_objective, argnums=5, has_aux=True)
    rho_grad, other_info = rho_grad_fn(
        params, model, x_batch, y_batch, key, ll_rho, n_samples
    )

    if max_grad_norm is not None:
        total_norm = jax.example_libraries.optimizers.l2_norm(rho_grad)

        def _shrink(g):
            # Unchanged below the threshold, rescaled above it.
            return jnp.where(total_norm < max_grad_norm, g, g * max_grad_norm / (total_norm + 1e-6))

        rho_grad = jax.tree_util.tree_map(_shrink, rho_grad)

    updated_state = ll_opt_update(step, rho_grad, ll_opt_state)
    return ll_get_params(updated_state), updated_state, other_info


def evaluate_model(
    key, 
    model, 
    dataloader
):
    """
    Evaluate the model on a held-out set.

    For each batch, take the predictive mean and diagonal variance returned
    by `model.f_distribution_mean_var` and the likelihood scale
    s = `model.ll_scale`. From these, accumulate:
    - log N(y | mean, s^2) - 0.5 * var / s^2. This is the expected Gaussian
      log-likelihood when the predictive over f is Gaussian with that mean
      and variance (assumes diag_cov holds per-output variances; confirm
      against the model).
    - the squared error of the mean.

    Both sums are divided by len(dataloader.dataset).

    params:
    - key (jax.random.PRNGKey): JAX random seed.
    - model (Model): BNN model exposing `ll_scale` and `f_distribution_mean_var`.
    - dataloader (DataLoader): data loader; must sample without replacement.

    returns:
    - test_loss (dict): {"expected_ll": ..., "mse": ...}.
    """
    assert dataloader.replacement == False, "Data should be sampled without replacement"
    
    # Get likelihood scale
    ll_scale = model.ll_scale
    n_samples = len(dataloader.dataset)

    expected_ll, mse = 0., 0.
    for x, y in dataloader:
        # Handle keys
        key, key1 = jax.random.split(key)
        # prediction
        mean, diag_cov = model.f_distribution_mean_var(x, key1, mc_samples=None)
        # Flatten targets and predictions identically. Otherwise a (B,)
        # target would broadcast against a (B, 1) mean column into a (B, B)
        # matrix and inflate the log-likelihood sum.
        y_flat = y.reshape(-1)
        mean_flat = mean.reshape(-1)
        expected_ll += jsp.stats.norm.logpdf(
            y_flat,
            loc=mean_flat,
            scale=ll_scale
        ).sum()
        expected_ll -= 0.5 * diag_cov.sum() / ll_scale**2
        mse += jnp.sum((mean_flat - y_flat) ** 2)
    mse /= n_samples
    expected_ll /= n_samples

    print(f"Expected log-likelihood: {expected_ll} - MSE: {mse}", flush=True)

    return {"expected_ll": expected_ll, "mse": mse}