from typing import List

from inference_rlhf.code.tasks.base import BaseDataLoader
from inference_rlhf.code.helpers.math_utils import extract_answer as _extract_answer, judge_correct as _judge_correct

# Template for a single question; the `{question}` placeholder is followed by a
# "Solution: " cue. Passed to BaseDataLoader as `question_format`.
QUESTION_FORMAT = "Question: {question}\nSolution: "
# Template for a single answer (the answer text verbatim). Passed to
# BaseDataLoader as `answer_format`.
ANSWER_FORMAT = "{answer}"
# Separator passed to BaseDataLoader as `sep`; empty here.
# NOTE(review): presumably joins few-shot examples in the prompt — confirm in BaseDataLoader.
SEP = ""

class GSM8KDataLoader(BaseDataLoader):
    """
    DataLoader for the GSM8K dataset.

    Configures the base loader with the module-level question/answer templates
    and separator, and with empty few-shot example lists.
    """

    def __init__(self, cfg):
        """
        Initialise the loader.

        Args:
            cfg: Configuration object forwarded unchanged to ``BaseDataLoader``.
        """
        super().__init__(
            cfg,
            question_format=QUESTION_FORMAT,
            answer_format=ANSWER_FORMAT,
            sep=SEP,
            question_examples=[],
            answer_examples=[],
        )

    def extract_groundtruth(self, answers: List[str]) -> List[str]:
        """
        Extract the final answers from a list of ground truth solution strings.

        Args:
            answers (List[str]): List of ground truth solution strings.

        Returns:
            List[str]: List of extracted final answers, in the same order.

        Raises:
            IndexError: If any string does not contain the ``"#### "`` marker.
            ValueError: If any extracted answer is not parseable as a float.
        """
        return [self._extract_groundtruth(answer) for answer in answers]

    def _extract_groundtruth(self, ground_truth: str) -> str:
        """
        Extract the final answer from the ground truth string.

        The text following the first ``"#### "`` marker (up to the next marker,
        if any) is taken, stripped of surrounding whitespace, and has ``,``
        and ``$`` characters removed.

        Args:
            ground_truth (str): The ground truth string to extract the answer from.

        Returns:
            str: The extracted answer, returned as a string (not converted to a number).

        Raises:
            IndexError: If ``ground_truth`` does not contain the ``"#### "`` marker.
            ValueError: If the extracted text cannot be converted to a float.
        """
        extracted = (
            ground_truth.split("#### ")[1].strip().replace(",", "").replace("$", "")
        )
        try:
            # Validation only: the numeric value is discarded and the cleaned
            # string is what gets returned.
            float(extracted)
        except ValueError as err:
            # Catch only ValueError so KeyboardInterrupt/SystemExit are not
            # converted into a ValueError. The message is built from adjacent
            # literals so no source indentation leaks into it.
            raise ValueError(
                f"The extracted groundtruth string {extracted!r} "
                "cannot be converted to float"
            ) from err
        return extracted

    def judge_correct(self, answer: str, prompt_idx: int) -> bool:
        """
        Compares the extracted model answer with the ground truth using the Eleuther harness code in math utils.

        Args:
            answer (str): The answer generated by the model to be evaluated.
            prompt_idx (int): The dataset index of the question whose ground truth answer is used for comparison.

        Returns:
            bool: True if the model's answer is considered correct, False otherwise.
        """
        # NOTE(review): `self.answers` is not set in this class — presumably
        # populated by BaseDataLoader; confirm.
        ground_truth = self.answers[prompt_idx]
        return _judge_correct(answer, ground_truth)

    def extract_answer(self, solution_str: str) -> str:
        """
        Extracts the answer from the solution string using the Eleuther harness code in math utils.

        Args:
            solution_str (str): The solution string to extract the answer from.

        Returns:
            str: The extracted answer.
        """
        return _extract_answer(solution_str)




