import jax 

import jax.numpy as jnp 
import jax.scipy as jsp

from functools import partial
from jax.example_libraries.optimizers import adam 

from models.TFSVI.training_utils.objective import nelbo_fsvi_objective, expected_ll_objective


def fit_model(
    key, 
    params, 
    model, 
    config, 
    train_dataloader, 
    val_dataloader
):
    """
    Fit the model.

    Alternates, for every training batch, one optimizer step on the
    pre-activation likelihood scale (``update_ll``) and one step on the
    network parameters (``update_nn``). The validation objective is computed
    every 100 epochs and on the last epoch; training stops early once the
    validation felbo has not improved for ``patience`` epochs.

    params:
    - key (jax.random.PRNGKey): random key.
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): BNN. Its ``training_steps`` attribute is read as the
        initial optimizer step and overwritten with the final step.
    - config (dict): configuration dictionary (reads ``config["tfsvi"]``).
    - train_dataloader (DataLoader): dataloader for train data.
    - val_dataloader (DataLoader): dataloader for val data.

    returns: 
    - params (jax.tree_util.pytree): parameters of the BNN (the best
        validated ones when early stopping is triggered).
    - ll_scale: likelihood scale matching the returned parameters.
    - val_loss (dict): validation metrics ("felbo", "expected_ll", "kl")
        matching the returned parameters; NaN if no validation pass ran.
    """
    # Read configuration
    train_cfg = config["tfsvi"]["training"]
    nb_epochs = train_cfg["nb_epochs"]
    mc_samples = train_cfg["mc_samples"]
    n_context_sets = train_cfg["n_context_sets"]
    n_context_points = train_cfg["n_context_points"]
    early_stopping_patience = train_cfg["patience"]
    max_grad_norm = train_cfg["max_grad_norm"]

    # Validation runs every `eval_interval` epochs; the early-stopping counter
    # is advanced by the same amount so that patience is expressed in epochs.
    eval_interval = 100

    # Both optimizers share the same Adam hyper-parameters
    adam_hparams = (
        train_cfg["lr"],
        train_cfg["b1"],
        train_cfg["b2"],
        train_cfg["eps"],
    )

    # Initialize parameter optimizer
    nn_opt_init, nn_opt_update, nn_get_params = adam(*adam_hparams)
    nn_opt_state = nn_opt_init(params)

    # Initialize likelihood rho optimizer.
    # rho is the inverse softplus of the scale; expm1 keeps precision when the
    # scale is small, where exp(x) - 1 suffers from cancellation.
    ll_scale = config["tfsvi"]["likelihood_scale"]
    ll_rho = jnp.log(jnp.expm1(ll_scale))
    ll_opt_init, ll_opt_update, ll_get_params = adam(*adam_hparams)
    ll_opt_state = ll_opt_init(ll_rho)

    # Early stopping initialization
    opt_felbo, no_improve_count = -jnp.inf, 0
    opt_params, opt_ll_scale, opt_metrics = None, None, None

    # Returned as-is when no validation pass runs (e.g. nb_epochs == 0)
    val_metrics = {
        "felbo": float("nan"),
        "expected_ll": float("nan"),
        "kl": float("nan"),
    }

    # Number of training and validation samples
    n_train_samples = len(train_dataloader.dataset)
    n_val_samples = len(val_dataloader.dataset)

    def sample_context(subkey, feature_dim):
        # Context points are drawn uniformly from the hyper-cube [-2, 2]^d
        return select_context_points(
            n_context_points,
            n_context_sets,
            context_points_maxval=2,
            context_points_minval=-2,
            feature_dim=feature_dim,
            key=subkey
        )

    # Training loop 
    print("Previous training steps: ", model.training_steps, flush=True)
    step = model.training_steps
    for epoch in range(nb_epochs):
        for x, y in train_dataloader:
            # Handle keys
            key, key1, key2 = jax.random.split(key, num=3)

            # Get context points
            x_context = sample_context(key1, x.shape[-1])

            # Update likelihood rho
            ll_rho, ll_opt_state, _ = update_ll(
                params, 
                ll_opt_state,
                ll_get_params, 
                x, 
                y, 
                key2, 
                ll_opt_update, 
                mc_samples,
                model,
                ll_rho, 
                n_train_samples, 
                step,
                max_grad_norm
            )
            ll_scale = jax.nn.softplus(ll_rho)

            # Update the model
            # NOTE(review): key2 is also consumed by update_ll above, so both
            # updates share the same randomness; split a dedicated key if
            # independent Monte-Carlo samples are intended.
            params, nn_opt_state, _ = update_nn(
                params, 
                nn_opt_state,
                nn_get_params, 
                x, 
                y, 
                x_context, 
                key2, 
                nn_opt_update, 
                mc_samples, 
                model, 
                ll_scale, 
                n_train_samples, 
                step,
                max_grad_norm
            )
            step += 1

        # Evaluation
        if epoch % eval_interval == 0 or epoch == nb_epochs-1:
            expected_ll, felbo, kl = 0., 0., 0.
            for x, y in val_dataloader:
                # Handle keys
                key, key1, key2 = jax.random.split(key, num=3)

                # Get context points
                x_context = sample_context(key1, x.shape[-1])

                # prediction
                val_loss, val_info = nelbo_fsvi_objective(
                    params, 
                    model, 
                    x, 
                    y, 
                    x_context, 
                    key2, 
                    mc_samples, 
                    False, #is_training
                    ll_scale, 
                    n_val_samples
                )
                expected_ll += val_info["expected_ll"]
                felbo += val_info["felbo"]
                kl += val_info["kl"]
            val_metrics = {"felbo": felbo, "expected_ll": expected_ll, "kl": kl}

            # Log validation loss 
            print(f"Epoch {epoch} - felbo: {felbo} - expected_ll {expected_ll} - KL {kl}", flush=True)

            # Early stopping
            if felbo > opt_felbo:
                opt_felbo = felbo
                opt_params = params
                opt_ll_scale = ll_scale
                opt_metrics = val_metrics
                no_improve_count = 0
            else:
                no_improve_count += eval_interval
                if no_improve_count >= early_stopping_patience:
                    # Roll back to the best validated state, if there is one
                    # (there is none when every felbo so far was NaN).
                    if opt_params is not None:
                        params = opt_params
                        ll_scale = opt_ll_scale
                        val_metrics = opt_metrics
                    print("Early stopping.", flush=True)
                    break
    model.training_steps = step
    print("Likelihood scale:", ll_scale, flush=True)

    return params, ll_scale, val_metrics


