import jax

import numpy as np
import gpjax as gpx
import pandas as pd
import jax.numpy as jnp
import jaxkern as jk

from sklearn.preprocessing import StandardScaler


def standardize_data(
    train_data,
    val_data,
    test_data
):
    """
    Standardize the three data splits with statistics of the training split.

    A single StandardScaler is fitted on the training data only and then
    applied unchanged to the validation and test data, so no information
    from those splits enters the scaling statistics.

    params:
    - train_data (array-like): training data; the scaler is fitted on it.
    - val_data (array-like): validation data.
    - test_data (array-like): test data.

    returns:
    - train_data (array): standardized training data.
    - val_data (array): standardized validation data.
    - test_data (array): standardized test data.
    """
    scaler = StandardScaler()

    # Fit and transform the training split in one pass.
    scaled_train = scaler.fit_transform(train_data)

    # Re-use the fitted statistics for the held-out splits.
    scaled_val, scaled_test = (
        scaler.transform(split) for split in (val_data, test_data)
    )

    return scaled_train, scaled_val, scaled_test


def read_uci_data(
    dataset_name
):
    """
    Load a UCI dataset from "../Data/<dataset_name>.csv" and shuffle its rows.

    The CSV is read without a header row. Rows are permuted with NumPy's
    global random state, so the order differs between calls unless the caller
    seeds `np.random` beforehand. The last column is treated as the target,
    all preceding columns as features.

    params:
    - dataset_name (string): name of the dataset (file name without ".csv").

    return:
    - X (np.ndarray): feature matrix, all columns but the last.
    - y (np.ndarray): target matrix with a single column.
    """
    table = pd.read_csv("../Data/" + dataset_name + ".csv", sep=",", header=None)

    # Shuffle along the first axis (rows); columns stay aligned.
    shuffled_rows = np.random.permutation(table.to_numpy())

    features = shuffled_rows[:, :-1]
    # Keep the targets two-dimensional: one column, one row per sample.
    targets = shuffled_rows[:, -1][:, np.newaxis]

    return features, targets


def read_toy_data(
    config
):
    """
    Read toy data.

    Generates the synthetic dataset selected by config["data"]["name"].
    The random key is always PRNGKey(0), so repeated calls with the same
    configuration use the same key.

    params:
    - config (dict): configuration dictionary. The following entries of
      config["data"] are read:
        - "name": either "truncated_sine" or a name starting with "GP" of
          the form "GP_<kernel>" (the text between the first and second
          underscore is used as the kernel name).
        - "n_samples": number of samples to generate.
        - "feature_dim": feature dimension; only read for "truncated_sine".

    returns:
    - X (jnp.array): feature matrix.
    - y (jnp.array): target matrix.

    raises:
    - ValueError: if the dataset name is unknown, or a "GP" name carries no
      "_<kernel>" suffix.
    """
    key = jax.random.PRNGKey(0)

    # Load configuration
    data_config = config["data"]
    name = data_config["name"]
    n_samples = data_config["n_samples"]

    if name == "truncated_sine":
        # feature_dim is only needed here, so GP configs may omit it.
        return _truncated_sine(key, n_samples, data_config["feature_dim"])

    if name[:2] == "GP":
        parts = name.split("_")
        if len(parts) < 2:
            # Fail with a clear message instead of an IndexError on parts[1].
            raise ValueError(
                f"GP toy dataset name must look like 'GP_<kernel>', got {name!r}."
            )
        return _GP_data(key, n_samples, parts[1])

    raise ValueError(f"Unknown toy dataset: {name!r}.")

    
def _truncated_sine(
    key,
    n_samples,
    feature_dim
):
    """
    Generated data from truncated sine function.

    params:
    - key (jax.random.PRNGKey): random key.
    - n_samples (int): number of samples.
    - feature_dim (int): feature dimension.

    returns:
    - X (jnp.array): feature matrix.
    - y (jnp.array): target matrix.
    """
    key1, key2, key3 = jax.random.split(key, num=3)

    # Features
    X1 = jax.random.uniform(key1, minval=-1, maxval=-0.5, shape=(n_samples//2, feature_dim))
    X2 = jax.random.uniform(key2, minval=0.5, maxval=1, shape=(n_samples//2, feature_dim))
    X = jnp.concatenate([X1, X2], axis=0)

    # Targets
    eps = 0.1*jax.random.normal(key3, shape=(n_samples,))
    y = jnp.sin(2*np.pi*X.mean(axis=-1)) + eps

    # Format
    X = X.reshape(-1, feature_dim)
    y = y.reshape(-1, 1)

    return X, y


def _GP_data(
        key,
        n_samples, 
        kernel_name
    ):
    """
    Generate data from GP prior.

    params:
    - key (jax.random.PRNGKey): random key.
    - n_samples (int): number of samples.
    - kernel_name (string): name of kernel.

    returns:
    - X (jnp.array): feature matrix.
    - y (jnp.array): target matrix.
    """
    key1, key2, key3, key4 = jax.random.split(key, num=4)


    if kernel_name == "RBF":
        kernel = jk.RBF()
    elif kernel_name == "Polynomial":
        kernel = jk.Polynomial(degree=6)
    elif kernel_name == "Matern12":
        kernel = jk.Matern12()
    elif kernel_name == "Matern32":
        kernel = jk.Matern32()
    elif kernel_name == "Matern52":
        kernel = jk.Matern52()
    else:
        raise Exception("Unknown kernel")
    
    # Initialise GP prior
    prior = gpx.Prior(kernel=kernel)
    parameter_state = gpx.initialise(prior, key1)
    parameter_state.params['kernel']['lengthscale'] = 0.5

    # Sample from GP prior
    X = jax.random.uniform(key2, minval=-1, maxval=1, shape=(n_samples,))
    prior_dist = prior(parameter_state.params)(X)
    eps = 0.1*jax.random.normal(key3, shape=(n_samples,1))
    y = prior_dist.sample(seed=key4, sample_shape=(1,)).T + eps

    # Format 
    X = X.reshape(-1, 1)
    y = y.reshape(-1, 1)

    return X, y
    