"""Plot training-loss and test-accuracy curves parsed from training logs.
"""
import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Times New Roman"
import numpy as np
import math

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Typography: one font size shared by labels, legend and annotations, plus
# the keyword arguments handed to set_xlabel / set_ylabel.
ftsize = 25
csfont = {'fontname': 'Times New Roman', 'fontsize': ftsize}

# Logs to compare.  log_files, legend_str and color are paired by position
# (the plotting loop zips log_files with color, and legend labels follow
# plotting order), so keep the three lists aligned.
log_files = [
    "../log/ablation/verbose_mm_damp.log",
    "../log/ablation/verbose_mm_nodamp.log",
    "../log/ablation/verbose_nomm_damp.log",
    "../log/ablation/verbose_nomm_nodamp.log",
]
legend_str = [
    "with momentum and damping",
    "with only momentum",
    "with only damping",
    "with no momentum or damping",
]
color = ['b', 'k', 'm', 'r']

# Alternative comparison, SGD vs. BFGS:
# log_files = ["../log/cifar10/verbose_SGD.log",
#              "../log/ablation/verbose_nomm_nodamp.log"]
# legend_str = ["SGD", "BFGS"]

# Log-parsing / x-axis settings.
step = 20          # batches assumed between consecutive logged progress lines
timescale = 3600   # accumulated time is divided by this (-> hours, if the log reports seconds)
realtime = False   # True: x-axes use accumulated time; False: iterations (top) / epochs (bottom)
max_epoch = 100    # not referenced by the plotting code

def _cur(splits, name):
    """Return the float token that immediately follows field *name*."""
    return float(splits[splits.index(name) + 1])


def _avg(splits, name):
    """Return the float two tokens after field *name*.

    The token's first and last characters are dropped before conversion;
    presumably they are the parentheses around a running average, e.g.
    ``Loss 0.51 (0.62)`` -> 0.62.
    """
    return float(splits[splits.index(name) + 2][1:-1])


def _nan_to_last(val, history):
    """Return *val*, or ``history[-1]`` when *val* is NaN and *history* is non-empty.

    With an empty *history* the NaN is returned unchanged instead of raising
    IndexError.
    """
    if math.isnan(val) and history:
        return history[-1]
    return val


def _parse_log(fp, step, timescale, test_done_mark="40/40", loss_clip=3.0):
    """Parse one training log into a dict of per-series lists.

    A line containing "Epoch" is treated as a training progress line, and a
    line containing "Test" (once a training line has been seen) as a test
    progress line.  Each matched line is whitespace-split and fields are read
    by name: the token after "Loss" / "Acc@1" / "Acc@5" is the current value,
    and the token after that (first/last character stripped) is the running
    average.  Accumulated time advances by ``step * <running-average Time>``
    for every matched line.

    Args:
        fp: iterable of text lines, e.g. an open log file.
        step: batches assumed between consecutive progress lines.
        timescale: divisor applied to the accumulated time.
        test_done_mark: substring marking the last progress line of a test
            pass; only that line's running averages go into the ``*_avg``
            test series, and it advances the epoch counter.
        loss_clip: upper bound for the running-average training loss; a NaN
            average is replaced by this value as well.

    Returns:
        dict of lists with keys
        - per training line: timestamp, batchstamp, loss_train,
          loss_train_avg, acc1_train, acc1_train_avg, acc5_train,
          acc5_train_avg;
        - per test line: timestamp_test, epochstamp_test, loss_test,
          acc1_test, acc5_test;
        - per completed test pass: timestamp_test_avg, epochstamp_test_avg,
          loss_test_avg, acc1_test_avg, acc5_test_avg.

    Raises:
        ValueError / IndexError: if a matched line lacks an expected field
            or a value token is not a number.
    """
    keys = ("timestamp", "batchstamp",
            "loss_train", "loss_train_avg",
            "acc1_train", "acc1_train_avg",
            "acc5_train", "acc5_train_avg",
            "timestamp_test", "epochstamp_test",
            "loss_test", "acc1_test", "acc5_test",
            "timestamp_test_avg", "epochstamp_test_avg",
            "loss_test_avg", "acc1_test_avg", "acc5_test_avg")
    series = {key: [] for key in keys}

    time_point = 0.0
    batch_point = 0
    epoch_point = 0
    start_train = False

    for line in fp:  # every line, including the first one
        if "Epoch" in line:
            start_train = True
            splits = line.split()

            time_point += step * _avg(splits, "Time")
            series["timestamp"].append(time_point / timescale)
            batch_point += step
            # Iterations in thousands, matching the 10k/20k/... tick labels.
            series["batchstamp"].append(batch_point / 1000)

            # A NaN batch loss falls back to the previous running average.
            series["loss_train"].append(
                _nan_to_last(_cur(splits, "Loss"), series["loss_train_avg"]))
            loss_avg = _avg(splits, "Loss")
            if math.isnan(loss_avg) or loss_avg > loss_clip:
                loss_avg = loss_clip
            series["loss_train_avg"].append(loss_avg)

            for field, prefix in (("Acc@1", "acc1"), ("Acc@5", "acc5")):
                series[prefix + "_train"].append(_cur(splits, field))
                series[prefix + "_train_avg"].append(_avg(splits, field))

        # Test lines appearing before the first training line are ignored.
        if "Test" in line and start_train:
            splits = line.split()

            time_point += step * _avg(splits, "Time")
            series["timestamp_test"].append(time_point / timescale)
            series["epochstamp_test"].append(epoch_point)

            test_fields = (("Loss", "loss"), ("Acc@1", "acc1"), ("Acc@5", "acc5"))
            for field, prefix in test_fields:
                series[prefix + "_test"].append(_cur(splits, field))

            if test_done_mark in line:
                # End of a test pass: record its running averages, placed at
                # the time of the most recent training line.
                series["timestamp_test_avg"].append(series["timestamp"][-1])
                series["epochstamp_test_avg"].append(epoch_point)
                epoch_point += 1
                for field, prefix in test_fields:
                    history = series[prefix + "_test_avg"]
                    history.append(_nan_to_last(_avg(splits, field), history))

    return series


