import json
import os
import argparse
def load_json(file_path):
    """Load a JSON Lines file and return its records as a list.

    Each non-blank line of *file_path* is parsed as one JSON value. Blank or
    whitespace-only lines (e.g. a trailing newline at end of file) are skipped
    instead of raising ``json.JSONDecodeError``.

    Args:
        file_path: Path to a ``.jsonl`` file.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.
    """
    # Explicit UTF-8 so results don't depend on the platform's default locale.
    with open(file_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

def _normalize_answer(answer):
    """Return *answer* as a string with one enclosing pair of parentheses removed.

    Multiple-choice targets may be stored as ``"(A)"``; stripping the
    parentheses lets ``"A"`` be matched inside the model's output. The
    parentheses are removed only when both are present, so a malformed value
    such as ``"(A"`` is kept as-is rather than being mangled.
    """
    answer = str(answer)
    if answer.startswith("(") and answer.endswith(")") and len(answer) >= 2:
        return answer[1:-1]
    return answer


def _as_text(value):
    """Return *value* if it is a string, otherwise ``""``."""
    return value if isinstance(value, str) else ""


def _extract_prediction(result):
    """Return ``(output_text, code_text)`` for one result record.

    Prefers ``attempt_answer["output"]`` and ``attempt_answer["code"]`` when
    both keys exist. Otherwise it falls back to a top-level ``result["text"]``,
    then to ``attempt_answer["text"]``. Missing or non-string values become
    ``""``, so such a record is simply scored as incorrect.
    """
    attempt = result.get("attempt_answer")
    if isinstance(attempt, dict) and "output" in attempt and "code" in attempt:
        return _as_text(attempt["output"]), _as_text(attempt["code"])
    # Key-membership check: hasattr() never sees dict keys.
    if "text" in result:
        return _as_text(result["text"]), ""
    if isinstance(attempt, dict):
        return _as_text(attempt.get("text")), ""
    return "", ""


def eval_bigbenchhard_accuracy(jsonl_file):
    """Score a BIG-Bench Hard results file by substring matching.

    A record counts as correct when its normalized ``correct_answer`` appears
    as a substring of the predicted output text or of the predicted code.

    Args:
        jsonl_file: Path to a JSONL file of result records. Each record must
            have a ``correct_answer`` key.

    Returns:
        A ``(correct, total)`` tuple, where ``total`` is the number of records.

    Raises:
        KeyError: if a record has no ``correct_answer`` key.
    """
    results = load_json(jsonl_file)
    correct = 0
    for result in results:
        expected = _normalize_answer(result["correct_answer"])
        if not expected:
            # "" is a substring of every string; never count it as a match.
            continue
        output_text, code_text = _extract_prediction(result)
        if expected in output_text or expected in code_text:
            correct += 1
    return correct, len(results)

if __name__ == "__main__":
    # CLI entry point: score one results file and print its accuracy.
    parser = argparse.ArgumentParser(
        description="Compute substring-match accuracy for a BIG-Bench Hard results JSONL file."
    )
    parser.add_argument(
        "--result_jsonl_file",
        type=str,
        required=True,
        help="Path to the results file (one JSON object per line).",
    )
    args = parser.parse_args()
    jsonl_file = args.result_jsonl_file
    correct, total = eval_bigbenchhard_accuracy(jsonl_file)
    if total == 0:
        # An empty file would otherwise raise ZeroDivisionError below.
        raise SystemExit(f"No results found in {jsonl_file}")
    print(f"Accuracy: {correct / total * 100:.2f}% ({correct}/{total})")