import argparse
import json
import os

# Checkpoint directories per GSoS variant; each list is indexed by --star_iter.
# Index 0 of every list is the same SFT checkpoint, so it is spelled out once.
_SFT_CKPT = "/home/user/train-countdown/stream-of-search/outputs/sft-gpt2/checkpoint-50000"

# GSoS (first)
GSOS_FIRST_CKPTS = [_SFT_CKPT] + [
    "/home/user/train-countdown/stream-of-search/outputs/star1-final-first-s0-gpt2/checkpoint-20000",
    "/home/user/train-countdown/stream-of-search/outputs/star2-final-first-s0-gpt2/checkpoint-20000",
]

# GSoS (rand)
GSOS_RAND_CKPTS = [_SFT_CKPT] + [
    "/home/user/train-countdown/stream-of-search/outputs/star1-final-rand-s0-gpt2/checkpoint-20000",
    "/home/user/train-countdown/stream-of-search/outputs/star2-final-rand-s0-gpt2/checkpoint-20000",
]

# GSoS (last)
GSOS_LAST_CKPTS = [_SFT_CKPT] + [
    "/home/user/train-countdown/stream-of-search/outputs/star1-final-last-s0-gpt2/checkpoint-20000",
    "/home/user/train-countdown/stream-of-search/outputs/star2-final-last-s0-gpt2/checkpoint-20000",
]


def _load_lengths(ckpt_dir, prefix):
    """Concatenate the length lists from every matching file in a directory.

    Args:
        ckpt_dir: Directory to scan for length files.
        prefix: Only files whose name starts with this prefix are read.

    Returns:
        A ``(base_lengths, new_lengths)`` pair of lists. Each is the
        concatenation of that key's list across the matching JSON files,
        taken in sorted (lexicographic) filename order.

    Raises:
        FileNotFoundError: If ``ckpt_dir`` does not exist.
        KeyError: If a file lacks ``"base_lengths"`` or ``"new_lengths"``.
    """
    base_lengths = []
    new_lengths = []
    file_names = sorted(f for f in os.listdir(ckpt_dir) if f.startswith(prefix))
    for file_name in file_names:
        with open(os.path.join(ckpt_dir, file_name), "r") as f:
            lengths = json.load(f)
        base_lengths += lengths["base_lengths"]
        new_lengths += lengths["new_lengths"]
    return base_lengths, new_lengths


def main(args):
    """Merge the length files of all GSoS variants into one JSON file.

    For each variant (first / rand / last), the checkpoint directory selected
    by ``args.star_iter`` is scanned for that variant's length files and their
    ``base_lengths`` / ``new_lengths`` lists are concatenated. The result is
    written to ``data/gsos_length_star{star_iter}.json`` relative to the
    current working directory.

    Args:
        args: Namespace with an integer ``star_iter`` attribute, used as the
            index into each ``GSOS_*_CKPTS`` list and in the output file name.

    Raises:
        IndexError: If ``args.star_iter`` is out of range for a checkpoint list.
    """
    # (output label, checkpoint list, length-file prefix); the order here is
    # the key order of the output JSON.
    variants = [
        ("GSoS (first)", GSOS_FIRST_CKPTS, "lengths_final_first_s0_train"),
        ("GSoS (rand)", GSOS_RAND_CKPTS, "lengths_final_rand_s0_train"),
        ("GSoS (last)", GSOS_LAST_CKPTS, "lengths_final_last_s0_train"),
    ]

    data = {}
    for label, ckpts, prefix in variants:
        base_lengths, new_lengths = _load_lengths(ckpts[args.star_iter], prefix)
        data[label] = {
            "base_lengths": base_lengths,
            "new_lengths": new_lengths,
        }

    # Save data to JSON. Create the output directory first so a fresh working
    # directory does not fail with FileNotFoundError after all the loading.
    os.makedirs("data", exist_ok=True)
    data_file = os.path.join("data", f"gsos_length_star{args.star_iter}.json")
    with open(data_file, "w") as f:
        json.dump(data, f, indent=4)


if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser(
        description="Merge GSoS length files for one checkpoint index into a JSON file."
    )
    parser.add_argument(
        "--star_iter",
        default=0,
        type=int,
        # Only accept indexes valid for every checkpoint list: a too-large
        # value would raise IndexError inside main(), and a negative value
        # would silently wrap around to a different checkpoint.
        choices=range(
            min(len(GSOS_FIRST_CKPTS), len(GSOS_RAND_CKPTS), len(GSOS_LAST_CKPTS))
        ),
        help="index into the GSOS_*_CKPTS checkpoint lists (default: 0)",
    )
    args = parser.parse_args()

    # Run main
    main(args)
