import argparse
import json
import os

import numpy as np

# Root directory that holds every training run's checkpoints.
_OUTPUTS_ROOT = "/home/user/train-countdown/stream-of-search/outputs"
# Shared first entry (STaR iteration 0) of every variant's checkpoint list.
_SFT_CKPT = f"{_OUTPUTS_ROOT}/sft-gpt2/checkpoint-50000"


def _star_ckpts(variant):
    """Return the [SFT, STaR-1, STaR-2] checkpoint paths for one GSoS *variant*."""
    star_ckpts = [
        f"{_OUTPUTS_ROOT}/star{i}-final-{variant}-s0-gpt2/checkpoint-20000"
        for i in (1, 2)
    ]
    return [_SFT_CKPT, *star_ckpts]


# GSoS (first)
GSOS_FIRST_CKPTS = _star_ckpts("first")

# GSoS (rand)
GSOS_RAND_CKPTS = _star_ckpts("rand")

# GSoS (last)
GSOS_LAST_CKPTS = _star_ckpts("last")


def _load_ratings(ckpt_dir, prefix):
    """Concatenate the "ratings" lists of all files in *ckpt_dir* whose name starts with *prefix*.

    Files are read in sorted filename order. Each file is expected to be a JSON
    object with a "ratings" list (assumed numeric — confirm against the
    producer of these result files).
    """
    ratings = []
    for name in sorted(f for f in os.listdir(ckpt_dir) if f.startswith(prefix)):
        with open(os.path.join(ckpt_dir, name), "r", encoding="utf-8") as f:
            ratings += json.load(f)["ratings"]
    return ratings


def _success_ratio(ratings, source):
    """Return the fraction of *ratings* strictly greater than 0.0 as a float.

    If *ratings* is empty (no matching result files), warn with *source* in the
    message and return NaN, so the missing data point is visible.
    """
    if not ratings:
        import warnings

        warnings.warn(f"no ratings found for {source}; recording NaN")
        return float("nan")
    return float(np.mean(np.array(ratings) > 0.0))


def main():
    """Compute per-STaR-iteration positive-rating ratios and save them to data/gsos_ratio.json.

    For each GSoS variant (first / rand / last) and each STaR iteration in
    ``star_iters``, collect the ratings from that checkpoint's result files
    and compute the fraction of ratings > 0.0. Presumably a positive rating
    marks a successful rollout; verify against the evaluation code.
    """
    star_iters = [0, 1, 2]
    # Output JSON key -> (result-file prefix, checkpoint list indexed by STaR iteration).
    variants = {
        "gsos_first_ratios": ("results_final_first_s0_depth2_train", GSOS_FIRST_CKPTS),
        "gsos_rand_ratios": ("results_final_rand_s0_depth2_train", GSOS_RAND_CKPTS),
        "gsos_last_ratios": ("results_final_last_s0_depth2_train", GSOS_LAST_CKPTS),
    }

    # Insertion order keeps the JSON key order: star_iters, first, rand, last.
    data = {"star_iters": star_iters}
    for key, (prefix, ckpts) in variants.items():
        data[key] = [
            _success_ratio(
                _load_ratings(ckpts[star_iter], prefix),
                f"{ckpts[star_iter]} ({prefix}*)",
            )
            for star_iter in star_iters
        ]

    # Save data to JSON; create the output directory if it is missing.
    out_dir = "data"
    os.makedirs(out_dir, exist_ok=True)
    data_file = os.path.join(out_dir, "gsos_ratio.json")
    with open(data_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


if __name__ == "__main__":
    # Only runs when executed as a script, not when this module is imported.
    main()
