import jax 

import haiku as hk
import jax.numpy as jnp

from functools import partial

from models.MFVI.model_utils.bnn import BNN
from models.MFVI.training_utils.plot_utils import plot_function_samples
from models.MFVI.training_utils.training import fit_model, evaluate_model


class MFVI:
    """
    Mean-Field Variational Inference (MFVI) BNN.

    Bundles a ``BNN`` with its parameters, a PRNG key and the run
    configuration, and exposes fit / evaluate / predict / plot helpers.

    Note on jit: the prediction helpers are jitted with ``self`` marked as
    static. Because this class does not define ``__hash__``/``__eq__``, the
    jit cache keys ``self`` by object identity, so anything read from
    ``self`` while tracing would be baked into the compiled function. To
    avoid stale predictions after ``fit`` (or after reassigning
    ``self.params``), parameters are passed to the jitted helpers as traced
    arguments, and a static ``_state_version`` counter forces a retrace
    whenever ``initialize_model``/``fit`` mutate model state (e.g.
    ``self.model.ll_scale``).
    """
    def __init__(
        self, 
        key, 
        config
    ):
        """
        Initialize the MFVI model.

        params:
        - key (jax.random.PRNGKey): random key.
        - config (dict): configuration dictionary. Reads the "mfvi" section
          (architecture, stochastic_layers, activation_fn, likelihood_scale,
          prior_scale) and the "data" section (batch_size, feature_dim).
        """
        self.key = key
        self.config = config

        # Static part of the jit cache key for the prediction helpers; bumped
        # whenever model state read during tracing may have changed.
        self._state_version = 0

        # Build model
        self.model = BNN(
            config["mfvi"]["architecture"],
            config["mfvi"]["stochastic_layers"],
            config["mfvi"]["activation_fn"],
            config["mfvi"]["likelihood_scale"],
            config["mfvi"]["prior_scale"]
        )
        
        # Initialize model
        self.params = self.initialize_model()
        
        print(f'Number of parameters: {hk.data_structures.tree_size(self.params)}', flush=True)


    def _bump_state_version(self):
        """
        Invalidate previously compiled prediction functions.

        The counter is passed as a static jit argument, so a new value makes
        the next prediction call retrace and pick up the current model state.
        Older compiled entries may remain in JAX's cache.
        """
        self._state_version = getattr(self, "_state_version", 0) + 1


    def initialize_model(self):
        """
        Initialize the BNN model parameters.

        Splits ``self.key`` (updating it in place). Runs the model's init
        function on an all-ones dummy batch of shape
        (config["data"]["batch_size"], config["data"]["feature_dim"]).
        Resets ``self.model.training_steps`` to 0.

        returns:
        - params (jax.tree_util.pytree): parameters of the BNN.
        """
        # Handle random key
        self.key, key1 = jax.random.split(self.key)

        # ``self.model.forward`` unpacks into an (init, apply) pair; only the
        # init function is needed here.
        init_fn, _ = self.model.forward
        x_init = jnp.ones(
            (self.config["data"]["batch_size"], self.config["data"]["feature_dim"])
        )
        params = init_fn(
            key1, 
            x_init, 
            key1, 
            is_training=True, 
            stochastic=True
        )

        # Reset training steps
        self.model.training_steps = 0

        # Model state changed: force prediction helpers to retrace.
        self._bump_state_version()

        return params


    def fit(
        self, 
        train_dataloader, 
        val_dataloader
    ):
        """
        Fit the model.

        Re-initializes the parameters and trains them with ``fit_model``.
        Stores the returned parameters in ``self.params`` and the returned
        likelihood scale in ``self.model.ll_scale``.

        params:
        - train_dataloader (DataLoader): training dataloader.
        - val_dataloader (DataLoader): validation dataloader.

        returns:
        - val_loss (dict): validation loss.
        """
        # Initialize model
        self.params = self.initialize_model()

        # Use a dedicated subkey for training. This keeps the key consumed by
        # ``fit_model`` from being reused by ``evaluate`` or split again by a
        # later ``initialize_model``.
        self.key, fit_key = jax.random.split(self.key)

        # Fit model
        self.params, self.model.ll_scale, val_loss = fit_model(
            fit_key, 
            self.params, 
            self.model, 
            self.config, 
            train_dataloader, 
            val_dataloader
        )

        # params / ll_scale changed: invalidate compiled prediction functions.
        self._bump_state_version()

        return val_loss
        

    def evaluate(
        self, 
        dataloader
    ):
        """
        Evaluate the model.

        Passes ``self.key`` to ``evaluate_model`` without splitting it.
        Repeated calls therefore reuse the same key.

        params:
        - dataloader (DataLoader): dataloader.

        returns:
        - test_loss (dict): test loss.
        """
        test_loss = evaluate_model(
            self.key, 
            self.params, 
            self.model, 
            dataloader
        )

        return test_loss


    def predict_f(
        self, 
        x, 
        key, 
        mc_samples, 
        is_training, 
        stochastic
    ):
        """
        Sample from function distribution.

        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).
        - is_training (bool): whether to use training mode (static under jit).
        - stochastic (bool): whether to sample from the posterior (static under jit).

        returns:
        - f_samples (jnp.array): function samples (first output of
          ``self.model.predict_f``).
        """
        return self._predict_f(
            self._state_version,
            self.params,
            x,
            key,
            mc_samples,
            is_training,
            stochastic
        )


    @partial(jax.jit, static_argnums=(0, 1, 5, 6, 7))
    def _predict_f(
        self,
        state_version,
        params,
        x,
        key,
        mc_samples,
        is_training,
        stochastic
    ):
        """Jitted core of ``predict_f``; ``params`` is traced, not captured."""
        # state_version only participates in the jit cache key.
        del state_version
        return self.model.predict_f(
            params, 
            x, 
            key, 
            mc_samples, 
            is_training, 
            stochastic
        )[0]

    
    def predict_y(
        self, 
        x, 
        key, 
        mc_samples, 
        is_training,
        stochastic
    ):
        """
        Sample from predictive distribution.

        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).
        - is_training (bool): whether to use training mode (static under jit).
        - stochastic (bool): whether to sample from the posterior (static under jit).

        returns:
        - y_samples (jnp.array): samples from the predictive distribution
          (first output of ``self.model.predict_y``).
        """
        return self._predict_y(
            self._state_version,
            self.params,
            x,
            key,
            mc_samples,
            is_training,
            stochastic
        )


    @partial(jax.jit, static_argnums=(0, 1, 5, 6, 7))
    def _predict_y(
        self,
        state_version,
        params,
        x,
        key,
        mc_samples,
        is_training,
        stochastic
    ):
        """Jitted core of ``predict_y``; ``params`` is traced, not captured."""
        # state_version only participates in the jit cache key.
        del state_version
        return self.model.predict_y(
            params, 
            x, 
            key, 
            mc_samples, 
            is_training, 
            stochastic
        )[0]


    def f_distribution_mean_cov(
        self,
        x, 
        key, 
        mc_samples
    ):
        """
        Estimate the mean and covariance of the functional distribution
        from samples, since there is no closed-form density over functions.

        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples.
        
        returns:
        - mean (jnp.array): mean of the distribution.
        - cov (jnp.array): covariance of the distribution.
        """
        return self._f_moments(
            self._state_version, self.params, x, key, mc_samples, True
        )
    
    
    def f_distribution_mean_var(
        self,
        x, 
        key, 
        mc_samples
    ):
        """
        Estimate the mean and variance of the functional distribution
        from samples, since there is no closed-form density over functions.

        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples.

        returns:
        - mean (jnp.array): mean of the distribution.
        - var (jnp.array): variance of the distribution.
        """
        return self._f_moments(
            self._state_version, self.params, x, key, mc_samples, False
        )


    @partial(jax.jit, static_argnums=(0, 1, 5, 6))
    def _f_moments(
        self,
        state_version,
        params,
        x,
        key,
        mc_samples,
        full_cov
    ):
        """
        Jitted Monte Carlo moment estimator shared by the two methods above.

        Draws posterior function samples (is_training=False, stochastic=True),
        flattens each sample to a row and returns (mean, cov) when
        ``full_cov`` is True, otherwise (mean, var).
        """
        # Assumes the leading axis of the samples is the MC axis — TODO
        # confirm against BNN.predict_f.
        f_hat = self._predict_f(
            state_version,
            params,
            x,
            key,
            mc_samples,
            False,
            True
        ).reshape(mc_samples, -1)
        mean = jnp.mean(f_hat, axis=0)
        if full_cov:
            spread = jnp.cov(f_hat, rowvar=False)
        else:
            spread = jnp.var(f_hat, axis=0)

        return mean, spread
    

    def plot(
        self, 
        dataloader
    ):
        """
        Plot function samples.

        Uses a fixed ``jax.random.PRNGKey(0)``, so plots are reproducible.

        params:
        - dataloader (DataLoader): data loader.
        """
        plot_function_samples(
            self.model, 
            self.params, 
            jax.random.PRNGKey(0), 
            self.config, 
            dataloader
        )