import jax 

import jax.numpy as jnp
import jax.scipy as jsp

from functools import partial
from jax.example_libraries.optimizers import adam 

from models.FVI.training_utils.objective import neg_felbo_objective, expected_ll_objective


def fit_model(
    key,
    params,
    model,
    config,
    train_dataloader,
    val_dataloader,
    prior
):
    """
    Fit the model.

    For every training batch, one Adam step is taken on the pre-activation
    likelihood scale (``update_ll``) followed by one Adam step on the BNN
    parameters (``update_nn``). The validation objective is evaluated every
    100 epochs and on the last epoch; its ``felbo`` value drives early stopping.

    params:
    - key (jax.random.PRNGKey): random key.
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): BNN. Its ``training_steps`` attribute is read at the start
        and overwritten with the updated step count at the end.
    - config (dict): configuration dictionary (reads ``config["fvi"]``).
    - train_dataloader (DataLoader): train data loader.
    - val_dataloader (DataLoader): val data loader.
    - prior (Prior): prior.

    returns:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - ll_scale: likelihood scale (softplus of the optimized pre-activation).
    - loss (dict): validation metrics of the most recent evaluation, summed over
        validation batches, under the keys "felbo", "expect_ll" and "kl_div".
        They are NaN if no evaluation ran (``nb_epochs == 0``).

    NOTE(review): when early stopping triggers, ``params`` and ``ll_scale`` are
    rolled back to the best evaluation, but the returned metrics are still those
    of the latest (non-improving) evaluation.
    """
    assert train_dataloader.replacement == val_dataloader.replacement == False, "Data should be sampled without replacement"

    # Read configuration
    training_cfg = config["fvi"]["training"]
    mc_samples = training_cfg["mc_samples"]
    nb_epochs = training_cfg["nb_epochs"]
    n_context_points = training_cfg["n_context_points"]
    context_selection = training_cfg["context_selection"]
    context_points_minval = training_cfg["min_context_val"]
    context_points_maxval = training_cfg["max_context_val"]
    early_stopping_patience = training_cfg["patience"]
    max_grad_val = training_cfg["max_grad_val"]

    # Validation (and early-stopping bookkeeping) happens every `eval_every` epochs.
    eval_every = 100

    # Initialize parameter optimizer
    nn_opt_init, nn_opt_update, nn_get_params = adam(
        training_cfg["lr"],
        training_cfg["b1"],
        training_cfg["b2"],
        training_cfg["eps"]
    )
    nn_opt_state = nn_opt_init(params)

    # Initialize likelihood rho optimizer.
    # ll_rho is the inverse softplus of the configured likelihood scale.
    ll_scale = config["fvi"]["likelihood_scale"]
    ll_rho = jnp.log(jnp.exp(ll_scale)-1)
    ll_opt_init, ll_opt_update, ll_get_params = adam(
        training_cfg["lr"],
        training_cfg["b1"],
        training_cfg["b2"],
        training_cfg["eps"]
    )
    ll_opt_state = ll_opt_init(ll_rho)

    # Early stopping initialization
    # (-jnp.inf rather than jnp.NINF: the latter was removed from recent JAX.)
    opt_gelbo, no_improve_count = -jnp.inf, 0
    opt_params, opt_ll_scale = None, None

    # Defined up front so the return statement is valid even if no evaluation runs.
    felbo = expected_ll = kl_div = float("nan")

    # Training loop
    print("Previous training steps: ", model.training_steps, flush=True)
    step = model.training_steps
    for epoch in range(nb_epochs):
        for x, y in train_dataloader:
            # Handle keys
            key, key1, key2, key3 = jax.random.split(key, num=4)

            # Get context points
            x_context = select_context_points(
                n_context_points,
                context_selection,
                context_points_maxval,
                context_points_minval,
                feature_dim=x.shape[-1],
                key=key1
            )

            # Update likelihood rho
            ll_rho, ll_opt_state, _ = update_ll(
                params,
                ll_opt_state,
                ll_get_params,
                x,
                y,
                key2,
                ll_opt_update,
                model,
                ll_rho,
                mc_samples,
                step,
                max_grad_val
            )
            ll_scale = jax.nn.softplus(ll_rho)

            # Update the model
            params, nn_opt_state, _ = update_nn(
                params,
                nn_opt_state,
                nn_get_params,
                prior,
                x,
                y,
                x_context,
                key3,
                nn_opt_update,
                model,
                ll_scale,
                mc_samples,
                step,
                max_grad_val
            )
            step += 1

        # Evaluation
        if epoch % eval_every == 0 or epoch == nb_epochs-1:
            expected_ll, felbo, kl_div = 0., 0., 0.
            for x, y in val_dataloader:
                # Handle keys
                key, key1, key2 = jax.random.split(key, num=3)

                # Get context points
                x_context = select_context_points(
                    n_context_points,
                    context_selection,
                    context_points_maxval,
                    context_points_minval,
                    feature_dim=x.shape[-1],
                    key=key1
                )

                # Prediction (only the auxiliary metrics are used here)
                _, val_info = neg_felbo_objective(
                    params,
                    model,
                    prior,
                    x,
                    y,
                    x_context,
                    key2,
                    ll_scale,
                    mc_samples
                )
                expected_ll += val_info["expected_ll"]
                felbo += val_info["felbo"]
                kl_div += val_info["kl_div"]

            print(f"Epoch {epoch} - felbo: {felbo} - expected log-likelihood {expected_ll} - fKL {kl_div}", flush=True)

            # Early stopping
            if felbo > opt_gelbo:
                opt_gelbo = felbo
                opt_params = params
                opt_ll_scale = ll_scale
                no_improve_count = 0
            else:
                no_improve_count += eval_every
                if no_improve_count >= early_stopping_patience:
                    # Roll back only if a best state was ever recorded (a NaN
                    # objective on the first evaluation would leave it unset).
                    if opt_params is not None:
                        params = opt_params
                        ll_scale = opt_ll_scale
                    print("Early stopping.", flush=True)
                    break
    model.training_steps = step
    print("Likelihood scale:", ll_scale, flush=True)

    return params, ll_scale, {"felbo": felbo, "expect_ll": expected_ll, "kl_div": kl_div}



