import re
from mathruler.grader import extract_boxed_content

def check_format(predict: str) -> bool:
    """Return True iff *predict* is exactly one think block followed by one answer block.

    The whole string must have the shape ``<think>...</think><answer>...</answer>``,
    with nothing before or after it (not even a trailing newline). The contents
    of both tags may span multiple lines because ``re.DOTALL`` lets ``.``
    match newlines.
    """
    full_match = re.fullmatch(r'^<think>(.*?)</think><answer>(.*?)</answer>$', predict, flags=re.DOTALL)
    return bool(full_match)


def compute_score(solution_str, ground_truth, format_score=0.2, score=1.0, nothink=False,
                  *, partial_weight=0.2, partial_min_len=5):
    """Score a model completion against a ground-truth sequence of integer indices.

    Args:
        solution_str: The model output. When ``nothink`` is False it must be
            exactly ``<think>...</think><answer>...</answer>`` (checked by
            ``check_format``), and the comma-separated indices are read from
            the answer tag. When ``nothink`` is True the whole string is
            parsed as comma-separated indices.
        ground_truth: Expected sequence of integer indices. Any iterable
            (list, tuple, array) is accepted; it is converted to a list
            before comparison.
        format_score: Bonus for well-formed output. Forced to 0 when
            ``nothink`` is True.
        score: Reward for an exact match.
        nothink: Skip the think/answer format check.
        partial_weight: Maximum partial-credit reward for a non-exact answer.
        partial_min_len: Partial credit is only given when the ground truth
            has at least this many elements.

    Returns:
        dict with keys ``score`` (total reward), ``acc_reward``,
        ``format_reward`` and ``acc`` (1 for an exact match, else 0).
        If the format check fails, every field is 0.
    """
    if nothink:
        answer_indices = solution_str.split(',')
        format_score = 0
    else:
        if not check_format(solution_str):
            return {'score': 0, 'acc_reward': 0, 'format_reward': 0, 'acc': 0}
        answer_text = solution_str.split("<answer>")[-1].split("</answer>")[0].strip()
        answer_indices = answer_text.split(',')

    # A single non-integer token makes the whole answer unparseable.
    try:
        answer_indices = [int(token.strip()) for token in answer_indices]
    except ValueError:
        answer_indices = []

    # Normalize so tuples / numpy arrays compare like plain lists
    # (a numpy array in `==` would otherwise yield an ambiguous truth value).
    ground_truth = list(ground_truth)

    if answer_indices == ground_truth:
        return {'score': format_score + score, 'acc_reward': score, 'format_reward': format_score, 'acc': 1}

    acc_reward = 0
    # Partial credit: the answer must have the same length as the ground truth
    # and contain no repeated index. Without the length/duplicate check, padding
    # the answer with duplicates could game the positional comparison.
    is_partial_candidate = (
        len(answer_indices) == len(ground_truth)
        and len(set(answer_indices)) == len(answer_indices)
    )
    if is_partial_candidate and len(ground_truth) >= partial_min_len:
        partial_correct_num = sum(a == g for a, g in zip(answer_indices, ground_truth))
        acc_reward = partial_correct_num / len(ground_truth) * partial_weight
    return {'score': format_score + acc_reward, 'acc_reward': acc_reward, 'format_reward': format_score, 'acc': 0}