"""MATH-500 dataset handler."""
import json
from typing import List, Dict
from .base import DatasetHandler


class MATH500Handler(DatasetHandler):
    """Dataset handler for MATH-500 JSONL files.

    Each non-blank line of the data file is parsed as a JSON object that must
    carry ``problem`` and ``answer`` keys; ``subject`` and ``level`` are
    optional. Scoring and answer extraction are delegated to the
    ``utils.reward_score.math`` module, trying its ``"strict"`` method first
    and its ``"flexible"`` method second.
    """

    name = "math500"
    # Both defaults point at the same file (test.jsonl).
    default_train_path = "data/math-500/test.jsonl"
    default_test_path = "data/math-500/test.jsonl"
    default_max_tokens = 2048

    def __init__(self):
        """Bind the math reward module to ``self.reward_module``."""
        # Imported here, at construction time, rather than at module import.
        from utils.reward_score import math as math_reward
        self.reward_module = math_reward

    def load_data(self, path: str, split: str = "train", max_samples: int = None) -> List[Dict]:
        """Read a JSONL file and convert each row into a task record.

        Args:
            path: Location of the JSONL file, read as UTF-8.
            split: Accepted for interface compatibility; not used here.
            max_samples: Stop once this many records have been collected.
                A falsy value (``None`` or ``0``) disables the cap.

        Returns:
            A list of dicts with the keys ``messages``, ``ground_truth``,
            ``problem``, ``subject`` and ``level``.

        Raises:
            KeyError: If a row lacks ``problem`` or ``answer``.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        capped = bool(max_samples)
        records = []
        with open(path, encoding='utf-8') as handle:
            for raw_line in handle:
                # Whitespace-only lines carry no record.
                if not raw_line.strip():
                    continue
                entry = json.loads(raw_line)
                question = entry["problem"]
                record = {
                    "messages": [{"role": "user", "content": question}],
                    "ground_truth": entry["answer"],
                    "problem": question,
                    "subject": entry.get("subject", ""),
                    "level": entry.get("level", 0),
                }
                records.append(record)
                # The cap is checked only after a record has been added.
                if capped and len(records) >= max_samples:
                    break
        return records

    def compute_reward(self, response: str, ground_truth: str) -> float:
        """Score ``response`` against ``ground_truth``.

        The ``"strict"`` score is returned unless it equals 0, in which case
        the ``"flexible"`` score is returned instead.
        """
        score = 0
        for method in ("strict", "flexible"):
            score = self.reward_module.compute_score(response, ground_truth, method=method)
            if score != 0:
                break
        return score

    def extract_answer(self, response: str) -> str:
        """Return the answer extracted from ``response``.

        Tries ``"strict"`` extraction, then ``"flexible"``; yields an empty
        string when both return ``None``.
        """
        for method in ("strict", "flexible"):
            candidate = self.reward_module.extract_solution(response, method=method)
            if candidate is not None:
                return candidate
        return ""

    def format_answer_for_check(self, answer: str) -> str:
        r"""Wrap ``answer`` as ``\boxed{answer}``."""
        # Doubled braces are literal braces; the inner pair is the placeholder.
        return "\\boxed{{{}}}".format(answer)

