import jax 

import haiku as hk
import jax.numpy as jnp

from functools import partial

from models.FVI.model_utils.bnn import BNN
from models.FVI.model_utils.prior import Prior
from models.FVI.training_utils.plot_utils import plot_function_samples
from models.FVI.training_utils.training import fit_model, evaluate_model


class FVI:
    """
    Placeholder for the FVI model.
    Code adapted to JAX from https://github.com/ssydasheng/FBNN/tree/master.
    """
    def __init__(
        self, 
        key, 
        config
    ):
        """
        Initialize the FVI model.

        params:
        - key (jax.random.PRNGKey): random key.
        - config (dict): configuration dictionary.
        """
        self.key = key
        self.config = config
    
        # Build model
        self.model = BNN(
            config["fvi"]["architecture"],
            config["fvi"]["stochastic_layers"],
            config["fvi"]["activation_fn"],
            config["fvi"]["likelihood_scale"]
        )
        
        # Initialize model
        self.params = self.initialize_model()
        
        print(f'Number of parameters: {hk.data_structures.tree_size(self.params)}', flush=True)


    def initialize_model(self):
        """
        Initialize the BNN model parameters.
        
        returns:
        - params (jax.tree_util.pytree): parameters of the BNN.
        """
        # Handle random key
        self.key, key1 = jax.random.split(self.key)

        # Initialize model
        init_fn, apply_fn = self.model.forward
        x_init = jnp.ones(
            (self.config["data"]["batch_size"], self.config["data"]["feature_dim"])
        )
        params = init_fn(key1, x_init, key1)

        # Reset training steps
        self.model.training_steps = 0

        return params


    def fit(
        self, 
        train_dataloader, 
        val_dataloader
    ):
        """
        Fit the model.

        params:
        - train_dataloader (DataLoader): train dataloader.
        - val_dataloader (DataLoader): validation dataloader.

        returns:
        - val_loss (dict): validation loss.
        """        
        # Initialize model
        self.params = self.initialize_model()

        # Load Prior
        self.prior = Prior(
            self.key,
            train_dataloader,
            self.config
        )

        # Fit the model
        self.params, self.model.ll_scale, val_loss = fit_model(
            self.key, 
            self.params, 
            self.model, 
            self.config, 
            train_dataloader, 
            val_dataloader,
            self.prior
        )

        return val_loss


    def evaluate(
        self, 
        dataloader
    ):
        """
        Evaluate the model.
        
        params:
        - dataloader (DataLoader): test dataloader.

        returns:
        - test_loss (dict): test loss.
        """
        test_loss = evaluate_model(
            self.key, 
            self.params, 
            self.model, 
            dataloader
        )

        return test_loss


    @partial(jax.jit, static_argnums=(0,3,4,5))
    def predict_f(
        self, 
        x, 
        key, 
        mc_samples, 
        is_training,
        stochastic
    ):
        """
        Sample from the linearized function distribution.
        
        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.ndarray): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples.
        - is_training (bool): dummy variable for compatibility.
        - stochastic (bool): dummy variable for compatibility.

        returns:
        - f_lin_nn (jnp.ndarray): function samples.
        """
        return self.model.predict_f(
            self.params, 
            x, 
            key, 
            mc_samples
        )

    
    @partial(jax.jit, static_argnums=(0,3,4,5))
    def predict_y(
        self, 
        x, 
        key, 
        mc_samples,
        is_training,
        stochastic
    ):
        """
        Sample from the linearized predictive distribution.
        
        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.ndarray): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples.
        - is_training (bool): dummy variable for compatibility.
        - stochastic (bool): dummy variable for compatibility.
        
        returns:
        - y (jnp.ndarray): function samples.
        """
        return self.model.predict_y(
            self.params, 
            x, 
            key, 
            mc_samples
        )


    @partial(jax.jit, static_argnums=(0,3))
    def f_distribution_mean_cov(
        self,
        x, 
        key, 
        mc_samples
    ):
        """
        Estimate the mean and variance the functional distribution
        from samples as there is no closed form density over functions.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.ndarray): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): dummy variable for compatibility.
        
        returns:
        - mean (jnp.ndarray): mean of the distribution.
        - cov (jnp.ndarray): covariance of the distribution.
        """
        f_hat = self.predict_f(
            x, 
            key, 
            mc_samples, 
            is_training=False, 
            stochastic=True
        ).reshape(mc_samples, -1)
        mean = jnp.mean(f_hat, axis=0)
        cov = jnp.cov(f_hat, rowvar=False)

        return mean, cov
    

    @partial(jax.jit, static_argnums=(0,3))
    def f_distribution_mean_var(
        self,
        x, 
        key, 
        mc_samples
    ):
        """
        Estimate the mean and variance the functional distribution
        from samples as there is no closed form density over functions.
        
        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.ndarray): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): dummy variable for compatibility.
        
        returns:
        - mean (jnp.ndarray): mean of the distribution.
        - diag_cov (jnp.ndarray): diagonal covariance of the distribution.
        """
        f_hat = self.predict_f(
            x, 
            key, 
            mc_samples, 
            is_training=False, 
            stochastic=True
        ).reshape(mc_samples, -1)
        mean = jnp.mean(f_hat, axis=0)
        var = jnp.var(f_hat, axis=0)

        return mean, var
    

    def plot(
        self, 
        dataloader
    ):
        """
        Plot function samples.

        params:
        - dataloader (DataLoader): dataloader.
        """
        plot_function_samples(
            self.model, 
            self.params, 
            jax.random.PRNGKey(0), 
            self.config, 
            dataloader
        )
        