import seaborn as sns
import pickle
import matplotlib.pyplot as plt
import numpy as np

from train_moons import make_moons

def _load_pickle(path):
    """Return the object unpickled from *path*, closing the file afterwards.

    NOTE: unpickling can execute arbitrary code; only point this at result
    files produced by this project's own runs, never at untrusted input.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Saved results for the three methods compared in the figure (plotted below
# as "Bridge", "ERM" and "Mixup"). Each result is indexed with [5] and [2]
# further down -- presumably per-sample losses for points on / within the
# boundary; TODO confirm against the script that writes these pickles.
# Paths are relative to the current working directory.
bridge = _load_pickle("../temp_balance__bridge_fig1.pickle")
erm = _load_pickle("../balance__none_fig1.pickle")
original = _load_pickle("../balance__original_fig1.pickle")

import pandas as pd
import numpy as np
# Long-format table for seaborn: one row per loss value.
# - '' (empty column name, so the categorical axis carries no label) holds
#   the method the value came from.
# - 'Boundary' records which result entry it came from: index 5 -> "On",
#   index 2 -> "Within".
# Rows are ordered method by method (ERM, Bridge, Mixup), "On" before "Within".
_loss_chunks, meth, bound = [], [], []
for _name, _result in (("ERM", erm), ("Bridge", bridge), ("Mixup", original)):
    for _idx, _where in ((5, "On"), (2, "Within")):
        _chunk = _result[_idx]
        _loss_chunks.append(_chunk)
        meth += [_name] * len(_chunk)
        bound += [_where] * len(_chunk)
loss = list(np.hstack(_loss_chunks))
# Drop the loop temporaries so only loss/meth/bound/df remain at module level.
del _loss_chunks, _name, _result, _idx, _where, _chunk

df = pd.DataFrame({'Loss': loss, '': meth, 'Boundary': bound})
import seaborn as sns
fig, ax = plt.subplots(figsize=(6, 2.5), dpi=200, ncols=2)

# Right panel: horizontal strip plot of the loss values, one row per method
# ('' column), with "On"/"Within" points split apart by the Boundary hue.
# The per-axes legend is suppressed; a figure-level legend is added at the end.
sns.stripplot(data=df, y='', x='Loss', hue='Boundary', dodge=True, ax=ax[1],
              alpha=0.3, size=4, legend=False)

# Dashed black line at each method's largest "On" loss (result index 5).
# The y-extents are hand-tuned around the method rows at y = 0, 1, 2
# (ERM, Bridge, Mixup, in order of appearance in df) -- adjust them if the
# methods or their order change.
ax[1].vlines(x=[erm[5].max(), bridge[5].max(), original[5].max()],
             ymin=[-0.3, 0.67, 1.65], ymax=[0.35, 1.32, 2.40],
             linestyles="--", colors="black")
ax[1].spines['top'].set_visible(False)
ax[1].spines['right'].set_visible(False)
# Fixed loss range; any points beyond 4 fall outside the visible axes.
ax[1].set_xlim([-0.1, 4])
ax[1].set_xticks([0, 4])
ax[1].set_xticklabels([0, 4])

# Left panel: two make_moons samples overlaid as small-marker scatters --
# 1000 points with use_noise=False, then 2000 points with use_noise=True
# (what use_noise does is defined in train_moons and not visible here).
for _n_samples, _noise, _use_noise in ((1000, 0.1, False), (2000, 0.065, True)):
    X, y = make_moons(_n_samples, noise=_noise, shuffle=True, use_noise=_use_noise)
    x0, x1 = X[:, 0], X[:, 1]
    ax[0].scatter(x0, x1, s=2)
del _n_samples, _noise, _use_noise

# Purely illustrative panel: hide all four spines and both tick axes.
for _side in ('top', 'right', 'left', 'bottom'):
    ax[0].spines[_side].set_visible(False)
del _side
ax[0].set_xticks([])
ax[0].set_yticks([])
# Figure-level legend placed above the panels. Only labels are passed, so
# matplotlib pairs them, in order, with the first legend-capable artists it
# finds across fig.axes (ax[0] first). Those are presumably the two ax[0]
# scatter collections, whose colours are expected to match the stripplot hues.
# NOTE(review): confirm this pairing; passing explicit handles would make it
# robust to changes in plotting order.
fig.legend(bbox_to_anchor=(0.8, 1.1), ncol=2,
           labels=["On Boundary", "Within Boundary"])
plt.tight_layout()
# Export (disabled): bbox_inches="tight" keeps the out-of-axes legend in the file.
#plt.savefig("lossbound.pdf", bbox_inches="tight")