import h5py
import jax.numpy as jnp
import numpy as np

def _read_h5(path):
    """Read the full ``data`` and ``tensor`` datasets from the HDF5 file at *path*."""
    # Open explicitly read-only: h5py < 3.0 defaulted to mode 'a', which would
    # create a missing file and open existing ones for writing.
    with h5py.File(path, "r") as f:
        return f["data"][:], f["tensor"][:]


def _build_split(data, tensor, h_x, h_y):
    """Turn raw ``data``/``tensor`` arrays into one ``(u, (force, tensor), fluxes)`` split.

    Channel 0 of ``data`` (axis 1) is used as the forcing and channel 1 as the
    field ``u``; assumes ``data`` is (N, channels, nx, ny) — TODO confirm.
    Columns ``1:`` of ``tensor`` are used as a per-sample velocity; column 0 is
    not used here.
    """
    force = data[:, 0][..., None]
    u = data[:, 1][..., None]

    # Cell-centred value: mean of the four corner nodes of each cell.
    u_mid = (u[:, :-1, :-1] + u[:, 1:, :-1] + u[:, :-1, 1:] + u[:, 1:, 1:]) / 4.0

    # Forward differences along axis 1 (x) and axis 2 (y), then average the two
    # parallel cell edges so both gradients sit on the same cells as u_mid.
    u_grad_x = (u[:, 1:] - u[:, :-1]) / h_x
    u_grad_y = (u[:, :, 1:] - u[:, :, :-1]) / h_y
    u_grad_x = (u_grad_x[:, :, 1:] + u_grad_x[:, :, :-1]) / 2.0
    u_grad_y = (u_grad_y[:, 1:] + u_grad_y[:, :-1]) / 2.0
    u_grad = jnp.concatenate([u_grad_x, u_grad_y], axis=-1)

    vel = tensor[:, 1:]
    # vel[:, None, None] broadcasts each sample's velocity over both spatial axes.
    fluxes = u_grad + u_mid * vel[:, None, None]
    fluxes = fluxes[..., None, :]

    return u, (force, tensor), fluxes


def load_ds(config):
    """Load the train and validation splits and their normalisation scales.

    Reads ``config.train_path``, ``config.val_path``,
    ``config.train_sub_sampling_rate`` and ``config.measure``.

    Only the training split is subsampled (every
    ``train_sub_sampling_rate``-th sample); the validation split is used in full.

    Returns:
        ``((train_ds, test_ds), (force_scales, tensor_scales))`` where each
        dataset is ``(u, (force, tensor), fluxes)``. Both scales are computed
        from the (subsampled) training split only.
    """
    # Grid spacing; presumably 128 nodes per axis on a unit domain — TODO confirm.
    h_x = h_y = 1 / 127

    subsampling_rate = config.train_sub_sampling_rate

    data, tensor = _read_h5(config.train_path)
    data = data[::subsampling_rate]
    tensor = tensor[::subsampling_rate]
    train_ds = _build_split(data, tensor, h_x, h_y)

    train_force = train_ds[1][0]
    tensor_scales = np.median(np.abs(tensor), axis=0)
    force_scales = np.median(
        np.linalg.norm(train_force, axis=(1, 2)), axis=0
    ) * np.prod(config.measure)

    val_data, val_tensor = _read_h5(config.val_path)
    test_ds = _build_split(val_data, val_tensor, h_x, h_y)

    return (train_ds, test_ds), (force_scales, tensor_scales)

