import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import math
# Plot styling. seaborn's theme is applied first because it rewrites the
# matplotlib rcParams (font sizes included); the bold, 25-pt default font
# override must come after it so it is not overwritten.
sns.set()
font = dict(weight='bold', size=25)
matplotlib.rc('font', **font)

# ---- Experiment configuration ---------------------------------------------
# One entry per plotted curve: method[i], color[i], style[i] and k_list[i]
# describe the same curve and are consumed in lockstep via zip().
method = ['log-barrier', 'log-barrier', 'log-barrier',
          'closed-form', 'closed-form', 'closed-form']

dataset = ['inventory']

color = ['#377eb8', '#ff7f00', '#4daf4a', '#f781bf', '#a65628', '#984ea3']

style = ['-', '-', '-', '-', '-', '-']

graph_T = ['inventory']

# Range of trial indices (low inclusive, up exclusive) whose result files are
# looked up; up - low is also reported as the trial count in the plot title.
low = 0
up = 8

num_trial = list(range(low, up))
# h and p appear in the result/figure file names and in the plot title
# (presumably the inventory holding and backorder costs -- confirm).
h = 0.25
p = 1
# k value paired with each entry of `method` (used in file names and legend).
k_list = [101, 301, 501, 101, 301, 501]
# Appears in the result file names; presumably the number of iterations per
# run -- confirm against the script that writes the .npy files.
n = 10000

# zip() silently truncates to the shortest list, which would drop curves from
# the plot without any error, so reject mismatched lengths up front.
if not len(method) == len(color) == len(style) == len(k_list):
    raise ValueError(
        'method, color, style and k_list must have equal lengths, got '
        '{}, {}, {}, {}'.format(len(method), len(color), len(style), len(k_list)))
def _load_trials(res_path, prefix, tag, trials, required=2):
    """Load and stack the saved per-trial curves for one (k, method) pair.

    Each file is ``<res_path>/<prefix><trial><tag>.npy``. Trials are read in
    the given order. The first ``required`` of them must exist (their
    FileNotFoundError propagates). After that, reading stops quietly at the
    first missing file.

    Returns:
        2-D array with one row per loaded trial. Assumes the saved arrays are
        1-D and of equal length (np.vstack raises ValueError otherwise
        instead of silently dropping trials) -- confirm against the writer.
    """
    curves = []
    for count, trial in enumerate(trials):
        path = os.path.join(res_path, '{}{}{}.npy'.format(prefix, trial, tag))
        try:
            curves.append(np.load(path))
        except FileNotFoundError:
            if count < required:
                raise
            break
    return np.vstack(curves)


def _running_average(curve):
    """Divide entry i of ``curve`` by i + 1.

    This gives the per-iteration average when ``curve`` holds cumulative
    totals.
    """
    return curve / np.arange(1, curve.shape[0] + 1)


# Every (k, method) configuration reads the '_fixed' result files.
tag = '_fixed'
# Iterations before this index are left out of the plot.
plot_start = 100

for ds in dataset:
    for graph_type in graph_T:
        res_path = './res_{}_{}'.format(ds, graph_type)
        os.makedirs('figs', exist_ok=True)
        fig, ax = plt.subplots()
        for c, m, s, k in zip(color, method, style, k_list):
            prefix = '{}_{}_{}_{}_{}_inventory_ogd_'.format(k, h, p, m.lower(), n)
            reg = _load_trials(res_path, prefix, tag, num_trial)
            print(reg)

            raw_mean = np.mean(reg, axis=0)
            print("mean:\n")
            print(raw_mean)

            # Scale over the actual curve length (reg.shape[1]) rather than
            # the configured n, so every entry is scaled whatever the length.
            mean = _running_average(raw_mean)
            print("scaled mean:\n")
            print(mean)

            std = _running_average(np.std(reg, axis=0))
            print(std)

            # Use a local length; the global n must keep its configured value
            # because it is part of later result file names and the figure name.
            horizon = reg.shape[1]
            # NOTE(review): entry i is divided by i + 1, so it likely belongs
            # to iteration i + 1 but is plotted at x = i -- confirm.
            iters = np.arange(plot_start, horizon)
            label = 'SquareCB.G' if m == 'closed-form' else 'SquareCB'
            # No fmt string: color/linestyle come only from the keywords.
            ax.semilogx(iters, mean[plot_start:], linewidth=1.0, color=c,
                        label='{}, k={}'.format(label, k), linestyle=s)
            ax.fill_between(iters, (mean - std)[plot_start:],
                            (mean + std)[plot_start:], color=c, alpha=0.4)

            print("# of attempts: ", reg.shape[0])
            print(m, mean[-1])

        ax.set_xlabel('Iterations')
        ax.set_ylabel('PV Loss')
        ax.set_title("Inventory graph: h={},p={}, {} trials".format(h, p, up - low))
        ax.legend()
        fig.savefig(
            "./figs/{}_{}_{}_{}_{}_ogd_{}_{}_log_plot.png".format(
                h, p, ds, graph_type, n, up - low, tag),
            dpi=600, bbox_inches='tight')
        # Release the figure; otherwise one stays open per (ds, graph_type).
        plt.close(fig)

