import numpy as np
import jax.numpy as jnp
from scipy.special import erf

from jax import config
# Turn on JAX 64-bit mode for the whole process. The generators in this
# module build their outputs with ``dtype=jnp.float64``; JAX only honours
# that dtype when x64 mode is enabled (otherwise it falls back to float32).
config.update("jax_enable_x64", True)

def f_synthetic_data_design_struct(a):
    """Return the A-dependent part of the synthetic outcome, ``a**2 - 0.3``.

    This matches the ``A ** 2 - 0.3`` term that ``synthetic_data_design`` and
    ``synthetic_noisy_data_design`` add to ``sin(2 * pi * U)`` when building
    ``Y``. It works element-wise on scalars, numpy arrays and jax arrays.

    Args:
        a: Scalar or array of values.

    Returns:
        ``a ** 2 - 0.3`` with the same shape as ``a``.
    """
    return a ** 2 - 0.3

def synthetic_data_design(data_size, U_range=(-1, 1), *, rng=None):
    """Sample a synthetic dataset driven by a hidden latent variable ``U``.

    Generative model (``U`` itself is not returned)::

        U ~ Uniform(U_range[0], U_range[1])
        A = erf(U) + eps_1,          eps_1 ~ N(0, 0.1**2)
        W = exp(U) + eps_2,          eps_2 ~ N(0, 0.05**2)
        Y = sin(2 * pi * U) + A**2 - 0.3

    Args:
        data_size: Number of samples to draw.
        U_range: ``(low, high)`` bounds of the uniform latent ``U``.
        rng: Optional random generator exposing ``uniform`` and ``normal``
            (e.g. ``np.random.default_rng(seed)`` or a ``RandomState``) used
            for every draw, so datasets can be reproduced without touching
            global state. When ``None`` (default) the legacy global
            ``np.random`` state is used, exactly as before.

    Returns:
        Tuple ``(A, Y, W)``; each is a jax array requested as
        ``jnp.float64`` with shape ``(data_size, 1)``.
    """
    # The numpy module itself exposes ``uniform``/``normal`` with the same
    # positional signature as Generator/RandomState, so falling back to it
    # reproduces the original draws (same calls, same order) bit-for-bit.
    sampler = np.random if rng is None else rng
    U = sampler.uniform(U_range[0], U_range[1], (data_size,))
    epsilon_1 = sampler.normal(0, 0.1, (data_size,))
    epsilon_2 = sampler.normal(0, 0.05, (data_size,))
    A = erf(U) + epsilon_1
    W = np.exp(U) + epsilon_2
    Y = np.sin(2 * np.pi * U) + A ** 2 - 0.3

    # Convert to column vectors for downstream (n, 1)-shaped consumers.
    A = jnp.array(A, dtype=jnp.float64).reshape(-1, 1)
    W = jnp.array(W, dtype=jnp.float64).reshape(-1, 1)
    Y = jnp.array(Y, dtype=jnp.float64).reshape(-1, 1)
    return A, Y, W

def synthetic_noisy_data_design(data_size, noise_var=0.1, U_range=(-1, 1), *, rng=None):
    """Sample the synthetic dataset with additive Gaussian noise on ``Y``.

    Generative model (``U`` itself is not returned)::

        U ~ Uniform(U_range[0], U_range[1])
        A = erf(U) + eps_1,          eps_1 ~ N(0, 0.1**2)
        W = exp(U) + eps_2,          eps_2 ~ N(0, 0.05**2)
        Y = sin(2 * pi * U) + A**2 - 0.3 + eta,   eta ~ N(0, noise_var**2)

    Args:
        data_size: Number of samples to draw.
        noise_var: Spread of the output noise ``eta``. NOTE: despite the
            name, this value is passed as the ``scale`` (standard deviation)
            of the normal draw, so the noise variance is ``noise_var**2``.
            Kept as-is for backward compatibility with existing experiments.
        U_range: ``(low, high)`` bounds of the uniform latent ``U``.
        rng: Optional random generator exposing ``uniform`` and ``normal``
            (e.g. ``np.random.default_rng(seed)`` or a ``RandomState``) used
            for every draw. When ``None`` (default) the legacy global
            ``np.random`` state is used, exactly as before.

    Returns:
        Tuple ``(A, Y, W, output_noise)``. ``A``, ``Y``, ``W`` are jax arrays
        requested as ``jnp.float64`` with shape ``(data_size, 1)``;
        ``output_noise`` is the raw 1-D numpy array of shape ``(data_size,)``
        that was added to ``Y``.
    """
    # Same draw order as the original (U, eps_1, eps_2, then output noise),
    # so seeding the global numpy RNG reproduces previous datasets exactly.
    sampler = np.random if rng is None else rng
    U = sampler.uniform(U_range[0], U_range[1], (data_size,))
    epsilon_1 = sampler.normal(0, 0.1, (data_size,))
    epsilon_2 = sampler.normal(0, 0.05, (data_size,))
    A = erf(U) + epsilon_1
    W = np.exp(U) + epsilon_2
    output_noise = sampler.normal(0, noise_var, (data_size,))
    Y = np.sin(2 * np.pi * U) + A ** 2 - 0.3 + output_noise

    # Convert to column vectors; output_noise is intentionally left as the
    # flat numpy array it was drawn as.
    A = jnp.array(A, dtype=jnp.float64).reshape(-1, 1)
    W = jnp.array(W, dtype=jnp.float64).reshape(-1, 1)
    Y = jnp.array(Y, dtype=jnp.float64).reshape(-1, 1)
    return A, Y, W, output_noise