import jax

from functools import partial

from models.GP.gp import GP
from models.GFSVI.gfsvi import GFSVI
from models.MFVI.mfvi import MFVI
from models.TFSVI.tfsvi import TFSVI
from models.Laplace.laplace import Laplace
from models.FVI.fvi import FVI


class Model:
    """
    Thin wrapper that selects one concrete model implementation from the
    configuration and forwards every call to it.

    The backend is chosen by ``config["model"]["name"]`` and stored in
    ``self.model``. Supported names: "GFSVI", "MFVI", "GP", "Laplace",
    "TFSVI", "FVI".

    The prediction methods are wrapped in ``jax.jit`` with ``self`` (argnum 0)
    and ``mc_samples`` (argnum 3) declared static. This class defines neither
    ``__hash__`` nor ``__eq__``, so the instance is hashed by identity.
    NOTE(review): if the backend keeps parameters as attributes that change
    after ``fit``, a function compiled before ``fit`` may keep using the
    values seen at trace time -- confirm against the backend implementations.
    """

    def __init__(
        self,
        key,
        config
    ):
        """
        Initialize model.

        params:
        - key (jax.random.PRNGKey): random key, forwarded to the backend.
        - config (dict): configuration dictionary; ``config["model"]["name"]``
          selects the backend and the full dictionary is forwarded to it.

        raises:
        - NotImplementedError: if the model name is not one of the supported
          names.
        """
        # Read the name once instead of re-indexing the config in each branch.
        name = config["model"]["name"]

        if name == "GFSVI":
            self.model = GFSVI(key, config)
        elif name == "MFVI":
            self.model = MFVI(key, config)
        elif name == "GP":
            self.model = GP(key, config)
        elif name == "Laplace":
            self.model = Laplace(key, config)
        elif name == "TFSVI":
            self.model = TFSVI(key, config)
        elif name == "FVI":
            self.model = FVI(key, config)
        else:
            # Name the offending value so a typo in the config is easy to spot.
            raise NotImplementedError(
                f"Unknown model name {name!r}; expected one of: "
                "GFSVI, MFVI, GP, Laplace, TFSVI, FVI."
            )

    def fit(
        self,
        train_dataloader,
        val_dataloader
    ):
        """
        Fit the model.

        params:
        - train_dataloader (DataLoader): training dataloader.
        - val_dataloader (DataLoader): validation dataloader.

        returns:
        - val_loss (dict): validation loss, as returned by the backend.
        """
        return self.model.fit(train_dataloader, val_dataloader)

    def evaluate(
        self,
        dataloader
    ):
        """
        Evaluate the model.

        params:
        - dataloader (DataLoader): data.

        returns:
        - test_loss (dict): test loss, as returned by the backend.
        """
        return self.model.evaluate(dataloader)

    # self and mc_samples are static: a new value of either triggers a retrace.
    @partial(jax.jit, static_argnums=(0, 3))
    def predict_f(
        self,
        x,
        key,
        mc_samples
    ):
        """
        Sample from function distribution.

        The backend is always called with ``is_training=False`` and
        ``stochastic=True``.

        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).

        returns:
        - f_samples (jnp.array): function samples.
        """
        return self.model.predict_f(
            x,
            key,
            mc_samples,
            is_training=False,
            stochastic=True
        )

    # self and mc_samples are static: a new value of either triggers a retrace.
    @partial(jax.jit, static_argnums=(0, 3))
    def predict_y(
        self,
        x,
        key,
        mc_samples
    ):
        """
        Sample from predictive distribution.

        The backend is always called with ``is_training=False`` and
        ``stochastic=True``.

        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).

        returns:
        - y_samples (jnp.array): samples from the predictive distribution.
        """
        return self.model.predict_y(
            x,
            key,
            mc_samples,
            is_training=False,
            stochastic=True
        )

    # self and mc_samples are static: a new value of either triggers a retrace.
    @partial(jax.jit, static_argnums=(0, 3))
    def f_distribution_mean_cov(
        self,
        x,
        key,
        mc_samples
    ):
        """
        Return the mean and covariance of the functional distribution.
        For MFVI, these values are estimated from samples as there is
        no closed form density over functions.

        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).

        returns:
        - mean (jnp.array): mean of the distribution.
        - cov (jnp.array): covariance of the distribution.
        """
        return self.model.f_distribution_mean_cov(
            x,
            key,
            mc_samples
        )

    # self and mc_samples are static: a new value of either triggers a retrace.
    @partial(jax.jit, static_argnums=(0, 3))
    def f_distribution_mean_var(
        self,
        x,
        key,
        mc_samples
    ):
        """
        Return the mean and variance of the functional distribution.
        For MFVI, these values are estimated from samples as there is
        no closed form density over functions.

        params:
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).

        returns:
        - mean (jnp.array): mean of the distribution.
        - var (jnp.array): variance of the distribution.
        """
        return self.model.f_distribution_mean_var(
            x,
            key,
            mc_samples
        )

    def plot(
        self,
        dataloader
    ):
        """
        Plot function samples.

        Delegates to the backend's ``plot``; its return value is discarded.

        params:
        - dataloader (DataLoader): dataloader.
        """
        self.model.plot(
            dataloader
        )
        