import jax 

import haiku as hk
import jax.numpy as jnp

from functools import partial

from models.TFSVI.model_utils.bnn import BNN
from models.TFSVI.training_utils.plot_utils import plot_function_samples
from models.TFSVI.training_utils.training import fit_model, evaluate_model


class TFSVI:
    """
    High-level wrapper around the TFSVI Bayesian neural network (BNN).

    The wrapper owns the random key, the configuration, the BNN model and its
    current parameters, and exposes training, evaluation, prediction and
    plotting entry points.

    The prediction methods are jit-compiled per instance. The current
    ``self.params`` is always passed to the compiled function as a regular
    (traced) argument, so predictions follow parameter updates made by
    ``fit`` instead of being frozen at the values seen during the first trace.
    """
    def __init__(
        self, 
        key, 
        config
    ):
        """
        Initialize the TFSVI model.

        params:
        - key (jax.random.PRNGKey): random key.
        - config (dict): configuration dictionary; the "tfsvi" and "data"
          sections are read by this class.
        """
        self.key = key
        self.config = config

        # Per-instance cache of jit-compiled prediction functions.
        self._jit_cache = {}

        # Build model
        tfsvi_config = config["tfsvi"]
        self.model = BNN(
            tfsvi_config["architecture"],
            tfsvi_config["stochastic_layers"],
            tfsvi_config["activation_fn"],
            tfsvi_config["likelihood_scale"],
            tfsvi_config["prior_scale"]
        )
        
        # Initialize model
        self.params = self.initialize_model()

        print(f'Number of parameters: {hk.data_structures.tree_size(self.params)}', flush=True)


    def _reset_jit_cache(self):
        """
        Drop all compiled prediction functions.

        Must be called whenever state stored on ``self.model`` changes,
        because such state is captured as a constant when a function is traced.
        """
        self._jit_cache = {}


    def _jitted(self, name, fn, static_argnums):
        """
        Return the jit-compiled version of fn, compiling it on first use.

        params:
        - name (str): cache key.
        - fn (callable): function to compile; its first argument is params.
        - static_argnums (tuple): positions of the static arguments of fn.

        returns:
        - jitted_fn (callable): compiled function cached on this instance.
        """
        # setdefault keeps this working on instances created without __init__.
        cache = self.__dict__.setdefault("_jit_cache", {})
        if name not in cache:
            cache[name] = jax.jit(fn, static_argnums=static_argnums)

        return cache[name]


    def initialize_model(self):
        """
        Initialize the BNN model parameters.

        Splits ``self.key``, resets ``self.model.training_steps`` to 0 and
        returns freshly initialized parameters (``self.params`` is not set
        here; the caller assigns the result).
        
        returns:
        - params (jax.tree_util.pytree): parameters of the BNN.
        """
        # Handle random key
        self.key, key1 = jax.random.split(self.key)

        # Initialize model
        init_fn, _ = self.model.forward
        data_config = self.config["data"]
        x_init = jnp.ones(
            (data_config["batch_size"], data_config["feature_dim"])
        )
        # NOTE(review): key1 is used both as the init rng and as the model's
        # key argument; kept as is to preserve the existing initialization.
        params = init_fn(key1, x_init, key1, is_training=True, stochastic=True)

        # Reset training steps
        self.model.training_steps = 0

        # Model state changed: previously compiled functions may be stale.
        self._reset_jit_cache()

        return params


    def fit(
        self, 
        train_dataloader,
        val_dataloader
    ):
        """
        Fit the model.

        The parameters are re-initialized before training, so every call
        starts from scratch. Afterwards ``self.params`` and
        ``self.model.ll_scale`` hold the values returned by ``fit_model``.

        params:
        - train_dataloader (DataLoader): training dataloader.
        - val_dataloader (DataLoader): validation dataloader.

        returns:
        - val_loss (dict): validation loss, as returned by fit_model.
        """
        # Initialize model
        self.params = self.initialize_model()
        
        # Fit the model
        self.params, self.model.ll_scale, val_loss = fit_model(
            self.key, 
            self.params, 
            self.model, 
            self.config, 
            train_dataloader, 
            val_dataloader
        )

        # ll_scale lives on the model and is baked into compiled functions.
        self._reset_jit_cache()

        return val_loss


    def evaluate(
        self, 
        dataloader
    ):
        """
        Evaluate the model.
        
        params:
        - dataloader (DataLoader): dataloader.

        returns:
        - loss (dict): loss on the given data, as returned by evaluate_model.
        """
        loss = evaluate_model(
            self.key, 
            self.params, 
            self.model, 
            dataloader, 
            self.config, 
        )

        return loss


    def predict_f(
        self, 
        x, 
        key, 
        mc_samples, 
        is_training, 
        stochastic
    ):
        """
        Sample from the function distribution using the current parameters.
        
        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).
        - is_training (bool): if true, apply local reparametrization trick.
        - stochastic (bool): if true, sample from the posterior distribution.

        returns:
        - f_lin_nn (jnp.array): function samples.
        """
        fn = self._jitted("predict_f", self.model.predict_f, (3, 4, 5))

        # Static arguments are forwarded positionally to match static_argnums.
        return fn(self.params, x, key, mc_samples, is_training, stochastic)

    
    def predict_y(
        self, 
        x, 
        key, 
        mc_samples,
        is_training, 
        stochastic
    ):
        """
        Sample from the predictive distribution using the current parameters.
        
        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).
        - is_training (bool): if true, apply local reparametrization trick.
        - stochastic (bool): if true, sample from the posterior distribution.

        returns:
        - y (jnp.array): predictive samples.
        """
        fn = self._jitted("predict_y", self.model.predict_y, (3, 4, 5))

        return fn(self.params, x, key, mc_samples, is_training, stochastic)


    def _f_samples(self, params, x, key, mc_samples):
        """
        Draw posterior function samples and flatten them per sample.

        returns:
        - f_hat (jnp.array): samples reshaped to (mc_samples, -1).
        """
        # is_training=False, stochastic=True, passed positionally.
        f_hat = self.model.predict_f(params, x, key, mc_samples, False, True)

        return f_hat.reshape(mc_samples, -1)


    def _f_mean_cov(self, params, x, key, mc_samples):
        """Monte Carlo mean and full covariance of the function samples."""
        f_hat = self._f_samples(params, x, key, mc_samples)

        return jnp.mean(f_hat, axis=0), jnp.cov(f_hat, rowvar=False)


    def _f_mean_var(self, params, x, key, mc_samples):
        """Monte Carlo mean and per-output variance of the function samples."""
        f_hat = self._f_samples(params, x, key, mc_samples)

        return jnp.mean(f_hat, axis=0), jnp.var(f_hat, axis=0)


    def f_distribution_mean_cov(
        self,
        x, 
        key, 
        mc_samples
    ):
        """
        Return the mean and covariance of the functional distribution,
        estimated from Monte Carlo samples.

        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).
        
        returns:
        - mean (jnp.array): mean of the distribution.
        - cov (jnp.array): covariance of the distribution.
        """
        fn = self._jitted("f_mean_cov", self._f_mean_cov, (3,))

        return fn(self.params, x, key, mc_samples)
    

    def f_distribution_mean_var(
        self,
        x, 
        key, 
        mc_samples
    ):
        """
        Return the mean and diagonal covariance (variance) of the functional
        distribution, estimated from Monte Carlo samples.

        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).
        
        returns:
        - mean (jnp.array): mean of the distribution.
        - var (jnp.array): variance of the distribution.
        """
        fn = self._jitted("f_mean_var", self._f_mean_var, (3,))

        return fn(self.params, x, key, mc_samples)
    
    
    def plot(
        self, 
        dataloader
    ):
        """
        Plot function samples.

        A fixed random key (PRNGKey(0)) is used, so repeated calls draw the
        same samples for the same parameters.

        params:
        - dataloader (DataLoader): wrapper for dataset
        """
        plot_function_samples(
            self.model, 
            self.params, 
            jax.random.PRNGKey(0), 
            self.config, 
            dataloader
        )
        