import sys
import ujson as json
import re
import numpy as np
import string
from collections import Counter

def exact_answer(prediction):
    """Pull the final answer text out of a raw model output.

    If the output contains an ``answer is <...>`` marker (matched
    case-insensitively, first occurrence wins), the text between the angle
    brackets is returned unchanged. Otherwise the last line of the
    whitespace-stripped output is returned, itself stripped.
    """
    found = re.search(r"answer is\s*<([^>]+)>", prediction, re.IGNORECASE)
    if found is not None:
        return found.group(1)
    # rpartition yields the whole string when there is no newline, so this
    # is always the final line (possibly "").
    final_line = prediction.strip().rpartition('\n')[2]
    return final_line.strip()

def normalize_answer(s):
    """Canonicalize an answer string before comparison.

    Applied in this order:
      1. lowercase;
      2. delete the whole words "a", "an" and "the";
      3. delete every character in ``string.punctuation``;
      4. collapse whitespace runs to single spaces and trim both ends.
    """
    lowered = s.lower()
    without_articles = re.sub(r'\b(a|an|the)\b', '', lowered)
    strip_punct = str.maketrans('', '', string.punctuation)
    without_punct = without_articles.translate(strip_punct)
    return ' '.join(without_punct.split())


def exact_match_score(prediction, ground_truth):
    """Score 1.0 if the ground truth appears as a whole-word phrase in the prediction.

    Both strings are first passed through ``normalize_answer``. The check is
    containment, not strict equality: the normalized ground truth may occur
    anywhere in the normalized prediction. It must be aligned on token
    boundaries, so gold "no" does not match a prediction of "not known".

    Args:
        prediction: Extracted model answer.
        ground_truth: Reference answer.

    Returns:
        1.0 on a match, else 0.0. Also 0.0 when the ground truth normalizes
        to the empty string.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    if not normalized_ground_truth:
        return 0.0

    # normalize_answer joins tokens with single spaces and trims the ends.
    # Padding both sides with one space therefore turns a plain substring
    # test into a token-boundary test: "yes" no longer hits "yesterday".
    padded_prediction = f" {normalized_prediction} "
    padded_truth = f" {normalized_ground_truth} "
    return 1.0 if padded_truth in padded_prediction else 0.0

def f1_score(prediction, ground_truth):
    """Compute the token-level F1 score between a prediction and a reference.

    Both strings are normalized with ``normalize_answer`` and split into
    tokens. F1 is the harmonic mean of precision and recall, computed over
    the multiset overlap of those tokens.

    If either normalized answer is one of "yes", "no" or "noanswer", only an
    exact normalized match earns credit; partial overlap scores 0.0.

    Args:
        prediction: Extracted model answer.
        ground_truth: Reference answer.

    Returns:
        F1 in [0.0, 1.0]; 0.0 when there is no token overlap.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    # Yes/no/no-answer questions are all-or-nothing. Compare the normalized
    # forms so that, e.g., prediction "yes" vs gold "Yes." counts as a match
    # instead of being rejected on punctuation or case differences.
    special_answers = ('yes', 'no', 'noanswer')
    if normalized_prediction != normalized_ground_truth and (
        normalized_prediction in special_answers
        or normalized_ground_truth in special_answers
    ):
        return 0.0

    pred_tokens = normalized_prediction.split()
    gt_tokens = normalized_ground_truth.split()

    # Multiset intersection: each shared token counts min(count_pred, count_gt).
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        # Also covers empty token lists, so the divisions below are safe.
        return 0.0

    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return (2 * precision * recall) / (precision + recall)

def compute_score(model_output, ground_truth, extra_info=None, method="flexible"):
    """Score a raw model output against a reference answer.

    The answer is first extracted with ``exact_answer`` and then scored.

    Args:
        model_output: Raw model output text. ``None`` scores 0.0.
        ground_truth: Reference answer string.
        extra_info: Not used by this function.
        method: Scoring method.
            - ``"strict"`` uses ``exact_match_score`` (0.0 or 1.0).
            - ``"flexible"`` uses ``f1_score``.
            - Any other value scores 0.0.

    Returns:
        The score as a float.
    """
    # exact_answer never returns None, so a None check on its result is dead.
    # Guard the raw input instead; re.search would raise TypeError on None.
    if model_output is None:
        return 0.0
    model_answer = exact_answer(model_output)

    if method == "strict":
        return exact_match_score(model_answer, ground_truth)
    if method == "flexible":
        return f1_score(model_answer, ground_truth)
    # Unrecognized method: keep the original non-raising behavior, but return
    # a float like every other path.
    return 0.0

