import jax 

import jax.scipy as jsp
import jax.numpy as jnp

from functools import partial

from jax.example_libraries.optimizers import adam 

from models.MFVI.training_utils.objective import nelbo_objective, expected_ll_objective


def _validation_metrics(key, params, model, dataloader, mc_samples, ll_scale, n_samples):
    """
    Sum the per-batch ELBO, expected log-likelihood and KL reported by
    ``nelbo_objective`` (evaluated with ``is_training=False``) over ``dataloader``.

    returns:
    - key (jax.random.PRNGKey): the key advanced past the per-batch splits.
    - elbo, expected_ll, kl: the summed validation metrics.
    """
    elbo, expected_ll, kl = 0., 0., 0.
    for x, y in dataloader:
        key, subkey = jax.random.split(key)
        _, info = nelbo_objective(
            params,
            model,
            x,
            y,
            subkey,
            mc_samples,
            False,  # is_training
            ll_scale,
            n_samples,
        )
        expected_ll += info["expected_ll"]
        elbo += info["elbo"]
        kl += info["kl"]
    return key, elbo, expected_ll, kl


def fit_model(
    key, 
    params, 
    model, 
    config, 
    train_dataloader,
    val_dataloader
):
    """
    Fit the BNN parameters and the likelihood scale by alternating Adam updates.

    For every training batch, the pre-activation likelihood scale ``ll_rho``
    (``ll_scale = softplus(ll_rho)``) is updated first, then the BNN parameters.
    Every ``eval_every`` epochs and on the last epoch the validation ELBO is
    computed and used for early stopping.

    params:
    - key (jax.random.PRNGKey): random key.
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): BNN. ``model.training_steps`` is read to resume the
        optimizer step counter and is overwritten with the final step count.
    - config (dict): configuration dictionary. Reads
        ``config["mfvi"]["likelihood_scale"]`` and, from
        ``config["mfvi"]["training"]``: ``mc_samples``, ``nb_epochs``,
        ``patience``, ``max_grad_norm``, ``lr``, ``b1``, ``b2``, ``eps`` and the
        optional ``eval_every`` (validation interval in epochs, default 100).
    - train_dataloader (DataLoader): data loader for training data.
    - val_dataloader (DataLoader): data loader for validation data.

    returns:
    - params (jax.tree_util.pytree): parameters of the BNN (the best validated
        parameters when early stopping triggered).
    - ll_scale: likelihood scale paired with the returned parameters.
    - metrics (dict): validation "elbo", "expected_ll" and "kl" of the returned
        state; NaN when no validation pass ran (``nb_epochs == 0``).
    """
    # Read configuration
    train_config = config["mfvi"]["training"]
    mc_samples = train_config["mc_samples"]
    nb_epochs = train_config["nb_epochs"]
    early_stopping_patience = train_config["patience"]
    max_grad_norm = train_config["max_grad_norm"]
    # Validation interval; also the increment of the no-improvement counter,
    # so ``patience`` is expressed in epochs.
    eval_every = train_config.get("eval_every", 100)
    adam_hparams = (
        train_config["lr"],
        train_config["b1"],
        train_config["b2"],
        train_config["eps"],
    )

    # Initialize parameter optimizer
    nn_opt_init, nn_opt_update, nn_get_params = adam(*adam_hparams)
    nn_opt_state = nn_opt_init(params)

    # Initialize likelihood rho optimizer. ll_rho is the inverse softplus of
    # ll_scale; s + log(-expm1(-s)) equals log(exp(s) - 1) but neither
    # overflows for large s nor cancels catastrophically for small s.
    ll_scale = config["mfvi"]["likelihood_scale"]
    ll_rho = ll_scale + jnp.log(-jnp.expm1(-ll_scale))
    ll_opt_init, ll_opt_update, ll_get_params = adam(*adam_hparams)
    ll_opt_state = ll_opt_init(ll_rho)

    # Number of training and validation samples
    n_train_samples = len(train_dataloader.dataset)
    n_val_samples = len(val_dataloader.dataset)

    # Early stopping initialization; -inf so the first finite ELBO is recorded.
    opt_elbo, no_improve_count = -jnp.inf, 0
    opt_params, opt_ll_scale, opt_metrics = None, None, None
    # Metrics of the latest validation pass (NaN until one has run).
    elbo = expected_ll = kl = float("nan")

    # Training loop
    print("Previous training steps:", model.training_steps, flush=True)
    step = model.training_steps
    for epoch in range(nb_epochs):
        for x, y in train_dataloader:
            # Independent keys for the likelihood and the parameter updates.
            key, ll_key, nn_key = jax.random.split(key, num=3)

            # Update likelihood rho
            ll_rho, ll_opt_state, _ = update_ll(
                params,
                ll_opt_state,
                ll_get_params,
                x,
                y,
                ll_key,
                ll_opt_update,
                mc_samples,
                model,
                ll_rho,
                n_train_samples,
                step,
                max_grad_norm,
            )
            ll_scale = jax.nn.softplus(ll_rho)

            # Update the model
            params, nn_opt_state, _ = update_nn(
                params,
                nn_opt_state,
                nn_get_params,
                x,
                y,
                nn_key,
                nn_opt_update,
                mc_samples,
                model,
                ll_scale,
                n_train_samples,
                step,
                max_grad_norm,
            )
            step += 1

        # Evaluation
        if epoch % eval_every == 0 or epoch == nb_epochs - 1:
            key, elbo, expected_ll, kl = _validation_metrics(
                key, params, model, val_dataloader, mc_samples, ll_scale, n_val_samples
            )

            # Log validation loss
            print(f"Epoch {epoch} - val elbo: {elbo} - val expected_ll {expected_ll} - KL {kl}", flush=True)

            # Early stopping
            if elbo > opt_elbo:
                opt_elbo = elbo
                opt_params = params
                opt_ll_scale = ll_scale
                opt_metrics = (elbo, expected_ll, kl)
                no_improve_count = 0
            else:
                no_improve_count += eval_every
                if no_improve_count >= early_stopping_patience:
                    # Roll back to the best validated state, unless no
                    # evaluation ever improved (e.g. every ELBO was NaN).
                    if opt_params is not None:
                        params, ll_scale = opt_params, opt_ll_scale
                        elbo, expected_ll, kl = opt_metrics
                    print("Early stopping.", flush=True)
                    break
    model.training_steps = step
    print("Likelihood scale:", ll_scale, flush=True)

    return params, ll_scale, {"elbo": elbo, "expected_ll": expected_ll, "kl": kl}