@partial(jax.jit, static_argnums=(0,1,2,3,4))
def select_context_points(
    n_context_points,
    n_context_sets,
    context_points_maxval,
    context_points_minval,
    feature_dim,
    key
):
    """
    Draw context points uniformly at random from a box in feature space.

    Every argument except ``key`` is declared static for ``jax.jit``.

    params:
    - n_context_points (int): number of context points per context set.
    - n_context_sets (int): number of context sets to draw.
    - context_points_maxval (float): upper bound passed to the uniform sampler.
    - context_points_minval (float): lower bound passed to the uniform sampler.
    - feature_dim (int): dimension of the feature space.
    - key: random key.

    returns:
    - context points (jnp.array): array of shape
        (n_context_sets, n_context_points, feature_dim).
    """
    sample_shape = (n_context_sets, n_context_points, feature_dim)
    return jax.random.uniform(
        key,
        sample_shape,
        minval=context_points_minval,
        maxval=context_points_maxval,
    )


@partial(jax.jit, static_argnums=(2,7,8,9,11,13))
def update_nn(
    params, 
    nn_opt_state,
    nn_get_params,
    x_batch,
    y_batch, 
    x_context, 
    key, 
    nn_opt_update,
    mc_samples,
    model,
    ll_scale, 
    n_samples,
    step, 
    max_grad_norm
):
    """
    One optimizer step on the network parameters.

    Differentiates ``nelbo_fsvi_objective`` with respect to ``params``,
    optionally rescales the gradients so that their global l2 norm does not
    exceed ``max_grad_norm``, and applies the optimizer update.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - nn_opt_state (jax.tree_util.pytree): optimizer state.
    - nn_get_params (callable): maps an optimizer state to parameters.
    - x_batch (jnp.array): a batch of input.
    - y_batch (jnp.array): a batch of labels.
    - x_context (jnp.array): context points forwarded to the objective.
    - key (jax.random.PRNGKey): JAX random seed.
    - nn_opt_update (callable): optimizer update function.
    - mc_samples (int): number of Monte-Carlo samples forwarded to the objective.
    - model (Model): BNN model.
    - ll_scale (float): scale of the likelihood model.
    - n_samples (int): total number of training samples.
    - step (int): current step.
    - max_grad_norm (float or None): maximum gradient norm; None disables
        clipping.

    returns:
    - params (jax.tree_util.pytree): updated parameters.
    - nn_opt_state (jax.tree_util.pytree): updated optimizer state.
    - other_info (dict): auxiliary output of the objective.
    """
    # The objective returns (loss, aux); only the gradient w.r.t. the first
    # argument (params) and the aux output are needed here.
    objective_grad = jax.grad(nelbo_fsvi_objective, has_aux=True)
    grads, other_info = objective_grad(
        params,
        model,
        x_batch,
        y_batch,
        x_context,
        key,
        mc_samples,
        True,  # is_training
        ll_scale,
        n_samples,
    )

    # max_grad_norm is a static argument, so this branch is resolved at trace time
    if max_grad_norm is not None:
        grad_norm = jax.example_libraries.optimizers.l2_norm(grads)

        def rescale(leaf):
            # Leave gradients untouched while under the threshold, otherwise
            # shrink every leaf by the same factor (1e-6 avoids division by 0).
            return jnp.where(
                grad_norm < max_grad_norm,
                leaf,
                leaf * max_grad_norm / (grad_norm + 1e-6),
            )

        grads = jax.tree_util.tree_map(rescale, grads)

    nn_opt_state = nn_opt_update(step, grads, nn_opt_state)
    new_params = nn_get_params(nn_opt_state)

    return new_params, nn_opt_state, other_info


