import re
import os
import sympy
import pandas as pd
from tot.tasks.base import Task, DATA_PATH
from tot.prompts.game24 import * 

def get_current_numbers(y: str) -> str:
    """Return the numbers still available after the last step of ``y``.

    The last line of ``y`` is expected to look like
    ``'4 + 8 = 12 (left: 4 6 12)'``; the text between ``'left: '`` and the
    next ``')'`` is returned (``'4 6 12'``).  When the last line carries no
    ``'left: '`` marker (e.g. ``y`` is the raw puzzle input), the line is
    returned up to its first ``')'``, or whole if it has none.
    """
    # Only the final line of the trajectory describes the current state.
    _, _, final_line = y.strip().rpartition('\n')
    # Keep what follows the last 'left: ' marker (whole line if absent).
    _, _, remainder = final_line.rpartition('left: ')
    # Drop the closing parenthesis and anything after it.
    numbers, _, _ = remainder.partition(')')
    return numbers


class Game24Task(Task):
    """
    Input (x)   : a string of 4 numbers
    Output (y)  : a trajectory of 3 steps to reach 24
    Reward (r)  : 0 or 1, depending on whether the trajectory is correct
    Input Example: 
        1 2 3 4
    Output Example: 
        1 + 2 = 3 (left: 3 3 4)
        3 + 3 = 6 (left: 4 6)
        6 * 4 = 24 (left: 24)
        (1 + 2 + 3) * 4 = 24

    Besides building prompts, the task runs "solver checks" on the model's
    intermediate steps and records every inconsistency it finds in
    ``self.df_errors`` (columns: Category, Sub-category, Input, Steps,
    Reason).
    """

    # Signed integers/decimals; used for intermediate states, which may
    # legitimately contain negative numbers (e.g. "left: -4 6").
    _SIGNED_NUMBER_PATTERN = r'-?\d*\.\d+|-?\d+'
    # Unsigned integers/decimals; used for final answers, where a '-' is
    # always an operator ("6-4" uses 6 and 4, not 6 and -4).
    _UNSIGNED_NUMBER_PATTERN = r'\d*\.\d+|\d+'

    def __init__(self, file='24.csv'):
        """
        file: a csv file (fixed), read from DATA_PATH/24/<file>; its
              'Puzzles' column provides the task inputs.
        """
        super().__init__()
        path = os.path.join(DATA_PATH, '24', file)
        self.data = list(pd.read_csv(path)['Puzzles'])
        self.value_cache = {}
        self.steps = 4
        self.stops = ['\n'] * 4
        # One row per inconsistency found by the solver checks.
        self.df_errors = pd.DataFrame(columns=['Category', 'Sub-category', 'Input', 'Steps', 'Reason'])
        self.df_replacements = pd.DataFrame(columns=['Input', 'Steps', 'Corrected-steps'])
        # File name carries the data file stem and a creation timestamp.
        self.csv_filename = f'game24_{file.split(".")[0]}_{pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")}.csv'
        self.prompt_df = pd.DataFrame(columns=['input_x', 'input_y_second_to_last', 'final_prompt'])
        # When True, value_prompt_wrap builds its prompt from corrected steps.
        self.replace_incorrect_results = False

    def __len__(self) -> int:
        """Number of puzzles loaded from the data file."""
        return len(self.data)
    
    def get_input(self, idx: int) -> str:
        """Return the puzzle string stored at position ``idx``."""
        return self.data[idx]

    def test_output(self, idx: int, output: str):
        """Score a final answer against puzzle ``idx``.

        The last line of ``output`` (with an optional 'answer: ' prefix and
        everything from the first '=' removed) must use exactly the puzzle's
        digits-groups and evaluate to 24.

        Returns ``{'r': 1}`` on success and ``{'r': 0}`` otherwise,
        including when the expression cannot be parsed.
        """
        final_line = output.strip().split('\n')[-1].lower()
        expression = final_line.replace('answer: ', '').split('=')[0]
        used_numbers = re.findall(r'\d+', expression)
        problem_numbers = re.findall(r'\d+', self.data[idx])
        if sorted(used_numbers) != sorted(problem_numbers):
            return {'r': 0}
        try:
            return {'r': int(sympy.simplify(expression) == 24)}
        except Exception:
            # Unparseable expressions simply score zero.
            return {'r': 0}

    def _log_error(self, category, sub_category, x, y, reason):
        """Append one row to ``self.df_errors`` using the public pandas API."""
        row = pd.DataFrame([{
            'Category': category,
            'Sub-category': sub_category,
            'Input': x,
            'Steps': y,
            'Reason': reason,
        }])
        if self.df_errors.empty:
            # Skip concat with an empty frame (deprecated dtype inference);
            # reindex keeps the existing column order.
            self.df_errors = row.reindex(columns=self.df_errors.columns)
        else:
            self.df_errors = pd.concat([self.df_errors, row], ignore_index=True)

    @staticmethod
    def _evaluate_lhs(lhs_text):
        """Evaluate an arithmetic expression, returning an int when integral.

        NOTE(security): this uses ``eval`` on text produced by the language
        model. It is kept to preserve existing behaviour, but it must never
        be fed input from an untrusted party.
        """
        result = eval(lhs_text)
        if result == int(result):
            result = int(result)  # e.g. 12.0 -> 12
        return result
        
    def solvercheck_cot_prompt_out(self, x, y, ans):
        """Check a final chain-of-thought answer and log any problems.

        x:   puzzle input, e.g. '4 4 6 8'
        y:   full trajectory (only used for logging)
        ans: answer line, e.g. '(6 - 4) * (4 + 8) = 24'

        Logs (category 'COT'):
          E1 - the expression does not evaluate to 24
          E2 - the expression cannot be parsed
          E3 - the expression does not use exactly the input numbers
        Nothing is returned; results are recorded in ``self.df_errors``.
        """
        expression = ans.split('=')[0]
        # Unsigned matching: in "6-4" the '-' is an operator, not a sign.
        numbers = re.findall(self._UNSIGNED_NUMBER_PATTERN, expression)
        problem_numbers = re.findall(self._UNSIGNED_NUMBER_PATTERN, x)
        try:
            value = sympy.simplify(expression)
            if value != 24: # needs to be corrected for the COT proposal prompt
                self._log_error('COT', 'INCORRECT(E1)', x, y, 'Expression does not equal 24')
        except Exception as e: # brackets issue
            print(e, "expression:", expression, "x:", x, "y:", y)
            self._log_error('COT', 'INCORRECT(E2)', x, y, 'Expression cannot be parsed correctly')
        if sorted(numbers) != sorted(problem_numbers): # needs to be corrected for the COT proposal prompt
            self._log_error('COT', 'INCORRECT(E3)', x, y, 'Different numbers have been used')

    def solvercheck_propose_prompt_out(self, x, y, last_line, current_numbers): # error logs arranged in the order of precedence, but we will still log all of them
        """Validate the last proposed step of ``y`` and try to repair it.

        x:               puzzle input, e.g. '4 4 6 8'
        y:               trajectory so far
        last_line:       last line of ``y``, e.g. '4 + 8 = 12 (left: 4 6 12)'
        current_numbers: the 'left' part of ``last_line``, e.g. '4 6 12'

        Checks, in order (category 'PROPOSE'):
          E3 - the expression uses a number that was not available
          E4 - the stated result differs from the evaluated expression
          E5 / E6 / E7 - the 'left' numbers are missing entries, have extra
                         entries, or contain wrong entries

        Returns:
          False - the step is consistent, no correction needed
          None  - the step could not be parsed or used an invalid number
          str   - a corrected copy of ``y``

        Raises:
          AssertionError if ``last_line`` has no 'left: ' part; ValueError if
          a 'left' correction could not be applied; any exception raised
          while evaluating the expression (e.g. ZeroDivisionError).
        """
        assert "left: " in last_line, f"The last line does not have the 'left:' part: {last_line}"
        y_correct = False # flag to check if y needs to be corrected

        # Numbers available before this step: the previous line's 'left'
        # part, or the puzzle input when this is the first step.
        lines = y.strip().split('\n')
        if len(lines) >= 2:
            previous_state = lines[-2].split('left: ')[-1].split(')')[0]
        else:
            previous_state = x

        number_pattern = self._SIGNED_NUMBER_PATTERN
        try:
            available_numbers = re.findall(number_pattern, previous_state) # e.g. ['4', '6', '4', '8']
            last_line_expr = last_line.split('(left: ')[0] # e.g. '4 + 8 = 12 '
            lhs_text = last_line_expr.split('=')[0] # e.g. '4 + 8 '
            lhs_numbers = re.findall(number_pattern, lhs_text) # e.g. ['4', '8']
            rhs_number = re.findall(number_pattern, last_line_expr.split('=')[1])[0] # e.g. '12'
            last_line_left_numbers = re.findall(number_pattern, current_numbers) # e.g. ['4', '6', '12']
        except Exception as e:
            print(e)
            print(x, y)
            return None

        # available_numbers - lhs_numbers + result should equal the 'left' numbers.
        for number in lhs_numbers:
            if number not in available_numbers:
                # e.g., 4 + 8 = 12, but 4 was not available in the previous state
                self._log_error('PROPOSE', 'INCORRECT(E3)', x, y, f'Invalid number {number} used in the expression')
                return None # no point checking further with an invalid number
            available_numbers.remove(number)

        # All operands are valid, so check the arithmetic itself.
        try:
            rhs_text = last_line_expr.split('=')[1]
            # Compare the absolute difference: a signed "lhs - rhs < 0.01"
            # would accept every result that is too large.
            if not sympy.simplify(f'Abs(({lhs_text}) - ({rhs_text})) < 0.01'):
                self._log_error('PROPOSE', 'INCORRECT(E4)', x, y, f'Expression {lhs_text} does not equal {rhs_number}')
                cut_string = last_line.split('=')[1].split('left')[0].strip() # text between '=' and 'left'; keeps '(' so it is easy to replace
                eval_res = self._evaluate_lhs(lhs_text)
                last_line_correct = last_line.replace(cut_string, str(eval_res)+" (") # swap in the correct result
                y_correct = y.replace(last_line, last_line_correct)
                if y_correct == y: raise ValueError(f'Error in replacing the left numbers: {y_correct} == {y}') # replacement had no effect
                # Later checks operate on the corrected step.
                last_line = last_line_correct
                y = y_correct
        except Exception as e: 
            print(e)
            print("Ignoring the exception and continuing...") 

        ideal_rhs_number = str(self._evaluate_lhs(lhs_text))
        ideal_left_numbers = sorted(available_numbers + [ideal_rhs_number])

        # Classify the mismatch (if any) between expected and stated 'left' numbers.
        if len(ideal_left_numbers) > len(last_line_left_numbers):
            missing_numbers = [number for number in ideal_left_numbers if number not in last_line_left_numbers]
            sub_category = 'INCORRECT(E5)'
            reason = f'Numbers {missing_numbers} are missing in the last line'
        elif len(ideal_left_numbers) < len(last_line_left_numbers):
            extra_numbers = [number for number in last_line_left_numbers if number not in ideal_left_numbers]
            sub_category = 'INCORRECT(E6)'
            reason = f'Numbers {extra_numbers} are extra in the last line'
        elif sorted(ideal_left_numbers) != sorted(last_line_left_numbers):
            incorrect_numbers = [number for number in last_line_left_numbers if number not in ideal_left_numbers]
            expected_numbers = [number for number in ideal_left_numbers if number not in last_line_left_numbers]
            sub_category = 'INCORRECT(E7)'
            reason = f'Numbers {incorrect_numbers} are incorrect in the last line, expected {expected_numbers}'
        else:
            # 'left' numbers are consistent; return any E4 correction made above.
            return y_correct

        self._log_error('PROPOSE', sub_category, x, y, reason)
        # Replace the stated 'left' numbers with the expected ones.
        last_line_correct = last_line.replace('left: ' + current_numbers, 'left: ' + ' '.join(map(str, ideal_left_numbers)))
        y_correct = y.replace(last_line, last_line_correct)
        if y_correct == y: raise ValueError(f'Error in replacing the left numbers: {y_correct} == {y}') # replacement had no effect
        return y_correct

    def value_prompt_wrap(self, x: str, y: str) -> tuple: # returns (prompt, y_correct [not None if correction is needed])
        """Build the value prompt for trajectory ``y`` of puzzle ``x``.

        For a final answer line (no 'left: ' part) the answer is checked and
        ``(value_last_step_prompt, None)`` is returned.  Otherwise the last
        step is checked and the result depends on
        ``self.replace_incorrect_results``:

          True  - ``(None, None)`` if the step is unusable,
                  ``(prompt, None)`` if it needs no correction,
                  ``(prompt built from corrected numbers, corrected y)`` else
          False - ``(prompt, check result)`` where the check result
                  (False, None or a corrected ``y``) is for logging only
        """
        last_line = y.strip().split('\n')[-1] # e.g. '4 + 8 = 12 (left: 4 6 12)'
        if 'left: ' not in last_line:  # last step - that's usually the case because of the cot prompt format
            ans = last_line.lower().replace('answer: ', '')
            self.solvercheck_cot_prompt_out(x, y, ans)
            return value_last_step_prompt.format(input=x, answer=ans), None

        current_numbers = get_current_numbers(y)
        try:
            y_correct = self.solvercheck_propose_prompt_out(x, y, last_line, current_numbers) # False, None, or a string
        except Exception as e:
            print(e)
            y_correct = None

        if not self.replace_incorrect_results:
            # The check result is only used for logging by the caller.
            return value_prompt.format(input=current_numbers), y_correct
        if y_correct is None: # the step is unusable, do not proceed
            return None, None
        if y_correct is False: # the step is fine as it is
            return value_prompt.format(input=current_numbers), None
        # The step was corrected: evaluate the corrected state instead.
        current_numbers = get_current_numbers(y_correct)
        return value_prompt.format(input=current_numbers), y_correct
    
    @staticmethod
    def standard_prompt_wrap(x: str, y:str='') -> str:
        """Format the standard prompt with the input followed by ``y``."""
        return standard_prompt.format(input=x+"\n"+y)

    @staticmethod
    def cot_prompt_wrap(x: str, y:str='') -> str:
        """Format the chain-of-thought prompt with input ``x`` and steps ``y``."""
        return cot_prompt.format(input=x, steps=y)
    
    @staticmethod
    def propose_prompt_wrap(x: str, y: str='') -> str:
        """Build the prompt that asks for the next step (or the final answer).

        With an empty ``y`` the current numbers are the puzzle input;
        afterwards they come from the 'left' part of the last line of ``y``.
        Once only '24' is left, the chain-of-thought prompt is used so the
        model writes out the final answer
        (e.g. 'Answer: (1 + 8 / 4) * 8 = 24').
        """
        current_numbers = get_current_numbers(y if y else x)
        if current_numbers == '24': # last step - if the number is not 24, we don't really care
            return cot_prompt.format(input=x, steps=y)
        return propose_prompt.format(input=current_numbers)
    
    @staticmethod
    def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:
        """Turn the model's value judgements into a single score.

        Returns 0 when ``y`` has four lines and the last one still contains
        'left' (the cot prompt was never triggered).  Otherwise each output
        is classified by the first of 'impossible', 'likely', 'sure' found
        anywhere in it (case-insensitive; outputs with none are ignored) and
        the class weights are summed.
        """
        all_ys = y.strip().split('\n')
        if len(all_ys) == 4 and 'left' in all_ys[-1].lower(): # the cot prompt was not triggered
            return 0

        verdicts = []
        for value_output in value_outputs:
            lowered = value_output.lower()
            # Order matters: 'impossible' wins over 'likely', which wins over 'sure'.
            for name in ('impossible', 'likely', 'sure'):
                if name in lowered:
                    verdicts.append(name)
                    break

        value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}  # TODO: ad hoc
        return sum(weight * verdicts.count(name) for name, weight in value_map.items())