from util import *
from jax import grad, value_and_grad, jit,vmap, random
@jit
def loss_jax(params, df_batch, un):
    """Variance of the gap between sorted targets and sorted scaled noise.

    Sorts ``df_batch[:, 1]`` and ``params * un`` independently and returns
    ``jnp.var`` (ddof=0) of their elementwise difference. Because only the
    variance of the difference is used, a constant offset between the two
    sorted samples contributes nothing to the loss.

    Args:
        params: scalar multiplier applied to ``un``.
        df_batch: 2-D array; only column 1 (the target values) is read.
        un: noise values to compare against; in this file one value per row
            of ``df_batch``.

    Returns:
        Scalar loss value.
    """
    y = df_batch[:, 1]
    vec = jnp.sort(y) - jnp.sort(params * un)
    return jnp.var(vec)
# Loss value together with its gradient w.r.t. the first argument (params).
val_and_grad = value_and_grad(loss_jax)
# Inner map: params and the batch stay fixed; `un` is walked along axis 1
# (one noise replicate per mapped call). Outputs stack on axis 0.
vmap_val_and_grad_inner = vmap(val_and_grad, in_axes=(None, None, 1))
# Outer map: pair batch i of `batches` (axis 0) with slab i of `un` (axis 2),
# then JIT-compile the whole nested computation. With `un` shaped
# (rows, nrep, n_batches), results come back shaped (n_batches, nrep).
vmap_val_and_grad_outer = jit(vmap(vmap_val_and_grad_inner, in_axes=(None, 0, 2)))
def batch_test(df, resolution, npos):
    """Batch ``df`` with the ``util`` helpers and stack the result.

    Returns ``(batches, batch_sz)``: the output of ``get_batches`` converted
    to one JAX array, and the number of rows in its first batch. The first
    batch must be 2-D; unpacking its shape raises otherwise.
    """
    neighbour_matrix = get_neighbor_matrix_fixed_num(df, resolution)
    stacked = jnp.array(
        get_batches(data=df, neighborM=neighbour_matrix,
                    resolution=resolution, npos=npos)
    )
    rows_in_first, _ = stacked[0].shape
    return stacked, rows_in_first
def test(batches, key_seed=42, step_sz=1.0, exp=200, nrep=100, *,
         theta_init=0.2, grad_fn=None):
    """Run gradient descent on a scalar parameter against batched noise.

    Each of the ``exp`` steps draws fresh uniform(0, 1) noise of shape
    ``(batch_sz, nrep, len(batches))``, evaluates ``grad_fn`` on every
    (batch, replicate) pair, and moves ``params`` against the mean gradient.

    Args:
        batches: stacked batches; ``batches[0].shape`` must be
            ``(batch_sz, n_cols)``.
        key_seed: seed for ``jax.random.PRNGKey``.
        step_sz: gradient-descent step size.
        exp: number of descent steps.
        nrep: noise replicates drawn per batch per step.
        theta_init: starting parameter value (formerly hard-coded to 0.2).
        grad_fn: callable ``(params, batches, un) -> (loss, grad)`` returning
            arrays that are averaged with ``np.mean``; defaults to the
            module-level ``vmap_val_and_grad_outer``.

    Returns:
        ``(loss_res, t_res, gradt_res, params)``: per-step mean loss,
        parameter value after each step, per-step mean gradient, and the
        final parameter value.
    """
    if grad_fn is None:
        grad_fn = vmap_val_and_grad_outer
    key = random.PRNGKey(key_seed)
    batch_sz, _ = batches[0].shape
    # Layout (rows, replicate, batch) matches the in_axes of the default
    # batched value-and-grad function.
    noise_shape = (batch_sz, nrep, len(batches))
    params = theta_init
    loss_res, t_res, gradt_res = [], [], []
    for _ in range(exp):
        # Fresh subkey every step so successive steps see independent noise.
        key, subkey = random.split(key)
        un = random.uniform(subkey, shape=noise_shape, minval=0.0, maxval=1.0)
        # `grad_val`, not `grad`, to avoid shadowing the imported jax.grad.
        loss_val, grad_val = grad_fn(params, batches, un)
        # Average over all (batch, replicate) pairs before stepping.
        ave_loss, ave_grad = np.mean(loss_val), np.mean(grad_val)
        params -= step_sz * ave_grad
        loss_res.append(ave_loss)
        t_res.append(params)
        gradt_res.append(ave_grad)
    return loss_res, t_res, gradt_res, params


