import re
from typing import Literal
from datasets import load_dataset

from .base import BaseBenchmark, BenchmarkSample


class AIME24Benchmark(BaseBenchmark):
    r"""Benchmark over the ``math-ai/aime24`` dataset (``test`` split).

    The reference answer of each sample is the content of the last
    ``\boxed{...}`` expression found in the item's ``solution`` field.
    Predictions are graded by extracting their last boxed expression (or,
    failing that, the last number in the text) and comparing it with the
    reference, numerically when both sides parse as floats.
    """

    # Opening token of a LaTeX boxed answer.
    _BOXED_MARKER = "\\boxed{"
    # Unsigned integer or decimal; used as a fallback when nothing is boxed.
    _NUMBER_RE = re.compile(r"\d+(?:\.\d+)?")

    def __init__(self):
        super().__init__()

    def _load_samples(self) -> list[BenchmarkSample]:
        """Load the dataset and convert every item into a ``BenchmarkSample``.

        Returns:
            One sample per dataset item. ``question`` comes from the item's
            ``problem`` field, ``expected_answer`` from the boxed answer of
            its ``solution`` field, and ``metadata`` carries ``id``, ``url``
            and the untouched solution text under ``full_solution``.
        """
        dataset = load_dataset(
            path="math-ai/aime24",
            split="test",
        )

        samples = []
        for item in dataset:
            # ``or ""`` also covers fields that are present but set to None.
            solution = item.get("solution") or ""
            samples.append(BenchmarkSample(
                question=item.get("problem") or "",
                expected_answer=self._extract_boxed_answer(solution),
                metadata={
                    "id": item.get("id"),
                    "url": item.get("url"),
                    "full_solution": solution,
                },
            ))

        return samples

    def _extract_boxed_answer(self, solution: str) -> str:
        r"""Return the content of the last balanced ``\boxed{...}`` in *solution*.

        Braces are matched by counting depth, so arbitrarily nested LaTeX
        such as ``\boxed{\frac{\sqrt{3}}{2}}`` is returned in full.

        Args:
            solution: Text to scan; ``None`` or an empty string is accepted.

        Returns:
            The stripped content of the last boxed expression whose braces
            balance, or ``""`` when there is none.
        """
        if not solution:
            return ""

        marker = self._BOXED_MARKER
        text_len = len(solution)
        last_content = ""
        search_from = 0

        while True:
            start = solution.find(marker, search_from)
            if start == -1:
                break

            content_start = start + len(marker)
            depth = 1
            pos = content_start
            while pos < text_len and depth:
                char = solution[pos]
                if char == "{":
                    depth += 1
                elif char == "}":
                    depth -= 1
                pos += 1

            if depth:
                # Unbalanced expression: ignore this marker but keep looking
                # for later (possibly enclosed) boxed expressions.
                search_from = content_start
                continue

            # ``pos`` is one past the closing brace of this expression.
            last_content = solution[content_start:pos - 1]
            search_from = pos

        return last_content.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        r"""Decide whether *predicted* contains the *expected* answer.

        The candidate answer is the last ``\boxed{...}`` content of
        *predicted*; if that is missing or empty, the last number appearing
        in the text is used instead. Commas are removed from both sides
        before comparing.

        Args:
            predicted: Raw model response.
            expected: Reference answer.

        Returns:
            ``True`` when both sides parse as floats and are equal, or when
            they are identical strings. ``False`` when no candidate answer
            could be extracted from *predicted*.
        """
        predicted = predicted or ""

        candidate = self._extract_boxed_answer(predicted)
        if not candidate:
            # findall on a pattern without consumed delimiters sees every
            # number, including ones separated by a single character.
            numbers = self._NUMBER_RE.findall(predicted)
            if numbers:
                candidate = numbers[-1]

        pred_clean = candidate.strip().replace(",", "")
        exp_clean = (expected or "").strip().replace(",", "")

        # A response without any answer must never match, even when the
        # reference answer is itself empty.
        if not pred_clean:
            return False

        try:
            return float(pred_clean) == float(exp_clean)
        except (ValueError, TypeError):
            return pred_clean == exp_clean

    def format_prompt(self, question: str) -> str:
        r"""Build the prompt asking for step-by-step reasoning and a boxed answer.

        Args:
            question: Problem statement to embed in the prompt.

        Returns:
            The instruction line, followed by ``Problem: <question>`` and a
            trailing ``Answer:`` cue.
        """
        return (
            "Please reason step by step, and put your final answer within "
            f"\\boxed{{}}.\nProblem: {question}\nAnswer:"
        )


