import jax 
import copy

import jaxkern as jk
import jax.numpy as jnp
import jax.scipy as jsp

from functools import partial
from jax.example_libraries.optimizers import adam 


# Supported kernel names (the values accepted in config["fvi"]["prior"]["kernel"]),
# each mapped to the jaxkern kernel class of the same name.
KERNEL_DICT = {
    kernel_name: getattr(jk, kernel_name)
    for kernel_name in (
        "RBF",
        "Matern12",
        "Matern32",
        "Matern52",
        "Linear",
        "RationalQuadratic",
    )
}


class Prior:
    """
    Gaussian process prior with a zero mean function and a jaxkern kernel.

    The kernel hyperparameters are either read from the configuration or
    tuned by minimizing the negative marginal likelihood of the data with Adam.
    """

    # Kernels whose positive hyperparameters are the variance and a lengthscale.
    _LENGTHSCALE_KERNELS = ("RBF", "Matern12", "Matern32", "Matern52")

    def __init__(
        self,
        key,
        dataloader,
        config,
    ):
        """
        Initialize the prior.

        params:
        - key (jax.random.PRNGKey): a random key.
        - dataloader (dataloader): wrapper for the dataset; only used when the
          prior parameters are tuned.
        - config (dict): configuration dictionary. Reads
          config["fvi"]["prior"]["kernel" | "parameters" | "parameter_tuning"]
          and config["data"]["feature_dim"].
        """
        self.key = key
        self.kernel_name = config["fvi"]["prior"]["kernel"]
        self.feature_dim = config["data"]["feature_dim"]
        self.kernel_params = config["fvi"]["prior"]["parameters"]
        self.parameter_tuning = config["fvi"]["prior"]["parameter_tuning"]

        # Initialize the kernel over all input feature dimensions
        self.kernel = KERNEL_DICT[self.kernel_name](active_dims=list(range(self.feature_dim)))
        self.params = self.kernel.init_params(self.key)

        # Set prior parameters
        if self.parameter_tuning:
            print("Tuning prior parameters...", flush=True)
            self._tune_parameters(config, dataloader)
        else:
            # Fixed hyperparameters: copy each relevant value from the configuration
            for name in self._positive_param_names():
                self.params[name] = self.kernel_params[name]

    def __call__(
        self,
        x=None
    ):
        """
        Compute the prior mean and covariance at the given inputs.

        params:
        - x (jnp.ndarray): the input data; required despite the None default.

        returns:
        - prior_mean (jnp.ndarray): the prior mean (zeros, one per row of x).
        - prior_cov (jnp.ndarray): the dense prior covariance (Gram matrix).

        raises:
        - ValueError: if x is None.
        """
        if x is None:
            raise ValueError("Prior.__call__ requires input data `x`.")
        prior_mean = jnp.zeros(x.shape[0])
        prior_cov = self.kernel.gram(self.params, x).to_dense()

        return prior_mean, prior_cov

    def _positive_param_names(self):
        """
        List the kernel hyperparameters that must stay strictly positive.

        returns:
        - names (list[str]): always "variance"; plus "lengthscale" for the
          RBF/Matern kernels, or "lengthscale" and "alpha" for the
          RationalQuadratic kernel.
        """
        names = ["variance"]
        if self.kernel_name in self._LENGTHSCALE_KERNELS:
            names.append("lengthscale")
        elif self.kernel_name == "RationalQuadratic":
            names += ["lengthscale", "alpha"]
        return names

    def _tune_parameters(
        self,
        config,
        dataloader
    ):
        """
        Select prior parameters by minimizing the negative marginal likelihood with Adam.

        Optimization runs on unconstrained parameters, which are mapped through a
        softplus. The epoch with the lowest summed negative marginal likelihood is
        kept. Training stops early after `patience` epochs without improvement.

        params:
        - config (dict): configuration dictionary.
        - dataloader (dataloader): wrapper for the dataset.
        """
        # Get configuration
        nb_epochs = config["fvi"]["prior"]["nb_epochs"]
        ll_scale = config["fvi"]["likelihood_scale"]
        early_stopping_patience = config["fvi"]["training"]["patience"]

        # Initialize optimizer
        opt_init, opt_update, get_params = adam(
            config["fvi"]["prior"]["lr"],
            config["fvi"]["prior"]["b1"],
            config["fvi"]["prior"]["b2"],
            config["fvi"]["prior"]["eps"]
        )
        # NOTE(review): the kernel's default values are used directly as the initial
        # *unconstrained* values, so the effective starting point is softplus(init).
        z_params = copy.deepcopy(self.params)
        opt_state = opt_init(z_params)

        # Best unconstrained parameters seen so far. Seeding them with the initial
        # values means a run whose losses never improve (e.g. NaN losses) still
        # ends with usable parameters instead of None.
        opt_n_mll, no_improve_count = jnp.inf, 0
        opt_params = z_params

        # Sample the dataset with replacement while tuning. Always restore the
        # default mode, even if training raises.
        dataloader.set_replacement_mode(replacement=True)
        try:
            step = 0
            for epoch in range(nb_epochs):
                n_mll = 0
                for x, y in dataloader:
                    # Update parameters
                    z_params, opt_state, loss = self.update(
                        self._negative_marginal_likelihood,
                        z_params,
                        opt_state,
                        get_params,
                        x,
                        y,
                        opt_update,
                        ll_scale,
                        step
                    )
                    # Accumulate the negative marginal likelihood over the epoch
                    n_mll += loss
                    step += 1

                # Log negative marginal likelihood
                if epoch % 100 == 0 or epoch == nb_epochs - 1:
                    print(f"Epoch {epoch} - negative marginal likelihood: {n_mll}", flush=True)

                # Early stopping
                if n_mll < opt_n_mll:
                    opt_n_mll = n_mll
                    opt_params = z_params
                    no_improve_count = 0
                else:
                    no_improve_count += 1
                    if no_improve_count >= early_stopping_patience:
                        print("Early stopping.", flush=True)
                        break
        finally:
            # Set the dataloader to sample data without replacement
            dataloader.set_replacement_mode(replacement=False)

        # Keep the best epoch's parameters, mapped to the positive (constrained) space
        self.params = self._enforce_positivity_constaints(opt_params)

        print(f"Optimal parameters: {self.params}", flush=True)

    @partial(jax.jit, static_argnums=(0, 1, 4, 7, 8))
    def update(
        self,
        loss,
        z_params,
        opt_state,
        get_params,
        x_batch,
        y_batch,
        opt_update,
        ll_scale,
        step
    ):
        """
        Perform one gradient update step.

        params:
        - loss (callable): loss function called as loss(z_params, ll_scale, x, y).
        - z_params (dict): unconstrained prior parameters.
        - opt_state (jax.tree_util.pytree): optimizer state.
        - get_params (callable): extracts parameters from the optimizer state.
        - x_batch (jnp.ndarray): input features.
        - y_batch (jnp.ndarray): targets.
        - opt_update (callable): optimizer update function.
        - ll_scale (float): scale of the likelihood model (static under jit).
        - step (int): current step (traced, so it does not trigger recompilation).

        returns:
        - z_params (jax.tree_util.pytree): updated unconstrained parameters.
        - opt_state (jax.tree_util.pytree): updated optimizer state.
        - loss_value (jnp.ndarray): loss evaluated at the parameters before the update.
        """
        loss_value, grads = jax.value_and_grad(loss)(
            z_params,
            ll_scale,
            x_batch,
            y_batch
        )
        opt_state = opt_update(step, grads, opt_state)
        z_params = get_params(opt_state)

        return z_params, opt_state, loss_value

    @partial(jax.jit, static_argnums=(0, 2))
    def _negative_marginal_likelihood(
        self,
        z_params,
        ll_scale,
        x,
        y
    ):
        """
        Compute the negative marginal likelihood of the batch.

        params:
        - z_params (dict): unconstrained prior parameters.
        - ll_scale (float): scale of the likelihood model.
        - x (jnp.ndarray): input features.
        - y (jnp.ndarray): targets; flattened to one value per row of x.

        returns:
        - n_mll (float): the negative marginal likelihood of the batch.
        """
        # Enforce parameter positivity constraints
        z_params = self._enforce_positivity_constaints(z_params)

        n = x.shape[0]
        identity = jnp.eye(n)

        # Prior mean and covariance; the small jitter keeps the Gram matrix
        # numerically positive definite
        prior_mean = jnp.zeros(n)
        prior_cov = self.kernel.gram(z_params, x).to_dense() + 1e-6 * identity

        # Evidence covariance: prior covariance plus Gaussian noise of variance ll_scale**2
        evidence_mean = prior_mean
        evidence_cov = prior_cov + ll_scale**2 * identity

        # Compute the marginal likelihood
        mll = jsp.stats.multivariate_normal.logpdf(y.reshape(-1), mean=evidence_mean, cov=evidence_cov)

        return -mll.sum()

    def _enforce_positivity_constaints(
        self,
        params
    ):
        """
        Map unconstrained prior parameters to positive values with a softplus.

        params:
        - params (dict): unconstrained prior parameters. Not modified.

        returns:
        - params (dict): a new dict in which the positive hyperparameters are
          passed through softplus; all other entries are copied unchanged.
        """
        # Copy so callers' dicts (e.g. the optimizer's current/best params) are
        # never mutated and softplus is never applied twice to the same dict
        constrained = dict(params)
        for name in self._positive_param_names():
            constrained[name] = jax.nn.softplus(params[name])

        return constrained