import json
import os

import numpy as np
import pandas as pd

# Checkpoint directory whose per-depth result JSON files main() aggregates by default.
SOS_CKPT = "/home/user/train-countdown/stream-of-search/outputs/sft-gpt2/checkpoint-50000"


def _load_depth_results(ckpt_dir, prefix):
    """Concatenate the ``ratings`` and ``losses`` lists of all matching result files.

    Every entry of *ckpt_dir* whose name starts with *prefix* is parsed as JSON,
    in sorted filename order, and its ``"ratings"`` and ``"losses"`` lists are
    appended to the running totals.

    Args:
        ckpt_dir: Directory to scan (only its direct entries are listed).
        prefix: Filename prefix selecting the files for one depth.

    Returns:
        A ``(ratings, losses)`` pair of equal-length lists; both are empty when
        no file matches.

    Raises:
        ValueError: If a file's ``ratings`` and ``losses`` lists differ in
            length, because the two lists are later paired element by element
            and a mismatch would silently misalign every following sample.
    """
    ratings = []
    losses = []
    for name in sorted(n for n in os.listdir(ckpt_dir) if n.startswith(prefix)):
        path = os.path.join(ckpt_dir, name)
        with open(path, "r", encoding="utf-8") as f:
            results = json.load(f)
        file_ratings = results["ratings"]
        file_losses = results["losses"]
        if len(file_ratings) != len(file_losses):
            raise ValueError(
                f"{path}: {len(file_ratings)} ratings but {len(file_losses)} losses"
            )
        ratings.extend(file_ratings)
        losses.extend(file_losses)
    return ratings, losses


def _summarize(ratings, losses):
    """Return ``(accuracy, loss_mean, loss_std)`` for one depth.

    A rating > 0 counts as a success. ``accuracy`` is the fraction of
    successes; ``loss_mean`` / ``loss_std`` are the mean and population
    standard deviation (``np.std`` default, ddof=0) of the losses of the
    successful samples only. A statistic with nothing to average is NaN
    (returned explicitly instead of triggering numpy's empty-slice warnings).
    """
    nan = float("nan")
    accuracy = np.mean([r > 0 for r in ratings]) if ratings else nan
    # zip keeps each loss paired with its own rating.
    solved = [loss for rating, loss in zip(ratings, losses) if rating > 0]
    if not solved:
        return accuracy, nan, nan
    return accuracy, np.mean(solved), np.std(solved)


def main(ckpt_dir=SOS_CKPT, out_path="data/sos_opt.csv", depths=(0, 1, 2, 3)):
    """Aggregate per-depth evaluation results into a CSV summary.

    For each depth, all files in *ckpt_dir* named
    ``results_final_optimal_s0_depth{depth}_train*`` are merged, and one row
    ``(depth, accuracy, loss_mean, loss_std)`` is produced (see ``_summarize``).

    Args:
        ckpt_dir: Directory holding the result JSON files.
        out_path: Destination CSV; its parent directory is created if missing.
            A relative path resolves against the current working directory.
        depths: Depth values to summarize, in output row order.

    Returns:
        The summary ``DataFrame`` that was written to *out_path*.

    Raises:
        ValueError: If a result file has mismatched ratings/losses lengths.
    """
    depths = list(depths)
    # Insertion order fixes the CSV column order.
    data = {"depth": depths, "accuracy": [], "loss_mean": [], "loss_std": []}
    for depth in depths:
        prefix = f"results_final_optimal_s0_depth{depth}_train"
        ratings, losses = _load_depth_results(ckpt_dir, prefix)
        accuracy, loss_mean, loss_std = _summarize(ratings, losses)
        data["accuracy"].append(accuracy)
        data["loss_mean"].append(loss_mean)
        data["loss_std"].append(loss_std)

    df = pd.DataFrame(data)
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(out_path, index=False)
    return df


if __name__ == "__main__":
    # Script entry point: aggregate the per-depth result files and write the
    # summary CSV.
    main()
