import json
import os

import numpy as np
import pandas as pd

# GSoS + PPO (token)
GSOS_PPO_TOK_DIR = "/home/user/train-countdown/tril/outputs/ppo/240825_231537"

# GSos + PPO (operation)
GSOS_PPO_OP_DIR = "/home/user/train-countdown/tril/outputs/hppo/240825_231537"


def main():
    epochs = np.arange(0, 210, 10)
    gsos_ppo_tok_value_mean_epoch = []
    gsos_ppo_tok_value_std_epoch = []
    gsos_ppo_op_value_mean_epoch = []
    gsos_ppo_op_value_std_epoch = []
    for epoch in epochs:
        # GSoS + PPO (token)
        gsos_ppo_tok_ckpt = os.path.join(GSOS_PPO_TOK_DIR, f"model_{epoch}")
        gsos_ppo_tok_prefix = "values_final_s0_train"
        gsos_ppo_tok_value_files = [
            f
            for f in os.listdir(gsos_ppo_tok_ckpt)
            if f.startswith(gsos_ppo_tok_prefix)
        ]
        gsos_ppo_tok_value_files = sorted(gsos_ppo_tok_value_files)
        gsos_ppo_tok_values = []
        for value_file in gsos_ppo_tok_value_files:
            value_file = os.path.join(gsos_ppo_tok_ckpt, value_file)
            with open(value_file, "r") as f:
                values = json.load(f)
                gsos_ppo_tok_values += values["values"]
        gsos_ppo_tok_value_mean = np.mean(gsos_ppo_tok_values)
        gsos_ppo_tok_value_std = np.std(gsos_ppo_tok_values)
        gsos_ppo_tok_value_mean_epoch.append(gsos_ppo_tok_value_mean)
        gsos_ppo_tok_value_std_epoch.append(gsos_ppo_tok_value_std)

        # GSoS + PPO (operation)
        gsos_ppo_op_ckpt = os.path.join(GSOS_PPO_OP_DIR, f"model_{epoch}")
        gsos_ppo_op_prefix = "values_final_s0_train"
        gsos_ppo_op_value_files = [
            f for f in os.listdir(gsos_ppo_op_ckpt) if f.startswith(gsos_ppo_op_prefix)
        ]
        gsos_ppo_op_value_files = sorted(gsos_ppo_op_value_files)
        gsos_ppo_op_values = []
        for value_file in gsos_ppo_op_value_files:
            value_file = os.path.join(gsos_ppo_op_ckpt, value_file)
            with open(value_file, "r") as f:
                values = json.load(f)
                gsos_ppo_op_values += values["values"]
        gsos_ppo_op_value_mean = np.mean(gsos_ppo_op_values)
        gsos_ppo_op_value_std = np.std(gsos_ppo_op_values)
        gsos_ppo_op_value_mean_epoch.append(gsos_ppo_op_value_mean)
        gsos_ppo_op_value_std_epoch.append(gsos_ppo_op_value_std)

    # Save data to CSV
    data = {
        "epoch": epochs,
        "gsos_ppo_tok_value_mean": gsos_ppo_tok_value_mean_epoch,
        "gsos_ppo_tok_value_std": gsos_ppo_tok_value_std_epoch,
        "gsos_ppo_op_value_mean": gsos_ppo_op_value_mean_epoch,
        "gsos_ppo_op_value_std": gsos_ppo_op_value_std,
    }
    df = pd.DataFrame(data)
    df.to_csv("data/gsos_rl_value.csv", index=False)


if __name__ == "__main__":
    # Run main
    main()
