import json
import os

import pandas as pd

# Checkpoint directories of the two fine-tuning runs being compared.
# main() reads the "trainer_state.json" file found inside each directory.
GPT2_CKPT = "/home/user/train-countdown/stream-of-search/outputs/sft-gpt2/checkpoint-50000"
GPT_NEO_CKPT = "/home/user/train-countdown/stream-of-search/outputs/sft-gpt-neo/checkpoint-50000"


def _load_eval_losses(ckpt_dir):
    """Extract the evaluation loss curves logged in a checkpoint directory.

    Reads ``trainer_state.json`` from ``ckpt_dir`` and scans its
    ``"log_history"`` list.

    Args:
        ckpt_dir: Directory that contains ``trainer_state.json``.

    Returns:
        A 4-tuple ``(val_steps, val_losses, test_steps, test_losses)`` of
        lists. The "val" pair comes from entries carrying
        ``"eval_valid_loss"``; the "test" pair comes from entries carrying
        ``"eval_valid_target_loss"``.

    Raises:
        OSError: If the state file cannot be opened.
        KeyError: If ``"log_history"`` or an entry's ``"step"`` is missing.
    """
    state_path = os.path.join(ckpt_dir, "trainer_state.json")
    with open(state_path, "r", encoding="utf-8") as f:
        log_history = json.load(f)["log_history"]

    val_steps, val_losses = [], []
    test_steps, test_losses = [], []
    for entry in log_history:
        # Each entry feeds at most one series; the validation key takes
        # precedence when both keys are present in the same entry.
        if "eval_valid_loss" in entry:
            val_steps.append(entry["step"])
            val_losses.append(entry["eval_valid_loss"])
        elif "eval_valid_target_loss" in entry:
            test_steps.append(entry["step"])
            test_losses.append(entry["eval_valid_target_loss"])
    return val_steps, val_losses, test_steps, test_losses


def main(gpt2_ckpt=None, gpt_neo_ckpt=None, output_path="data/spt_loss.csv"):
    """Export the GPT-2 and GPT-Neo evaluation loss curves to one CSV file.

    Args:
        gpt2_ckpt: GPT-2 checkpoint directory. Defaults to ``GPT2_CKPT``.
        gpt_neo_ckpt: GPT-Neo checkpoint directory. Defaults to
            ``GPT_NEO_CKPT``.
        output_path: Destination CSV path. Missing parent directories are
            created.

    The CSV has eight columns, ``<model>_<split>_steps`` and
    ``<model>_<split>_losses`` for ``model`` in ``gpt2``/``gpt_neo`` and
    ``split`` in ``val``/``test``. Columns shorter than the longest one are
    padded with empty cells instead of raising; a padded integer column is
    written as floats because of the NaN padding.
    """
    # Resolve defaults at call time so the module constants are only needed
    # when the caller does not supply explicit directories.
    if gpt2_ckpt is None:
        gpt2_ckpt = GPT2_CKPT
    if gpt_neo_ckpt is None:
        gpt_neo_ckpt = GPT_NEO_CKPT

    columns = {}
    for prefix, ckpt_dir in (("gpt2", gpt2_ckpt), ("gpt_neo", gpt_neo_ckpt)):
        val_steps, val_losses, test_steps, test_losses = _load_eval_losses(ckpt_dir)
        columns[prefix + "_val_steps"] = val_steps
        columns[prefix + "_val_losses"] = val_losses
        columns[prefix + "_test_steps"] = test_steps
        columns[prefix + "_test_losses"] = test_losses

    # Wrapping each list in a Series lets pandas align on the row index and
    # NaN-pad, whereas a dict of unequal-length lists raises ValueError.
    df = pd.DataFrame({name: pd.Series(values) for name, values in columns.items()})

    # to_csv does not create directories, so make sure the target exists.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(output_path, index=False)


if __name__ == "__main__":
    # Only run the export when executed as a script, so that importing this
    # module does not read the checkpoints or write the CSV.
    main()
