import pathlib
import json
from tqdm import tqdm

from olym_gen.utils.utils import retrieve_id_from_name


def main(save_dir: str, question_name: str, proof_name: str) -> None:
    """Check that the JSON files under a directory agree on question and proof.

    Every ``*.json`` file found recursively under ``save_dir`` is loaded, and
    its file name is handed to ``retrieve_id_from_name`` to obtain a
    ``(problem_index, proof_index, generation_index)`` triple. The first
    question encountered for a problem index, and the first proof encountered
    for a ``(problem_index, proof_index)`` pair, become the reference values;
    any later file whose value differs from the reference is reported.
    The outcome is printed to stdout.

    Args:
        save_dir: Directory searched recursively for ``*.json`` files.
        question_name: Key under which the question is stored in each file.
        proof_name: Key under which the proof is stored in each file.

    Raises:
        KeyError: If a loaded JSON object lacks ``question_name`` or
            ``proof_name``.
    """
    wrong_index = []
    problem_dict = {}
    problem_proof_dict = {}

    print(f"Checking files in {save_dir} ...")
    for file in tqdm(pathlib.Path(save_dir).rglob("*.json")):
        problem_index, proof_index, _generation_index = retrieve_id_from_name(file.name)
        # Use a context manager so each handle is closed right away instead of
        # being left to the garbage collector while scanning many files.
        with open(file, "r", encoding="utf-8") as handle:
            data = json.load(handle)
        question = data[question_name]
        proof = data[proof_name]

        mismatch = False
        if problem_index not in problem_dict:
            problem_dict[problem_index] = question
        elif problem_dict[problem_index] != question:
            mismatch = True

        proof_key = (problem_index, proof_index)
        if proof_key not in problem_proof_dict:
            problem_proof_dict[proof_key] = proof
        elif problem_proof_dict[proof_key] != proof:
            mismatch = True

        # Record the file once even when both checks fail, so the reported
        # count is a number of files rather than a number of failed checks.
        if mismatch:
            wrong_index.append(file)

    if wrong_index:
        print(f"Found {len(wrong_index)} wrong files:")
        for file in wrong_index:
            print(file)
    else:
        print("All Pass!")
    
def parse_args():
    """Parse the three positional command-line arguments of this script.

    Returns:
        The ``argparse.Namespace`` produced by the parser, carrying the
        string attributes ``save_dir``, ``question_name`` and ``proof_name``.
    """
    import argparse

    # (argument name, help text) pairs, registered in positional order.
    positionals = (
        ("save_dir", "The directory to save the json files."),
        ("question_name", "The key name for question in json file."),
        ("proof_name", "The key name for proof in json file."),
    )

    cli = argparse.ArgumentParser()
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, type=str, help=help_text)
    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: read the CLI arguments and run the consistency check.
    cli_args = parse_args()
    main(
        save_dir=cli_args.save_dir,
        question_name=cli_args.question_name,
        proof_name=cli_args.proof_name,
    )