@partial(jax.jit, static_argnums=(2, 6, 7, 8, 10, 12))
def update_nn(
    params, 
    nn_opt_state,
    nn_get_params,
    x_batch,
    y_batch, 
    key, 
    nn_opt_update,
    mc_samples,
    model,
    ll_scale, 
    n_samples,
    step, 
    max_grad_norm
):
    """
    Take one optimizer step on the BNN parameters using the gradient of
    ``nelbo_objective`` (evaluated with ``is_training=True``).

    Arguments 2, 6, 7, 8, 10 and 12 are static for ``jax.jit``: they must be
    hashable, and changing them triggers a recompilation.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - nn_opt_state (jax.tree_util.pytree): optimizer state.
    - nn_get_params (callable): maps an optimizer state to parameters.
    - x_batch (jnp.array): a batch of inputs.
    - y_batch (jnp.array): a batch of labels.
    - key (jax.random.PRNGKey): JAX random seed.
    - nn_opt_update (callable): optimizer update function.
    - mc_samples (int): number of Monte-Carlo samples for estimating the expected
        log likelihood.
    - model (Model): BNN model.
    - ll_scale (float): scale of the likelihood model.
    - n_samples (int): total number of training samples.
    - step (int): current step.
    - max_grad_norm (float or None): if not None, gradients whose global L2
        norm reaches this value are rescaled to (approximately) this norm.

    returns:
    - params (jax.tree_util.pytree): updated parameters.
    - nn_opt_state (jax.tree_util.pytree): updated optimizer state.
    - other_info (dict): auxiliary output of ``nelbo_objective``.
    """
    grad_fn = jax.grad(nelbo_objective, has_aux=True)
    is_training = True
    grads, other_info = grad_fn(
        params, model, x_batch, y_batch, key, mc_samples, is_training, ll_scale, n_samples
    )

    # max_grad_norm is static, so this branch is resolved at trace time.
    if max_grad_norm is not None:
        global_norm = jax.example_libraries.optimizers.l2_norm(grads)

        def _rescale(g):
            return jnp.where(global_norm < max_grad_norm, g, g * max_grad_norm / (global_norm + 1e-6))

        grads = jax.tree_util.tree_map(_rescale, grads)

    new_state = nn_opt_update(step, grads, nn_opt_state)
    return nn_get_params(new_state), new_state, other_info


