import re
import sys

from tqdm import tqdm
import pdb

def find_answer(text):
    """Extract the predicted multiple-choice letter (A or B) from model output.

    Only the text before the first ``"\\n\\nQ: "`` separator is inspected, so
    any follow-up question the model generated is ignored. Three patterns are
    tried from most to least specific: ``(A)``, ``A)``, then a standalone
    ``A``. The first pattern that matches anywhere wins, and its LAST
    occurrence in the text is returned.

    Args:
        text: Raw model output string.

    Returns:
        ``"A"`` or ``"B"`` if a choice was found, otherwise ``None``.
    """
    head, _, _ = text.partition('\n\nQ: ')
    candidate_patterns = (
        r"\(([A-B])\)",
        r"([A-B])\)",
        r"\b([A-B])\b",
    )
    for regex in candidate_patterns:
        found = re.findall(regex, head, flags=re.DOTALL)
        if not found:
            continue
        # Later mentions usually carry the final decision, so take the last one.
        return found[-1].strip()
    return None

def evaluate(preds, golds):
    """Score multiple-choice predictions against gold answers and print accuracy.

    Each prediction is parsed with ``find_answer``. A prediction counts as
    correct only when a letter was extracted and it equals the gold
    ``'final_answer'``. Predictions with no extractable answer count as wrong.

    Args:
        preds: Sequence of model output strings.
        golds: Sequence of mappings, each with a ``'final_answer'`` key,
            aligned index-by-index with ``preds``.

    Returns:
        Accuracy as a float, ``correct / len(preds)``. Returns 0.0 when
        ``preds`` is empty instead of raising ZeroDivisionError.
    """
    correct = 0
    # The denominator is len(preds). zip() stops at the shorter input, so any
    # preds without a matching gold entry are effectively counted as wrong.
    total = len(preds)
    for pred, gold in tqdm(zip(preds, golds), total=total):
        ans_pred = find_answer(pred)
        ans = gold['final_answer']
        # find_answer returns None when nothing matched; never count that as
        # correct, even if the gold answer itself happens to be None.
        if ans_pred is not None and ans_pred == ans:
            correct += 1
    print(f"{float(correct)} / {total}")
    result = float(correct) / total if total else 0.0
    print("=========acc result:{}".format(result))
    return result