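"""Plot training-loss and test-accuracy curves parsed from training logs.

Reads the per-run log files listed in `log_files` and plots, for each
optimizer (SGD and SLIM-QN), the mean curve over its runs together with a
standard-error band.
"""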
# import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Times New Roman"
import numpy as np

# Font settings passed to the axis-label calls (csfont) and legends (ftsize).
csfont = { 'fontname': 'Times New Roman', 'fontsize': 25 }
ftsize = 25

# Number of runs (seeds) per optimizer.  The log-file list and the colour
# list below are derived from it so the three can never get out of sync.
n_seeds = 5

# Directory holding one sub-directory of logs per optimizer.
_log_root = "../log/ICLR/cifar10/resnet18"
# ( sub-directory name, matplotlib colour code ) for each optimizer, in the
# order the rest of the script expects: first group SGD, second group SLIM.
_methods = ( ( "SGD", 'b' ), ( "SLIM", 'r' ) )

# Runs are numbered from 1: .../<method>/verbose_run<k>.log
log_files = [
	"{}/{}/verbose_run{}.log".format( _log_root, _method, _run )
	for _method, _ in _methods
	for _run in range( 1, n_seeds + 1 ) ]

# One colour entry per log file, aligned with `log_files`.
color = [ _clr for _, _clr in _methods for _run in range( n_seeds ) ]

# Multiplier applied to each logged batch time and added to the iteration
# counter per training log line (presumably the logging interval in
# batches -- confirm against the training script).
step = 20
timescale = 3600  # seconds per hour; accumulated time is divided by this
# True: x-axis is wall-clock hours; False: iterations (train) / epochs (test).
realtime = False
max_epoch = 225  # NOTE(review): not referenced elsewhere in this script
# True: smooth the per-run curves with a moving average before aggregating.
average = True

# Two separate 8x4-inch figures, each with a single axes:
# figure 1 / ax1 carries the training loss, figure 2 / ax2 the test accuracy.
fig1 = plt.figure( num=1, figsize=(8, 4) )
ax1 = fig1.add_subplot( 111 )
fig2 = plt.figure( num=2, figsize=(8, 4) )
ax2 = fig2.add_subplot( 111 )
# pyplot-level call: acts on the current figure, which is figure 2 here.
plt.subplots_adjust( hspace=5 )
hold = True  # NOTE(review): not referenced elsewhere in this script

# Per-run series, one entry per log file in the order of `log_files`.
loss_trains = []
loss_trains_avg = []
acc1_tests = []
acc1_tests_avg = []

# Substring that identifies the test line on which the running averages are
# recorded and the epoch counter advances (presumably the last of 40 test
# batches -- confirm against the log format).
last_test_batch = "40/40"

for log_file, clr in zip( log_files, color ):
	# x-axis stamps
	timestamp = []            # cumulative hours at each training log line
	timestamp_test = []       # cumulative hours at each test log line
	timestamp_test_avg = []   # training timestamp at the end of each epoch
	batchstamp = []           # cumulative iterations / 1000 (training lines)
	epochstamp_test = []      # epoch index of each test log line
	epochstamp_test_avg = []  # epoch index, one entry per finished epoch
	# training metrics: current value and the parenthesised running value
	acc1_train = []
	acc5_train = []
	acc1_train_avg = []
	acc5_train_avg = []
	loss_train = []
	loss_train_avg = []
	# test metrics: per-line value, and running value on the final test line
	acc1_test = []
	acc5_test = []
	acc1_test_avg = []
	acc5_test_avg = []
	loss_test = []
	loss_test_avg = []

	time_point = 0
	batch_point = 0
	epoch_point = 0
	# Test lines are ignored until the first training line has been seen.
	start_train = False

	with open( log_file ) as fp:
		# Iterate over every line; a readline()-before-and-inside-the-loop
		# pattern would skip the first line and process the empty EOF string.
		for line in fp:
			if "Epoch" in line:
				start_train = True
				splits = line.split()

				# Each metric is logged as "<name> <value> (<running>)";
				# [1:-1] strips the surrounding parentheses.
				indx = splits.index( "Time" )
				time_point += step * float( splits[ indx+2 ][ 1:-1 ] )
				timestamp.append( time_point/timescale )

				batch_point += step
				batchstamp.append( batch_point/1000 )

				indx = splits.index( "Loss" )
				loss_train.append( float( splits[ indx+1 ] ) )
				loss_train_avg.append( float( splits[ indx+2 ][ 1:-1 ] ) )

				indx = splits.index( "Acc@1" )
				acc1_train.append( float( splits[ indx+1 ] ) )
				acc1_train_avg.append( float( splits[ indx+2 ][ 1:-1 ] ) )

				indx = splits.index( "Acc@5" )
				acc5_train.append( float( splits[ indx+1 ] ) )
				acc5_train_avg.append( float( splits[ indx+2 ][ 1:-1 ] ) )

			if "Test" in line and start_train:
				splits = line.split()
				epoch_done = last_test_batch in line

				indx = splits.index( "Time" )
				time_point += step * float( splits[ indx+2 ][ 1:-1 ] )
				timestamp_test.append( time_point/timescale )
				epochstamp_test.append( epoch_point )
				if epoch_done:
					# The per-epoch stamp is the last *training* timestamp.
					timestamp_test_avg.append( timestamp[ -1 ] )
					epochstamp_test_avg.append( epoch_point )
					epoch_point += 1

				indx = splits.index( "Loss" )
				loss_test.append( float( splits[ indx+1 ] ) )
				if epoch_done:
					loss_test_avg.append( float( splits[ indx+2 ][ 1:-1 ] ) )

				indx = splits.index( "Acc@1" )
				acc1_test.append( float( splits[ indx+1 ] ) )
				if epoch_done:
					acc1_test_avg.append( float( splits[ indx+2 ][ 1:-1 ] ) )

				# Same bookkeeping as the other test metrics: per-line value
				# always, running value only on the final test line.
				indx = splits.index( "Acc@5" )
				acc5_test.append( float( splits[ indx+1 ] ) )
				if epoch_done:
					acc5_test_avg.append( float( splits[ indx+2 ][ 1:-1 ] ) )

	# Fail with the offending file name instead of an opaque padding error.
	if not loss_train or not acc1_test_avg:
		raise ValueError(
			"no training/test records parsed from {}".format( log_file ) )

	if average:
		# Moving average with reflect padding so the output keeps the input
		# length: window 25 for the training loss, 5 for the test accuracy.
		w1 = 25
		pad1 = int( ( w1-1 ) / 2 )
		w2 = 5
		pad2 = int( ( w2-1 ) / 2 )
		loss_train = np.pad( np.array( loss_train ), pad1, mode='reflect' )
		loss_train = np.convolve( loss_train, np.ones( w1 ), 'valid' ) / w1
		acc1_test_avg = np.pad( np.array( acc1_test_avg ), pad2, mode='reflect' )
		acc1_test_avg = np.convolve( acc1_test_avg, np.ones( w2 ), 'valid' ) / w2

	# x-axis values used for plotting after the loop (taken from the last
	# file processed): wall-clock hours, or iterations / epochs.
	if realtime:
		stamp = timestamp
		stamp_test = timestamp_test
		stamp_test_avg = timestamp_test_avg
	else:
		stamp = batchstamp
		stamp_test = epochstamp_test
		stamp_test_avg = epochstamp_test_avg

	loss_trains.append( loss_train )
	loss_trains_avg.append( loss_train_avg )
	acc1_tests.append( acc1_test )
	acc1_tests_avg.append( acc1_test_avg )

# Stack the per-run curves into 2-D arrays of shape (run, point).  This
# requires every run to have produced the same number of points.
loss_trains = np.array( loss_trains )
acc1_tests_avg = np.array( acc1_tests_avg )

# Row ranges of the two optimizers: the first n_seeds runs are SGD, the next
# n_seeds are SLIM.  Both bounds are derived from n_seeds rather than a
# literal run count.
_sgd_rows = slice( 0, n_seeds )
_slim_rows = slice( n_seeds, 2 * n_seeds )


def _mean_and_se( runs ):
	"""Return the point-wise mean and standard error across the rows of `runs`.

	The standard error is np.std (default ddof=0) divided by sqrt(n_seeds).
	"""
	return np.mean( runs, axis=0 ), np.std( runs, axis=0 ) / np.sqrt( n_seeds )


loss_avg_sgd, loss_se_sgd = _mean_and_se( loss_trains[ _sgd_rows ] )
loss_avg_slim, loss_se_slim = _mean_and_se( loss_trains[ _slim_rows ] )
acc1_avg_sgd, acc1_se_sgd = _mean_and_se( acc1_tests_avg[ _sgd_rows ] )
acc1_avg_slim, acc1_se_slim = _mean_and_se( acc1_tests_avg[ _slim_rows ] )

# Best (maximum) test accuracy reached by each individual run.
acc1_max_sgd = np.max( acc1_tests_avg[ _sgd_rows ], axis=1 )
acc1_max_slim = np.max( acc1_tests_avg[ _slim_rows ], axis=1 )

# Band edges: mean +/- one standard error.
loss_sgd_upper = loss_avg_sgd + loss_se_sgd
loss_sgd_lower = loss_avg_sgd - loss_se_sgd
loss_slim_upper = loss_avg_slim + loss_se_slim
loss_slim_lower = loss_avg_slim - loss_se_slim
acc1_sgd_upper = acc1_avg_sgd + acc1_se_sgd
acc1_sgd_lower = acc1_avg_sgd - acc1_se_sgd
acc1_slim_upper = acc1_avg_slim + acc1_se_slim
acc1_slim_lower = acc1_avg_slim - acc1_se_slim

# ---- ax1: training loss, shaded mean +/- standard-error bands ----
_loss_bands = (
	( loss_sgd_lower, loss_sgd_upper, 'b' ),
	( loss_slim_lower, loss_slim_upper, 'r' ),
)
for _band_lo, _band_hi, _band_clr in _loss_bands:
	ax1.fill_between( stamp, _band_lo, _band_hi, alpha=0.2, color=_band_clr )
ax1.tick_params( axis='x', labelsize=20 )
ax1.set_yscale( "log" )
if not realtime:
	# batch stamps are stored in thousands of iterations, hence the 'k' labels
	ax1.set_xticks( [0, 10, 20, 30, 40, 50] )
	ax1.set_xticklabels( ['0', '10k', '20k', '30k', '40k', '50k'] )
	ax1.set_xlabel( "Iterations", **csfont )
else:
	ax1.set_xticks( [0, 2, 4, 6, 8] )
	ax1.set_xlabel( "Hours", **csfont )
# Plain-number major tick labels and no minor tick labels on the log axis;
# set after set_yscale so the formatters chosen here are the ones in effect.
_loss_yaxis = ax1.get_yaxis()
_loss_yaxis.set_major_formatter( matplotlib.ticker.ScalarFormatter() )
_loss_yaxis.set_minor_formatter( matplotlib.ticker.NullFormatter() )
ax1.tick_params( axis='y', labelsize=20 )
ax1.set_ylabel( "Training loss", **csfont )

# ---- ax2: test top-1 accuracy, shaded mean +/- standard-error bands ----
_acc_bands = (
	( acc1_sgd_lower, acc1_sgd_upper, 'b' ),
	( acc1_slim_lower, acc1_slim_upper, 'r' ),
)
for _band_lo, _band_hi, _band_clr in _acc_bands:
	ax2.fill_between( stamp_test_avg, _band_lo, _band_hi, alpha=0.2, color=_band_clr )
for _axis_name in ( 'x', 'y' ):
	ax2.tick_params( axis=_axis_name, labelsize=20 )
ax2.set_ylabel( "Validation acc.(%)", **csfont )
ax2.set_xlabel( "Epochs", **csfont )

# ---- mean curves drawn on top of the bands ----
# Line format = colour code of the first run of each group + solid line.
_sgd_fmt = color[0] + "-"
_slim_fmt = color[n_seeds] + "-"
ax1.plot( stamp, loss_avg_sgd, _sgd_fmt, linewidth=0.5, label='SGD' )
ax1.plot( stamp, loss_avg_slim, _slim_fmt, linewidth=0.5, label='SLIM-QN' )
# The accuracy figures in these legend labels are fixed text, not computed.
ax2.plot( stamp_test_avg, acc1_avg_sgd, _sgd_fmt, linewidth=0.5,
		  label='SGD:94.2%'+r'$\pm$'+'0.16' )
ax2.plot( stamp_test_avg, acc1_avg_slim, _slim_fmt, linewidth=0.5,
		  label='SLIM: 94.6%'+r'$\pm$'+'0.05' )
for _axes in ( ax1, ax2 ):
	_axes.legend( fontsize=ftsize )

# Summary over runs: highest per-run best accuracy / std of per-run bests.
_sgd_best, _sgd_spread = np.max( acc1_max_sgd ), np.std( acc1_max_sgd )
_slim_best, _slim_spread = np.max( acc1_max_slim ), np.std( acc1_max_slim )
print( "SGD: {}/{}".format( _sgd_best, _sgd_spread ) )
print( "SLIM-QN: {}/{}".format( _slim_best, _slim_spread ) )

plt.show()
