import argparse
import json
import os
# Fixed instruction text stored in the "instruction" field of every sample.
instruction = "Complete the following problem."


def default_eval_path(generation_file):
    """Return the default evaluation-result path for *generation_file*.

    The file extension is replaced by ``_eval.json``
    (``runs/gen.json`` -> ``runs/gen_eval.json``).
    """
    base, _ext = os.path.splitext(generation_file)
    return base + "_eval.json"


def default_output_path(eval_file):
    """Return the default Alpaca output path for *eval_file*.

    Only the final extension is swapped for ``_alpaca_sample.json``. A plain
    ``str.replace(".json", ...)`` would also rewrite ".json" inside directory
    names, and would return *eval_file* unchanged (so the evaluation file
    would be overwritten) when the path contains no ".json" at all.
    """
    base, _ext = os.path.splitext(eval_file)
    return base + "_alpaca_sample.json"


def build_alpaca_dataset(data, eval_results):
    """Build Alpaca-format samples from the proofs that passed evaluation.

    Args:
        data: Iterable of items providing the keys ``problem_id``, ``header``,
            ``formal_statement`` and ``proof`` (a sequence of candidate proofs).
        eval_results: Lookup from ``problem_id`` to an entry whose ``passed``
            value is a sequence of truthy/falsy flags aligned with ``proof``.

    Returns:
        A list of dicts with the keys ``instruction``, ``input`` and
        ``output``; one dict per proof whose flag is truthy, in input order.

    Raises:
        KeyError: If an item lacks a required key or its ``problem_id`` has
            no entry in *eval_results*.
    """
    samples = []
    for item in data:
        prompt = item["header"].strip() + "\n" + item["formal_statement"].strip()
        passed_flags = eval_results[item["problem_id"]]["passed"]
        # zip pairs each proof with its flag; if the two sequences differ in
        # length, the unmatched tail is ignored.
        samples.extend(
            {"instruction": instruction, "input": prompt, "output": proof}
            for proof, passed in zip(item["proof"], passed_flags)
            if passed
        )
    return samples


def main(argv=None):
    """Command-line entry point.

    Reads the generation file and the evaluation file, keeps the proofs
    flagged as passed, writes them as an Alpaca-format JSON list and prints
    the number of samples written.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Extract correct proofs from data file using evaluation results."
    )
    # Required: both default paths below are derived from this argument.
    parser.add_argument(
        "--generation_file",
        type=str,
        required=True,
        help="Path to the data JSON file",
    )
    parser.add_argument(
        "--eval_file",
        type=str,
        help="Path to the evaluation result JSON file "
        "(default: generation file with its extension replaced by _eval.json)",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help="Path of the Alpaca-format JSON file to write "
        "(default: eval file with its extension replaced by _alpaca_sample.json)",
    )
    args = parser.parse_args(argv)

    if args.eval_file is None:
        args.eval_file = default_eval_path(args.generation_file)
    if args.output_file is None:
        args.output_file = default_output_path(args.eval_file)

    # Explicit UTF-8 everywhere: the output is dumped with ensure_ascii=False,
    # so non-ASCII text must not depend on the platform's locale encoding.
    with open(args.generation_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    with open(args.eval_file, "r", encoding="utf-8") as f:
        eval_results = json.load(f)["results"]

    alpaca_dataset = build_alpaca_dataset(data, eval_results)

    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(alpaca_dataset, f, indent=2, ensure_ascii=False)

    print(f"Got {len(alpaca_dataset)} Alpaca-formatted samples")


if __name__ == "__main__":
    main()