import argparse
import json
import re
from pathlib import Path
from typing import Optional

import numpy as np
from tqdm import tqdm

def is_correct_mmlu_pro(og_pred, gt):
    """
    Check whether a model's raw output states the ground-truth answer.

    The prediction counts as correct if it contains any of these forms,
    where ``X`` is the ground-truth answer:

    - ``$$X$$``, ``$$(X)$$``, ``$$(X$$``, ``$$X)$$``
    - ``$$\\text{(X)}$$``, ``$$\\text{X}$$``
    - ``The final answer is X``, where X must not be immediately followed by
      a letter or digit (so answer ``A`` does not match ``Absolutely``, and
      ``10`` does not match ``100``)
    - ``The final answer is (X)``

    :param og_pred: Original (unparsed) model output. ``None`` is treated as
        an incorrect prediction.
    :param gt: Ground truth answer; non-string values are converted with ``str``.
    :return: True if any of the patterns are found, otherwise False.
    """
    if og_pred is None:
        return False

    # Escape the answer so characters such as "(" or "." are matched literally.
    s = re.escape(str(gt))

    patterns = (
        rf"\$\${s}\$\$",
        rf"\$\$\({s}\)\$\$",
        rf"\$\$\({s}\$\$",
        rf"\$\${s}\)\$\$",
        # Braces are escaped so a numeric answer (e.g. "3") is not parsed as
        # a regex repetition count like "t{3}".
        rf"\$\$\\text\{{\({s}\)\}}\$\$",
        rf"\$\$\\text\{{{s}\}}\$\$",
        # The lookahead stops "A" from matching the start of a longer token.
        rf"The final answer is {s}(?![A-Za-z0-9])",
        rf"The final answer is \({s}\)",
    )
    return any(re.search(pattern, og_pred) for pattern in patterns)


def parse_args():
    """Parse the command-line options for this evaluation script.

    Returns:
        argparse.Namespace with ``answer_path`` (required path to the answers
        file) and ``output_dir`` (base results directory, default
        ``"output/answer"``).
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--answer_path", {"type": str, "required": True}),
        ("--output_dir", {"type": str, "default": "output/answer"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    parsed = cli.parse_args()
    return parsed


def main():
    """Score every prediction in ``--answer_path`` and save per-instance results.

    Reads one JSON object per non-blank line (each must contain
    ``pred_answer`` and ``gt_answer``), prints the overall accuracy, and writes
    every instance with an added ``is_correct`` field to
    ``<output_dir>/<answer file's parent dir name>/<answer file name>_eval.jsonl``.
    The output file is overwritten on each run, so re-running does not
    duplicate rows.
    """
    args = parse_args()
    answer_path = Path(args.answer_path)

    # Results are grouped by the name of the directory holding the answer file;
    # a bare filename (no parent dir) writes directly into output_dir.
    output_dir = Path(args.output_dir) / answer_path.parent.name
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"{answer_path.name.split('.jsonl')[0]}_eval.jsonl"

    # load answers (skip blank lines, e.g. a trailing newline)
    with answer_path.open("r", encoding="utf-8") as f:
        answers = [json.loads(line) for line in f if line.strip()]

    # evaluate answers
    is_corrects = [
        is_correct_mmlu_pro(instance["pred_answer"], instance["gt_answer"])
        for instance in tqdm(answers)
    ]

    if is_corrects:
        print(f"Accuracy: {np.mean(is_corrects)}")
    else:
        # np.mean([]) would print nan and emit a RuntimeWarning.
        print("Accuracy: n/a (no answers found)")

    # save results: open once and truncate, instead of appending per instance
    print("Saving results...")
    with output_path.open("w", encoding="utf-8") as f:
        for instance, is_correct in zip(answers, is_corrects):
            instance["is_correct"] = is_correct
            f.write(json.dumps(instance) + "\n")


if __name__ == "__main__":
    main()