class GPQADiamondBenchmark(BaseBenchmark):
    r"""Multiple-choice benchmark over ``fingertap/GPQA-Diamond`` (``test`` split).

    Reference answers are stored as an upper-cased letter. Predictions are
    graded by locating the chosen letter (A-D) in the response, preferring a
    ``\boxed{X}`` expression, then an explicit "answer/choice ... X" phrase,
    then progressively weaker cues.
    """

    _CHOICES = ("A", "B", "C", "D")

    # \boxed{B}, \boxed{(B)}, \boxed{\text{B}}; matched on the raw response.
    _BOXED_RE = re.compile(
        r"\\boxed\{\s*(?:\\text\{\s*)?\(?([ABCD])\)?\s*\}",
        re.IGNORECASE,
    )
    # The remaining patterns run on the upper-cased response.
    # "ANSWER: B", "ANSWER IS B", "CHOICE B", "ANSWER: **B**", "ANSWER: $B$".
    _ANSWER_RE = re.compile(
        r"(?:ANSWER|CHOICE)[\s:*]*(?:IS[\s:*]+)?[$(]?([ABCD])\b"
    )
    # "(B)"
    _PAREN_RE = re.compile(r"\(([ABCD])\)")
    # Response starting with "B.", "B)" or "B ".
    _LEADING_RE = re.compile(r"^([ABCD])[\.\)\s]")
    # Response ending with a stand-alone letter; the look-behind keeps the
    # final letter of an ordinary word (e.g. "BETA") from being taken as a
    # choice.
    _TRAILING_RE = re.compile(r"(?<![A-Z])([ABCD])$")

    def __init__(self):
        super().__init__()

    def _load_samples(self) -> list[BenchmarkSample]:
        """Load the dataset and convert every item into a ``BenchmarkSample``.

        Returns:
            One sample per dataset item, with ``question`` taken from the
            item's ``question`` field, ``expected_answer`` from its
            ``answer`` field (stripped and upper-cased) and empty metadata.
        """
        dataset = load_dataset(
            path="fingertap/GPQA-Diamond",
            split="test",
        )

        return [
            BenchmarkSample(
                question=item.get("question") or "",
                # ``or ""`` also covers an answer that is present but None.
                expected_answer=(item.get("answer") or "").strip().upper(),
                metadata={},
            )
            for item in dataset
        ]

    def check_answer(self, predicted: str, expected: str) -> bool:
        r"""Decide whether *predicted* selects the *expected* choice letter.

        Extraction order (the first cue that is present decides):

        1. the last ``\boxed{X}`` expression;
        2. the whole response being a single letter;
        3. the last "answer/choice ... X" phrase;
        4. the last parenthesised letter ``(X)``;
        5. a letter opening the response (``X.``, ``X)`` or ``X`` followed
           by whitespace);
        6. a stand-alone letter closing the response.

        For cues that may occur several times the last occurrence is used,
        so letters mentioned while reasoning do not override the final
        answer line.

        Args:
            predicted: Raw model response; ``None`` is treated as empty.
            expected: Reference letter; compared case-insensitively.

        Returns:
            ``True`` when the extracted letter equals *expected*, ``False``
            when it differs or no letter could be extracted.
        """
        predicted = predicted or ""
        exp_clean = (expected or "").strip().upper()

        boxed = self._BOXED_RE.findall(predicted)
        if boxed:
            return boxed[-1].upper() == exp_clean

        pred_clean = predicted.strip().upper()
        if pred_clean in self._CHOICES:
            return pred_clean == exp_clean

        # Cues that can appear repeatedly: the final occurrence wins.
        for pattern in (self._ANSWER_RE, self._PAREN_RE):
            letters = pattern.findall(pred_clean)
            if letters:
                return letters[-1] == exp_clean

        # Anchored cues: at most one possible match each.
        for pattern in (self._LEADING_RE, self._TRAILING_RE):
            match = pattern.search(pred_clean)
            if match:
                return match.group(1) == exp_clean

        return False

    def format_prompt(self, question: str) -> str:
        """Build the multiple-choice prompt for *question*.

        Args:
            question: Question text to append after the instructions.

        Returns:
            The instructions (final line must read ``Answer: $LETTER``),
            a blank line, then the question.
        """
        return (
            "Answer the following multiple choice question. The last line of "
            "your response should be of the following format: 'Answer: $LETTER' "
            "(without quotes) where LETTER is one of ABCD. Think step by step "
            "before answering.\n\n"
            f"{question}"
        )


def get_benchmark(
    name: Literal["aime24", "gpqa_diamond"],
) -> BaseBenchmark:
    """Instantiate the benchmark registered under *name*.

    Args:
        name: Registry key, ``"aime24"`` or ``"gpqa_diamond"``.

    Returns:
        A new instance of the matching benchmark class.

    Raises:
        ValueError: If *name* is not a registered benchmark.
    """
    registry = {
        "aime24": AIME24Benchmark,
        "gpqa_diamond": GPQADiamondBenchmark,
    }

    benchmark_cls = registry.get(name)
    if benchmark_cls is None:
        raise ValueError(f"Unknown benchmark: {name}. Available: {list(registry)}")

    return benchmark_cls()