fig = plt.figure(1, figsize=(16, 12))
# fig = plt.figure(1, figsize=(16, 8))  # alternative figure size
ax1 = fig.add_subplot(2, 1, 1)  # top: running-average training loss
ax2 = fig.add_subplot(2, 1, 2)  # bottom: end-of-pass test Acc@1

# (substring of log path, text, arrow target, text position, colour)
_ANNOTATIONS = (
    ("verbose_mm_nodamp", 'Abnormal eigenvalues \nin the Hessian',
     (160, 88), (160, 50), 'black'),
    ("verbose_nomm_damp", 'Large stochastic noise',
     (125, 58), (80, 45), 'purple'),
    ("verbose_nomm_nodamp", 'Diverges easily',
     (1, 30), (40, 40), 'red'),
)

for log_file, clr in zip(log_files, color):
    with open(log_file) as fp:
        data = _parse_log(fp, step, timescale)

    if realtime:
        stamp, stamp_test_avg = data["timestamp"], data["timestamp_test_avg"]
    else:
        stamp, stamp_test_avg = data["batchstamp"], data["epochstamp_test_avg"]

    ax1.plot(stamp, data["loss_train_avg"], clr + "-", linewidth=2)
    ax2.plot(stamp_test_avg, np.array(data["acc1_test_avg"]), clr + "-", linewidth=2)

    for key, text, xy, xytext, text_color in _ANNOTATIONS:
        if key in log_file:
            # Text is passed positionally: the keyword was `s` before
            # Matplotlib 3.3 and `text` afterwards (`s` removed in 3.5).
            ax2.annotate(text,
                         xy=xy, xytext=xytext, color=text_color,
                         fontfamily="Times New Roman", fontsize=ftsize,
                         verticalalignment='center',
                         arrowprops=dict(arrowstyle='->', color=text_color))

# Axis formatting and legend are applied once, after all curves are drawn.
ax1.tick_params(axis='x', labelsize=20)
ax1.tick_params(axis='y', labelsize=20)
ax1.set_xticks([0, 10, 20, 30, 40])
ax1.set_xticklabels(['0', '10k', '20k', '30k', '40k'])
ax1.set_ylabel("Training loss", **csfont)
ax1.set_yscale("log")
ax1.set_xlabel("Iterations", **csfont)
ax1.legend(legend_str, fontsize=ftsize, loc='lower left')

ax2.tick_params(axis='x', labelsize=20)
ax2.tick_params(axis='y', labelsize=20)
ax2.set_ylabel("Test acc.", **csfont)
ax2.set_xlabel("Epochs", **csfont)

plt.show()