import re
import sympy
import ast
import networkx as nx

def escape(x: str):
    """Return *x* with every double quote (") preceded by a backslash."""
    escaped_quote = '\\"'
    pieces = x.split('"')
    return escaped_quote.join(pieces)

def get_current_numbers(y: str) -> str:
    """Extract the remaining numbers from the last line of *y*.

    Only the final line of the stripped text is inspected. The result is the
    text that follows the last ``'left: '`` marker on that line, cut at the
    first ``')'``. When the marker is absent the line itself (up to the
    first ``')'``) is returned.
    """
    _, _, final_line = y.strip().rpartition('\n')
    _, _, after_marker = final_line.rpartition('left: ')
    remaining, _, _ = after_marker.partition(')')
    return remaining


# Matches the whole words "correct" / "incorrect", case-insensitively.
WORD_RE = re.compile(r'\b(?:incorrect|correct)\b', re.I)


def correct_is_last(text: str) -> bool:
    """Tell whether the final correct/incorrect verdict word in *text* is "correct".

    Returns False when neither word occurs in *text*.
    """
    verdicts = WORD_RE.findall(text)
    if not verdicts:
        return False
    return verdicts[-1].lower() == 'correct'

class Game24Task():
    """
    Game-of-24 task: prompt construction and scoring of model outputs.

    Input (x)   : a string of 4 numbers
    Output (y)  : a trajectory of 3 steps to reach 24
    Reward (r)  : 0 or 1, depending on whether the trajectory is correct
    Input Example:
        1 2 3 4
    Output Example:
        1 + 2 = 3 (left: 3 3 4)
        3 + 3 = 6 (left: 4 6)
        6 * 4 = 24 (left: 24)
        (1 + 2 + 3) * 4 = 24
    """
    def __init__(self, dataset):
        """
        Load the graphs for *dataset* and bind the prompt templates.

        dataset: name interpolated into the three file paths
            ``data/optimal_graphs/{dataset}_T.graphml``, ``..._W.graphml``
            and ``..._S.graphml``, which are read with networkx.
        """
        super().__init__()

        T, W, S = [nx.read_graphml(f"data/optimal_graphs/{dataset}_{x}.graphml") for x in "TWS"]
        self.dataset = dataset
        self.T = T
        self.W = W
        self.S = S
        # The problems are the nodes of T that have no incoming edge.
        self.samples = [n for n in T.nodes() if T.in_degree(n) == 0]
        # One more than the largest number of spaces found in any sample,
        # i.e. the token count of the longest single-space-separated sample.
        self.steps = max(x.count(" ") for x in self.samples) + 1

        from tot.prompts_original_mistral import (
            standard_prompt, cot_prompt, propose_prompt,
            value_last_step_prompt, value_prompt,
            verify_math_prompt, verify_numbers_selection_prompt,
            verify_leftovers_prompt,
        )
        self.verify_math_prompt = verify_math_prompt
        self.verify_numbers_selection_prompt = verify_numbers_selection_prompt
        self.verify_leftovers_prompt = verify_leftovers_prompt
        self.standard_prompt = standard_prompt
        self.cot_prompt = cot_prompt
        self.propose_prompt = propose_prompt
        self.value_last_step_prompt = value_last_step_prompt
        self.value_prompt = value_prompt
        self.data = self.samples
        self.value_cache = {}

    def __len__(self) -> int:
        """Return the number of problems in ``self.data``."""
        return len(self.data)

    def get_input(self, idx: int) -> str:
        """Return the problem string stored at position *idx*."""
        return self.data[idx]

    def test_output(self, idx: int, output: str):
        """
        Score the final answer in *output* against problem *idx*.

        The expression is read from the last line of *output*: lower-cased,
        with ``'answer: '`` removed and everything from the first ``'='``
        dropped. Returns ``{'r': 1}`` when the digit runs in the expression
        match those of the problem (as a multiset) and the expression
        compares equal to 24 after ``sympy.simplify``; otherwise
        ``{'r': 0}``, including when evaluation raises.
        """
        expression = output.strip().split('\n')[-1].lower().replace('answer: ', '').split('=')[0]
        numbers = re.findall(r'\d+', expression)
        problem_numbers = re.findall(r'\d+', self.data[idx])
        if sorted(numbers) != sorted(problem_numbers):
            return {'r': 0}
        try:
            # NOTE(review): sympy converts strings with an eval-based parser,
            # and `expression` is model-generated text. Confirm this only
            # runs on trusted outputs or sanitise the expression first.
            return {'r': int(sympy.simplify(expression) == 24)}
        except Exception:
            # Anything that cannot be parsed or evaluated scores as wrong.
            return {'r': 0}

    def standard_prompt_wrap(self, x: str, y: str = '') -> str:
        """Format the standard prompt with the escaped input, then append escaped *y*."""
        return self.standard_prompt.format(input=escape(x)) + escape(y)

    def cot_prompt_wrap(self, x: str, y: str = '') -> str:
        """Format the chain-of-thought prompt with *x* and append *y* unchanged."""
        return self.cot_prompt.format(input=x) + y

    def propose_prompt_wrap(self, x: str, y: str = '') -> str:
        """
        Build the prompt that asks for the next step.

        The remaining numbers are read from *y* (or from *x* when *y* is
        empty). Once they are exactly ``'24'`` the chain-of-thought template
        is formatted with the escaped steps; otherwise the propose template
        is formatted with the remaining numbers.
        """
        current_numbers = get_current_numbers(y if y else x)
        if current_numbers == '24':
            prompt = self.cot_prompt.format(steps=escape(y))
        else:
            prompt = self.propose_prompt.format(input=current_numbers)
        return prompt

    def value_prompt_wrap(self, x: str, y: str) -> str:
        """
        Build the prompt used to evaluate the partial solution *y*.

        A last line without ``'left: '`` is treated as the final answer and
        evaluated with the last-step template; otherwise the remaining
        numbers are evaluated with the regular value template.
        """
        last_line = y.strip().split('\n')[-1]
        if 'left: ' not in last_line:  # last step
            ans = last_line.lower().replace('answer: ', '')
            return self.value_last_step_prompt.format(input=x, answer=escape(ans))
        current_numbers = get_current_numbers(y)
        return self.value_prompt.format(input=current_numbers)

    def value_outputs_unwrap(self, x: str, y: str, value_outputs: list) -> float:
        """
        Turn evaluator completions into a single numeric value.

        Returns 0 when *y* already has ``self.steps`` lines but contains no
        'answer'. Otherwise the lower-cased last line of every completion is
        scanned for the keywords of ``value_map`` (substring match), and the
        keyword weights are summed once per matching completion.
        """
        if len(y.strip().split('\n')) == self.steps and 'answer' not in y.lower():
            return 0
        value_names = [_.split('\n')[-1].lower() for _ in value_outputs]
        value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}  # TODO: ad hoc
        return sum(
            weight * sum(name in verdict for verdict in value_names)
            for name, weight in value_map.items()
        )

    def get_verification_prompts(self, x, y):
        """
        Build the three verification prompts for the trajectory *y*.

        Returns a dict with the keys ``is_math_correct_prompt``,
        ``is_selection_correct_prompt`` and ``is_leftovers_correct_prompt``.
        Each value is the corresponding template with its ``{input}`` /
        ``{steps}`` placeholders substituted and then parsed with
        ``ast.literal_eval``.

        Raises ValueError when the prompts cannot be assembled, and
        propagates whatever ``ast.literal_eval`` raises for a template that
        is not a valid Python literal after substitution.
        """
        try:
            step_lines = y.strip().split("\n")
            # The selection check compares the newest step with the state it
            # started from: the problem itself for the first step, otherwise
            # the previous step line.
            a = x if len(step_lines) == 1 else step_lines[-2]
            b = step_lines[-1]
            math = self.verify_math_prompt.replace("{input}", escape(y))
            selection = self.verify_numbers_selection_prompt.replace("{input}", escape(a)).replace("{steps}", escape(b))
            leftovers = self.verify_leftovers_prompt.replace("{input}", escape(x)).replace("{steps}", escape(y))
        except Exception as err:
            print("failed to setup prompts")
            print(x, y)
            # Fail here with the real cause attached instead of letting
            # ast.literal_eval choke on a None prompt further down.
            raise ValueError("failed to set up verification prompts") from err
        d = dict(
            is_math_correct_prompt=math,
            is_selection_correct_prompt=selection,
            is_leftovers_correct_prompt=leftovers,
        )
        return {k: ast.literal_eval(v) for k, v in d.items()}

    def verification_outputs_unwrap(self, out):
        """
        Majority-vote the verifier completions in *out* (a list of str).

        A completion votes "correct" when its last correct/incorrect word is
        "correct". Returns True when at least half of the completions vote
        "correct"; ties and an empty list therefore also return True.
        """
        votes = [correct_is_last(text) for text in out]
        return sum(votes) >= len(votes) / 2
    
