import os

import pandas as pd

# SOS + STaR + PPO
SOS_STAR_PPO_CKPT = "/home/user/train-countdown/tril/outputs/hppo/240818_173641"

# SOS + PPO
SOS_PPO_CKPT = "/home/user/train-countdown/tril/outputs/hppo/240820_193933"


def main():
    # SOS + STaR + PPO
    sos_star_ppo_stat_file = os.path.join(SOS_STAR_PPO_CKPT, "rollout_info.jsonl")
    sos_star_ppo_stats = pd.read_json(sos_star_ppo_stat_file, lines=True)
    sos_star_ppo_kl_div = sos_star_ppo_stats["rollout_buffer/kl_div"]

    # SOS + PPO
    sos_ppo_stat_file = os.path.join(SOS_PPO_CKPT, "rollout_info.jsonl")
    sos_ppo_stats = pd.read_json(sos_ppo_stat_file, lines=True)
    sos_ppo_kl_div = sos_ppo_stats["rollout_buffer/kl_div"]

    # Save data to CSV
    data = {
        "sos_star_ppo_kl_div": sos_star_ppo_kl_div,
        "sos_ppo_kl_div": sos_ppo_kl_div,
    }
    df = pd.DataFrame(data)
    df.to_csv("data/sos_rl_kl_div.csv", index=False)


if __name__ == "__main__":
    # Run main
    main()
