import os
import argparse

import pandas as pd
import wandb
import plotext as plt

def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Returns:
        A parser accepting the positional ``run_path`` and the optional
        ``--output_dir`` / ``-o`` flag.
    """
    parser = argparse.ArgumentParser(
        description="Download the sample tables logged by a wandb run and plot reward against KL."
    )
    parser.add_argument(
        "run_path",
        type=str,
        help="wandb unique run path (found under the run overview in wandb console)",
    )
    # A default is required here: without one, omitting the flag leaves
    # ``output_dir`` as None and ``os.makedirs(None)`` raises TypeError.
    parser.add_argument(
        "--output_dir",
        "-o",
        type=str,
        default=".",
        help="Output directory for the PPO outputs (default: current directory)",
    )
    return parser


def collect_samples(run) -> pd.DataFrame:
    """Gather every ``"samples"`` table logged by ``run`` into one DataFrame.

    Args:
        run: Object exposing ``logged_artifacts()``; each yielded artifact must
            provide ``get(name)`` and ``version``, and each table returned by
            ``get`` must provide ``data`` and ``columns``.

    Returns:
        One row per table row, with an extra ``"version"`` column holding the
        version of the artifact the row came from. An empty DataFrame when no
        artifact yields a ``"samples"`` table.
    """
    frames = []
    for artifact in run.logged_artifacts():
        table = artifact.get("samples")
        if table is None:
            # This artifact has no "samples" entry; nothing to collect.
            continue
        frame = pd.DataFrame(data=table.data, columns=table.columns)
        frame["version"] = artifact.version
        frames.append(frame)

    if not frames:
        return pd.DataFrame()
    # Concatenate once (rather than inside the loop) to avoid re-copying the
    # accumulated rows on every iteration; ignore_index gives unique labels.
    return pd.concat(frames, ignore_index=True)


def mean_reward_by_kl_bin(df: pd.DataFrame, n_bins: int = 25) -> pd.DataFrame:
    """Average ``reward`` and ``sample_kl`` within quantile bins of ``sample_kl``.

    Args:
        df: Frame containing ``"reward"`` and ``"sample_kl"`` columns.
        n_bins: Number of quantile bins requested. Fewer bins are produced
            when repeated KL values make some quantile edges coincide.

    Returns:
        A frame indexed by KL interval with the per-bin mean of ``"reward"``
        and ``"sample_kl"``.

    Raises:
        KeyError: If either required column is missing from ``df``.
    """
    metrics = df[["reward", "sample_kl"]]
    # duplicates="drop" merges coinciding quantile edges instead of raising
    # "Bin edges must be unique" when many samples share the same KL value.
    kl_bins = pd.qcut(metrics["sample_kl"], n_bins, duplicates="drop")
    # Rename the grouping key so it cannot be confused with the "sample_kl"
    # column, which must stay in the result to provide the x coordinates.
    return metrics.groupby(kl_bins.rename("kl_bin"), observed=True).mean()


def plot_kl_summary(df: pd.DataFrame, hist_bins: int = 100, quantile_bins: int = 25) -> None:
    """Draw the KL histogram and the reward-vs-KL curve in the terminal.

    Both figures are rendered with plotext; no matplotlib figure is created.

    Args:
        df: Frame containing ``"reward"`` and ``"sample_kl"`` columns.
        hist_bins: Number of bins for the KL histogram.
        quantile_bins: Number of KL quantile bins for the reward curve.
    """
    # plotext is handed plain Python lists rather than pandas objects.
    plt.hist(df["sample_kl"].tolist(), bins=hist_bins, label="Sample output KL distribution")
    plt.xlabel("KL")
    plt.show()
    plt.clear_figure()

    binned = mean_reward_by_kl_bin(df, quantile_bins)
    plt.plot(binned["sample_kl"].tolist(), binned["reward"].tolist())
    plt.xlabel("KL")
    plt.ylabel("RM Reward")
    plt.show()


def main(argv=None) -> None:
    """Fetch a run's samples, save them as JSON, and plot reward against KL.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        SystemExit: If the run logged no ``"samples"`` tables.
    """
    args = build_parser().parse_args(argv)

    run = wandb.Api().run(args.run_path)
    df = collect_samples(run)
    if df.empty:
        # Fail with a clear message instead of a KeyError on the missing
        # "reward"/"sample_kl" columns further down.
        raise SystemExit(f"No 'samples' tables found in the artifacts logged by run {args.run_path!r}")

    os.makedirs(args.output_dir, exist_ok=True)
    df.to_json(os.path.join(args.output_dir, "outputs.json"), orient="records")

    plot_kl_summary(df)


if __name__ == "__main__":
    main()