# The name matches pytest's default ``test*`` collection pattern, but this is
# a helper with required arguments, not a test case; stop pytest from trying
# to collect it when it is imported into a test module.
test.__test__ = False
# Generate a dataset with 10,000 samples.
def f_t(x):
    """Noise-free signal 0.1 * ((2.5 x)^3 - x), applied elementwise."""
    return 0.1 * ((2.5 * x) ** 3 - x)


def _fit_and_time(batches, repeat=7, **fit_kwargs):
    """Run ``test`` once and keep its results, then time ``repeat`` more runs.

    Replaces the notebook's ``%timeit name = test(...)`` lines: ``%timeit``
    executes its statement inside a throwaway function, so those assignments
    never reached the caller. The first call here also absorbs JIT
    compilation for these batch shapes, so the timed runs are steady-state.
    """
    import timeit

    fit_results = test(batches, **fit_kwargs)
    run_times = timeit.repeat(lambda: test(batches, **fit_kwargs),
                              number=1, repeat=repeat)
    return fit_results, run_times


# The dataset does not depend on `resolution`, so generate it once and reuse
# it for every resolution below.
nsamples = 10000
npos = 50  # 100
n = nsamples

key = random.PRNGKey(0)
k1, k2 = random.split(key)
ksample, knoise = random.split(k1)
# Sample x with `ksample`, not `k1`: k1 was already consumed by the split
# above, and a JAX key must never be reused.
x_samples = random.uniform(ksample, shape=(nsamples, 1), minval=-1, maxval=1)
# f_t is elementwise, so one vectorised call replaces a Python loop that
# dispatched 10k tiny JAX computations.
y_samples = np.array(f_t(x_samples))
y_samples += 1.0 * random.uniform(knoise, shape=(nsamples, 1), minval=0.0, maxval=1.0)
x = x_samples.reshape(-1)
y = y_samples.reshape(-1)

plt.scatter(x, y, marker='.')
plt.xlabel('X', fontsize=20)
plt.ylabel('Y', fontsize=20)

# df_c holds (x, y) rows and df_rv the swapped (y, x) rows; each is passed
# through sortBycol(..., 0) (from util).
df_c = np.zeros([n, 2])
df_c[:, 0], df_c[:, 1] = x, y
df_sort_c = sortBycol(df_c, 0)
df_rv = np.zeros([n, 2])
df_rv[:, 0], df_rv[:, 1] = y, x
df_sort_rv = sortBycol(df_rv, 0)

# Fit and benchmark on df_sort_c at each batching resolution.
results = {}
timings = {}
for resolution in (0.001, 0.01, 0.1):
    c_batches, c_batch_sz = batch_test(df_sort_c, resolution, npos)
    results[resolution], timings[resolution] = _fit_and_time(
        c_batches, key_seed=42, step_sz=1.0, exp=100, nrep=50)
    mean_t = np.mean(timings[resolution])
    std_t = np.std(timings[resolution])
    print(f"resolution={resolution}: {mean_t:.3f} s ± {std_t:.3f} s per run "
          f"(mean ± std. dev. of {len(timings[resolution])} runs)")

# Module-level names refer to the last resolution, matching the final state
# of the original cell sequence.
c_loss_res, c_t_res, c_gradt_res, params_c = results[resolution]