from typing import Dict, List, Optional
import numpy as np
import os, re, sympy
import datasets
from importlib import import_module
from .base import *
# from inference_rlhf.code.preprocess.preprocess_anonymous_generations import extract_answers, extract_results


# Template for a single problem; filled via str.format with `question`.
QUESTION_FORMAT: str = "Question: {question}\nSolution: "
# Template for an answer: the value is substituted with no surrounding text.
ANSWER_FORMAT: str = "{answer}"
# Passed to BaseDataLoader as `sep`; deliberately the empty string.
SEP: str = ""

def find_numbers(x: str) -> List[str]:
    """Return every number-like substring of ``x``, in order of appearance.

    Matches optionally negative integers and decimals such as ``"42"``,
    ``"-7"`` and ``"3.5"``. Results are returned as strings; callers are
    responsible for converting them. Thousands separators are not handled:
    ``"1,000"`` yields ``"1"`` and ``"000"``.

    Args:
        x: Text to scan.

    Returns:
        List of matched number strings (empty if none).
    """
    # Search the raw text. Escaping it first (re.escape) would insert a
    # backslash before every '.', splitting decimals like "3.5" into "3", "5".
    return re.findall(r'-?\d+(?:\.\d+)?', x)

def find_answer2(x: str, answer_patterns: List[str], strict: bool = False):
    """Fallback answer extractor based on plain substring matching.

    Patterns are tried in order. For the first pattern that occurs in ``x``
    and is followed (after its *last* occurrence) by at least one number,
    the first such number is returned. If no pattern yields a number and
    ``strict`` is False, the last number anywhere in ``x`` is returned.

    Args:
        x: Model-generated solution text.
        answer_patterns: Marker strings expected to precede the final answer.
        strict: When True, skip the whole-string fallback.

    Returns:
        An ``int`` if the matched token has no decimal point, a ``float`` if
        it does, or ``None`` when nothing was found.
    """
    def _as_number(token):
        # Tokens containing a decimal point become floats; all others ints.
        return float(token) if '.' in token else int(token)

    for pattern in answer_patterns:
        if pattern not in x:
            continue
        # NOTE(review): an empty pattern reaches str.split("") here, which
        # raises ValueError.
        tail = x.split(pattern)[-1]
        candidates = find_numbers(tail)
        if candidates:
            return _as_number(candidates[0])

    if strict:
        return None

    candidates = find_numbers(x)
    return _as_number(candidates[-1]) if candidates else None

class AIME2025DataLoader(BaseDataLoader):
    """
    DataLoader for the AIME 2025 dataset.

    Configures the base loader with this module's question/answer templates
    and no few-shot examples, and adds answer extraction (``extract_answer``)
    and scoring against ground truth (``judge_correct``).
    """

    def __init__(self, cfg):
        super().__init__(
            cfg,
            question_format=QUESTION_FORMAT,
            answer_format=ANSWER_FORMAT,
            sep=SEP,
            question_examples=[],
            answer_examples=[],
        )

        # Marker strings searched for in generations, in priority order.
        self.answer_patterns = cfg.policy.answer_patterns

    def extract_answer(self, solution_str: str) -> Optional[float]:
        """
        Extracts the answer from the solution string using regex parsing with answer patterns.
        TODO: replace with Eleuther evaluation harness

        First pass: for each pattern in order, look for a number that
        *immediately* follows the pattern (no intervening whitespace) and
        return the first such match. If none matches, fall back to
        ``find_answer2`` in strict mode, which looks for the first number
        after the last occurrence of each pattern.

        Args:
            solution_str (str): The solution string to extract the answer from.

        Returns:
            The extracted answer as an ``int`` (no decimal point) or ``float``
            (decimal point present), or ``None`` if no answer was found.
        """
        for answer_pattern in self.answer_patterns:
            escaped = re.escape(answer_pattern)
            ans_re = re.compile(fr"{escaped}(-?\d+(?:\.\d+)?)")
            match = ans_re.search(solution_str)
            if match:
                # The capture group can only contain digits, '-' and '.',
                # so no whitespace or comma cleanup is needed.
                match_str = match.group(1)
                return float(match_str) if '.' in match_str else int(match_str)

        return find_answer2(solution_str, self.answer_patterns, True)

    def judge_correct(self, answer: Optional[float], prompt_idx: int) -> bool:
        """
        Judges the correctness of the answer based on the ground truth answer.

        The answer is correct if it equals the ground truth directly, or if
        both sides convert to equal floats (so e.g. a ground truth stored as
        the text ``"70"`` matches a predicted ``70``).

        Args:
            answer: The answer generated by the model to be evaluated
                (as returned by ``extract_answer``; may be ``None``).
            prompt_idx (int): The dataset index of the question whose ground truth answer is used for comparison.

        Returns:
            bool: True if the model's answer is considered correct, False otherwise.
        """
        gt = self.answers[prompt_idx]
        if answer is None:
            return False

        if gt == answer:
            return True

        # NOTE(review): the element type of ``self.answers`` is decided by
        # BaseDataLoader, which is not visible here. If it is text, a plain
        # ``==`` against a numeric answer is always False, so also compare
        # numerically when both sides are convertible.
        try:
            return float(gt) == float(answer)
        except (TypeError, ValueError, OverflowError):
            return False



