In [1]:
from util import *
from jax import grad, value_and_grad, jit, vmap, random
In [2]:
@jit
def loss_jax(params, df_batch, un):
    # Sort-matching loss: variance of the residual between the sorted y-values
    # of a batch and the sorted scaled noise params*un (x is unpacked but unused here).
    x, y = df_batch[:, 0], df_batch[:, 1]
    vec = jnp.sort(y) - jnp.sort(params * un)
    return jnp.var(vec)

val_and_grad = value_and_grad(loss_jax)
# Inner vmap: sweep over the nrep noise replicates (axis 1 of un).
vmap_val_and_grad_inner = vmap(val_and_grad, in_axes=(None, None, 1), out_axes=0)
# Outer vmap: sweep over the batches (axis 0 of batches, axis 2 of un).
vmap_val_and_grad_outer = vmap(vmap_val_and_grad_inner, in_axes=(None, 0, 2), out_axes=0)
vmap_val_and_grad_outer = jit(vmap_val_and_grad_outer)
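As a quick shape check, here is a sketch of calling the double vmap with dummy data; the names and sizes (demo_key, demo_batch_sz, demo_nrep, demo_nbatches) are placeholders, not the values used later in this notebook:

demo_key = random.PRNGKey(0)
k_b, k_u = random.split(demo_key)
demo_batch_sz, demo_nrep, demo_nbatches = 8, 4, 3
demo_batches = random.normal(k_b, shape=(demo_nbatches, demo_batch_sz, 2))
demo_un = random.laplace(k_u, shape=(demo_batch_sz, demo_nrep, demo_nbatches))
demo_loss, demo_grad = vmap_val_and_grad_outer(0.2, demo_batches, demo_un)
# demo_loss and demo_grad both have shape (demo_nbatches, demo_nrep):
# one (loss, d loss / d theta) pair per batch and per noise replicate.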
In [3]:
def batch_test(df, resolution, npos):
    # Build neighborhood-based batches from the sorted data (helpers from util).
    nghM = get_neighbor_matrix_fixed_num(df, resolution)
    batches = get_batches(data=df, neighborM=nghM, resolution=resolution, npos=npos)
    batches = jnp.array(batches)

    # All batches share the same size; read it off the first one.
    df_batch = batches[0]
    batch_sz, _ = df_batch.shape
    return batches, batch_sz
In [4]:
def test(batches, key_seed=42, step_sz=1.0, exp=200, nrep=100):
    key = random.PRNGKey(key_seed)
    df_batch = batches[0]
    batch_sz, _ = df_batch.shape
    theta_H = 0.2          # initial value of the scale parameter theta
    params = theta_H

    loss_res = []
    t_res = []
    gradt_res = []

    for j in range(exp):
        key, subkey = random.split(key)
        # Draw nrep Laplace noise replicates per batch.
        un = random.laplace(subkey, shape=(batch_sz, nrep, len(batches)))
        loss_val, grad_val = vmap_val_and_grad_outer(params, batches, un)
        ave_loss, ave_grad = np.mean(loss_val), np.mean(grad_val)
        params -= step_sz * ave_grad   # gradient-descent update on theta
        loss_res.append(ave_loss)
        t_res.append(params)
        gradt_res.append(ave_grad)
        if j % 10 == 0:
            sys.stdout.write("\rIteration %i" % j)

    return loss_res, t_res, gradt_res, params

Data generation

  1. $Y = X + E_y$
  2. $Y = X + 0.5 X^3 + E_y$

$E_y$ follows a Laplace distribution.

In [5]:
nsamples = 100
resolution = 0.2
npos = 50

def f_t(x):
    return x
    # return x + 0.5*x**3
In [6]:
# PRNG keys for the x samples and the additive noise
key = random.PRNGKey(7)
k1, k2 = random.split(key)

# Draw x uniformly on [-1, 1] and add Laplace noise to f_t(x)
x_samples = random.uniform(k1, shape=(nsamples, 1), minval=-1, maxval=1)

# y_samples = vmap(f_t)(x_samples)
y_samples = np.array([f_t(x) for x in x_samples])
# noise = random.uniform(k2, shape=(nsamples, 1), minval=0, maxval=1)
noise = random.laplace(k2, shape=(nsamples, 1))
y_samples += noise
x = x_samples.reshape(-1)
y = y_samples.reshape(-1)

Preprocessing

We only keep the data within two standard deviations.
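The ±2 cutoff below corresponds to two standard deviations because x and y are standardized first. normalize comes from util and is not shown here; a plausible z-score implementation (an assumption about util, not its actual code) would be:

def normalize(v):
    # Assumed z-score standardization; the real util.normalize is not shown.
    return (v - np.mean(v)) / np.std(v)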

In [7]:
x = normalize(x)
y = normalize(y)
x, y = x.reshape(-1), y.reshape(-1)

# Keep only points with |x| < rangex and |y| < rangey
rangex = 2
rangey = 2
mask = (np.abs(x) < rangex) & (np.abs(y) < rangey)
x, y = x[mask], y[mask]

# Shuffle the sample order
ind = np.arange(len(x))
key, subkey = random.split(key)
ind = random.permutation(key, ind)

x = x[ind]
y = y[ind]
In [8]:
n = len(x)
plt.scatter(x, y, marker='.')
plt.xlabel('X', fontsize=20)
plt.ylabel('Y', fontsize=20)

# Causal direction X -> Y: columns (x, y), sorted by x
df_c = np.zeros([n, 2])
df_c[:, 0], df_c[:, 1] = x, y
df_sort_c = sortBycol(df_c, 0)

# Reverse direction Y -> X: columns (y, x), sorted by y
df_rv = np.zeros([n, 2])
df_rv[:, 0], df_rv[:, 1] = y, x
df_sort_rv = sortBycol(df_rv, 0)
In [9]:
if n < 50:
    npos=n
In [10]:
# Fit in the causal direction (X -> Y) and in the reverse direction (Y -> X),
# then compare the final losses normalized by the fitted theta.
c_batches, c_batch_sz = batch_test(df_sort_c, resolution, npos)
c_loss_res, c_t_res, c_gradt_res, params_c = test(c_batches, key_seed=42, step_sz=0.05, exp=100, nrep=50)
loss_c = np.mean(c_loss_res[-10:]) / params_c

rv_batches, rv_batch_sz = batch_test(df_sort_rv, resolution, npos)
rv_loss_res, rv_t_res, rv_gradt_res, params_rv = test(rv_batches, key_seed=42, step_sz=0.05, exp=100, nrep=50)
loss_rv = np.mean(rv_loss_res[-10:]) / params_rv
Iteration 90
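The direction with the smaller normalized loss is then taken as the inferred causal direction; a minimal sketch of that comparison (an assumed reading of the two scores, consistent with the values printed below):

if loss_c < loss_rv:
    print('Inferred direction: X -> Y')
else:
    print('Inferred direction: Y -> X')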
In [11]:
print(loss_c)
plt.figure(figsize=(15,5))

plt.subplot(1,3,1)
plt.title('loss')
plt.plot(np.arange(0,len(c_loss_res),1), c_loss_res)

plt.subplot(1,3,2)
plt.title('theta')
plt.plot(np.arange(0,len(c_t_res),1), c_t_res)

plt.subplot(1,3,3)
plt.title('grad theta')
plt.plot(np.arange(0,len( c_gradt_res),1), c_gradt_res)
0.2052097
Out[11]:
[<matplotlib.lines.Line2D at 0x7f9fa0748828>]
In [12]:
print(loss_rv)
plt.figure(figsize=(15,5))

plt.subplot(1,3,1)
plt.title('loss')
plt.plot(np.arange(0,len(rv_loss_res),1), rv_loss_res)

plt.subplot(1,3,2)
plt.title('theta')
plt.plot(np.arange(0,len(rv_t_res),1), rv_t_res)

plt.subplot(1,3,3)
plt.title('grad theta')
plt.plot(np.arange(0,len( rv_gradt_res),1), rv_gradt_res)
0.3279853
Out[12]:
[<matplotlib.lines.Line2D at 0x7f9fb0d6b710>]

Plot for Sec. 6

In [13]:
x_axis = [25, 50, 75, 100, 200, 500]
y_linear = [0.69, 0.96, 0.96, 0.99, 1.0, 1.0]
y_poli = [0.67, 0.89, 0.96, 0.99, 1.0, 1.0]
In [14]:
plt.figure(figsize=(12,10))
plt.plot(x_axis,y_poli,'o-',markersize=20)


# plt.legend(loc=0,fontsize=40)
plt.title('$Y = X+0.5X^3 + E_y$',fontsize=40)
plt.ylim([0.0,1.05])
plt.xlabel('Sample size',fontsize=40)
plt.ylabel('Accuracy',fontsize=40)

plt.tick_params(axis='x', labelsize=32)
plt.tick_params(axis='y', labelsize=32)
plt.grid()

# plt.savefig('syn_cmp_poli.pdf')
In [15]:
plt.figure(figsize=(12,10))
plt.plot(x_axis,y_linear,'o-',markersize=20)


# plt.legend(loc=0,fontsize=40)
plt.title('$Y = X + E_y$',fontsize=40)
# plt.title('$Y = 0.1((2.5X)^3-X) + E_y$',fontsize=40)
# plt.title('$Y = f_{piece}(X) + E_y$',fontsize=40)
plt.ylim([0.0,1.05])
plt.xlabel('Sample size',fontsize=40)
plt.ylabel('Accuracy',fontsize=40)

plt.tick_params(axis='x', labelsize=32)
plt.tick_params(axis='y', labelsize=32)
plt.grid()

# plt.savefig('syn_cmp_linear.pdf')
In [ ]: