from brl.utils import *

# Steps at which results are aggregated: 1500 * i for i in interval(0, 4)
# together with 5000 * i for i in interval(1, 20), in ascending order.
# NOTE(review): interval() comes from brl.utils; presumably both bounds are
# inclusive -- confirm against its definition.
xs = sorted(
    [i * 1500 for i in interval(0, 4)] +
    [i * 5000 for i in interval(1, 20)]
)

# One RV accumulator per step, keyed by the step value and named after it.
# pv_bs1 later receives the first loaded baseline and pv_bs4 the second.
# NOTE(review): save_data=True presumably makes RV keep every appended
# sample -- confirm in brl.utils.
pv_bs1 = dict((x, RV(str(x), save_data=True)) for x in xs)
pv_bs4 = dict((x, RV(str(x), save_data=True)) for x in xs)

# For every seed, load that run's pickled plot data and feed its two baseline
# curves into the per-step accumulators.
for seed in [0]:
    plot = LinePlot('', '')
    plot.load_data('laeq_seed{}/BRL/brl/envs/laeq.png.pkl'.format(seed))

    # baselines[0] feeds pv_bs1 and baselines[1] feeds pv_bs4, in that order.
    for idx, trackers in ((0, pv_bs1), (1, pv_bs4)):
        curve = plot.baselines[idx]
        for step, score in zip(curve['x'], curve['y']):
            # Points recorded at steps outside xs are ignored.
            if step in xs:
                trackers[step].append(score)

# Reduce each accumulator to a (step, mean, mean_ci) tuple, keeping the
# iteration order of the tracker dicts.
data_bs1 = []
for step, tracker in pv_bs1.items():
    data_bs1.append((step, tracker.mean(), tracker.mean_ci()))

data_bs4 = []
for step, tracker in pv_bs4.items():
    data_bs4.append((step, tracker.mean(), tracker.mean_ci()))

# Dump both tables to stdout, one tuple per line under its heading.
for heading, rows in (('bs1:', data_bs1), ('bs4:', data_bs4)):
    print(heading)
    for pt in rows:
        print(pt)

# Draw both aggregated curves on one figure: x is the step, y the mean and
# y_range the value returned by mean_ci() for that step.
plot = LinePlot('step', 'corpus BLEU')

# The beam-size-4 curve is added before the beam-size-1 curve.
for label, points in (('beam size = 4', data_bs4),
                      ('beam size = 1', data_bs1)):
    plot.add_line(label=label,
                  x=[step for step, _, _ in points],
                  y=[mean for _, mean, _ in points],
                  y_range=[ci for _, _, ci in points])

plot.show()

# Dashed gray horizontal reference line at y = 27.3.
# NOTE(review): 27.3 is presumably a reference BLEU score -- confirm its
# source. The line is added after show() but before the figure is written
# out; confirm that ordering is intended.
plot.ax.axhline(y=27.3, color='gray', linestyle='--')

# The output file is named after the last '/'-separated component of the
# current working directory and the number of samples held by the step-0
# accumulator (this requires 0 to be a key of pv_bs1).
exp_name = os.getcwd().rsplit('/', 1)[-1]
num_seeds = pv_bs1[0].size()
plot.output('{}_{}seeds.png'.format(exp_name, num_seeds))

# Empty second figure shown with block=True.
# NOTE(review): presumably this blocking window is what keeps the script
# alive so the curves above are displayed -- confirm in LinePlot.show.
plot_blank = LinePlot('', '', title='close this window to see the curves')
plot_blank.show(block=True)