@partial(jax.jit, static_argnums=(2, 6, 7, 8, 10, 12))
def update_ll(
    params, 
    ll_opt_state,
    ll_get_params, 
    x, 
    y, 
    key, 
    ll_opt_update, 
    mc_samples,
    model, 
    ll_rho, 
    n_samples, 
    step,
    max_grad_norm
):
    """
    Take one optimizer step on the pre-activation likelihood scale ``ll_rho``.

    Only ``ll_rho`` (positional argument 7 of ``expected_ll_objective``) is
    differentiated; the BNN parameters are held fixed. Arguments 2, 6, 7, 8,
    10 and 12 are static for ``jax.jit``.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - ll_opt_state (jax.tree_util.pytree): optimizer state.
    - ll_get_params (callable): maps an optimizer state to ``ll_rho``.
    - x (jnp.array): a batch of inputs.
    - y (jnp.array): a batch of labels.
    - key (jax.random.PRNGKey): JAX random seed.
    - ll_opt_update (callable): optimizer update function.
    - mc_samples (int): number of Monte-Carlo samples for estimating the objective.
    - model (Model): BNN model.
    - ll_rho (float): pre-activated likelihood scale.
    - n_samples (int): total number of training samples.
    - step (int): current step.
    - max_grad_norm (float or None): if not None, a gradient whose L2 norm
        reaches this value is rescaled to (approximately) this norm.

    returns:
    - ll_rho: updated pre-activated likelihood scale.
    - ll_opt_state (jax.tree_util.pytree): updated optimizer state.
    - other_info (dict): auxiliary output of ``expected_ll_objective``.
    """
    rho_argnum = 7  # position of ll_rho in expected_ll_objective's arguments
    grad_fn = jax.grad(expected_ll_objective, argnums=rho_argnum, has_aux=True)
    objective_args = (params, model, x, y, key, mc_samples, True, ll_rho, n_samples)
    rho_grad, other_info = grad_fn(*objective_args)

    # max_grad_norm is static, so this branch is resolved at trace time.
    if max_grad_norm is not None:
        grad_norm = jax.example_libraries.optimizers.l2_norm(rho_grad)

        def _rescale(g):
            return jnp.where(grad_norm < max_grad_norm, g, g * max_grad_norm / (grad_norm + 1e-6))

        rho_grad = jax.tree_util.tree_map(_rescale, rho_grad)

    new_state = ll_opt_update(step, rho_grad, ll_opt_state)
    return ll_get_params(new_state), new_state, other_info


def evaluate_model(
    key, 
    params, 
    model, 
    dataloader,
    mc_samples=100,
):
    """
    Evaluate the model on a dataset (typically the test set).

    params:
    - key (jax.random.PRNGKey): JAX random seed.
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): BNN model; ``model.ll_scale`` is used as the scale of the
        Gaussian likelihood.
    - dataloader (DataLoader): data loader; must sample without replacement so
        each example contributes exactly once.
    - mc_samples (int): number of Monte-Carlo function samples drawn per batch
        (default 100).

    returns:
    - metrics (dict): "expected_ll", the Gaussian log-likelihood of the labels
        averaged over axis 0 of the predictions (presumably the MC-sample axis)
        and divided by the dataset size; "mse", the summed squared error of the
        mean prediction divided by the dataset size.
    """
    assert dataloader.replacement == False, "Data should be sampled without replacement"

    # Get likelihood scale
    ll_scale = model.ll_scale

    expected_ll, mse = 0., 0.
    for x, y in dataloader:
        # Handle keys
        key, key1 = jax.random.split(key)

        # Prediction; the KL term is not needed for these metrics.
        f_hat, _ = model.predict_f(params, x, key1, mc_samples, is_training=False, stochastic=True)
        expected_ll += jnp.mean(
            jsp.stats.norm.logpdf(y, loc=f_hat, scale=ll_scale),
            axis=0
        ).sum()
        mse += jnp.sum((f_hat.mean(0).reshape(-1) - y.reshape(-1))**2)

    n_examples = len(dataloader.dataset)
    mse /= n_examples
    expected_ll /= n_examples

    print(f"Expected log-likelihood: {expected_ll} - MSE: {mse}", flush=True)

    return {"expected_ll": expected_ll, "mse": mse}