from typing import List

import numpy as np
import re
import sympy

from inference_rlhf.code.tasks.base import BaseDataLoader

QUESTION_FORMAT = "\nInput: {question}\nSteps:\n"
ANSWER_FORMAT = "{answer}"
SEP = ""

class Game24DataLoader(BaseDataLoader):
    """
    DataLoader for the Game24 dataset.

    Each question is stored as a space-separated string of the puzzle's numbers. A
    model answer is judged correct when it is an arithmetic expression that uses
    exactly those numbers and evaluates exactly to 24.
    """
    def __init__(self, cfg):
        super().__init__(
            cfg,
            question_format=QUESTION_FORMAT,
            answer_format=ANSWER_FORMAT,
            sep=SEP,
            question_examples=[],
            answer_examples=[],
        )

        # NOTE(review): assumes the base loader fills self.questions with iterables of
        # numbers (e.g. [4, 7, 8, 8]); they are flattened to strings like "4 7 8 8".
        self.questions = [" ".join(map(str, q)) for q in self.questions]
        # Read from the config; not referenced elsewhere in this class
        # (presumably consumed by the base class or callers — confirm).
        self.answer_patterns = cfg.policy.answer_patterns

    def judge_correct(self, expression: str, prompt_idx: int) -> bool:
        """
        Judges whether the expression solves the puzzle at ``prompt_idx``.

        The expression is correct when it uses exactly the numbers of the question
        (each as many times as it appears there) and evaluates exactly to 24.

        Args:
            expression (str): The expression generated by the model to be evaluated.
            prompt_idx (int): Dataset index of the question. It is used to look up
                the numbers the expression must use; the target value is always 24.

        Returns:
            bool: True if the model's expression is considered correct, False otherwise.
        """
        # _final_answer reports every failure as -1, so a direct comparison suffices.
        return self._final_answer(expression, prompt_idx) == 24

    def extract_answer(self, solution_str: str) -> str:
        """
        Extracts the final expression from the solution string.

        Takes the last line of the stripped solution, lower-cases it, removes every
        ``"answer: "`` marker and keeps only the text before the first ``=``. For
        example, ``"Answer: (1 + 2 + 3) * 4 = 24"`` becomes ``"(1 + 2 + 3) * 4"``.

        Args:
            solution_str (str): The solution string to extract the final expression from.

        Returns:
            str: The extracted, not yet evaluated, expression.
        """
        last_line = solution_str.strip().split('\n')[-1]
        return last_line.lower().replace('answer: ', '').split('=')[0].strip()

    def _final_answer(self, expression: str, prompt_idx: int) -> int:
        """
        Evaluates the expression after checking that it uses the question's numbers.

        Args:
            expression (str): The expression to evaluate.
            prompt_idx (int): The dataset index of the question, used to get the
                question numbers.

        Returns:
            int: The exact integer value of the expression, or -1 in any of these cases:
            the expression does not use exactly the question's numbers, it cannot be
            parsed or evaluated, or its value is not an integer.
        """
        question = self.questions[prompt_idx]

        # Compare the multisets of digit runs in the solution and in the question.
        # Sorting the strings is enough for a multiset equality check.
        numbers = re.findall(r'\d+', expression)
        problem_numbers = re.findall(r'\d+', question)
        if sorted(numbers) != sorted(problem_numbers):
            return -1

        # SECURITY(review): sympy parses string input through sympify, which uses eval().
        # `expression` is model-generated text, so this evaluation is not sandboxed.
        try:
            value = sympy.simplify(expression)
            as_int = int(value)
        except Exception:
            # Deliberate best-effort: anything unparsable or non-numeric (syntax errors,
            # free symbols, division by zero, ...) is simply judged wrong.
            return -1

        # int() truncates, so 5*5 - 1/5 (= 124/5) would otherwise become 24.
        # Accept the value only when it is exactly an integer.
        if value != as_int:
            return -1
        return as_int




