import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
# Global text style: bold, size 25. Applied after sns.set() so these font.*
# settings override the seaborn theme's.
# 'normal' is not an installed font family: matplotlib warns
# "findfont: Font family 'normal' not found" and falls back to its default
# font, DejaVu Sans. Name that font explicitly to render the same glyphs
# without the warning.
font = {'family': 'DejaVu Sans',
        'weight': 'bold',
        'size': 25}

matplotlib.rc('font', **font)

# Experiment configuration read by the plotting loop below.

# Dataset tags; results are loaded from ./res_<dataset>_<graph>/.
dataset = ['rcv1_50full']

# Feedback graphs to compare, and their legend labels.
graph_T = ['bandit', 'full_info', 'robs_cops']
graph_map = {
    'bandit': 'Bandit Graph',
    'full_info': 'Full Information Graph',
    'robs_cops': 'Cops and Robbers Graph',
}

# Human-readable display names, keyed by method tag.
method_map = {
    'log-barrier': 'Log-barrier',
    'closed-form': 'SquareCB.G',
}

# Method whose '<method>_loss.npy' results are plotted (a key of method_map).
method = 'closed-form'

# One colour and one line style per entry of graph_T, paired by position.
# Alternative palette: ['#377eb8', '#ff7f00', '#4daf4a', '#f781bf', '#a65628', '#984ea3']
color = ['#377eb8', '#4daf4a', '#984ea3']
style = ['-', '-.', ':']
# Leading iterations excluded from the plot.
burn_in = 100
out_dir = 'rcv1_figs'
os.makedirs(out_dir, exist_ok=True)

for ds in dataset:
    fig = plt.figure()
    for c, graph, s in zip(color, graph_T, style):
        res_path = './res_{}_{}'.format(ds, graph)
        # NOTE(review): assumed to be a 2-D array (runs, iterations) of
        # accumulated loss -- TODO confirm against the script that writes it.
        regret = np.load(os.path.join(res_path, '{}_loss.npy'.format(method.lower())))
        regret = regret.astype(np.float32)
        _, n = regret.shape  # unpacking also enforces a 2-D array
        if n <= burn_in:
            raise ValueError(
                '{}: only {} iterations, need more than {} to plot after the burn-in'
                .format(res_path, n, burn_in))

        # Divide column i by its 1-based iteration count (i + 1); broadcasting
        # applies the same divisor to every run (row).
        regret = regret / np.arange(1, n + 1, dtype=np.float32)
        regret = regret[:, burn_in:]
        # 1-based iteration numbers of the kept columns; shared by the line
        # and the shaded band so the two stay aligned.
        indices = np.arange(burn_in + 1, n + 1)
        mean, std = np.mean(regret, axis=0), np.std(regret, axis=0)

        # No fmt string: colour and line style come from the kwargs alone.
        plt.semilogx(indices, mean, linewidth=2.0, color=c,
                     label=graph_map[graph], linestyle=s)
        # Shade one standard deviation across runs around the mean.
        plt.fill_between(indices, mean - std, mean + std, color=c, alpha=0.4)
        print(method, graph, mean[-1], std[-1])

    plt.xlabel('Iterations', fontsize=16)
    plt.ylabel('PV Loss', fontsize=16)
    # plt.title(method_map[method])
    plt.legend()
    # NOTE: the filename depends only on `method`, so with several datasets
    # each figure overwrites the previous one.
    plt.savefig(os.path.join(out_dir, '{}.png'.format(method)), dpi=600, bbox_inches='tight')
    plt.close(fig)  # release the figure instead of accumulating open ones

