"""Plot training-loss and validation-accuracy curves parsed from training logs.
"""
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Times New Roman"
import numpy as np

# Font settings shared by axis labels and legends.
ftsize = 25
csfont = {"fontname": "Times New Roman", "fontsize": ftsize}

# Training logs to compare; each is plotted with the colour at the same index
# of `color`.
log_files = [
    "../log/flower102/verbose_SGD.log",
    "../log/flower102/verbose_SLIM.log",
]
color = ["b", "k"]

step = 20            # iterations between two consecutive logged training lines
timescale = 60 * 60  # divisor applied to accumulated time (3600 -> hours)
realtime = False     # x-axis: accumulated time (True) or iterations/epochs (False)
average = True       # smooth the validation-accuracy curve with a moving average
max_epoch = 100      # not referenced by the plotting code in this script

# Names of every series collected by `parse_training_log`.
_SERIES = (
    "timestamp", "batchstamp",
    "loss_train", "loss_train_avg",
    "acc1_train", "acc1_train_avg",
    "acc5_train", "acc5_train_avg",
    "timestamp_test", "epochstamp_test",
    "loss_test", "acc1_test", "acc5_test",
    "timestamp_test_avg", "epochstamp_test_avg",
    "loss_test_avg", "acc1_test_avg", "acc5_test_avg",
)


def _current(fields, name):
    """Return the value immediately following the token ``name`` as a float."""
    return float(fields[fields.index(name) + 1])


def _running_avg(fields, name):
    """Return the second value after ``name`` as a float.

    The token is expected to look like ``(<value>)``; its first and last
    characters (presumably the parentheses) are stripped before parsing.
    """
    return float(fields[fields.index(name) + 2][1:-1])


def parse_training_log(lines, step, timescale, test_end_marker="4/4"):
    """Parse a verbose training log into training and validation series.

    A line containing ``"Epoch"`` is a training line. A line containing
    ``"Test"`` is a test line, but only once a training line has been seen.
    Both kinds are split on whitespace and must contain ``Time``, ``Loss``,
    ``Acc@1`` and ``Acc@5`` tokens, each followed by a current value and a
    bracketed running average.

    Args:
        lines: iterable of log lines (an open text file works).
        step: number of iterations covered by one logged training line.
        timescale: divisor applied to the accumulated time stamps.
        test_end_marker: substring marking the last test line of an epoch.
            NOTE(review): this is a plain substring test, so "4/4" would also
            match e.g. "14/4".

    Returns:
        dict mapping each name in ``_SERIES`` to a list of numbers.
    """
    run = {key: [] for key in _SERIES}
    time_point = 0
    batch_point = 0
    epoch_point = 0
    start_train = False

    # Iterate the file directly. The old readline() loop discarded the first line.
    for line in lines:
        if "Epoch" in line:
            start_train = True
            fields = line.split()

            # Running-average "Time" scaled by `step` approximates the time
            # covered by this logged line.
            time_point += step * _running_avg(fields, "Time")
            run["timestamp"].append(time_point / timescale)

            batch_point += step
            run["batchstamp"].append(batch_point / 1000)  # thousands of iterations

            for key, token in (("loss", "Loss"), ("acc1", "Acc@1"), ("acc5", "Acc@5")):
                run[key + "_train"].append(_current(fields, token))
                run[key + "_train_avg"].append(_running_avg(fields, token))

        if "Test" in line and start_train:
            fields = line.split()

            # NOTE(review): scaled by `step` like training lines; confirm test
            # lines are also logged every `step` batches.
            time_point += step * _running_avg(fields, "Time")
            run["timestamp_test"].append(time_point / timescale)
            run["epochstamp_test"].append(epoch_point)

            run["loss_test"].append(_current(fields, "Loss"))
            run["acc1_test"].append(_current(fields, "Acc@1"))
            run["acc5_test"].append(_current(fields, "Acc@5"))

            if test_end_marker in line:
                # The running averages on the final test line are recorded as
                # this epoch's validation result.
                # NOTE(review): x position is the last *training* timestamp,
                # not time_point; confirm this is intended.
                run["timestamp_test_avg"].append(run["timestamp"][-1])
                run["epochstamp_test_avg"].append(epoch_point)
                run["loss_test_avg"].append(_running_avg(fields, "Loss"))
                run["acc1_test_avg"].append(_running_avg(fields, "Acc@1"))
                run["acc5_test_avg"].append(_running_avg(fields, "Acc@5"))
                epoch_point += 1

    return run


def moving_average(values, window):
    """Return the 'valid'-mode moving average of ``values`` as a float array.

    Element ``i`` is the mean of ``values[i:i + window]``, so the result has
    ``len(values) - window + 1`` entries. It is empty when there are fewer than
    ``window`` values. ``window <= 1`` returns the values unchanged.
    """
    values = np.asarray(values, dtype=float)
    if window <= 1:
        return values
    if len(values) < window:
        # np.convolve swaps its operands when the data is shorter than the
        # kernel, which would return meaningless values.
        return values[:0]
    return np.convolve(values, np.ones(window), "valid") / window


# Left: training loss; right: (optionally smoothed) validation accuracy.
fig = plt.figure(1, figsize=(16, 8))
ax1 = fig.add_subplot(1, 2, 1)
ax2 = fig.add_subplot(1, 2, 2)

smooth_window = 5  # epochs per moving-average window when `average` is on
# Always defined: previously `w` only existed when `average` was True.
window = smooth_window if average else 1

for log_file, clr in zip(log_files, color):
    with open(log_file) as fp:
        run = parse_training_log(fp, step, timescale)

    if realtime:
        stamp, stamp_test_avg = run["timestamp"], run["timestamp_test_avg"]
    else:
        stamp, stamp_test_avg = run["batchstamp"], run["epochstamp_test_avg"]

    ax1.plot(stamp, run["loss_train_avg"], clr + "-", linewidth=2)

    acc1_test_avg = moving_average(run["acc1_test_avg"], window)
    # Each smoothed point is drawn at the first x position of its window.
    # Slicing by length also works for window == 1, unlike [0:1-w].
    ax2.plot(stamp_test_avg[:len(acc1_test_avg)], acc1_test_avg, clr + "-", linewidth=2)

# Axis styling is applied once, after all curves are drawn.
# Legend entries follow the order of `log_files`.
ax1.tick_params(axis="x", labelsize=20)
ax1.set_yscale("log")
ax1.set_yticks([0.1, 0.2, 1])
# These tick labels assume x is in thousands of iterations (realtime=False).
ax1.set_xticks([0, 1, 2, 3, 4, 5, 6])
ax1.set_xticklabels(["0", "1k", "2k", "3k", "4k", "5k", "6k"])
ax1.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
ax1.get_yaxis().set_minor_formatter(matplotlib.ticker.NullFormatter())
ax1.tick_params(axis="y", labelsize=20)
ax1.set_xlabel("Iterations", **csfont)
ax1.set_ylabel("Training loss", **csfont)
ax1.legend(["SGD", "SLIM-QN"], fontsize=ftsize)

ax2.tick_params(axis="x", labelsize=20)
ax2.tick_params(axis="y", labelsize=20)
ax2.set_ylabel("Validation acc.(%)", **csfont)
ax2.set_xlabel("Epochs", **csfont)
# NOTE: final accuracies in the legend are hard-coded, not computed from the logs.
ax2.legend(["SGD:87.3%", "SLIM-QN: 87.4%"], fontsize=ftsize)

plt.show()