import re
from typing import Dict, Optional
import logging
import math

def extract_boxed_content(text: str) -> Optional[str]:
    r"""Return the stripped content of the first ``\boxed{...}`` in *text*.

    Braces are matched by nesting depth, so answers such as
    ``\boxed{\frac{1}{2}}`` are returned whole instead of being cut at the
    first closing brace. A backslash escapes the character that follows it,
    which keeps ``\{`` and ``\}`` from affecting the depth count.

    Args:
        text: Text that may contain a ``\boxed{...}`` expression.

    Returns:
        The stripped content between the braces (possibly an empty string),
        or ``None`` when *text* is empty or contains no closed ``\boxed{``.
    """
    if not text:
        return None

    marker = "\\boxed{"
    marker_pos = text.find(marker)
    if marker_pos == -1:
        return None

    content_start = marker_pos + len(marker)
    depth = 1
    idx = content_start
    end = len(text)
    while idx < end:
        ch = text[idx]
        if ch == "\\":
            # Skip the escaped character so "\{" / "\}" are not counted.
            idx += 2
            continue
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:idx].strip()
        idx += 1

    # Braces never balanced: fall back to the shortest "\boxed{...}" span so
    # malformed input still yields whatever a plain non-greedy match finds.
    match = re.search(r"\\boxed\{(.*?)\}", text, re.DOTALL)
    return match.group(1).strip() if match else None

def extract_first_number(text: str) -> Optional[float]:
    """Parse the first signed integer or decimal found in *text*.

    The input is converted with ``str()`` and stripped, and every comma is
    removed before searching, so ``"1,234.5"`` is read as ``1234.5``.

    Args:
        text: Value to search; ``None`` is accepted.

    Returns:
        The first number as a ``float``, or ``None`` when the input is
        ``None``, blank, or contains no number.
    """
    stripped = "" if text is None else str(text).strip()
    if not stripped:
        return None

    found = re.search(r'[-+]?\d*\.\d+|[-+]?\d+', stripped.replace(',', ''))
    if found is None:
        return None

    try:
        return float(found.group(0))
    except ValueError:
        return None

def compare_numbers(target_num: Optional[float],
                    prediction_num: Optional[float],
                    max_relative_change: float = 0.05) -> bool:
    """Check whether *prediction_num* is within a relative tolerance of *target_num*.

    Args:
        target_num: Reference value.
        prediction_num: Value being checked.
        max_relative_change: Largest accepted value of
            ``abs(prediction - target) / abs(target)``.

    Returns:
        ``True`` when the prediction is accepted. ``None`` or NaN on either
        side is always rejected. Infinite values are accepted only when both
        sides are the same infinity. A target with magnitude below ``1e-9``
        is treated as zero and requires the prediction's magnitude to be
        below ``1e-9`` as well.
    """
    if target_num is None or prediction_num is None:
        return False
    if math.isnan(target_num) or math.isnan(prediction_num):
        return False

    # inf - inf is NaN, so the relative-error formula below would reject two
    # identical infinities; compare them directly instead.
    if math.isinf(target_num) or math.isinf(prediction_num):
        return target_num == prediction_num

    # Relative error is undefined for a zero target; use an absolute check.
    if abs(target_num) < 1e-9:
        return abs(prediction_num) < 1e-9

    relative_change = abs(prediction_num - target_num) / abs(target_num)
    return relative_change <= max_relative_change

def _normalize_for_match(value) -> str:
    """Lower-case *value*, drop punctuation and collapse runs of whitespace."""
    if value is None:
        return ""
    cleaned = str(value).strip().lower()
    cleaned = re.sub(r'[^\w\s]', '', cleaned)
    # Removing punctuation can leave doubled or trailing whitespace behind
    # ("a , b" -> "a  b", "yes ?!" -> "yes "); collapse it so such leftovers
    # do not cause a mismatch.
    return re.sub(r'\s+', ' ', cleaned).strip()


def exact_string_match(target_str: str, prediction_str: str) -> bool:
    """Compare two answers, ignoring case, punctuation and extra whitespace.

    Both values are converted with ``str()``, stripped, lower-cased, have
    every character that is neither a word character nor whitespace removed,
    and have whitespace runs collapsed to a single space.

    Args:
        target_str: Reference answer; ``None`` is treated as an empty string.
        prediction_str: Candidate answer; ``None`` is treated as an empty
            string.

    Returns:
        ``True`` when the normalized strings are identical.
    """
    return _normalize_for_match(target_str) == _normalize_for_match(prediction_str)


