#%%
from matplotlib import pyplot as plt
import numpy as np
from utils import postprocess_many
import os
import config as co


# Typeset all figure text with LaTeX, using a 10pt serif font.
plt.rcParams.update(
    {
        "font.family": "serif",
        "font.size": 10,
        "pgf.texsystem": "pdflatex",
    }
)
plt.rcParams["text.usetex"] = True

# Directory that holds the stored experiment outputs, and the names of
# everything currently inside it.
results_dir = "./results/"
file_list = os.listdir(results_dir)
#%%


#%% 
# Agents left out of the plots (due to particular mistakes, not general).
agents_to_exclude = [
    "LinUcbOracle",
    "NL_alpha001",
    "NL_alpha1",
    "NL_alpha10",
]

# Make sure the output directory exists so that savefig cannot fail on a
# fresh checkout.
figures_dir = "./figures"
os.makedirs(figures_dir, exist_ok=True)

# One figure is produced per experiment configuration.
# Other available configurations: co.fig_a, co.fig_b, co.fig_c, co.fig_d
for exp_dict in [co.exp_proto]:
    # Result files of an experiment share this file-name prefix.
    # Single quotes inside the f-strings keep the file valid on Python < 3.12.
    prefix = (
        f"u{exp_dict['n_users']}d{exp_dict['dim']}"
        f"h{exp_dict['horizon']}c{exp_dict['n_clusters']}"
        f"i{exp_dict['imbalance']}p{exp_dict['p']}q{exp_dict['q']}"
    )
    # list the files corresponding to the experiment (possibly differing by agents)
    exp_file_list = [
        os.path.join(results_dir, f) for f in file_list if f.startswith(prefix)
    ]

    rewards, rewards_oracle, agent_names = postprocess_many(exp_file_list)
    # Per-step regret is the oracle reward minus the obtained reward; it is
    # accumulated along axis 2.
    # NOTE(review): arrays are presumably (agent, run, time) - confirm
    # against postprocess_many.
    regret_raw = rewards_oracle - rewards
    regret_cumul = np.cumsum(regret_raw, axis=2)
    regret_cumul_dict = dict(zip(agent_names, regret_cumul))
    # Sort by agent name so the curve/legend order is alphabetical.
    regret_cumul_dict = dict(sorted(regret_cumul_dict.items(), key=lambda x: x[0]))

    plt.figure()

    for agent_name, regret in regret_cumul_dict.items():
        if agent_name in agents_to_exclude:
            continue

        # Turn the internal agent identifier into the label shown in the legend.
        label = agent_name.replace("NL", "NetLasso")
        label = label.replace("deg", "")
        label = label.replace("_alpha01", "")
        label = label.replace("_corrected", "")
        if label == "OLS":
            label = "OLS-ITL"

        # Mean curve with a +/- one standard deviation band, both taken
        # over axis 0 of this agent's cumulative regret.
        reg_mean = np.mean(regret, axis=0)
        reg_std = np.std(regret, axis=0)
        plt.plot(reg_mean, label=label)
        plt.fill_between(
            np.arange(len(reg_mean)),
            reg_mean + reg_std,
            reg_mean - reg_std,
            alpha=0.2,
        )

    # Decorate and save once per figure, after every curve has been drawn.
    # Labels and legend are set before tight_layout so the layout accounts
    # for them.
    plt.xlabel("time steps")
    plt.ylabel("cumulative regret")
    plt.legend()
    plt.tight_layout()
    # plt.show()
    for fmt in ["pdf"]:
        plt.savefig(f"{figures_dir}/{prefix}.{fmt}", format=fmt)

# %%
