import jax 
import tree 

import haiku as hk
import jax.numpy as jnp

from jax import eval_shape
from functools import partial

from models.Laplace.model_utils.mlp import MLP
from models.Laplace.training_utils.plot_utils import plot_function_samples
from models.Laplace.training_utils.inference import laplace_inference, split_parameters
from models.Laplace.training_utils.training import fit_model, evaluate_model, fit_to_prior


# Lookup table from the activation name used in the configuration to the
# corresponding JAX activation function.
ACTIVATION_DICT = dict(
    tanh=jnp.tanh,
    relu=jax.nn.relu,
    lrelu=jax.nn.leaky_relu,
    elu=jax.nn.elu,
)


class Laplace:
    """
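    Laplace approximation around a Haiku MLP.

    `fit` estimates the mean (posterior mode) parameters and then a covariance
    via `laplace_inference`; the `predict_*` and `f_distribution_*` methods use
    the network linearized at the mean parameters.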
    """
    def __init__(
        self, 
        key, 
        config
    ):
        """
        Set up the Laplace model from a random key and a configuration.

        params:
        - key (jax.random.PRNGKey): random key.
        - config (dict): configuration dictionary. Its "laplace" entry must
          provide "stochastic_layers", "likelihood_scale", "architecture" and
          "activation_fn"; its "data" entry is read by `initialize_model`.
        """
        self.key = key
        self.config = config

        laplace_config = config["laplace"]
        self.stochastic_layers = laplace_config["stochastic_layers"]
        self.ll_scale = laplace_config["likelihood_scale"]

        # Turn the MLP forward pass into a pure (init, apply) pair.
        forward_fn = self.make_forward_fn(
            laplace_config["architecture"],
            laplace_config["activation_fn"],
        )
        self.forward = hk.transform(forward_fn)

        # Draw the initial parameters; they serve as the Laplace mean until
        # `fit` replaces them.
        self.mean_params = self.initialize_model()
        self.training_steps = 0


    def initialize_model(self):
        """
        Initialize the BNN model parameters.
        
        returns:
        - params (jax.tree_util.pytree): parameters of the BNN.
        """
        # Handle random key
        self.key, key1 = jax.random.split(self.key)

        init_fn, apply_fn = self.forward
        x_init = jnp.ones(
            (self.config["data"]["batch_size"], self.config["data"]["feature_dim"])
        )
        params = init_fn(key1, x_init)

        # Reset training steps
        self.training_steps = 0

        return params

    
    @property
    def apply_fn(self):
        """
        Apply function of the transformed MLP stored in `self.forward`.

        returns:
        - apply (function): apply function.
        """
        transformed = self.forward
        return transformed.apply
    

    def make_forward_fn(
        self, 
        architecture, 
        activation_fn
    ):
        """
        Build the forward function.

        The MLP is constructed inside the returned function, so nothing is
        instantiated (and the activation name is not looked up) until the
        returned function is actually called.

        params:
        - architecture (list): architecture of the MLP.
        - activation_fn (str): activation function, a key of ACTIVATION_DICT.
        
        returns:
        - forward_fn (function): forward function mapping inputs to MLP outputs.
        """
        def forward_fn(x):
            activation = ACTIVATION_DICT[activation_fn]
            network = MLP(architecture=architecture, activation_fn=activation)
            return network(x)

        return forward_fn
        

    def fit(
        self, 
        train_dataloader, 
        val_dataloader
    ):
        """
        Fit the Laplace approximation: posterior mode first, covariance second.

        The mean parameters are re-initialized at the start of every call, so
        repeated calls to `fit` start from fresh parameters instead of silently
        continuing from the previous fit.

        params:
        - train_dataloader (DataLoader): training dataloader. 
        - val_dataloader (DataLoader): validation dataloader.

        returns:
        - val_loss (dict): validation loss returned by `evaluate` on
          `val_dataloader`, computed after the covariance has been fitted.
        """
        # Re-initialize the parameters that are actually trained below.
        # `self.params` is still set to the same initial values so that any
        # external code reading that attribute keeps working.
        self.mean_params = self.initialize_model()
        self.params = self.mean_params
        
        # NOTE(review): the same `self.key` is handed to fit_model,
        # laplace_inference and evaluate; whether they split it internally is
        # not visible from this file.

        # Posterior mode 
        print("Fitting mean parameters of Laplace approximation...", flush=True)
        # The loss reported by fit_model is discarded: the returned loss is
        # recomputed below once the covariance is available.
        self.mean_params, self.ll_scale, _ = fit_model(
            self.key, 
            self.mean_params, 
            self, # model
            self.config, 
            train_dataloader,
            val_dataloader
        )

        # Posterior covariance approx. 
        print("Fitting covariance parameters of Laplace approximation...", flush=True)
        self.cov = laplace_inference(
            self, # model
            self.mean_params, 
            train_dataloader, 
            self.key, 
            self.config
        )

        return self.evaluate(val_dataloader)


    def evaluate(
        self, 
        dataloader
    ):
        """
        Evaluate the model on a dataloader.

        Delegates to `evaluate_model`, passing the model's current key and the
        model itself.

        params:
        - dataloader (DataLoader): dataloader to evaluate on.

        returns:
        - loss (dict): loss.
        """
        return evaluate_model(self.key, self, dataloader)


    @partial(jax.jit, static_argnums=(0,3,4,5))
    def predict_f(
        self, 
        x, 
        key, 
        mc_samples, 
        is_training,
        stochastic
    ):
        """
        Sample from the function distribution.
        
        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples.
        - is_training (bool): dummy argument.
        - stochastic (bool): dummy argument.

        returns 
        - f_lnn (jnp.array): function samples.
        """
        # Get configuration 
        cov_type = self.config["laplace"]["cov_type"]

        # Split keys
        key1, key2 = jax.random.split(key)

        # Split parameters
        stochastic_params, params_static = self.partition_params()

        # Vectorize parameters
        vec_mean_params, params_unravel = jax.flatten_util.ravel_pytree(stochastic_params)

        # Sample parameters
        if cov_type == "full":
            vec_sample_params = jax.random.multivariate_normal(
                key1, 
                mean=jnp.zeros_like(vec_mean_params), 
                cov=self.cov, 
                shape=(mc_samples,)
            )
        elif cov_type == "diag":
            vec_sample_params = self.cov**0.5 * jax.random.normal(
                key=key1, 
                shape=(mc_samples, vec_mean_params.shape[0])
            )

        # Combine inference and static parameters
        sample_params = jax.vmap(params_unravel)(vec_sample_params)
        sample_params = jax.vmap(hk.data_structures.merge, in_axes=(0, None))(sample_params, params_static)

        # GLM 
        fwd = lambda p: self.apply_fn(p, key2, x)        
        f_lnn = fwd(self.mean_params)
        f_lnn += jax.vmap(jax.jvp, in_axes=(None, None, 0))(fwd, (self.mean_params,), (sample_params,))[1]

        return f_lnn

    
    @partial(jax.jit, static_argnums=(0,3,4,5))
    def predict_y(
        self, 
        x, 
        key, 
        mc_samples,
        is_training,
        stochastic
    ):
        """
        Sample from predictive distribution.
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples.
        - is_training (bool): dummy argument.
        - stochastic (bool): dummy argument.

        returns:
        - y (jnp.array): function samples.
        """
        key1, key2 = jax.random.split(key)
        
        # Sample from function distribution
        f = self.f_predict(
            x, 
            key1, 
            mc_samples
        ) 

        # Add observation noise
        y = f + self.ll_scale*jax.random.normal(key2, shape=f.shape) 

        return y


    @partial(jax.jit, static_argnums=(0,3))
    def f_distribution_mean_cov(
        self,
        x, 
        key, 
        mc_samples
    ):
        """
        Return the mean and covariance the linearized functional distribution.

        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): dummy argument.
        
        returns:
        - mean (jnp.array): mean of the functional distribution.
        - cov (jnp.array): covariance of the functional distribution.
        """
        # Split keys
        key1, key2 = jax.random.split(key)

        # Mean 
        mean = self.apply_fn(self.mean_params, key1, x)

        # Covariance
        cov = self.f_distribution_kernel(x, x, key2)

        return mean, cov
    

    @partial(jax.jit, static_argnums=(0,3))
    def f_distribution_mean_var(
        self,
        x, 
        key, 
        mc_samples
    ):
        """
        Return the mean and diagonalized covariance the linearized functional distribution.

        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): dummy variable.
        
        returns:
        - mean (jnp.array): mean of the distribution.
        - diag_cov (jnp.array): diagonal covariance of the distribution.
        """
        # Split keys
        key1, key2 = jax.random.split(key)

        # Mean
        mean = self.apply_fn(self.mean_params, key1, x)
        
        # Covariance
        x = jnp.expand_dims(x, axis=1)
        kernel_fn = lambda z: self.f_distribution_kernel(z, z, key2)
        diag_cov = jax.vmap(kernel_fn, in_axes=0)(x).reshape(-1)

        return mean, diag_cov
    
    
    @partial(jax.jit, static_argnums=(0,))
    def f_distribution_kernel(
        self, 
        x1, 
        x2, 
        key
    ):
        """
        Compute the kernel induced by the linearized BNN.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x1 (jnp.array): input data.
        - x2 (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.

        returns
        - kernel (jnp.array): kernel matrix.
        """
        # Get configuration
        cov_type = self.config["laplace"]["cov_type"]

        # Parition parameters 
        params_stochastic, params_static = self.partition_params()

        # Define predict function
        predict_fn = lambda p, x: self.apply_fn(self.join_parameters(p, params_static), key, x).reshape(-1)
        f1 = lambda p: predict_fn(p, x1)
        f2 = lambda p: predict_fn(p, x2)

        # Get unravel function
        unravel = jax.flatten_util.ravel_pytree(params_stochastic)[1]

        # Covariance Jacobian product
        if cov_type == "full":
            # Unravel covariance
            pytree_cov = jax.vmap(unravel)(self.cov)

            # Covariance Jacobian.T product
            SJ = jax.vmap(jax.jvp, in_axes=(None, None, 0))(f1, (params_stochastic,), (pytree_cov,))[1] # (p,n,1)
            leaves = jax.tree_util.tree_flatten(SJ)[0]
            JtS = jnp.concatenate([i.reshape(self.cov.shape[0], -1) for i in leaves], axis=-1).T # (n,p)
            pytree_JtS = jax.vmap(unravel)(JtS)
            
            # Jacobian covariance Jacobian product
            kernel = jax.vmap(jax.jvp, in_axes=(None, None, 0))(f2, (params_stochastic,), (pytree_JtS,))[1] # (n,n)
            kernel = kernel.reshape(x1.shape[0], x2.shape[0])
        elif cov_type == "diag":
            # Unravel covariance
            pytree_cov = unravel(self.cov)

            def delta_vjp_jvp(delta):
                delta_vjp = lambda delta: jax.vjp(f2, params_stochastic)[1](delta)
                vj_prod = tree.map_structure(
                    lambda x1, x2: x1 * x2, pytree_cov, delta_vjp(delta)[0]
                )
                leaves, _ = jax.tree_util.tree_flatten(vj_prod)
                tmp = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(params_stochastic), leaves)
                return jax.jvp(f1, (params_stochastic,), (tmp,))[1]
            
            # Jacobian covariance Jacobian product
            fx2 = eval_shape(f2, params_stochastic)
            eye = jnp.eye(x1.shape[0])
            kernel = jax.vmap(jax.linear_transpose(delta_vjp_jvp, fx2))(eye)[0]

        return kernel
    

    def plot(
        self, 
        dataloader
    ):
        """
        Plot function samples.

        Delegates to `plot_function_samples` with a fixed random key
        (PRNGKey(0)), independent of `self.key`.

        params:
        - dataloader (DataLoader): dataloader.
        """
        plot_key = jax.random.PRNGKey(0)
        plot_function_samples(self, plot_key, self.config, dataloader)
        
    
    def partition_params(self):
        """
        Split the mean parameters into stochastic and non-stochastic ones.

        A parameter is stochastic when the entry of `self.stochastic_layers`
        selected by its module name is truthy.
        
        returns:
        - stochastic_params (jax.tree_util.pytree): stochastic parameters of the BNN.
        - non_stochastic_params (jax.tree_util.pytree): non-stochastic parameters of the BNN.
        """
        def is_stochastic(module_name, name, value):
            # Everything after the first 23 characters of the module name is
            # read as the layer index; an empty remainder means layer 0.
            # NOTE(review): presumably 23 is the length of a fixed module-name
            # prefix produced by the MLP - confirm against its naming scheme.
            index_text = module_name[23:]
            layer_index = int(index_text) if index_text else 0
            return self.stochastic_layers[layer_index]

        selected, remaining = hk.data_structures.partition(
            is_stochastic, self.mean_params
        )

        return selected, remaining
    

    def join_parameters(
        self,
        stochastic_params, 
        params_static
    ):
        """
        Merge stochastic and non-stochastic parameters into one structure.
        
        params:
        - stochastic_params (jax.tree_util.pytree): stochastic parameters.
        - params_static (jax.tree_util.pytree): static parameters.

        returns:
        - params (jax.tree_util.pytree): parameters of the BNN.
        """
        merged = hk.data_structures.merge(stochastic_params, params_static)
        return merged