def calculate_qa_accuracy(predicted_answer: Optional[str],
                          ground_truth_answer: str,
                          numeric_tolerance: float = 0.01) -> float:
    """Score a predicted answer against the ground truth as 1.0 or 0.0.

    When the ground truth contains at least one digit, the first number of
    each side is extracted with ``extract_first_number`` and compared with
    ``compare_numbers``. Otherwise the two strings are compared with
    ``exact_string_match``.

    Args:
        predicted_answer: Candidate answer; ``None`` or blank scores 0.0.
        ground_truth_answer: Reference answer; ``None`` scores 0.0.
        numeric_tolerance: Passed to ``compare_numbers`` as its tolerance
            argument for the numeric path.

    Returns:
        ``1.0`` when the prediction is judged correct, else ``0.0``.
    """
    if ground_truth_answer is None:
        return 0.0
    gt_text = str(ground_truth_answer).strip()

    if predicted_answer is None:
        return 0.0
    pred_text = str(predicted_answer).strip()
    if pred_text == "":
        return 0.0

    # Any digit in the ground truth selects the numeric comparison path.
    is_numeric = re.search(r'\d', gt_text) is not None
    eval_type = "Numeric Comparison" if is_numeric else "Textual Comparison"
    logging.debug(f"QA Accuracy - Type: {eval_type}")

    if is_numeric:
        gt_value = extract_first_number(gt_text)
        pred_value = extract_first_number(pred_text)
        if gt_value is None or pred_value is None:
            matched = False
        else:
            matched = compare_numbers(gt_value, pred_value, numeric_tolerance)
    else:
        matched = exact_string_match(gt_text, pred_text)

    logging.debug(f"QA Accuracy - Result: Correct={matched} (GT='{gt_text}', Pred='{pred_text}')")
    return 1.0 if matched else 0.0

def extract_tag_content(tag: str, text: str) -> Optional[str]:
    """Return the stripped content of the first ``<tag>...</tag>`` pair in *text*.

    Matching is case-insensitive, spans newlines, and allows attributes after
    the tag name in the opening tag. The tag name is matched literally, so
    regex metacharacters in *tag* have no special meaning.

    For the ``answer`` tag only, a malformed closing tag written as
    ``/answer>`` (missing the leading ``<``) is also accepted when no
    well-formed pair is found.

    Args:
        tag: Tag name without angle brackets.
        text: Text to search.

    Returns:
        The stripped inner content, or ``None`` when *text* is empty or no
        matching pair is found.
    """
    if not text:
        return None

    flags = re.DOTALL | re.IGNORECASE
    # Escape the tag so names such as "a.b" or "c++" are not read as regex.
    literal_tag = re.escape(tag)
    match = re.search(rf"<{literal_tag}(?: [^>]*)?>(.*?)</{literal_tag}>", text, flags)
    if match is None and tag.lower() == 'answer':
        match = re.search(r"<answer(?: [^>]*)?>(.*?)/answer>", text, flags)

    return match.group(1).strip() if match else None

def compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
    r"""Score a model response on answer format and answer accuracy.

    The format score is 1.0 when the response contains a ``\boxed{...}``
    expression. The accuracy score comes from ``calculate_qa_accuracy``
    applied to the boxed content and the ground truth. When the ground truth
    contains ``<answer>``, the content of that tag is used as the reference.

    Args:
        predict_str: Raw model response; ``None`` is treated as empty.
        ground_truth: Reference answer, optionally wrapped in
            ``<answer>...</answer>``; ``None`` is treated as empty.

    Returns:
        A dict with keys ``"overall"`` (``0.8 * accuracy + 0.2 * format``),
        ``"format"`` and ``"accuracy"``.
    """
    # Tolerate None like the other helpers in this module do, instead of
    # raising AttributeError on .strip().
    predict_str = "" if predict_str is None else str(predict_str).strip()
    ground_truth = "" if ground_truth is None else str(ground_truth).strip()

    if "<answer>" in ground_truth:
        # May yield None when the tag is never closed.
        ground_truth = extract_tag_content("answer", ground_truth)

    format_match = bool(re.search(r"\\boxed\{.*?\}", predict_str, re.DOTALL))
    format_score = 1.0 if format_match else 0.0

    predicted_answer = None
    if format_match:
        extracted = extract_boxed_content(predict_str)
        if extracted:
            predicted_answer = extracted

    accuracy_score = 0.0
    # An empty reference cannot be scored meaningfully: a punctuation-only
    # prediction would normalize to "" and count as a match.
    if ground_truth:
        accuracy_score = calculate_qa_accuracy(predicted_answer, ground_truth)
    else:
        logging.debug("Cannot calculate QA Accuracy because GT Answer is empty.")

    overall_score = 0.8 * accuracy_score + 0.2 * format_score

    return {
        "overall": overall_score,
        "format": format_score,
        "accuracy": accuracy_score,
    }