import matplotlib.pyplot as plt
from matplotlib import rc
import numpy as np
import pandas as pd
import seaborn as sns

# Global plot styling, applied once for every figure below:
#   - seaborn's default theme,
#   - 4 x 3 inch figures,
#   - text rendered through LaTeX (matplotlib's usetex mode needs a working
#     LaTeX installation on the machine running this script).
sns.set(rc={
    'figure.figsize': (4, 3),
    'text.usetex': True,
})

# Load the logged training runs. Every file stores its DataFrame under the
# HDF5 key 'df'; the plots below use its 'Loss' and 'Ortho Norm' columns.
# ('-lr-power' presumably denotes a power-law learning-rate schedule --
# NOTE(review): confirm against the training script.)

# Unaugmented baseline, shared by all comparisons below (read only once).
none_lr_pow = pd.read_hdf('aug-none-lr-power.h5', 'df')

# Gaussian augmentation; the none/exp/power variants are plotted as
# 'Gauss const' / 'Gauss exp' / 'Gauss pow'.
gauss_none_lr_pow = pd.read_hdf('aug-gaussian-none-lr-power.h5', 'df')
gauss_exp_lr_pow = pd.read_hdf('aug-gaussian-exp-lr-power.h5', 'df')
gauss_pow_lr_pow = pd.read_hdf('aug-gaussian-power-lr-power.h5', 'df')

# Random-projection augmentation, same three variants.
proj_none_lr_pow = pd.read_hdf('aug-randproj-none-lr-power.h5', 'df')
proj_exp_lr_pow = pd.read_hdf('aug-randproj-exp-lr-power.h5', 'df')
proj_pow_lr_pow = pd.read_hdf('aug-randproj-power-lr-power.h5', 'df')

# '-bs20' runs: these are the ones plotted with the '(SGD)' titles.
none_lr_pow_bs20 = pd.read_hdf('aug-none-lr-power-bs20.h5', 'df')
gauss_none_lr_pow_bs20 = pd.read_hdf('aug-gaussian-none-lr-power-bs20.h5', 'df')
gauss_exp_lr_pow_bs20 = pd.read_hdf('aug-gaussian-exp-lr-power-bs20.h5', 'df')
gauss_pow_lr_pow_bs20 = pd.read_hdf('aug-gaussian-power-lr-power-bs20.h5', 'df')

# Show which columns the logs contain (sanity check).
print(none_lr_pow.keys())

# --- Gaussian augmentation vs. unaugmented baseline, GD runs ---
# Only the first `gd_steps` logged rows of each run are plotted.
gd_steps = 10000
gauss_gd_runs = {
    'Unaugmented': none_lr_pow,
    'Gauss const': gauss_none_lr_pow,
    'Gauss exp': gauss_exp_lr_pow,
    'Gauss pow': gauss_pow_lr_pow,
}
# Dict insertion order fixes the column order, hence line and legend order.
# .iloc makes the truncation explicitly positional (first gd_steps rows).
plot1_df = pd.DataFrame({label: run['Loss'].iloc[:gd_steps]
                         for label, run in gauss_gd_runs.items()})
plot2_df = pd.DataFrame({label: run['Ortho Norm'].iloc[:gd_steps]
                         for label, run in gauss_gd_runs.items()})

# Training loss, log-scale y axis; legend placement left to seaborn.
plt.figure()
plot1 = sns.lineplot(data=plot1_df)
plot1.set_title('Mean Squared Error (GD)')
plot1.set_xlabel('Steps')
plot1.set(yscale='log')
plot1.set(ylabel='')
fig = plot1.get_figure()
plt.tight_layout()
fig.savefig('mse.pdf', bbox_inches='tight')
plt.close(fig)  # free the figure once it has been written to disk

# Frobenius norm of the orthogonal weight component, log-scale y axis.
plt.figure()
plot2 = sns.lineplot(data=plot2_df)
plot2.set_title(r'Orthogonal Weight Norm $\|W_{t, \perp}\|_F$ (GD)')
plot2.set_xlabel('Steps')
plot2.legend(loc='upper right')
plot2.set(yscale='log')
plot2.set(ylabel='')
fig = plot2.get_figure()
plt.tight_layout()
fig.savefig('wperp.pdf', bbox_inches='tight')
plt.close(fig)


