
import sys
import ujson as json
import re
import numpy as np
import string
from collections import Counter

def exact_answer(prediction):
    """Pull the final answer string out of a raw model response.

    If the response contains ``answer is <...>`` (matched case-insensitively,
    with optional whitespace before the opening ``<``), the text between the
    angle brackets of the first such occurrence is returned exactly as
    written, without trimming. Otherwise the last line of the
    whitespace-trimmed response is returned, itself trimmed.
    """
    tagged = re.search(r"answer is\s*<([^>]+)>", prediction, re.IGNORECASE)
    if tagged is not None:
        return tagged.group(1)

    # No tagged answer: fall back to whatever follows the final newline.
    _, _, final_line = prediction.strip().rpartition('\n')
    return final_line.strip()

def normalize_answer(s):
    """Canonicalize text for lenient comparison.

    Steps, applied in this order:
      1. lowercase the text;
      2. delete the standalone words ``a``, ``an`` and ``the``;
      3. delete every character listed in ``string.punctuation``;
      4. collapse whitespace runs to single spaces and trim both ends.

    Articles are removed before punctuation, so an article delimited only by
    punctuation (e.g. the ``a`` in ``"u.s.a."``) is removed as well.
    """
    without_articles = re.sub(r'\b(a|an|the)\b', '', s.lower())
    drop_punctuation = str.maketrans('', '', string.punctuation)
    cleaned = without_articles.translate(drop_punctuation)
    return ' '.join(cleaned.split())

def exact_match_score(prediction, ground_truth):
    """Score a prediction by substring match against the ground truth.

    Both strings are passed through ``normalize_answer`` before comparison.

    Args:
        prediction: Candidate answer text.
        ground_truth: Reference answer text.

    Returns:
        1.0 if the normalized ground truth occurs anywhere inside the
        normalized prediction, otherwise 0.0. Also 0.0 when the ground truth
        normalizes to an empty string (e.g. it is empty, or consists only of
        articles and punctuation).
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    # An empty string is a substring of every string, so without this guard
    # an empty reference would award full credit to any prediction.
    if not normalized_ground_truth:
        return 0.0

    # A separate check for "answer is " + ground truth is unnecessary: any
    # prediction containing that phrase also contains the ground truth itself.
    if normalized_ground_truth in normalized_prediction:
        return 1.0
    return 0.0



def compute_score(model_output, ground_truth, extra_info=None):
    """Score a raw model response against the reference answer.

    The answer is extracted from ``model_output`` with ``exact_answer`` and
    then compared to ``ground_truth`` with ``exact_match_score``.

    Args:
        model_output: Full text produced by the model.
        ground_truth: Reference answer text.
        extra_info: Accepted but not used by this function.

    Returns:
        The value returned by ``exact_match_score``, or 0.0 if no answer
        could be extracted.
    """
    extracted = exact_answer(model_output)
    # Defensive guard in case the extractor yields no answer at all.
    if extracted is None:
        return 0.0
    return exact_match_score(extracted, ground_truth)

