import re
from typing import Optional, List
import logging

# Configure logging at INFO level. This is a module-level call, so it runs
# whenever this file is imported as well as when it is executed as a script.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)

# Numeric token: an optional minus sign followed by either a comma-grouped
# number ("1,234", "1,234,567.89") or a plain integer/decimal ("42", "3.14",
# ".5"). The (?!\d) guard keeps malformed groupings such as "1,2345" from
# being read as one thousands-separated value.
_NUMBER_PATTERN = r'-?(?:\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d*\.?\d+)'
_NUMBER_RE = re.compile(_NUMBER_PATTERN)

# Cue word followed by the first number after it. The gap is matched lazily
# so a leading "-" or "." stays attached to the number instead of being
# consumed by the gap (a greedy gap turns "answer is -5" into 5).
_CUE_NUMBER_RE = re.compile(
    r'(?:answer|result|therefore)\D*?(' + _NUMBER_PATTERN + r')'
)

# Fallback candidates whose magnitude reaches this bound are discarded.
_MAX_FALLBACK_MAGNITUDE = 1e10


def _parse_number(token: str) -> float:
    """Convert a matched numeric token to float, dropping thousands separators."""
    return float(token.replace(',', ''))


def extract_final_number(text: str) -> Optional[float]:
    """Extract the number that most likely represents the final answer in *text*.

    Three strategies are tried in order; the first one that yields a number wins:

    1. If the text contains ``####``, the first number after the last
       ``####`` marker.
    2. The first number following the first occurrence of one of the cue
       words "answer", "result" or "therefore" (case-insensitive).
    3. The last number anywhere in the text whose magnitude is below 1e10.

    Numbers may carry a leading minus sign, a decimal part, and comma
    thousands separators (``"1,234.5"`` is read as ``1234.5``).

    Args:
        text: Free-form text, e.g. a model response or a reference answer.

    Returns:
        The extracted value as a float, or ``None`` when no usable number is
        found or when extraction fails (for example, *text* is not a string).
        Failures are logged, never raised.
    """
    try:
        if '####' in text:
            tail = text.split('####')[-1]
            marked = _NUMBER_RE.search(tail)
            if marked:
                return _parse_number(marked.group())

        # Cue words are matched on the lowercased text; digits are unaffected.
        cued = _CUE_NUMBER_RE.search(text.lower())
        if cued:
            return _parse_number(cued.group(1))

        tokens = _NUMBER_RE.findall(text)
        if tokens:
            in_range = [
                value
                for value in map(_parse_number, tokens)
                if abs(value) < _MAX_FALLBACK_MAGNITUDE
            ]
            if in_range:
                return in_range[-1]
            logger.warning("Found numbers but all were too large: %s", tokens)
            return None

        logger.warning("No numbers found in text: %s", text)
        return None

    except Exception as e:
        # Deliberate best-effort: any failure (e.g. non-string input) is
        # logged and reported as "no number" rather than propagated.
        logger.error("Error extracting number from text: %s. Error: %s", text, e)
        return None

def normalize_answer(text: str) -> str:
    """Return a canonical form of *text* for token-level comparison.

    The text is lowercased, every character other than ``a-z``, ``0-9`` or
    whitespace is replaced by a space, and whitespace runs are collapsed to
    single spaces with no leading or trailing whitespace.
    """
    lowered = text.lower()
    # Punctuation, symbols and non-ASCII letters all become separators.
    separated = re.sub(r'[^a-z0-9\s]', ' ', lowered)
    # split() without arguments discards leading/trailing whitespace and
    # collapses runs, so the joined result needs no further trimming.
    words = separated.split()
    return ' '.join(words)

def compute_exact_match(prediction: str, ground_truth: str) -> bool:
    """Report whether *prediction* matches *ground_truth*.

    When ``extract_final_number`` yields a number for both inputs, the two
    numbers are compared with an absolute tolerance of 1e-6. Otherwise the
    inputs are compared as text after ``normalize_answer``.

    Args:
        prediction: The predicted answer text.
        ground_truth: The reference answer text.

    Returns:
        ``True`` on a match, ``False`` otherwise. Any exception raised during
        the comparison is logged and reported as ``False``.
    """
    try:
        predicted_value = extract_final_number(prediction)
        expected_value = extract_final_number(ground_truth)

        # Without a number on both sides, fall back to comparing normalized text.
        if predicted_value is None or expected_value is None:
            is_match = normalize_answer(prediction) == normalize_answer(ground_truth)
            logger.debug(f"Text exact match: {is_match}")
            return is_match

        is_match = abs(predicted_value - expected_value) < 1e-6
        logger.debug(f"Numeric exact match: {is_match} (Prediction: {predicted_value}, Ground Truth: {expected_value})")
        return is_match

    except Exception as err:
        logger.error(f"Error computing exact match: {err}")
        return False

def compute_f1(prediction: str, ground_truth: str) -> float:
    """Score *prediction* against *ground_truth* with an F1 measure.

    When ``extract_final_number`` yields a number for both inputs, the score
    is 1.0 if the numbers agree within an absolute tolerance of 1e-6 and 0.0
    otherwise. Otherwise the score is the F1 of the two sets of unique
    tokens produced by ``normalize_answer`` (duplicates do not count twice).

    Args:
        prediction: The predicted answer text.
        ground_truth: The reference answer text.

    Returns:
        A score between 0.0 and 1.0. Any exception raised during scoring is
        logged and reported as 0.0.
    """
    try:
        predicted_value = extract_final_number(prediction)
        expected_value = extract_final_number(ground_truth)
        if predicted_value is not None and expected_value is not None:
            if abs(predicted_value - expected_value) < 1e-6:
                f1_score = 1.0
            else:
                f1_score = 0.0
            logger.debug(f"Numeric F1 score: {f1_score}")
            return f1_score

        # Text path: compare the sets of unique normalized tokens.
        predicted_vocab = set(normalize_answer(prediction).split())
        expected_vocab = set(normalize_answer(ground_truth).split())

        shared = predicted_vocab.intersection(expected_vocab)
        if len(shared) == 0:
            logger.debug("No common tokens found")
            return 0.0

        overlap = len(shared)
        precision = overlap / len(predicted_vocab)
        recall = overlap / len(expected_vocab)

        if (precision + recall) > 0:
            f1_score = 2 * (precision * recall) / (precision + recall)
        else:
            f1_score = 0.0
        logger.debug(f"Text F1 score: {f1_score:.2f}")
        return f1_score

    except Exception as err:
        logger.error(f"Error computing F1 score: {err}")
        return 0.0

if __name__ == "__main__":
    # Small smoke demo: score one sample prediction against its reference
    # with both metrics and log the outcomes.
    sample_prediction = "The final answer is 42."
    sample_truth = "42"

    logger.info("\n=== Exact Match Test ===")
    logger.info(f"Exact Match: {compute_exact_match(sample_prediction, sample_truth)}")

    logger.info("\n=== F1 Score Test ===")
    logger.info(f"F1 Score: {compute_f1(sample_prediction, sample_truth)}")
