import os
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
sns.set()
# Bold, large text for publication-quality figures. There is deliberately no
# 'family' key: 'normal' is not a font family matplotlib can resolve (it only
# produces "findfont" warnings and a fallback font), so the sans-serif family
# configured by sns.set() above is kept instead.
font = {'weight': 'bold',
        'size': 25}

matplotlib.rc('font', **font)
# Output directory for the saved figures; no error if it already exists.
os.makedirs('rcv1_figs', exist_ok=True)
plt.figure()
# Number of leading iterations plot_loss drops from each curve by default.
start_iter = 500

def plot_loss(loss, color, label, linestyle, start=None):
    """Plot mean +/- one std of the iteration-normalised loss on a log x-axis.

    Column ``i`` of ``loss`` is divided by ``i + 1`` (its 1-based iteration
    number). The first ``start`` columns are then dropped, and the mean and
    standard deviation are taken across rows (axis 0). The mean is drawn as
    a line, and the mean +/- std band is shaded in the same colour. The label,
    final mean and final std are printed.

    Args:
        loss: Array of shape (rows, iterations). Rows are presumably
            independent runs (TODO confirm against the producer of the
            ``.npy`` files). A 1-D array is treated as a single row. The
            caller's array is never modified.
        color: Matplotlib colour for both the line and the shaded band.
        label: Legend label; also printed with the final mean and std.
        linestyle: Matplotlib line style of the mean curve.
        start: Number of leading iterations to skip. Defaults to the
            module-level ``start_iter``.

    Raises:
        ValueError: If ``loss`` has more than two dimensions, or if ``start``
            is negative or leaves no iterations to plot.
    """
    if start is None:
        start = start_iter
    # float32 matches the original precision; the division below always
    # allocates a new array, so the input is only read.
    loss = np.atleast_2d(np.asarray(loss, dtype=np.float32))
    if loss.ndim != 2:
        raise ValueError(f"loss must be 1-D or 2-D, got shape {loss.shape}")
    n_iters = loss.shape[1]
    if not 0 <= start < n_iters:
        raise ValueError(
            f"start={start} leaves no iterations to plot "
            f"(loss has {n_iters} iterations)")

    # 1-based iteration numbers of the kept columns. They serve both as the
    # normalising divisor and as the x-coordinates of the line AND the band,
    # so the two stay aligned.
    iters = np.arange(start + 1, n_iters + 1)
    normalised = loss[:, start:] / iters.astype(np.float32)
    mean = normalised.mean(axis=0)
    std = normalised.std(axis=0)

    # Colour and style come only from keywords; adding a fmt string such as
    # 'k-' would be overridden by them and makes matplotlib warn.
    plt.semilogx(iters, mean, linewidth=2.0, color=color, label=label,
                 linestyle=linestyle)
    plt.fill_between(iters, mean - std, mean + std, color=color, alpha=0.4)
    print(label, mean[-1], std[-1])


# Curve 1: SquareCB.G results. Assumes each .npy file holds a 2-D
# (rows, iterations) array as plot_loss expects — TODO confirm.
path = './new_res_vw_rcv1_10full_random/graph_loss.npy'
graph_loss = np.load(path)
plot_loss(graph_loss, color='#377eb8', label='SquareCB.G', linestyle='-')

# Curve 2: plain SquareCB baseline, dotted so it stays distinguishable.
log_loss = np.load('./new_res_vw_rcv1_10full_random/squarecb_loss.npy')
plot_loss(log_loss, color='#984ea3', label='SquareCB', linestyle=':')

# Axis labels use a smaller size than the global font setting.
axis_fontsize = 16
plt.xlabel('Iterations', fontsize=axis_fontsize)
plt.ylabel('PV Loss', fontsize=axis_fontsize)
plt.legend()

# High-resolution export, with the bounding box trimmed to the drawn content.
plt.savefig("./rcv1_figs/squarecb_graph.png", dpi=600, bbox_inches='tight')