
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt 
import os, sys

# The directory holding the .npy result files is the first CLI argument.
if len(sys.argv) < 2:
	print('Usage: <script> <path>')
	# sys.exit (not the site-provided exit()) with a non-zero status so that
	# callers can detect the usage error.
	sys.exit(1)
path = sys.argv[1]
# Explicit check instead of `assert`, which is removed under `python -O`.
if not os.path.exists(path):
	sys.exit('Path does not exist: %s' % path)

# Result files to plot. Each base name is three '-'-separated floats; the
# second and third values select the legend label below.
fnames = [
	'10.00-0.00-0.00.npy',
	'10.00-0.01-0.00.npy',
	'10.00-0.00-1.00.npy',
	'10.00-0.01-1.00.npy',
]

# One (color, marker) pair is consumed per curve that is actually drawn, so
# skipped entries do not shift the styling of the remaining curves.
styles = iter(zip(['green', 'gold', 'blue', 'red'], ['o', 'v', '^', 'd']))

fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
font = {'size': 18}
ax.set_xlim(0, 20)
ax.set_ylim(0.4, 1.0)
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

for fname in fnames:
	fpath = os.path.join(path, fname)
	# Ignore directories and anything that is not a .npy file.
	if os.path.isdir(fpath) or not fname.endswith('.npy'):
		continue
	print(fpath)

	# Keep index 0 of the last axis; the result is 2-D. Axis 0 is aggregated
	# below (presumably independent runs -- confirm), axis 1 is the x axis.
	data = np.load(fpath)[:, :, 0]

	weights = [float(w) for w in fname[:-4].split('-')]
	if len(weights) != 3:
		raise ValueError('expected 3 weights in file name %r' % fname)

	# Build the legend label from the second and third weights.
	label = 'AE'
	if weights[1] > 0:
		label += '+CLUB'
	label += '($d_z=5$)' if weights[2] > 0 else '($d_z=7$)'

	x = np.arange(data.shape[1])
	mu = data.mean(axis=0)
	min_v = data.min(axis=0)
	max_v = data.max(axis=0)

	c, m = next(styles)
	ax.plot(x, mu, color=c, marker=m, label=label)
	# Shaded band spans the min..max envelope over axis 0.
	ax.fill_between(x, min_v, max_v, alpha=0.2, color=c)

plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.grid(True, axis='y')
plt.xlabel('epochs', font)
plt.ylabel('$R^2$ score', font)
plt.legend(loc='lower right', fontsize=18)
# Fixed output name, written to the current working directory. There is no
# placeholder to format, so no '%' operator is applied to the file name.
plt.savefig('all.png')
