
import h5py
import jax.numpy as jnp
import numpy as np

def load_ds(config):
    """Load a gridded field dataset, derive a flux target, and split it.

    Reads the ``'u'`` and ``'vel'`` datasets from the HDF5 file at
    ``config.path``, binarizes ``vel`` into a mask, computes a flux
    ``grad(u) + r_adv * u`` at cell midpoints, shuffles the samples with a
    fixed seed and returns a 95% / 5% train/test split.

    Args:
        config: object exposing ``path`` (HDF5 file to read) and
            ``train_sub_sampling_rate`` (stride applied to the training split
            only; the test split is not subsampled).

    Returns:
        ``(train_ds, test_ds)``, each a tuple ``(u, vel, flux)``. Assuming
        ``u`` squeezes to ``(N, H, W)`` and ``vel`` is stored as ``(N, H, W)``
        (TODO confirm against the data files), ``u``/``vel`` are
        ``(N, H, W, 1)`` and ``flux`` is ``(N, H-1, W-1, 1, 2)``.
    """
    subsampling_rate = config.train_sub_sampling_rate

    # Explicit read-only mode: older h5py versions default to 'a' (append),
    # which can create or modify the file.
    with h5py.File(config.path, 'r') as f:
        u = f['u'][:]
        vel = f['vel'][:]

    # Grid spacing; presumably 128 nodes per axis on a unit domain — TODO confirm.
    h_x = h_y = 1 / 127

    # Drop singleton axes but never the sample axis 0: a bare squeeze would
    # also remove the batch dimension when the file holds a single sample.
    u = jnp.asarray(u)
    singleton_axes = tuple(ax for ax in range(1, u.ndim) if u.shape[ax] == 1)
    if singleton_axes:
        u = jnp.squeeze(u, axis=singleton_axes)
    u = u[..., None]

    # Binary mask: 1 where vel > 1, else 0.
    vel = jnp.where(vel[..., None] > 1, 1, 0)

    def cell_average(a):
        # Mean of the four neighbouring nodes along axes 1 and 2.
        return (a[:, :-1, :-1] + a[:, 1:, :-1] + a[:, :-1, 1:] + a[:, 1:, 1:]) / 4.

    u_mid = cell_average(u)

    # Forward differences along axis 1 (x) and axis 2 (y); each is then
    # averaged along the other axis so both end up on the same midpoints.
    u_grad_x = (u[:, 1:] - u[:, :-1]) / h_x
    u_grad_y = (u[:, :, 1:] - u[:, :, :-1]) / h_y
    u_grad_x = (u_grad_x[:, :, 1:] + u_grad_x[:, :, :-1]) / 2.
    u_grad_y = (u_grad_y[:, 1:] + u_grad_y[:, :-1]) / 2.
    u_grad = jnp.concatenate([u_grad_x, u_grad_y], axis=-1)

    # Advection coefficient: 5 where the mask is set, 0.1 elsewhere.
    # NOTE: an earlier version used the negated coefficient (sign changed).
    r_adv = cell_average(jnp.where(vel > .5, 5, 0.1))
    r_adv = jnp.concatenate([r_adv, r_adv], axis=-1)

    flux = (u_grad + r_adv * u_mid)[..., None, :]

    data_len = u.shape[0]
    # Local RNG with the same seed as the former global np.random.seed(42):
    # yields the identical permutation without clobbering numpy's global state.
    indices = np.random.RandomState(42).permutation(data_len)
    u, vel, flux = u[indices], vel[indices], flux[indices]

    split = int(data_len * 0.95)
    train_ds = tuple(a[:split][::subsampling_rate] for a in (u, vel, flux))
    test_ds = tuple(a[split:] for a in (u, vel, flux))

    return (train_ds, test_ds)