from utils import genDataLinear, distDiscrepancy
from linear import linearModel
import numpy as np
import matplotlib.pyplot as plt

# ---- Experiment configuration ----
# Particle counts to compare; each one gets its own column in the result arrays.
num_particle = [20, 50, 100, 200]
# Number of independent repetitions per particle count (one row each).
num_repeat = 1000

# Result tables indexed as [repeat, particle-count index]:
#   b_res receives the distances computed for the plain bootstrap samples,
#   c_res receives the distances computed for the centroid samples.
_result_shape = (num_repeat, len(num_particle))
b_res = np.zeros(_result_shape)
c_res = np.zeros(_result_shape)

# Project-local helpers. NOTE(review): the meaning of the constructor
# argument 4 is defined in linear.py (presumably the model dimension) — confirm.
disc = distDiscrepancy()
LR = linearModel(4)

# ---- Simulation ----
# For every particle count, repeat the experiment `num_repeat` times:
#   1. draw a fresh dataset,
#   2. build a large 10000-draw bootstrap sample (`beta_pop`) that both
#      methods are compared against,
#   3. build a plain bootstrap sample and a weighted centroid sample with
#      `n_particles` particles each,
#   4. store the distance of each sample to `beta_pop` in c_res / b_res.
# Method names (`centoridBootstrap`, `wasserstain`) are spelled as the
# project modules define them.
for idx, n_particles in enumerate(num_particle):
    # The step size depends only on the particle count, so it is computed once
    # per setting instead of once per repeat. 20 particles use a fixed
    # 0.5 / 50.; every other count uses 0.25 / n_particles (the former
    # dedicated branch for 200 was identical to the default branch).
    # NOTE(review): the 50 here presumably mirrors the sample size passed to
    # genDataLinear below — confirm before changing either.
    lr_use = 0.5 / 50. if n_particles == 20 else 0.25 / n_particles

    for rep in range(num_repeat):
        print('num particle = {}, num repeat = {}'.format(n_particles, rep))
        X, Y = genDataLinear(50)

        # Call order is kept as in the original so that any shared random
        # state is consumed in the same sequence.
        beta_pop = LR.bootstrap(X, Y, 10000)
        beta_bs = LR.bootstrap(X, Y, n_particles)
        beta_ct, ctw = LR.centoridBootstrap(X, Y, n_particles, lr=lr_use)

        # Centroid particles carry weights `ctw`; bootstrap particles are
        # passed without weights.
        c_res[rep, idx] = disc.wasserstain(beta_ct, beta_pop, u_weight=ctw)
        b_res[rep, idx] = disc.wasserstain(beta_bs, beta_pop)


# ---- Plot: mean distance per particle count, with a +/- standard-error band ----
# Axis-label font and legend font use the same settings.
font1 = {
    'family': 'Times New Roman',
    'weight': 'normal',
    'size': 30,
}
font_legend = dict(font1)

x_axis = np.array(num_particle)

# Column-wise mean and standard error (std / sqrt(number of repeats)).
root_repeats = np.sqrt(num_repeat)
c_avg = c_res.mean(axis=0)
c_std = c_res.std(axis=0) / root_repeats
b_avg = b_res.mean(axis=0)
b_std = b_res.std(axis=0) / root_repeats

color_list = [u'#1f77b4', u'#ff7f0e']

plt.figure(figsize=(8, 5))

# Bootstrap is drawn first, then Centroid, so legend and z-order are unchanged.
series = (
    ('Bootstrap', b_avg, b_std, color_list[0]),
    ('Centroid', c_avg, c_std, color_list[1]),
)
for series_label, avg, err, series_color in series:
    plt.plot(x_axis, avg, label=series_label, color=series_color, linewidth=8)
    plt.fill_between(x_axis, avg + err, avg - err,
                     color=series_color, alpha=0.2, lw=15)

plt.grid(color="k", linestyle="-.", alpha=0.2)
plt.xlabel('Num particles', fontdict=font1)
plt.ylabel('Wasserstein Dist', fontdict=font1)

# Fixed tick positions for both axes.
my_x_ticks = np.array([20, 50, 100, 200])
my_y_ticks = np.array([0.05, 0.1, 0.15, 0.2])
tick_style = {'fontproperties': 'Times New Roman', 'fontsize': 28}
plt.xticks(my_x_ticks, **tick_style)
plt.yticks(my_y_ticks, **tick_style)

plt.legend(prop=font_legend)
# Save before show(): the figure is written to disk, then displayed.
plt.savefig('Linear_dist.pdf', bbox_inches='tight')
plt.show()
