import os
import glob
import json
import pandas as pd
import matplotlib.pyplot as plt

# Root directory; runs are looked up under <cwd>/<algo>/<task>/<run>.
cwd = "/Users/hfml/Documents/Ran/rambo_seeds"
# Algorithm sub-directory to summarise (alternative used previously: "rirl").
algo = "birl_0"
# Epoch whose evaluation return is reported as the final score of a run.
max_epoch = 500
# When True, one learning-curve figure is saved per completed run.
plot_history = True

# Figures are written to <cwd>/<algo>_fig. exist_ok=True creates the directory
# only when needed, without a separate check-then-create step.
fig_path = os.path.join(cwd, f"{algo}_fig")
os.makedirs(fig_path, exist_ok=True)

# One dict per completed run: {"env", "data", "seed", "score"}.
results = []
# Every entry under <cwd>/<algo> is treated as one task directory.
paths = glob.glob(os.path.join(cwd, f"{algo}", "*"))
print(paths)
for p in paths:
    # parse task name: remove "-v2", then the first "-"-separated token is the
    # environment and the remaining tokens form the dataset name
    task = os.path.basename(p).replace("-v2", "").split("-")
    env_name = task[0]
    data_name = "-".join(task[1:])

    # parse task results: every entry inside a task directory is one run
    result_paths = glob.glob(os.path.join(p, "*"))
    for result_path in result_paths:
        config_file = os.path.join(result_path, "config.json")
        history_file = os.path.join(result_path, "history.csv")

        # Stray files or runs that have not written their outputs yet would
        # otherwise abort the whole summary; skip them explicitly instead.
        if not (os.path.isfile(config_file) and os.path.isfile(history_file)):
            print(f"skipping {result_path}: missing config.json or history.csv")
            continue

        with open(config_file, "rb") as f:
            config = json.load(f)
        seed = config["run_params"]["seed"]

        df = pd.read_csv(history_file)

        # Runs with fewer logged rows than max_epoch are treated as unfinished.
        if len(df) < max_epoch:
            continue

        # The row count alone does not guarantee that a row for max_epoch
        # exists (e.g. when epochs are numbered from 0), so check before
        # indexing to avoid an IndexError on an empty selection.
        final_scores = df.loc[df["epoch"] == max_epoch, "evaluation/return-average"]
        if final_scores.empty:
            print(f"skipping {result_path}: no row with epoch == {max_epoch}")
            continue
        score = final_scores.values[0]
        # score = df.loc[df["epoch"] == 300]["evaluation/return-average-last-10-iter"].values[0]

        results.append({
            "env": env_name,
            "data": data_name,
            "seed": seed,
            "score": score,
        })

        # plot learning curve
        if plot_history:
            fig, ax = plt.subplots(1, 1, figsize=(6, 4))
            ax.plot(df["epoch"], df["evaluation/return-average"])
            ax.set_title(f"{env_name}-{data_name}, seed={seed}")
            fig.tight_layout()
            fig.savefig(os.path.join(fig_path, f"{env_name}-{data_name}-{seed}.png"), dpi=100)
            # Close this specific figure so figures do not accumulate across runs.
            plt.close(fig)

# Fixing the columns keeps the frame well-formed (and sortable) even when no
# run was collected; without it an empty list yields a frame with no columns
# and sort_values raises KeyError.
df_results = pd.DataFrame(results, columns=["env", "data", "seed", "score"])
df_results = df_results.sort_values(by=["env", "data"]).reset_index(drop=True)

if df_results.empty:
    print(f"\nno completed runs found for {algo} (max_epoch={max_epoch})")
else:
    # Aggregate the score only: the seed is an identifier, so its mean and
    # standard deviation carry no meaning.
    score_by_task = df_results.groupby(["env", "data"], as_index=False)["score"]
    df_results_mean = score_by_task.mean()
    df_results_std = score_by_task.std()

    print("\nmean")
    print(df_results_mean)
    print("\nstd")
    print(df_results_std)

# Per-run results (one row per env/data/seed) are saved next to the run folders.
df_results.to_csv(os.path.join(cwd, f"{algo}_results.csv"), index=False)