import jax.numpy as jnp
import numpy as np
from jax import vmap
import h5py

def _read_split(path, subsampling_rate=None):
    """Read the 'fields' and 'tensor' datasets from the HDF5 file at *path*.

    Both arrays are subsampled along their first axis with step
    *subsampling_rate*; ``None`` keeps every sample.
    """
    # Open read-only explicitly: h5py releases before 3.0 default to mode 'a',
    # which would silently create an empty file for a mistyped path.
    with h5py.File(path, 'r') as f:
        data = f['fields'][:]
        tensor = f['tensor'][:]
    return data[::subsampling_rate], tensor[::subsampling_rate]


def _prepare_split(data, tensor, correction_factor, h_x, h_y):
    """Build ``(u, force, tensor, fluxes)`` for one dataset split.

    Args:
        data: array whose ``data[:, 0]`` is the forcing and ``data[:, 1]`` the
            solution field; presumably shape (N, 2, H, W) — TODO confirm
            against the HDF5 files.
        tensor: per-sample coefficients; entries 0, 1, 2 fill the symmetric
            2x2 matrix [[K0, K2], [K2, K1]].
        correction_factor: multiplier applied to both ``tensor`` and the forcing.
        h_x, h_y: grid spacings along axes 1 and 2 of the field.

    Returns:
        u: the field with a trailing channel axis ((N, H, W, 1) for
            (N, 2, H, W) data).
        force: the scaled forcing, same shape as ``u``.
        tensor: the scaled coefficient array.
        fluxes: per-sample K . grad(u), shape (N, H, W, 1, 2) for
            (N, 2, H, W) data.
    """
    tensor = correction_factor * tensor
    force = correction_factor * data[:, 0][..., None]
    u = data[:, 1][..., None]

    # Periodic padding: append the first row/column so differences wrap around.
    u = jnp.concat([u, u[:, 0:1]], axis=1)
    u = jnp.concat([u, u[:, :, 0:1]], axis=2)

    # Forward differences along each axis, then average the two neighbouring
    # values along the other axis so both components land on the same
    # (H, W) grid.
    u_grad_x = (u[:, 1:] - u[:, :-1]) / h_x
    u_grad_y = (u[:, :, 1:] - u[:, :, :-1]) / h_y
    u_grad_x = (u_grad_x[:, :, 1:] + u_grad_x[:, :, :-1]) / 2.
    u_grad_y = (u_grad_y[:, 1:] + u_grad_y[:, :-1]) / 2.
    u_grad = jnp.concat([u_grad_x, u_grad_y], axis=-1)

    # Assemble the symmetric 2x2 coefficient matrix for every sample.
    visc = vmap(lambda K: jnp.array([[K[0], K[2]], [K[2], K[1]]]))(tensor)

    # fluxes[..., j] = sum_i grad_i * K[i, j], i.e. (K @ grad)_j since K is symmetric.
    fluxes = jnp.sum(u_grad[..., None] * visc[:, None, None], axis=-2)
    fluxes = fluxes[..., None, :]

    # Drop the periodic padding again.
    u = u[:, :-1, :-1]
    return u, force, tensor, fluxes


def load_ds(config):
    """Load the training and validation splits described by *config*.

    ``config.train_path`` is subsampled along the sample axis with
    ``config.train_sub_sampling_rate``; ``config.val_path`` is used in full.
    Each split is returned as ``(u, (force, tensor), fluxes)``.

    Returns:
        ``((train_ds, test_ds), (force_scales, tensor_scales))`` where
        ``tensor_scales`` is the per-component median of ``|tensor|`` over the
        scaled training set, and ``force_scales`` is the median over training
        samples of the Frobenius norm (over axes 1 and 2) of the scaled
        forcing, multiplied by ``np.prod(config.measure)``.
    """
    # NOTE(review): fixed scaling applied to both tensor and forcing; its
    # origin is not visible here — confirm.
    correction_factor = 50.
    # NOTE(review): hard-coded grid spacing; presumably a unit periodic domain
    # with 128 points per axis — confirm against the data files.
    h_x = h_y = 1/128

    data, tensor = _read_split(config.train_path, config.train_sub_sampling_rate)
    u, force, tensor, fluxes = _prepare_split(data, tensor, correction_factor, h_x, h_y)
    train_ds = (u, (force, tensor), fluxes)

    # Normalisation scales come from the (scaled) training split only.
    tensor_scales = np.median(np.abs(tensor), axis=0)
    force_scales = np.median(np.linalg.norm(force, axis=(1, 2)), axis=0) * np.prod(config.measure)

    data, tensor = _read_split(config.val_path)
    u, force, tensor, fluxes = _prepare_split(data, tensor, correction_factor, h_x, h_y)
    test_ds = (u, (force, tensor), fluxes)

    return (train_ds, test_ds), (force_scales, tensor_scales)