from util import *
from jax import grad, value_and_grad, jit,vmap, random
@jit
def loss_jax(params, df_batch, un):
    """Return the variance of sorted residuals minus sorted scaled samples.

    Args:
        params: mapping with entries ``'w'`` and ``'theta'``.
        df_batch: 2-D array; column 0 is read as ``x`` and column 1 as ``y``.
        un: array of samples that is scaled by ``params['theta']``.

    Returns:
        ``jnp.var`` of ``sort(y - w * x) - sort(theta * un)``.
    """
    inputs = df_batch[:, 0]
    targets = df_batch[:, 1]
    # Both sides are sorted independently before being compared element-wise.
    sorted_residuals = jnp.sort(targets - params['w'] * inputs)
    sorted_scaled_samples = jnp.sort(params['theta'] * un)
    return jnp.var(sorted_residuals - sorted_scaled_samples)
# Loss value and gradient with respect to the first argument (params).
val_and_grad = value_and_grad(loss_jax)

# Inner map: params and the batch are held fixed while axis 1 of `un` is
# mapped over; results are stacked along axis 0.
vmap_val_and_grad_inner = vmap(
    val_and_grad,
    in_axes=(None, None, 1),
    out_axes=0,
)

# Outer map: params are held fixed while axis 0 of the batch argument and
# axis 2 of `un` are mapped over together. The composed function is compiled
# with jit.
vmap_val_and_grad_outer = jit(
    vmap(
        vmap_val_and_grad_inner,
        in_axes=(None, 0, 2),
        out_axes=0,
    )
)
def batch_test(df, resolution, npos):
    """Build the batches for ``df`` and report the row count of the first one.

    The batches come from ``get_neighbor_matrix_fixed_num`` and ``get_batches``,
    which are defined outside this file; their exact semantics are not visible
    here.

    Args:
        df: data passed to both helpers.
        resolution: passed to both helpers.
        npos: passed to ``get_batches``.

    Returns:
        ``(batches, batch_sz)`` where ``batches`` is the ``get_batches`` result
        converted with ``jnp.array`` and ``batch_sz`` is the first dimension of
        ``batches[0]``.
    """
    neighbor_matrix = get_neighbor_matrix_fixed_num(df, resolution)
    raw_batches = get_batches(
        data=df,
        neighborM=neighbor_matrix,
        resolution=resolution,
        npos=npos,
    )
    batches = jnp.array(raw_batches)
    # The first batch must be 2-D; only its row count is reported.
    batch_sz, _n_cols = batches[0].shape
    return batches, batch_sz
def test(batches, key_seed=42, step_sz=1.0, exp=200, nrep=100):
    """Run ``exp`` gradient-descent steps on ``w`` and ``theta``.

    Starting from ``w = 0.1`` and ``theta = 0.2``, every iteration draws fresh
    uniform samples between 0.0 and 1.0 of shape
    ``(batch_sz, nrep, len(batches))``, evaluates ``vmap_val_and_grad_outer``,
    averages the losses and gradients with ``np.mean`` and takes one step of
    size ``step_sz`` on both parameters.

    Args:
        batches: indexable collection of 2-D batches; ``batches[0]`` provides
            the row count used for the random draws.
        key_seed: seed for ``random.PRNGKey``.
        step_sz: step size applied to both parameters.
        exp: number of iterations.
        nrep: size of the second axis of the random draws.

    Returns:
        Five lists with one entry per iteration: mean loss, ``w`` after the
        update, ``theta`` after the update, mean gradient for ``w`` and mean
        gradient for ``theta``.
    """
    key = random.PRNGKey(key_seed)
    batch_sz, _n_cols = batches[0].shape
    n_batches = len(batches)

    params = {'w': 0.1, 'theta': 0.2}
    loss_res, w_res, t_res = [], [], []
    gradw_res, gradt_res = [], []

    for j in range(exp):
        # A fresh subkey per iteration keeps the draws independent.
        key, subkey = random.split(key)
        un = random.uniform(
            subkey,
            shape=(batch_sz, nrep, n_batches),
            minval=0.0,
            maxval=1.0,
        )
        loss_val, grads = vmap_val_and_grad_outer(params, batches, un)
        ave_loss = np.mean(loss_val)
        ave_grad = tree_map(np.mean, grads)

        params['w'] -= step_sz * ave_grad['w']
        params['theta'] -= step_sz * ave_grad['theta']

        loss_res.append(ave_loss)
        w_res.append(params['w'])
        t_res.append(params['theta'])
        gradw_res.append(ave_grad['w'])
        gradt_res.append(ave_grad['theta'])

        # Progress indicator, rewritten in place every tenth iteration.
        if not j % 10:
            sys.stdout.write("\rDoing thing %i" % j)

    return loss_res, w_res, t_res, gradw_res, gradt_res
