import json
import random
import re
import string


def normalize_answer(s):
    """Canonicalise an answer string so two answers can be compared exactly.

    The transformations are applied in this order:

    1. lowercase the text;
    2. delete every character found in ``string.punctuation``;
    3. replace the standalone words "a", "an" and "the" with a space;
    4. collapse all whitespace runs to single spaces and trim the ends.

    Args:
        s: The raw answer text.

    Returns:
        The normalised string (possibly empty).
    """
    lowered = s.lower()
    # A deletion-only translation table drops all ASCII punctuation in one pass.
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    # Articles are removed after punctuation so that e.g. "the," is caught too.
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct)
    return " ".join(without_articles.split())


def em_check(prediction, golden_answers):
    """Return 1 if the prediction exactly matches any gold answer, else 0.

    Both sides are passed through ``normalize_answer`` before comparison.

    Args:
        prediction: The predicted answer text.
        golden_answers: Either a single gold answer string or an iterable of
            gold answer strings.

    Returns:
        ``1`` when at least one normalised gold answer equals the normalised
        prediction, otherwise ``0``.
    """
    # A bare string is treated as a one-element list of gold answers.
    candidates = [golden_answers] if isinstance(golden_answers, str) else golden_answers
    target = normalize_answer(prediction)
    # any() stops at the first hit, mirroring an early loop exit.
    return int(any(normalize_answer(candidate) == target for candidate in candidates))


def extract_solution(solution_str):
    """Return the content of the last ``<answer>...</answer>`` block.

    Args:
        solution_str: The full text to search.

    Returns:
        The text between the final matching pair of answer tags with
        surrounding whitespace stripped, or ``None`` when no complete
        ``<answer>...</answer>`` pair is present.
    """
    # With a single capture group, findall yields the captured strings directly.
    # DOTALL lets an answer span multiple lines.
    captured = re.findall(r"<answer>(.*?)</answer>", solution_str, flags=re.DOTALL)
    return captured[-1].strip() if captured else None


def format_judge(solution_str: str):
    """Check that a model output is built from well-formed tagged blocks.

    The text is first cleaned: every ``<tool_response>...</tool_response>``
    span is removed, lines consisting solely of ``user`` or ``assistant`` are
    removed, and the result is stripped. The cleaned text is then rejected if:

    * it contains the literal ``[Client Error]``;
    * any ``think`` / ``tool_call`` / ``answer`` opening or closing tag is
      left over once all properly paired blocks have been removed;
    * it is non-empty but contains no properly paired block;
    * a ``<tool_call>`` block is blank or its content is not valid JSON.

    Args:
        solution_str: The raw model output.

    Returns:
        A ``(is_valid, message)`` tuple, where ``message`` explains the
        verdict.
    """
    without_tool_output = re.sub(
        r"<tool_response>.*?</tool_response>", "", solution_str, flags=re.DOTALL)
    cleaned = re.sub(r"^(user|assistant)\s*$", "",
                     without_tool_output, flags=re.MULTILINE).strip()

    if "[Client Error]" in cleaned:
        return False, "Validation failed: Contains '[Client Error]'."

    # The backreference forces the closing tag to match the opening tag.
    block_re = re.compile(
        r"<(think|tool_call|answer)>(.*?)</\1>", re.DOTALL)
    stray_tag_re = re.compile(r"</?(think|tool_call|answer)>")

    # Whatever tag survives the removal of paired blocks has no partner.
    remainder = block_re.sub("", cleaned)
    if stray_tag_re.search(remainder) is not None:
        return False, "Validation failed: Malformed, nested, or mismatched tags detected."

    # Two capture groups, so each entry is a (tag_name, content) tuple.
    blocks = block_re.findall(cleaned)

    if not blocks:
        if not cleaned:
            return True, "Valid format (empty)"
        return False, "Validation failed: No valid <think>, <tool_call>, or <answer> blocks found."

    for tag_name, payload in blocks:
        # Only tool calls carry content that is validated further.
        if tag_name != "tool_call":
            continue
        if not payload.strip():
            return False, f"Validation failed: Content of <{tag_name}> is empty."
        try:
            json.loads(payload)
        except json.JSONDecodeError:
            return False, f"Validation failed: Content of <{tag_name}> is not valid JSON."

    return True, "Valid format"


def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0, extra_info=None):
    """Score a model output by exact match, gated by format validation.

    Args:
        solution_str: The raw model output to score.
        ground_truth: Gold answer(s), forwarded unchanged to ``em_check``.
        method: Accepted but not used by this function.
        format_score: Accepted but not used by this function.
        score: Accepted but not used by this function.
        extra_info: Accepted but not used by this function.

    Returns:
        ``-1.0`` when ``format_judge`` reports an invalid format, ``0.0`` when
        no ``<answer>`` block can be extracted, otherwise the exact-match
        result (``0.0`` or ``1.0``) as a float.
    """
    try:
        format_ok, why = format_judge(solution_str)
    except Exception as exc:
        # A crash in the validator is logged and then treated as a pass, so
        # scoring continues with answer extraction.
        print(f"FORMAT JUDGE ERROR: {exc} \nSOLUTION: {solution_str}")
    else:
        if not format_ok:
            print(f"FORMAT_ERROR: {why} \nSOLUTION: {solution_str}")
            return -1.0

    extracted = extract_solution(solution_str=solution_str)
    if extracted is None:
        return 0.0

    return float(em_check(extracted, ground_truth))
