import argparse
import json
import os

# Root directory that holds every training run's outputs.
_OUTPUTS_ROOT = "/home/user/train-countdown/stream-of-search/outputs"
# Every GSoS variant's index-0 entry is this same SFT checkpoint.
_SFT_CKPT = f"{_OUTPUTS_ROOT}/sft-gpt2/checkpoint-50000"


def _star_ckpts(variant):
    """Return the checkpoint paths for one GSoS variant, indexed by STaR iteration.

    Index 0 is the shared SFT checkpoint; indices 1 and 2 are the
    ``star{i}-final-{variant}-s0-gpt2`` runs at step 20000.
    """
    star_runs = [
        f"{_OUTPUTS_ROOT}/star{i}-final-{variant}-s0-gpt2/checkpoint-20000"
        for i in (1, 2)
    ]
    return [_SFT_CKPT, *star_runs]


# GSoS (first)
GSOS_FIRST_CKPTS = _star_ckpts("first")

# GSoS (rand)
GSOS_RAND_CKPTS = _star_ckpts("rand")

# GSoS (last)
GSOS_LAST_CKPTS = _star_ckpts("last")


def _load_losses(ckpt_dir, prefix):
    """Concatenate the loss lists from every matching loss file in a directory.

    Args:
        ckpt_dir: Checkpoint directory to scan (non-recursively).
        prefix: Only file names starting with this prefix are read.

    Returns:
        A ``(base_losses, new_losses)`` tuple: the ``"base_losses"`` and
        ``"new_losses"`` lists of every matching JSON file, concatenated in
        lexicographic file-name order. Both are empty if nothing matches.

    Raises:
        FileNotFoundError: If ``ckpt_dir`` does not exist.
        KeyError: If a matching file lacks either expected key.
    """
    base_losses = []
    new_losses = []
    # os.listdir order is arbitrary; sorting makes the concatenation reproducible.
    matching = sorted(name for name in os.listdir(ckpt_dir) if name.startswith(prefix))
    for name in matching:
        with open(os.path.join(ckpt_dir, name), "r", encoding="utf-8") as f:
            losses = json.load(f)
        base_losses += losses["base_losses"]
        new_losses += losses["new_losses"]
    return base_losses, new_losses


def main(args):
    """Collect GSoS losses for one STaR iteration and save them as JSON.

    For each GSoS variant (first / rand / last), reads every
    ``losses_final_<variant>_s0_train*`` file in that variant's checkpoint
    for ``args.star_iter``. It concatenates their ``base_losses`` and
    ``new_losses`` lists and writes all three variants to
    ``data/gsos_loss_star{star_iter}.json``.

    Args:
        args: Parsed CLI namespace; only ``args.star_iter`` (int) is used,
            as an index into the ``GSOS_*_CKPTS`` tables.

    Raises:
        ValueError: If ``args.star_iter`` is outside the checkpoint tables.
            Without this check, negative values would silently select a
            checkpoint from the end of the list.
    """
    n_iters = min(len(GSOS_FIRST_CKPTS), len(GSOS_RAND_CKPTS), len(GSOS_LAST_CKPTS))
    if not 0 <= args.star_iter < n_iters:
        raise ValueError(
            f"star_iter must be in [0, {n_iters - 1}], got {args.star_iter}"
        )

    # (output label, checkpoint table, loss-file prefix); order fixes JSON key order.
    variants = [
        ("GSoS (first)", GSOS_FIRST_CKPTS, "losses_final_first_s0_train"),
        ("GSoS (rand)", GSOS_RAND_CKPTS, "losses_final_rand_s0_train"),
        ("GSoS (last)", GSOS_LAST_CKPTS, "losses_final_last_s0_train"),
    ]
    data = {}
    for label, ckpts, prefix in variants:
        base_losses, new_losses = _load_losses(ckpts[args.star_iter], prefix)
        data[label] = {
            "base_losses": base_losses,
            "new_losses": new_losses,
        }

    # Create the output directory so a fresh checkout doesn't fail at the very
    # end, after all the loss files have already been read.
    out_dir = "data"
    os.makedirs(out_dir, exist_ok=True)
    data_file = os.path.join(out_dir, f"gsos_loss_star{args.star_iter}.json")
    with open(data_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


if __name__ == "__main__":
    # Command-line entry point: --star_iter (default 0) selects which index of
    # the GSOS_*_CKPTS tables is read; index 0 is the SFT checkpoint.
    cli = argparse.ArgumentParser()
    cli.add_argument("--star_iter", type=int, default=0)
    main(cli.parse_args())
