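"""Plot training-loss and validation-accuracy curves parsed from training logs.

Reads the verbose logs of the ImageNet/ResNet-50 runs listed in `log_files`
(SGD, SLIM and KFAC) and shows two figures: training loss and validation
top-1 accuracy, each as a mean curve with a shaded error band.
"""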
# import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Times New Roman"
import numpy as np

# ---- Text styling -----------------------------------------------------------
ftsize = 25
# Keyword arguments forwarded to the axis-label calls.
csfont = dict( fontname='Times New Roman', fontsize=25 )

# ---- Input logs -------------------------------------------------------------
# Order matters: the statistics below slice this list by position
# (first `n_seeds` entries, next `n_seeds` entries, then one more).
log_files = [
	"../log/ICLR/imagenet/resnet50/SGD/verbose_run1.log",
	"../log/ICLR/imagenet/resnet50/SGD/verbose_run2.log",
	"../log/ICLR/imagenet/resnet50/SGD/verbose_run3.log",
	"../log/ICLR/imagenet/resnet50/SLIM/verbose_run1.log",
	"../log/ICLR/imagenet/resnet50/SLIM/verbose_run2.log",
	"../log/ICLR/imagenet/resnet50/SLIM/verbose_run3.log",
	"../log/ICLR/imagenet/resnet50/KFAC/verbose_run1.log",
]
# Zipped with `log_files` by the parsing loop, so it must be at least as long.
color = [ 'b' ] * 5 + [ 'r' ] * 5

# ---- Parsing / plotting switches --------------------------------------------
step = 20         # iterations credited to every parsed "Epoch" log line
timescale = 3600  # divisor applied to accumulated time (seconds -> hours)
realtime = False  # True: x-axis is estimated time; False: iterations / epochs
average = True    # True: smooth the curves with a moving average
n_seeds = 3       # number of runs per optimizer for the first two groups
max_epoch = 100   # not referenced elsewhere in this script
hold = True       # not referenced elsewhere in this script

# ---- Figures: 1 = training loss, 2 = validation accuracy --------------------
fig1 = plt.figure( 1, figsize=(8, 4) )
ax1 = fig1.add_subplot( 111 )
fig2 = plt.figure( 2, figsize=(8, 4) )
ax2 = fig2.add_subplot( 111 )
# Applies to the current figure, i.e. the one created last (figure 2).
plt.subplots_adjust( hspace=5 )

# Per-run series, appended in the order of `log_files` (one entry per log).
loss_trains = []
loss_trains_avg = []
acc1_tests = []
acc1_tests_avg = []

# Hard-coded time per training iteration for each optimizer; the "Time"
# column of the log is not parsed. The first key contained in the log path
# wins. Presumably seconds, given the division by `timescale` below.
train_iter_seconds = { 'SGD': 0.22, 'SLIM': 0.239, 'KFAC': 0.359 }
# Hard-coded time per evaluation iteration, same unit as above.
test_iter_seconds = 0.078
# Test lines carrying this token close an epoch; presumably it marks the
# final evaluation batch -- TODO confirm against the log format.
last_test_batch = "196/196"


def _moving_average( values, window ):
	"""Return `values` smoothed by a centered moving average of odd size `window`.

	The series is reflect-padded by (window-1)/2 samples on both sides so the
	result has the same length as the input.
	"""
	half = ( window-1 ) // 2
	padded = np.pad( np.array( values ), half, mode='reflect' )
	return np.convolve( padded, np.ones( window ), 'valid' ) / window


# `color` is only zipped in to keep the original pairing (and truncation)
# behaviour; the colour value itself is not used while parsing.
for log_file, _clr in zip( log_files, color ):
	# Resolve the per-iteration time once per file. Failing loudly here avoids
	# silently reusing a duration left over from a previous file or test line.
	per_iter = next(
		( secs for name, secs in train_iter_seconds.items() if name in log_file ),
		None )
	if per_iter is None:
		raise ValueError(
			"Cannot infer optimizer (SGD/SLIM/KFAC) from log path: {}".format( log_file ) )
	train_duration = step * per_iter
	test_duration = step * test_iter_seconds

	timestamp = []            # estimated time after each train log line
	timestamp_test = []       # estimated time after each test log line
	timestamp_test_avg = []   # time stamp recorded once per finished epoch
	batchstamp = []           # iteration count / 1000 after each train log line
	epochstamp_test = []      # epoch index of each test log line
	epochstamp_test_avg = []  # epoch index recorded once per finished epoch
	acc1_train = []
	acc5_train = []
	acc1_train_avg = []
	acc5_train_avg = []
	loss_train = []
	loss_train_avg = []
	acc1_test = []
	acc5_test = []
	acc1_test_avg = []
	acc5_test_avg = []
	loss_test = []
	loss_test_avg = []

	time_point = 0
	batch_point = 0
	epoch_point = 0
	start_train = False

	with open( log_file ) as fp:
		# The first line of the log has never been parsed by this script;
		# keep skipping it so the resulting series are unchanged.
		next( fp, None )

		for line in fp:
			if "Epoch" in line:
				start_train = True
				splits = line.split()

				time_point += train_duration
				timestamp.append( time_point/timescale )

				batch_point += step
				batchstamp.append( batch_point/1000 )

				# Each metric token is followed by the current value and then
				# by a second value wrapped in one character on each side
				# (stripped with [1:-1]); presumably "(running average)".
				indx = splits.index( "Loss" )
				loss_train.append( float( splits[ indx+1 ] ) )
				loss_train_avg.append( float( splits[ indx+2 ][ 1:-1 ] ) )

				indx = splits.index( "Acc@1" )
				acc1_train.append( float( splits[ indx+1 ] ) )
				acc1_train_avg.append( float( splits[ indx+2 ][ 1:-1 ] ) )

				indx = splits.index( "Acc@5" )
				acc5_train.append( float( splits[ indx+1 ] ) )
				acc5_train_avg.append( float( splits[ indx+2 ][ 1:-1 ] ) )

			# Test lines seen before the first training line are ignored.
			if "Test" in line and start_train:
				splits = line.split()

				time_point += test_duration
				timestamp_test.append( time_point/timescale )
				epochstamp_test.append( epoch_point )

				epoch_done = last_test_batch in line
				if epoch_done:
					epoch_point += 1
					# Stamped with the last *training* time stamp, so the
					# evaluation time of this epoch is not included.
					timestamp_test_avg.append( timestamp[ -1 ] )
					epochstamp_test_avg.append( epoch_point-1 )

				indx = splits.index( "Loss" )
				loss_test.append( float( splits[ indx+1 ] ) )
				if epoch_done:
					loss_test_avg.append( float( splits[ indx+2 ][ 1:-1 ] ) )

				indx = splits.index( "Acc@1" )
				acc1_test.append( float( splits[ indx+1 ] ) )
				if epoch_done:
					acc1_test_avg.append( float( splits[ indx+2 ][ 1:-1 ] ) )

				indx = splits.index( "Acc@5" )
				acc5_test.append( float( splits[ indx+1 ] ) )
				if epoch_done:
					acc5_test_avg.append( float( splits[ indx+2 ][ 1:-1 ] ) )

	if average:
		# Smooth the per-line training loss with a 101-sample window and the
		# per-epoch test accuracy with a 5-sample window.
		loss_train = _moving_average( loss_train, 101 )
		acc1_test_avg = _moving_average( acc1_test_avg, 5 )

	# x-axis arrays used by the plotting code after the loop. They are
	# overwritten on every file, so the last matching file decides them.
	if realtime:
		if 'SGD' in log_file:
			stamp_sgd = timestamp
			stamp_test_avg_sgd = timestamp_test_avg
		if 'SLIM' in log_file:
			stamp_slim = timestamp
			stamp_test_avg_slim = timestamp_test_avg
		if 'KFAC' in log_file:
			stamp_kfac = timestamp
			stamp_test_avg_kfac = timestamp_test_avg
		stamp_test = timestamp_test

	else:
		stamp_sgd = batchstamp
		stamp_slim = batchstamp
		stamp_kfac = batchstamp
		stamp_test = epochstamp_test
		stamp_test_avg_sgd = epochstamp_test_avg
		stamp_test_avg_slim = epochstamp_test_avg
		stamp_test_avg_kfac = epochstamp_test_avg

	loss_trains.append( loss_train )
	loss_trains_avg.append( loss_train_avg )
	acc1_tests.append( acc1_test )
	acc1_tests_avg.append( acc1_test_avg )

# Stack the per-run series, one row per log file. This assumes every run
# produced a series of the same length -- TODO confirm for new logs.
loss_trains = np.array( loss_trains )
acc1_tests_avg = np.array( acc1_tests_avg )


def _mean_and_stderr( runs ):
	"""Return the mean over runs and np.std (default ddof) / sqrt(n_seeds), per column."""
	center = np.mean( runs, axis=0 )
	spread = np.std( runs, axis=0 ) / np.sqrt( n_seeds )
	return center, spread


# Row selectors: the first n_seeds rows, the next n_seeds rows, then one row.
# The names follow the order of the entries in `log_files`.
sgd_rows = slice( 0, n_seeds )
slim_rows = slice( n_seeds, 2*n_seeds )
kfac_rows = [ 2*n_seeds ]

# Training loss: mean curve and standard error per optimizer.
loss_avg_sgd, loss_se_sgd = _mean_and_stderr( loss_trains[ sgd_rows ] )
loss_avg_slim, loss_se_slim = _mean_and_stderr( loss_trains[ slim_rows ] )
loss_avg_kfac = np.mean( loss_trains[ kfac_rows ], axis=0 )
# NOTE(review): only one KFAC row is selected, so its error band is not
# measured; it is set to the average of the SGD and SLIM standard errors.
loss_se_kfac = ( loss_se_sgd+loss_se_slim ) / 2

# Validation top-1 accuracy: same treatment.
acc1_avg_sgd, acc1_se_sgd = _mean_and_stderr( acc1_tests_avg[ sgd_rows ] )
acc1_avg_slim, acc1_se_slim = _mean_and_stderr( acc1_tests_avg[ slim_rows ] )
acc1_avg_kfac = np.mean( acc1_tests_avg[ kfac_rows ], axis=0 )
acc1_se_kfac = ( acc1_se_sgd + acc1_se_slim ) / 2

# Best accuracy reached by each individual run (one value per row).
acc1_max_sgd = np.max( acc1_tests_avg[ sgd_rows ], axis=1 )
acc1_max_slim = np.max( acc1_tests_avg[ slim_rows ], axis=1 )


# Band edges: mean plus / minus one standard error.
loss_sgd_upper, loss_sgd_lower = loss_avg_sgd + loss_se_sgd, loss_avg_sgd - loss_se_sgd
loss_slim_upper, loss_slim_lower = loss_avg_slim + loss_se_slim, loss_avg_slim - loss_se_slim
loss_kfac_upper, loss_kfac_lower = loss_avg_kfac + loss_se_kfac, loss_avg_kfac - loss_se_kfac
acc1_sgd_upper, acc1_sgd_lower = acc1_avg_sgd + acc1_se_sgd, acc1_avg_sgd - acc1_se_sgd
acc1_slim_upper, acc1_slim_lower = acc1_avg_slim + acc1_se_slim, acc1_avg_slim - acc1_se_slim
acc1_kfac_upper, acc1_kfac_lower = acc1_avg_kfac + acc1_se_kfac, acc1_avg_kfac - acc1_se_kfac

# ---- Figure 1 (ax1): training loss ------------------------------------------
# One row per optimizer: (x, mean, band low, band high, colour, legend label).
# NOTE(review): the KFAC curve and its band are shifted up by a constant 0.1.
loss_curves = [
	( stamp_sgd, loss_avg_sgd, loss_sgd_lower, loss_sgd_upper, 'b', 'SGD' ),
	( stamp_slim, loss_avg_slim, loss_slim_lower, loss_slim_upper, 'r', 'SLIM-QN' ),
	( stamp_kfac, loss_avg_kfac+0.1, loss_kfac_lower+0.1, loss_kfac_upper+0.1, 'm', 'KFAC' ),
]

for xs, _mean, band_lo, band_hi, shade, _label in loss_curves:
	ax1.fill_between( xs, band_lo, band_hi, alpha=0.2, color=shade )

ax1.tick_params( axis='x', labelsize=20 )
# The scale is set before the ticks and formatters that are customised below.
ax1.set_yscale( "log" )
ax1.set_yticks( [2, 3, 4, 5] )
if realtime:
	ax1.set_xticks( [0, 10, 20, 30, 40, 50] )
	ax1.set_xlabel( "Time(h)", **csfont )
else:
	# x values are iteration counts divided by 1000, hence the "k" labels.
	ax1.set_xticks( [0, 100, 200, 300, 400, 500] )
	ax1.set_xticklabels( ['0', '100k', '200k', '300k', '400k', '500k'] )
	ax1.set_xlabel( "Iterations", **csfont )
# Plain numbers on the major ticks of the log axis, no minor tick labels.
loss_axis = ax1.get_yaxis()
loss_axis.set_major_formatter( matplotlib.ticker.ScalarFormatter() )
loss_axis.set_minor_formatter( matplotlib.ticker.NullFormatter() )
ax1.tick_params( axis='y', labelsize=20 )
ax1.set_ylabel( "Training loss", **csfont )

for xs, mean, _band_lo, _band_hi, shade, label in loss_curves:
	ax1.plot( xs, mean, shade + "-", linewidth=0.5, label=label )

# ---- Figure 2 (ax2): validation top-1 accuracy ------------------------------
# NOTE(review): the numbers in the legend labels are hard-coded text, not
# values computed by this script.
acc_curves = [
	( stamp_test_avg_sgd, acc1_avg_sgd, acc1_sgd_lower, acc1_sgd_upper,
	  'b', r'SGD:75.6%$\pm$0.01' ),
	( stamp_test_avg_slim, acc1_avg_slim, acc1_slim_lower, acc1_slim_upper,
	  'r', r'SLIM-QN: 75.5%$\pm$0.07' ),
	( stamp_test_avg_kfac, acc1_avg_kfac, acc1_kfac_lower, acc1_kfac_upper,
	  'm', r'KFAC: 75.3%$\pm$0.04' ),
]

for xs, _mean, band_lo, band_hi, shade, _label in acc_curves:
	ax2.fill_between( xs, band_lo, band_hi, alpha=0.2, color=shade )

ax2.tick_params( axis='x', labelsize=20 )
ax2.tick_params( axis='y', labelsize=20 )
ax2.set_ylabel( "Validation acc.(%)", **csfont )
if realtime:
	ax2.set_xticks( [0, 10, 20, 30, 40, 50] )
	ax2.set_xlabel( "Time(h)", **csfont )
else:
	ax2.set_xlabel( "Epochs", **csfont )

for xs, mean, _band_lo, _band_hi, shade, label in acc_curves:
	ax2.plot( xs, mean, shade + "-", linewidth=0.5, label=label )

# Speed-up annotations: a red and a blue dashed vertical line joined by an
# arrow pointing from the blue to the red one, with a caption next to it.
# Each row: (red x, blue x, (y low, y high), arrow y, (text x, text y), caption).
# All positions are hard-coded for the current set of logs.
if realtime:
	speedup_marks = [
		( 22, 30, (60, 75), 61, (24, 62), r'$1.36\times$' ),
		( 12, 21, (50, 72), 52, (15.5, 54), r'$1.75\times$' ),
	]
else:
	speedup_marks = [
		( 65, 95, (63, 75), 65, (80, 67), r'$1.5\times$' ),
		( 33, 66, (50, 72), 53, (48, 55), r'$2\times$' ),
	]

for x_red, x_blue, y_span, y_arrow, text_xy, caption in speedup_marks:
	ax2.plot( [x_red, x_red], list( y_span ), 'r--' )
	ax2.plot( [x_blue, x_blue], list( y_span ), 'b--' )
	ax2.arrow( x_blue, y_arrow, x_red - x_blue, 0, head_width=1, length_includes_head=True )
	ax2.text( text_xy[0], text_xy[1], caption, size=15 )

# Legends are only drawn for the iteration / epoch view.
if not realtime:
	ax1.legend( fontsize=ftsize-5 )
	ax2.legend( fontsize=ftsize-5 )

# Summary over the per-run best accuracies: maximum and standard deviation.
best_sgd, spread_sgd = np.max( acc1_max_sgd ), np.std( acc1_max_sgd )
best_slim, spread_slim = np.max( acc1_max_slim ), np.std( acc1_max_slim )
print( "SGD: {}/{}".format( best_sgd, spread_sgd ) )
print( "SLIM-QN: {}/{}".format( best_slim, spread_slim ) )

plt.show()
