import tensorflow as tf
from tensorboard.backend.event_processing import event_accumulator as ea

from matplotlib import pyplot as plt
from matplotlib import colors as colors

plt.rc('axes', labelsize=16)

# Number of leading logged points (epochs) shown for each run.
fin = 25

# NOTE(review): this log directory was left blank in the original script.
# By analogy with the other runs it is presumably
# 'primal_results/fastmri2_recon_sub50_patch80_poisson_paper_results' -- confirm and fill in.
PRIMAL_R2_LOGDIR = ''

# One entry per run: (legend label, TensorBoard log directory, extra plot style).
# The order fixes both the legend order and the default colour assigned to each curve.
RUNS = [
    ('Nonconvex SGD Primal, R=2', PRIMAL_R2_LOGDIR, {'linestyle': '--'}),
    ('Convex SGD Dual, R=2', 'dual_results/fastmri2_sub50_patch80_poisson_paper_results', {}),
    ('Nonconvex SGD Primal, R=4', 'primal_results/fastmri4_recon_sub50_patch80_poisson_paper_results', {'linestyle': '--'}),
    ('Convex SGD Dual, R=4', 'dual_results/fastmri4_sub50_patch80_poisson_paper_results/', {}),
    ('Nonconvex SGD Primal, R=8', 'primal_results/fastmri8_recon_sub50_patch80_poisson_paper_results', {'linestyle': '--'}),
    ('Convex SGD Dual, R=8', 'dual_results/fastmri8_sub50_patch80_poisson_paper_results', {}),
]


def load_losses(logdir):
    """Read the train and test loss curves from one TensorBoard log directory.

    Args:
        logdir: Path to the run's event-file directory.

    Returns:
        A dict with keys 'train' and 'test'. Each value is a ``(steps, values)``
        pair of lists, taken from the 'Train_Loss_Epoch' and 'Test_Loss' scalar
        tags respectively. Steps are converted to ``int``.
    """
    acc = ea.EventAccumulator(logdir)
    acc.Reload()
    curves = {}
    for split, tag in (('train', 'Train_Loss_Epoch'), ('test', 'Test_Loss')):
        # Read each tag once and take steps and values from the same events.
        events = acc.Scalars(tag)
        curves[split] = ([int(e.step) for e in events], [e.value for e in events])
    return curves


def plot_losses(runs, split, ylabel, legend_loc, filename, n_epochs=fin):
    """Plot one split of every run on a log-scale figure, save it, then show it.

    Args:
        runs: Sequence of ``(label, style, curves)`` tuples, where ``curves``
            is a dict as returned by ``load_losses``.
        split: 'train' or 'test'.
        ylabel: Y-axis label.
        legend_loc: Matplotlib legend location string.
        filename: Output image path.
        n_epochs: Number of leading points plotted for each run.
    """
    plt.figure(figsize=(8, 5))
    for label, style, curves in runs:
        steps, values = curves[split]
        # Each run uses its own step numbers, so runs with different logging
        # lengths stay aligned and x/y lengths always match.
        plt.semilogy(steps[:n_epochs], values[:n_epochs], label=label, linewidth=3, **style)
    plt.legend(loc=legend_loc, fontsize='small')
    plt.ylim(2e-3, 1e-1)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    # Save before show(): with a blocking backend, show() closes the figure
    # when the window is dismissed, and a later savefig would write a blank image.
    plt.savefig(filename)
    plt.show()


runs = [(label, style, load_losses(logdir)) for label, logdir, style in RUNS]

plot_losses(runs, 'train', 'Training loss', 'upper right', 'train_plots.png')
plot_losses(runs, 'test', 'Testing loss', 'upper left', 'test_plots.png')