from typing import List

from inference_rlhf.code.tasks.base import BaseDataLoader
from inference_rlhf.code.helpers.math_utils import extract_answer as _extract_answer, judge_correct as _judge_correct

# Prompt settings that MATHDataLoader.__init__ hands to BaseDataLoader.
# How the placeholders are filled in is decided by the base class.

# Template for the question part of a prompt.
QUESTION_FORMAT: str = "Question: {question}\nSolution: "
# Template for the answer part of a prompt (the answer with no decoration).
ANSWER_FORMAT: str = "{answer}"
# Passed as ``sep``; presumably the separator placed between prompt pieces —
# TODO confirm against BaseDataLoader.
SEP: str = ""

class MATHDataLoader(BaseDataLoader):
    """
    Data loader specialised for the MATH dataset.

    Configures ``BaseDataLoader`` with the module-level prompt templates and
    empty example lists, and routes answer extraction and grading through the
    helpers imported from ``math_utils``.
    """

    def __init__(self, cfg):
        """
        Initialise the base loader with the MATH prompt settings.

        Args:
            cfg: Configuration object, forwarded unchanged to ``BaseDataLoader``.
        """
        super().__init__(
            cfg,
            question_format=QUESTION_FORMAT,
            answer_format=ANSWER_FORMAT,
            sep=SEP,
            question_examples=[],
            answer_examples=[],
        )

    def extract_groundtruth(self, answers: List[str]) -> List[str]:
        """
        Reduce each ground truth solution string to its final answer.

        Every element is passed through ``self.extract_answer``; the output
        keeps the order of the input.

        Args:
            answers (List[str]): Ground truth solution strings.

        Returns:
            List[str]: One extracted final answer per input string.
        """
        return list(map(self.extract_answer, answers))

    def judge_correct(self, answer: str, prompt_idx: int) -> bool:
        """
        Grade a model answer against the stored ground truth for one question.

        The comparison itself is delegated to ``judge_correct`` from
        ``math_utils``.

        Args:
            answer (str): The model's answer to evaluate (per the original
                documentation, the already-extracted answer).
            prompt_idx (int): Index into ``self.answers`` selecting the ground
                truth to compare against.

        Returns:
            bool: The verdict returned by the ``math_utils`` judge.
        """
        # NOTE(review): ``self.answers`` is not assigned in this class; it is
        # presumably populated by BaseDataLoader — confirm.
        return _judge_correct(answer, self.answers[prompt_idx])

    def extract_answer(self, solution_str: str) -> str:
        """
        Pull the final answer out of a solution string.

        Thin wrapper around ``extract_answer`` from ``math_utils``.

        Args:
            solution_str (str): Full solution text to extract from.

        Returns:
            str: Whatever the ``math_utils`` extractor returns for the input.
        """
        extracted = _extract_answer(solution_str)
        return extracted

    