@partial(jax.jit, static_argnums=(0,1,2,3,4))
def select_context_points(
    n_context_points,
    context_selection,
    context_points_maxval,
    context_points_minval,
    feature_dim,
    key,
):
    """
    Select context points.

    All arguments except ``key`` are static under jit, so the branching
    below happens at trace time.

    params:
    - n_context_points (int): number of context points to select
        (used by "random" only).
    - context_selection (str): context selection method, "random" or "grid".
    - context_points_maxval (float): maximum value of context points
        (used by "random" only).
    - context_points_minval (float): minimum value of context points
        (used by "random" only).
    - feature_dim (int): dimension of the feature space.
    - key: random key (used by "random" only).

    returns:
    - context points (jnp.array): context points. For "random", an array of
        shape [n_context_points, feature_dim] drawn uniformly between
        the min and max values. For "grid", a fixed 28 x 28 grid over [-1, 1]^2
        flattened to shape [784, 2].

    raises:
    - ValueError: if ``context_selection`` is not a supported method.
    - AssertionError: if "grid" is requested with ``feature_dim != 2``.
    """
    if context_selection == "random":
        return jax.random.uniform(
            key=key,
            shape=[n_context_points, feature_dim],
            minval=context_points_minval,
            maxval=context_points_maxval,
        )

    if context_selection == "grid":
        assert feature_dim == 2, "Grid context selection only works for 2D features."
        x1 = jnp.linspace(-1, 1, 28)
        x2 = jnp.linspace(-1, 1, 28)
        mesh = jnp.meshgrid(x1, x2, indexing='ij')
        return jnp.stack(mesh, axis=-1).reshape(-1, 2)

    # Previously an unknown method fell through to an UnboundLocalError.
    raise ValueError(
        f"Unknown context selection method {context_selection!r}; "
        "expected 'random' or 'grid'."
    )


@partial(jax.jit, static_argnums=(2,3,8,9,11,13))
def update_nn(
    params,
    nn_opt_state,
    nn_get_params,
    prior,
    x_batch,
    y_batch,
    x_context,
    key,
    nn_opt_update,
    model,
    ll_scale,
    mc_samples,
    step,
    max_grad_val
):
    """
    Gradient update step on model params.

    Differentiates ``neg_felbo_objective`` with respect to ``params``,
    optionally clips the gradients element-wise, and applies one optimizer step.
    ``nn_get_params``, ``prior``, ``nn_opt_update``, ``model``, ``mc_samples``
    and ``max_grad_val`` are static under jit.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - nn_opt_state (jax.tree_util.pytree): optimizer state.
    - nn_get_params (callable): function extracting parameters from the
        optimizer state.
    - prior (Prior): prior.
    - x_batch (jnp.array): a batch of input.
    - y_batch (jnp.array): a batch of labels.
    - x_context (jnp.array): a batch of context points to
        evaluate the regularized KL-divergence term in the ELBO objective.
    - key (jax.random.PRNGKey): JAX random seed.
    - nn_opt_update (callable): optimizer update function.
    - model (Model): BNN model.
    - ll_scale (float): scale of the likelihood model.
    - mc_samples (int): number of Monte Carlo samples.
    - step (int): current step, forwarded to the optimizer update.
    - max_grad_val (float or None): element-wise bound on gradient values;
        every gradient entry is clipped to [-max_grad_val, max_grad_val].
        None disables clipping.

    returns:
    - params (jax.tree_util.pytree): updated parameters.
    - nn_opt_state (jax.tree_util.pytree): updated optimizer state.
    - other_info (dict): auxiliary output of the objective.
    """
    grads, other_info = jax.grad(neg_felbo_objective, has_aux=True)(
        params,
        model,
        prior,
        x_batch,
        y_batch,
        x_context,
        key,
        ll_scale,
        mc_samples
    )
    if max_grad_val is not None:
        # Bounds are passed positionally: the a_min/a_max keywords of jnp.clip
        # are deprecated in recent JAX releases.
        grads = jax.tree_util.tree_map(
            lambda g: jnp.clip(g, -max_grad_val, max_grad_val),
            grads
        )

    nn_opt_state = nn_opt_update(step, grads, nn_opt_state)
    params = nn_get_params(nn_opt_state)

    return params, nn_opt_state, other_info


