from jax import vmap, grad, random
import jax.numpy as jnp
import numpy as np


def eval_analytic_solution(x, epsilon):
    """Evaluate u(x) = 1 - (exp(x/eps) - 1) / (exp(1/eps) - 1).

    This closed form satisfies ``epsilon * u'' = u'`` with ``u(0) = 1`` and
    ``u(1) = 0``. For small ``epsilon`` it develops a steep layer near x = 1.

    The quotient is evaluated in an algebraically identical form. Numerator
    and denominator are both multiplied by ``exp(-1/eps)``:

        (exp((x - 1)/eps) - exp(-1/eps)) / (1 - exp(-1/eps))

    The naive form computes ``exp(1/eps)``. In float32 that overflows to
    ``inf`` once eps is below about 0.0113, giving ``inf/inf = nan``, and
    ``jax.grad`` through it gives NaN as well. In the rewritten form every
    exponent is <= 0 whenever ``x <= 1``, so nothing overflows on the
    problem domain.

    Args:
      x: evaluation point(s); the boundary conditions are posed on [0, 1].
      epsilon: positive parameter multiplying u''; broadcasts against ``x``.

    Returns:
      u(x) with the broadcast shape of ``x`` and ``epsilon``.
    """
    inv_eps = 1.0 / epsilon
    # 1 - exp(-1/eps) via expm1, which stays accurate when 1/eps is small.
    denom = -jnp.expm1(-inv_eps)
    ratio = (jnp.exp((x - 1.0) * inv_eps) - jnp.exp(-inv_eps)) / denom
    return 1.0 - ratio

def grad_analytic_solution(x, epsilon):
    """Return du/dx of ``eval_analytic_solution`` at every point of ``x``.

    ``x`` is mapped over its leading axis, and one ``epsilon`` is shared by
    all points. ``grad`` differentiates a scalar-valued function, so each
    mapped element of ``x`` must be a scalar of floating dtype.
    """
    # Derivative with respect to the first positional argument (x) only.
    du_dx = grad(eval_analytic_solution, argnums=0)
    # Batch over axis 0 of x; epsilon is not mapped (broadcast to every point).
    batched_du_dx = vmap(du_dx, in_axes=(0, None))
    return batched_du_dx(x, epsilon)




def load_ds(config, *, num_samples=100, num_nodes=128, train_fraction=0.9,
            eps_seed=41, split_seed=42):
    """Build shuffled train/test sets of analytic solutions.

    Epsilons are drawn from a normal distribution with mean 0.5 and std 0.2,
    clipped to [0.1, 0.9], and sorted. ``eval_analytic_solution`` is then
    evaluated for each epsilon on ``num_nodes`` uniform points in [0, 1].
    The samples are shuffled and split by ``train_fraction``. Only the
    training split is subsampled by ``config.train_sub_sampling_rate``.

    Args:
      config: object providing ``train_sub_sampling_rate`` (slice step
        applied to the training split; ``None`` means no subsampling).
      num_samples: number of epsilon values / solution profiles.
      num_nodes: number of grid points in [0, 1].
      train_fraction: fraction of samples placed in the training split.
      eps_seed: seed of the JAX PRNG key used to draw epsilons.
      split_seed: seed used for the shuffle before splitting.

    Returns:
      ``(train_ds, test_ds)``. Each is a ``(solns, epsilons)`` pair with
      ``solns`` of shape (n, num_nodes, 1) and ``epsilons`` of shape (n, 1).

    Raises:
      ValueError: if ``train_sub_sampling_rate`` is not positive.
    """
    subsampling_rate = config.train_sub_sampling_rate
    # A zero step is a slice error and a negative one would silently reverse
    # the training set, so reject both up front.
    if subsampling_rate is not None and subsampling_rate < 1:
        raise ValueError(
            f"train_sub_sampling_rate must be positive, got {subsampling_rate}")

    x_nodes = jnp.linspace(0, 1, num_nodes)

    # One scalar N(0.5, 0.2^2) draw per key, clipped so epsilon stays in [0.1, 0.9].
    keys = random.split(random.PRNGKey(eps_seed), num_samples)
    epsilons = jnp.sort(
        jnp.clip(vmap(random.normal)(keys) * 0.2 + 0.5, 0.1, 0.9))
    # Map over epsilons, sharing the grid; trailing axis makes a 1-channel field.
    solns = vmap(eval_analytic_solution, in_axes=(None, 0))(x_nodes, epsilons)[..., None]
    epsilons = epsilons[..., None]

    data_len = solns.shape[0]

    # NOTE(review): this reseeds NumPy's *global* RNG as a side effect. It is
    # kept as-is in case downstream code relies on that seed for reproducibility.
    np.random.seed(split_seed)
    indices = np.random.permutation(data_len)
    solns = solns[indices]
    epsilons = epsilons[indices]

    n_train = int(data_len * train_fraction)
    train_ds = (solns[:n_train][::subsampling_rate],
                epsilons[:n_train][::subsampling_rate])
    test_ds = (solns[n_train:], epsilons[n_train:])

    return (train_ds, test_ds)
