import os
from os import listdir

import numpy as np
import pandas as pd
import wandb


def fetch_runs_dataframe(api, project_name):
    """Return one row per W&B run in *project_name*.

    Columns are ``name`` and ``id``, then the run's config entries (keys
    starting with ``_`` are dropped), then the run's summary values.
    """
    name_list = []
    config_list = []
    summary_list = []
    for run in api.runs(project_name):
        # run.summary holds the output key/values like accuracy. ._json_dict is
        # used (per the original author) to omit large files.
        summary_list.append(run.summary._json_dict)
        # run.config holds the inputs; drop wandb's internal "_"-prefixed keys.
        config_list.append({k: v for k, v in run.config.items() if not k.startswith("_")})
        name_list.append({"name": run.name, "id": run.id})

    name_df = pd.DataFrame(name_list)
    config_df = pd.DataFrame.from_records(config_list)
    summary_df = pd.DataFrame.from_records(summary_list)
    # All three frames share a default RangeIndex in run order, so a
    # column-wise concat lines them up run-by-run.
    return pd.concat([name_df, config_df, summary_df], axis=1)


def load_returns(path):
    """Read every CSV file in directory *path* and stack them row-wise.

    Files are processed in sorted name order so the output row order is
    deterministic (``listdir`` order is arbitrary).

    Raises:
        ValueError: if *path* contains a non-CSV entry or no files at all.
    """
    frames = []
    for file_name in sorted(listdir(path)):
        print("PATH: ", file_name)
        if not file_name.endswith(".csv"):
            raise ValueError(f"Need a csv file to process, got: {file_name!r}")
        frames.append(pd.read_csv(os.path.join(path, file_name)))
    if not frames:
        raise ValueError(f"No csv files found in {path!r}.")
    return pd.concat(frames, ignore_index=True)


def extract_run_metrics(returns, row):
    """Build the metrics table for the single run described by *row*.

    *returns* must contain a ``Step`` column plus ``"<run name> - reward"``,
    ``"<run name> - epsilon"`` and ``"<run name> - delta_goal"`` columns.
    *row* must provide ``name``, ``env_dims`` and ``algo``. Rows with any
    missing value are dropped.

    Raises:
        KeyError: if a required column of *returns* or field of *row* is missing.
    """
    run_name = row["name"]
    data = {"step": returns["Step"]}
    for out_column, metric in (("return", "reward"), ("epsilon", "epsilon"), ("delta_goal", "delta_goal")):
        data[out_column] = returns[f"{run_name} - {metric}"]
    df = pd.DataFrame(data)
    df["dims"] = row["env_dims"]
    df["method"] = row["algo"]
    return df.dropna()


def main(project_name="abi/benchmark_toy_ircp_4", path="data", output="metrics2.csv"):
    """Join W&B run metadata with exported CSV curves and write *output*."""
    # Credentials are resolved by wandb itself (WANDB_API_KEY env var or
    # ~/.netrc); never commit an API key to source.
    wandb.login()
    api = wandb.Api()

    all_df = fetch_runs_dataframe(api, project_name)
    print("all_df shape")
    print(all_df.shape)

    returns = load_returns(path)
    print("returns shape")
    print(returns.shape)

    metrics = []
    for _, row in all_df.iterrows():
        # Best-effort per run: a run whose columns/config are missing is
        # reported and skipped rather than aborting the whole export.
        try:
            metrics.append(extract_run_metrics(returns, row))
        except Exception as exc:
            print(f'Exception for: {row["name"]}: {exc}')

    if not metrics:
        raise SystemExit("No run metrics could be extracted; nothing written.")
    c_metrics = pd.concat(metrics, ignore_index=True)
    print(c_metrics.shape)
    c_metrics.to_csv(output)


if __name__ == "__main__":
    main()