# Set problem dimensions
# nsamples: number of rows drawn for x_samples / y_samples below.
nsamples = 100
# resolution: forwarded to batch_test(...) as its `resolution` argument.
resolution = 0.6
# npos: forwarded to batch_test(...) as its `npos` argument.
npos = 100
def f_t(x):
    """Return the noiseless response for ``x``; currently ``1.0 * x``.

    Alternatives that were previously tried here (kept for reference):
      * piecewise cubic: ``0.5*x**3 - x`` for ``x < 0``,
        otherwise ``1 - 0.5*x**3 + x``
      * ``jnp.sin(4*x)``
    """
    slope = 1.0
    return slope * x
# Seed the PRNG and derive separate keys for the inputs and the noise.
key = random.PRNGKey(10)
k1, k2 = random.split(key)  # k2 is not used in the lines visible here

# Generate samples with additional noise
ksample, knoise = random.split(k1)
# Draw x from `ksample`: `k1` has already been split above, and a JAX key
# must not be reused once it has been split.
x_samples = random.uniform(ksample, shape=(nsamples, 1), minval=-1, maxval=1)
# f_t is applied row by row so that a non-vectorised f_t (e.g. one that
# branches on the sign of x) keeps working.
y_samples = np.array([f_t(x) for x in x_samples])
y_samples += 1.0 * random.uniform(knoise, shape=(nsamples, 1), minval=0.0, maxval=1.0)

# Flatten the (nsamples, 1) columns to 1-D vectors.
x = x_samples.reshape(-1)
y = y_samples.reshape(-1)
# Quick look at the generated samples.
plt.scatter(x, y)

n = nsamples

# Two-column array with x in column 0 and y in column 1, then passed to
# sortBycol with column index 0 (helper defined outside this file).
df_c = np.zeros([n, 2])
df_c[:, 0] = x
df_c[:, 1] = y
df_sort_c = sortBycol(df_c, 0)

# Same construction with the columns swapped: y in column 0, x in column 1.
df_rv = np.zeros([n, 2])
df_rv[:, 0] = y
df_rv[:, 1] = x
df_sort_rv = sortBycol(df_rv, 0)
# Fit on the (x, y)-ordered data.
c_batches, c_batch_sz = batch_test(df_sort_c, resolution, npos)
(c_loss_res, c_w_res, c_t_res, c_gradw_res, c_gradt_res) = test(
    c_batches,
    key_seed=42,
    step_sz=1.0,
    exp=100,
    nrep=50,
)
# Score: mean of the last 10 losses divided by the mean of the last 10 thetas.
loss_c = np.mean(c_loss_res[-10:]) / np.mean(c_t_res[-10:])

