import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt


def plot_function_samples(
    model,
    key,
    config,
    dataloader,
    *,
    mc_samples=100,
):
    """
    Plot function samples from the model together with the training data.

    Draws ``mc_samples`` stochastic function evaluations on a fixed 1-D grid
    over [-2, 2) with step 0.01. Overlays the training points, the sample
    mean, a +/- 1 std-dev band and every individual sample. Saves the figure
    to ``GP_{kernel}_{dataset}.pdf`` in the current working directory.

    params:
    - model (Model): model exposing ``predict_f(x, key, mc_samples=...,
      is_training=..., stochastic=...)``.
    - key (jax.random.PRNGKey): random key; only a split subkey is consumed.
    - config (dict): configuration; ``config["gp"]["prior"]["kernel"]`` and
      ``config["data"]["name"]`` are used to build the output filename.
    - dataloader (Dataloader): dataloader whose ``dataset.get_data()``
      returns the training inputs and targets.
    - mc_samples (int): number of function samples to draw (default 100).
    """
    # Config values used only for the output filename.
    kernel = config["gp"]["prior"]["kernel"]
    dataset = config["data"]["name"]

    # Only the subkey is consumed by the model.
    _, sub_key = jax.random.split(key)

    # Evaluation grid as a column of 1-D inputs.
    x = jnp.arange(-2, 2, 0.01).reshape(-1, 1)

    f_samples = model.predict_f(
        x,
        sub_key,
        mc_samples=mc_samples,
        is_training=False,
        stochastic=True,
    )

    # Flatten to (num_samples, num_points). Unlike a bare squeeze, this keeps
    # the sample axis even when mc_samples == 1. Assumes the sample axis is
    # first (the mean/std over axis 0 below relies on that) and a single
    # output dimension -- TODO confirm against model.predict_f.
    x = x.reshape(-1)
    f_samples = f_samples.reshape(f_samples.shape[0], -1)
    pred_mean = f_samples.mean(0)
    pred_std = f_samples.std(0)

    X_train, y_train = dataloader.dataset.get_data()

    # Draw on a dedicated figure and always close it. Plotting on pyplot's
    # implicit current figure made repeated calls pile artists on top of
    # previous plots and never released the figure.
    fig, ax = plt.subplots()
    try:
        # Training data.
        ax.scatter(
            X_train.reshape(-1),
            y_train.reshape(-1),
            label="Train data",
        )

        # Predictive mean.
        ax.plot(
            x,
            pred_mean,
            label="Predictive mean",
            c="r",
            linewidth=2,
        )

        # Predictive +/- 1 standard deviation band.
        ax.fill_between(
            x,
            pred_mean - pred_std,
            pred_mean + pred_std,
            color="r",
            alpha=0.3,
            label="Predictive 1-std-dev",
        )

        # Individual function samples: each column of the transposed matrix
        # is one sample, so a single call draws all of them.
        ax.plot(x, f_samples.T, c="g", alpha=0.2)

        ax.set_ylim(-5, 5)
        # Labels above were set but never shown without a legend.
        ax.legend()
        fig.savefig(
            f"GP_{kernel}_{dataset}.pdf",
            bbox_inches='tight'
        )
    finally:
        plt.close(fig)