"""OlympiadBench (Open-ended Math) dataset handler."""
from typing import Dict, List, Optional

import pandas as pd

from .base import DatasetHandler


class OlympiadBenchHandler(DatasetHandler):
    """Dataset handler for the OlympiadBench open-ended English math subset.

    Loads questions from a parquet file and scores model responses with the
    project's math reward module (``utils.reward_score.math``).
    """

    name = "olympiadbench"
    # NOTE: train and test defaults point at the same file; there is no
    # separate held-out split for this dataset here.
    default_train_path = "data/olympiadbench/OE_TO_maths_en_COMP.parquet"
    default_test_path = "data/olympiadbench/OE_TO_maths_en_COMP.parquet"
    default_max_tokens = 2048

    def __init__(self):
        # Imported lazily, presumably to avoid import-time cost or import
        # cycles when this module is loaded but the handler is never built.
        from utils.reward_score import math as math_reward
        self.reward_module = math_reward

    @staticmethod
    def _first_answer(final_answer) -> str:
        """Return the first reference answer from a ``final_answer`` cell.

        The column is expected to hold a list-like of answers (parquet list
        columns typically come back as Python lists or NumPy arrays). Two
        other cell shapes are handled instead of crashing or mis-reading:

        - A bare string is used as-is. Indexing it would keep only its first
          character.
        - A missing or scalar cell (``None``, NaN) yields ``""``. Calling
          ``len()`` on it would raise ``TypeError``.
        """
        if final_answer is None:
            return ""
        if isinstance(final_answer, str):
            return final_answer
        try:
            if len(final_answer) == 0:
                return ""
        except TypeError:
            # Scalar cell such as NaN for a missing value: no usable answer.
            return ""
        first = final_answer[0]
        return "" if first is None else str(first)

    def load_data(
        self, path: str, split: str = "train", max_samples: Optional[int] = None
    ) -> List[Dict]:
        """Load OlympiadBench problems as chat-style task records.

        Args:
            path: Parquet file with ``question`` and ``final_answer`` columns,
                and optionally ``subject`` / ``subfield`` columns.
            split: Unused; the same file serves every split.
            max_samples: Stop after this many records. ``None`` or ``0``
                loads every row.

        Returns:
            One dict per row with these keys:

            - ``messages``: a single user turn containing the question.
            - ``ground_truth``: the first final answer, or ``""``.
            - ``subject``: defaults to ``"Math"`` when the column is absent.
            - ``subfield``: defaults to ``""`` when the column is absent.
        """
        df = pd.read_parquet(path)
        task_datas = []
        for _, row in df.iterrows():
            task_datas.append({
                "messages": [{"role": "user", "content": row["question"]}],
                "ground_truth": self._first_answer(row["final_answer"]),
                "subject": row.get("subject", "Math"),
                "subfield": row.get("subfield", ""),
            })
            if max_samples and len(task_datas) >= max_samples:
                break
        return task_datas

    def compute_reward(self, response: str, ground_truth: str) -> float:
        """Score ``response`` against ``ground_truth``.

        Tries the reward module's ``"strict"`` method first. Only when that
        returns 0 does it retry with the ``"flexible"`` method, so a strict
        match is never overridden.
        """
        reward = self.reward_module.compute_score(response, ground_truth, method="strict")
        if reward == 0:
            reward = self.reward_module.compute_score(response, ground_truth, method="flexible")
        return reward

    def extract_answer(self, response: str) -> str:
        """Extract the final answer from ``response``.

        Uses ``"strict"`` extraction first and falls back to ``"flexible"``.
        Returns ``""`` when neither method finds an answer.
        """
        answer = self.reward_module.extract_solution(response, method="strict")
        if answer is None:
            answer = self.reward_module.extract_solution(response, method="flexible")
        return answer if answer is not None else ""

    def format_answer_for_check(self, answer: str) -> str:
        """Wrap ``answer`` in a LaTeX ``\\boxed{...}`` for answer checking."""
        return f"\\boxed{{{answer}}}"