# Fit on the column-swapped (y, x) data with the same settings.
rv_batches, rv_batch_sz = batch_test(df_sort_rv, resolution, npos)
(rv_loss_res, rv_w_res, rv_t_res, rv_gradw_res, rv_gradt_res) = test(
    rv_batches,
    key_seed=42,
    step_sz=1.0,
    exp=100,
    nrep=50,
)
loss_rv = np.mean(rv_loss_res[-10:]) / np.mean(rv_t_res[-10:])
# For each fit: print its score, then draw loss / theta / w against the
# iteration index in a 1x3 figure. The (x, y) fit comes first, then the
# column-swapped fit.
_runs = (
    (loss_c, (('loss', c_loss_res), ('theta', c_t_res), ('w', c_w_res))),
    (loss_rv, (('loss', rv_loss_res), ('theta', rv_t_res), ('w', rv_w_res))),
)
for _score, _panels in _runs:
    print(_score)
    plt.figure(figsize=(15, 5))
    for _slot, (_panel_title, _history) in enumerate(_panels, start=1):
        plt.subplot(1, 3, _slot)
        plt.title(_panel_title)
        plt.plot(np.arange(0, len(_history), 1), _history)
# Accuracy figure saved as syn2.pdf. The values below are hard-coded; how
# they were produced is not shown in this file.
x_axis = [0.05, 0.1, 0.15, 0.2, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99]
y_iden = [0.26, 0.64, 0.68, 0.58, 0.54, 0.47, 0.4, 0.31, 0.24, 0.1, 0.08]
y_debias = [0.28, 0.65, 0.66, 0.64, 0.64, 0.66, 0.66, 0.62, 0.69, 0.7, 0.57]

plt.figure(figsize=(8, 6))
# Raw strings keep the LaTeX backslashes literal; '\m' in a normal string is
# an invalid escape sequence and triggers a warning on current Python.
plt.plot(x_axis, y_debias, '*-', markersize=20, label=r'$D_{var}(\mathbf{v})$+debiasing', c='b')
plt.plot(x_axis, y_iden, 'v-', markersize=20, label=r'$D_{var}(\mathbf{v})$', c='r')
plt.legend(loc=0, fontsize=20)
plt.title('')
plt.xlabel('Batch size', fontsize=20)
plt.ylabel('Accuracy', fontsize=20)
plt.tick_params(axis='x', labelsize=16)
plt.tick_params(axis='y', labelsize=16)
plt.savefig('syn2.pdf')
# Estimated-theta figure with error bars, saved as syn3.pdf. The values below
# are hard-coded; how they were produced is not shown in this file.
x_axis = [0.05, 0.1, 0.15, 0.2, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99]
y_iden = [0.75070566, 0.9194967, 0.9679439, 1.027566, 1.1283402, 1.3811135,
          1.5308298, 1.6863382, 1.8529547, 2.0301418, 2.1752474]
e_iden = [0.040009316, 0.043040432, 0.044457022, 0.047091447, 0.05011093,
          0.084434845, 0.109450355, 0.12603875, 0.13183458, 0.13365275,
          0.13085239]
y_debias = [0.74767524, 0.89844733, 0.9278445, 0.95092875, 0.968776,
            0.9839461, 0.9877516, 0.98825914, 0.98868203, 0.98957497,
            0.98911023]
e_debias = [0.040188283, 0.0434733, 0.0454809, 0.04889683, 0.05098795,
            0.05749743, 0.056389857, 0.052207127, 0.05019227, 0.048613697,
            0.049608827]

plt.figure(figsize=(8, 6))
# Raw strings keep the LaTeX backslashes literal; '\m' and '\w' in normal
# strings are invalid escape sequences and trigger warnings on current Python.
plt.errorbar(x_axis, y_debias, e_debias, linestyle='-', marker='*', markersize=15,
             label=r'$D_{var}(\mathbf{v})$+debiasing', c='b')
plt.errorbar(x_axis, y_iden, e_iden, marker='v', linestyle='-', markersize=15,
             label=r'$D_{var}(\mathbf{v})$', c='r')
plt.legend(loc=0, fontsize=20)
plt.title('')
plt.xlabel('Batch size', fontsize=20)
plt.ylabel(r'Estimated parameter: $\widehat{\theta}$', fontsize=20)
plt.tick_params(axis='x', labelsize=16)
plt.tick_params(axis='y', labelsize=16)
plt.savefig('syn3.pdf')