@partial(jax.jit, static_argnums=(2,6,7,9,11))
def update_ll(
    params,
    ll_opt_state,
    ll_get_params,
    x,
    y,
    key,
    ll_opt_update,
    model,
    ll_rho,
    mc_samples,
    step,
    max_grad_val
):
    """
    Gradient update step on model likelihood rho.

    Differentiates ``expected_ll_objective`` with respect to ``ll_rho`` only
    (the BNN parameters are held fixed), optionally clips the gradient
    element-wise, and applies one optimizer step. ``ll_get_params``,
    ``ll_opt_update``, ``model``, ``mc_samples`` and ``max_grad_val`` are
    static under jit.

    params:
    - params (jax.tree_util.pytree): parameters of the BNN.
    - ll_opt_state (jax.tree_util.pytree): optimizer state.
    - ll_get_params (callable): function extracting the optimized value from
        the optimizer state.
    - x (jnp.array): a batch of input.
    - y (jnp.array): a batch of labels.
    - key (jax.random.PRNGKey): JAX random seed.
    - ll_opt_update (callable): optimizer update function.
    - model (Model): BNN model.
    - ll_rho (float): pre-activated likelihood scale.
    - mc_samples (int): number of Monte Carlo samples.
    - step (int): current step, forwarded to the optimizer update.
    - max_grad_val (float or None): element-wise bound on gradient values;
        the gradient is clipped to [-max_grad_val, max_grad_val].
        None disables clipping.

    returns:
    - ll_rho (jax.tree_util.pytree): updated parameters.
    - ll_opt_state (jax.tree_util.pytree): updated optimizer state.
    - other_info (dict): auxiliary output of the objective.
    """
    grads, other_info = jax.grad(expected_ll_objective, argnums=5, has_aux=True)(
        params,
        model,
        x,
        y,
        key,
        ll_rho, # differentiate wrt ll_rho
        mc_samples
    )
    if max_grad_val is not None:
        # Bounds are passed positionally: the a_min/a_max keywords of jnp.clip
        # are deprecated in recent JAX releases.
        grads = jax.tree_util.tree_map(
            lambda g: jnp.clip(g, -max_grad_val, max_grad_val),
            grads
        )

    ll_opt_state = ll_opt_update(step, grads, ll_opt_state)
    ll_rho = ll_get_params(ll_opt_state)

    return ll_rho, ll_opt_state, other_info


def evaluate_model(
    key,
    params,
    model,
    dataloader
):
    """
    Evaluate the model.

    Iterates once over ``dataloader``, drawing 100 Monte Carlo predictions
    per batch via ``model.predict_f``, and reports the per-datapoint expected
    Gaussian log-likelihood (using ``model.ll_scale``) and the mean squared
    error of the Monte Carlo mean prediction.

    params:
    - key (jax.random.PRNGKey): JAX random seed.
    - params (jax.tree_util.pytree): parameters of the BNN.
    - model (Model): BNN model.
    - dataloader (DataLoader): data loader.

    returns:
    - test_loss (dict): test loss, with keys "expected_ll" and "mse".
    """
    assert dataloader.replacement == False, "Data should be sampled without replacement"

    # Fixed number of Monte Carlo samples used at evaluation time
    n_mc = 100

    # Likelihood scale stored on the model
    noise_scale = model.ll_scale

    # Running sums over all batches
    ll_total, sq_err_total = 0., 0.
    for inputs, targets in dataloader:
        # Fresh subkey for this batch
        key, subkey = jax.random.split(key)

        # Monte Carlo predictions; axis 0 is averaged over below
        f_samples = model.predict_f(params, inputs, subkey, n_mc)

        log_probs = jsp.stats.norm.logpdf(targets, loc=f_samples, scale=noise_scale)
        ll_total += jnp.mean(log_probs, axis=0).sum()

        residual = f_samples.mean(0).reshape(-1) - targets.reshape(-1)
        sq_err_total += jnp.sum(residual**2)

    # Normalize by the dataset size
    mse = sq_err_total / len(dataloader.dataset)
    expected_ll = ll_total / len(dataloader.dataset)

    print(f"Expected log-likelihood: {expected_ll} - MSE: {mse}", flush=True)

    return {"expected_ll": expected_ll, "mse": mse}


