#%%
from matplotlib import pyplot as plt
import numpy as np
from utils import postprocess
import os

from mpl_toolkits.axes_grid1 import ImageGrid

# Global figure styling: serif 14 pt text, with pdflatex as the TeX system
# used by the PGF backend.
plt.rc("font", family="serif", size=14)
plt.rc("pgf", texsystem="pdflatex")
# Optional PGF preamble, currently disabled:
#     r"\usepackage[utf8x]{inputenc}",
#     r"\usepackage[T1]{fontenc}",
#     r"\usepackage{cmbright}",
#     r"\usepackage{amsmath, amsfonts, amssymb, amstext, amsthm, bbm, mathtools}",

# Render all figure text through LaTeX.
plt.rcParams["text.usetex"] = True

#%% 
# file_name = "2024-04-02 15:50:40.204198.pkl"
# with open(f"./results/{file_name}", 'rb') as file:
#     res_dict = pickle.load(file)
results_dir = "./results/"
figures_dir = "./figures/"
# Agents whose name ends with one of these suffixes are left out of the plots.
excluded_suffixes = ("alpha001", "alpha1", "alpha10")

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(figures_dir, exist_ok=True)

# Produce one figure per results file: the cumulative regret of every retained
# agent, drawn as the mean over axis 0 of its regret array with a +/- one
# standard deviation band.
# Sorting the listing makes the processing order deterministic.
for filename in sorted(os.listdir(results_dir)):
    rewards, rewards_oracle, agent_names = postprocess(os.path.join(results_dir, filename))

    # Regret is the oracle reward minus the agent reward, accumulated along
    # axis 2 (presumably the time axis, with agents on axis 0 and repetitions
    # on axis 1 -- TODO confirm against `postprocess`).
    regret_raw = rewards_oracle - rewards
    regret_cumul = np.cumsum(regret_raw, axis=2)
    regret_cumul_dict = dict(zip(agent_names, regret_cumul))
    # Sort by agent name so curve colours and legend order are stable.
    regret_cumul_dict = dict(sorted(regret_cumul_dict.items(), key=lambda x: x[0]))

    fig = plt.figure()
    plotted = False
    for agent_name, regret in regret_cumul_dict.items():
        if agent_name.endswith(excluded_suffixes):
            continue

        # Shorten the internal agent identifier into a legend label.
        label = (
            agent_name.replace("NL", "NetLasso")
            .replace("deg", "")
            .replace("_alpha01", "")
        )
        reg_mean = np.mean(regret, axis=0)
        reg_std = np.std(regret, axis=0)
        plt.plot(reg_mean, label=label)
        plt.fill_between(
            np.arange(len(reg_mean)),
            reg_mean + reg_std,
            reg_mean - reg_std,
            alpha=0.5,
        )
        plotted = True

    # Finalise and write the figure once per file, after all curves are drawn,
    # instead of re-saving it after every agent. Nothing is written when no
    # agent passed the filter, as before.
    if plotted:
        plt.legend()
        plt.tight_layout()
        # splitext drops the extension whatever its length.
        stem = os.path.splitext(filename)[0]
        for fmt in ["pdf"]:
            plt.savefig(os.path.join(figures_dir, f"{stem}.{fmt}"), format=fmt)

    # Release the figure so they do not pile up in memory across files.
    plt.close(fig)

# %%