@partial(jax.jit, static_argnums=(2,6,7,8,10,12))
def update_ll(
    params, 
    ll_opt_state,
    ll_get_params, 
    x, 
    y, 
    key, 
    ll_opt_update, 
    mc_samples,
    model, 
    ll_rho, 
    n_samples, 
    step,
    max_grad_norm
):
    """
    One optimizer step on the pre-activation likelihood scale (rho).

    Differentiates ``expected_ll_objective`` with respect to ``ll_rho`` only,
    optionally rescales the gradient so that its l2 norm does not exceed
    ``max_grad_norm``, and applies the optimizer update.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN (held fixed).
    - ll_opt_state (jax.tree_util.pytree): optimizer state.
    - ll_get_params (callable): maps an optimizer state to rho.
    - x (jnp.array): a batch of input.
    - y (jnp.array): a batch of labels.
    - key (jax.random.PRNGKey): JAX random seed.
    - ll_opt_update (callable): optimizer update function.
    - mc_samples (int): number of Monte-Carlo samples forwarded to the objective.
    - model (Model): BNN model.
    - ll_rho (float): pre-activated likelihood scale.
    - n_samples (int): total number of training samples.
    - step (int): current step.
    - max_grad_norm (float or None): maximum gradient norm; None disables
        clipping.

    returns:
    - ll_rho (jax.tree_util.pytree): updated rho.
    - ll_opt_state (jax.tree_util.pytree): updated optimizer state.
    - other_info (dict): auxiliary output of the objective.
    """
    # argnums=7 selects ll_rho, the eighth positional argument of the objective
    rho_grad = jax.grad(expected_ll_objective, argnums=7, has_aux=True)
    grads, other_info = rho_grad(
        params,
        model,
        x,
        y,
        key,
        mc_samples,
        True,  # is_training
        ll_rho,
        n_samples,
    )

    # max_grad_norm is a static argument, so this branch is resolved at trace time
    if max_grad_norm is not None:
        grad_norm = jax.example_libraries.optimizers.l2_norm(grads)

        def rescale(leaf):
            # Identity below the threshold, uniform shrink above it
            # (1e-6 avoids division by 0).
            return jnp.where(
                grad_norm < max_grad_norm,
                leaf,
                leaf * max_grad_norm / (grad_norm + 1e-6),
            )

        grads = jax.tree_util.tree_map(rescale, grads)

    ll_opt_state = ll_opt_update(step, grads, ll_opt_state)
    new_rho = ll_get_params(ll_opt_state)

    return new_rho, ll_opt_state, other_info


def evaluate_model(
    key, 
    params, 
    model, 
    dataloader, 
    config
):
    """
    Evaluate the model on a dataset.

    Accumulates, over all batches, the Gaussian log-density of the labels
    under the sampled predictions (averaged over the leading axis of the
    predictions) and the squared error of the mean prediction, then divides
    both by the dataset size.

    params:
    - key (jax.random.PRNGKey): JAX random seed.
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): BNN model; provides ``ll_scale`` and ``predict_f``.
    - dataloader (DataLoader): data loader; must not sample with replacement.
    - config (dict): configuration (currently not read by this function).
    
    returns:
    - test_loss (dict): "expected_ll" and "mse", both per-sample averages.
    """
    # Every sample has to be visited exactly once for the averages to be valid
    assert dataloader.replacement == False, "Data should be sampled without replacement"

    # Number of stochastic forward passes per batch (fixed, not configurable)
    mc_samples = 100

    # Likelihood scale stored on the model
    ll_scale = model.ll_scale

    expected_ll = 0.
    mse = 0.
    for x, y in dataloader:
        # Fresh sub-key for each batch
        key, predict_key = jax.random.split(key)

        # Sampled predictions; axis 0 presumably indexes the MC samples
        f_hat = model.predict_f(
            params, x, predict_key, mc_samples, is_training=False, stochastic=True
        )

        # NOTE(review): relies on y broadcasting against f_hat sample-wise —
        # confirm that y has the same trailing shape as f_hat.
        log_density = jsp.stats.norm.logpdf(y, loc=f_hat, scale=ll_scale)
        expected_ll += jnp.mean(log_density, axis=0).sum()

        # Squared error of the mean prediction, flattened to align with y
        residual = f_hat.mean(0).reshape(-1) - y.reshape(-1)
        mse += jnp.sum(residual**2)

    mse /= len(dataloader.dataset)
    expected_ll /= len(dataloader.dataset)
        
    print(f"Expected log-likelihood: {expected_ll} - MSE: {mse}", flush=True)

    return {"expected_ll": expected_ll, "mse": mse}