# --- Random-projection augmentation vs. unaugmented baseline, GD runs ---
# Only the first `gd_steps` logged rows of each run are plotted.
gd_steps = 10000
proj_gd_runs = {
    'Unaugmented': none_lr_pow,
    'Rand Proj const': proj_none_lr_pow,
    'Rand Proj exp': proj_exp_lr_pow,
    'Rand Proj pow': proj_pow_lr_pow,
}
# Dict insertion order fixes the column order, hence line and legend order.
# .iloc makes the truncation explicitly positional (first gd_steps rows).
plot1_df = pd.DataFrame({label: run['Loss'].iloc[:gd_steps]
                         for label, run in proj_gd_runs.items()})
plot2_df = pd.DataFrame({label: run['Ortho Norm'].iloc[:gd_steps]
                         for label, run in proj_gd_runs.items()})

# Training loss, log-scale y axis; legend placement left to seaborn.
plt.figure()
plot1 = sns.lineplot(data=plot1_df)
plot1.set_title('Mean Squared Error (GD)')
plot1.set_xlabel('Steps')
plot1.set(yscale='log')
plot1.set(ylabel='')
fig = plot1.get_figure()
plt.tight_layout()
fig.savefig('mse-proj.pdf', bbox_inches='tight')
plt.close(fig)  # free the figure once it has been written to disk

# Frobenius norm of the orthogonal weight component, log-scale y axis.
plt.figure()
plot2 = sns.lineplot(data=plot2_df)
plot2.set_title(r'Orthogonal Weight Norm $\|W_{t, \perp}\|_F$ (GD)')
plot2.set_xlabel('Steps')
plot2.legend(loc='upper right')
plot2.set(yscale='log')
plot2.set(ylabel='')
fig = plot2.get_figure()
plt.tight_layout()
fig.savefig('wperp-proj.pdf', bbox_inches='tight')
plt.close(fig)

# --- Gaussian augmentation vs. unaugmented baseline, '-bs20' (SGD) runs ---
# Only the first `sgd_steps` logged rows of each run are plotted.
sgd_steps = 50000
gauss_sgd_runs = {
    'Unaugmented': none_lr_pow_bs20,
    'Gauss const': gauss_none_lr_pow_bs20,
    'Gauss exp': gauss_exp_lr_pow_bs20,
    'Gauss pow': gauss_pow_lr_pow_bs20,
}
# Dict insertion order fixes the column order, hence line and legend order.
# .iloc makes the truncation explicitly positional (first sgd_steps rows).
plot1_df = pd.DataFrame({label: run['Loss'].iloc[:sgd_steps]
                         for label, run in gauss_sgd_runs.items()})
plot2_df = pd.DataFrame({label: run['Ortho Norm'].iloc[:sgd_steps]
                         for label, run in gauss_sgd_runs.items()})

# Training loss, log-scale y axis.
plt.figure()
plot1 = sns.lineplot(data=plot1_df)
plot1.set_title('Mean Squared Error (SGD)')
plot1.set_xlabel('Steps')
plot1.legend(loc='upper right')
plot1.set(yscale='log')
plot1.set(ylabel='')
fig = plot1.get_figure()
plt.tight_layout()
fig.savefig('mse-bs20.pdf', bbox_inches='tight')
plt.close(fig)  # free the figure once it has been written to disk

# Frobenius norm of the orthogonal weight component, log-scale y axis.
plt.figure()
plot2 = sns.lineplot(data=plot2_df)
plot2.set_title(r'Orthogonal Weight Norm $\|W_{t, \perp}\|_F$ (SGD)')
plot2.set_xlabel('Steps')
plot2.legend(loc='upper right')
plot2.set(yscale='log')
plot2.set(ylabel='')
fig = plot2.get_figure()
plt.tight_layout()
fig.savefig('wperp-bs20.pdf', bbox_inches='tight')
plt.close(fig)
