import ast
import re
import json
import sympy
from prompts.game24 import *
from tasks.base import BaseTask
from collections import Counter
import networkx as nx

# Matches an optionally negative integer or decimal literal, e.g. "-3" or "2.5".
# NOTE(review): not referenced anywhere in this chunk; confirm it is used
# elsewhere before relying on (or removing) it.
_scoring_number_pattern = re.compile(r'-?\d+(?:\.\d+)?')

def escape(x: str):
    """Return *x* JSON-encoded but without the surrounding double quotes.

    Quotes, backslashes, control characters and non-ASCII characters come
    back as backslash escapes. The prompt wrappers in this module use this
    to splice text into templates that are later parsed with
    ``ast.literal_eval``.
    """
    quoted = json.dumps(x)
    # drop the opening and closing '"' added by json.dumps
    return quoted[1:len(quoted) - 1]

def extract_numbers_from_root(root):
    """
    Parse a whitespace-separated game24 instance such as "2 5 7 2".

    Returns the numbers as a list of ints, in input order.
    """
    return list(map(int, root.split()))

# Integer literal with an optional leading minus. The look-behind forbids a
# match from starting right after a digit or ")", so a "-" there is read as
# subtraction rather than a sign ("13-5" -> 13, 5; "(1+2)-3" -> 1, 2, 3).
_number_pat = re.compile(r'(?<![\d)])-?\d+')


def extract_numbers_from_eq(eq: str):
    """
    Return the integers to the left of the first '=' in *eq*, in the order
    they appear.

    A '-' anywhere other than directly after a digit or ')' is taken as a
    sign, e.g. "-3 + 27 = 24" -> [-3, 27].
    NOTE(review): this also means "10 -4" (minus glued to the right operand)
    yields [10, -4].
    """
    lhs, _, _ = eq.partition('=')
    return list(map(int, _number_pat.findall(lhs)))

def uses_exactly(root, eq):
    """
    Check that the left-hand side of *eq* uses precisely the numbers of *root*.

    Order is ignored but multiplicity is not (root "1 1 2 3" needs two 1s).
    Returns True on an exact multiset match, False otherwise.
    """
    expected = sorted(extract_numbers_from_root(root))
    found = sorted(extract_numbers_from_eq(eq))
    # equal sorted lists of ints <=> equal multisets
    return expected == found

class Game24Task(BaseTask):
    """
    Game of 24: combine the given numbers with arithmetic to reach 24.

    Input (x)   : a string of 4 numbers
    Output (y)  : a trajectory of 3 steps to reach 24
    Reward (r)  : 0 or 1, depending on whether the trajectory is correct
    Input Example:
        1 2 3 4
    Output Example:
        1 + 2 = 3 (left: 3 3 4)
        3 + 3 = 6 (left: 4 6)
        6 * 4 = 24 (left: 24)
        (1 + 2 + 3) * 4 = 24

    Search nodes are space-separated number strings (the root is the puzzle
    itself, e.g. "1 2 3 4"); a node holding a single number is irreducible.
    """

    def __init__(self, dataset):
        """Load the precomputed graphs for *dataset*.

        Reads ``data/optimal_graphs/{dataset}_T.graphml``, ``..._W.graphml``
        and ``..._S.graphml``. ``T`` must be a DAG: its nodes without incoming
        edges are the samples (puzzle roots), and its longest path length is
        stored as ``max_depth``. A root that is not a node of ``W`` is treated
        as unsolvable by this class.
        NOTE(review): the role of ``S`` is not visible in this class.
        """
        T, W, S = [nx.read_graphml(f"data/optimal_graphs/{dataset}_{x}.graphml") for x in "TWS"]
        self.T = T
        self.W = W
        self.S = S
        self.dataset = dataset
        self.samples = [n for n in T.nodes() if T.in_degree(n) == 0]
        # selects the "mixed" prompt variants that allow an "unsolvable" answer
        self.has_unsolvable = any(n not in self.W for n in self.samples)
        self.max_depth = nx.dag_longest_path_length(T)

    @staticmethod
    def check_solution(root: str, eq: str) -> bool:
        """Return True iff *eq* is a correct solution for puzzle *root*.

        Args:
            root: space-separated puzzle numbers, e.g. "1 2 3 4".
            eq: candidate equation, e.g. "(1 + 2 + 3) * 4 = 24". Only the text
                before the first "=" is checked.

        The left-hand side must use exactly the numbers of *root* (as a
        multiset), contain only digits, whitespace, ".", "+", "-", "*", "/"
        and parentheses, and evaluate to 24. Any parse or evaluation error,
        or a non-string *eq* such as None, yields False.
        """
        try:
            if not uses_exactly(root, eq):
                return False
            lhs = eq.split("=", 1)[0]
            # sympy parses string input with eval(); eq is typically model
            # output, so only plain arithmetic is let through.
            if not re.fullmatch(r"[\d\s.+\-*/()]+", lhs):
                return False
            return sympy.simplify(lhs) == 24
        except Exception:
            # malformed equations (syntax errors, division by zero, None) are
            # simply wrong answers
            return False

    def io_prompt_wrap(self, x: str) -> str:
        """Fill the IO prompt template with puzzle *x* and parse it.

        Uses the "mixed" template when the dataset has unsolvable samples.
        The filled template is parsed with ``ast.literal_eval`` (presumably
        into a list of chat turns -- confirm against prompts.game24).
        """
        p = io_turns_mixed_prompt if self.has_unsolvable else io_turns_prompt
        convo = p.replace("{input}", x)
        return ast.literal_eval(convo)

    def cot_prompt_wrap(self, x: str) -> str:
        """Fill the chain-of-thought prompt template with puzzle *x* and parse it.

        Uses the "mixed" template when the dataset has unsolvable samples;
        the result is whatever ``ast.literal_eval`` yields for the template.
        """
        p = cot_turns_mixed_prompt if self.has_unsolvable else cot_turns_prompt
        convo = p.replace("{input}", x)
        return ast.literal_eval(convo)

    def add_finalize_io_answer_turn(self, conversation: list):
        """Append the finalize-IO-answer follow-up to *conversation* in place.

        Uses the "mixed" follow-up when the dataset has unsolvable samples.
        Returns the same list.
        """
        p = finalize_io_mixed_follow_up_prompt if self.has_unsolvable else finalize_io_follow_up_prompt
        conversation.append(p)
        return conversation

    @staticmethod
    def propose_convo_wrap(node: str) -> str:
        """Fill ``propose_prompt`` with *node* (JSON-escaped) and parse it."""
        convo = propose_prompt.replace("{input}", escape(node))
        return ast.literal_eval(convo)

    @staticmethod
    def apply_convo_wrap(node: str, move: str) -> str:
        """Fill ``apply_prompt`` with *node* and *move* (both JSON-escaped) and parse it."""
        convo = apply_prompt.replace("{input}", escape(node))
        convo = convo.replace("{step}", escape(move))
        return ast.literal_eval(convo)

    @staticmethod
    def add_apply_follow_up_turn(conversation: list):
        """Append ``apply_follow_up_prompt`` to *conversation* in place and return it."""
        conversation.append(apply_follow_up_prompt)
        return conversation

    @staticmethod
    def get_move_verification_prompts(node: str, move: str):
        """Return ``{prompt_name: parsed prompt}`` for checking *move* taken from *node*.

        Keys:
            "is_math_correct_prompt": ``verify_math_prompt`` filled with *move*.
            "is_selection_correct": ``verify_numbers_selection_prompt`` filled
                with *node* and *move*.
        All substituted text is JSON-escaped before parsing.
        """
        math = verify_math_prompt.replace("{input}", escape(move))
        selection = verify_numbers_selection_prompt.replace("{input}", escape(node)).replace("{steps}", escape(move))
        d = dict(is_math_correct_prompt = math, is_selection_correct=selection)
        return {k: ast.literal_eval(v) for k, v in d.items()}

    @staticmethod
    def get_node_verification_prompts(node: str, next_node: str, move: str):
        """Return ``{"is_applied_correct": parsed prompt}`` for *move* from *node* to *next_node*.

        The prompt is ``verify_leftovers_prompt`` with *node*, *move* and
        *next_node* substituted (JSON-escaped).
        """
        p = verify_leftovers_prompt.replace("{input}", escape(node))
        p = p.replace("{step}", escape(move))
        p = p.replace("{result}", escape(next_node))
        d = dict(is_applied_correct = p)
        return {k: ast.literal_eval(v) for k, v in d.items()}

    @staticmethod
    def add_verification_turn(conversation: list):
        """Append ``verification_follow_up_prompt`` to *conversation* in place and return it."""
        conversation.append(verification_follow_up_prompt)
        return conversation

    @staticmethod
    def is_irreducible_node(node):
        """Return True if *node* holds a single number (exactly one space-separated token)."""
        return len(node.split(" ")) == 1

    @staticmethod
    def is_solution_node(node):
        """Return True if *node* parses as a number equal to 24, else False."""
        try:
            return float(node) == 24
        except (TypeError, ValueError, OverflowError):
            return False

    @staticmethod
    def value_prompt_wrap(node: str) -> str:
        """Fill ``value_prompt`` with *node* and parse it.

        NOTE(review): unlike the propose/apply wrappers, *node* is inserted
        without ``escape``.
        """
        return ast.literal_eval(value_prompt.replace("{input}", node))

    @staticmethod
    def add_value_turn(conversation: list):
        """Append ``value_follow_up_prompt`` to *conversation* in place and return it."""
        conversation.append(value_follow_up_prompt)
        return conversation

    @staticmethod
    def value_outputs_unwrap(answer: str) -> float:
        """Score a value-model *answer* by the verdict words it contains.

        Every whole-word, case-sensitive occurrence of "impossible", "likely"
        or "sure" adds 0.001, 1 or 20 respectively. Whole-word matching keeps
        words like "unlikely", "unsure" or "ensure" from being scored as
        "likely"/"sure".
        """
        value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}
        return sum(
            len(re.findall(rf"\b{k}\b", answer)) * v
            for k, v in value_map.items()
        )

    @staticmethod
    def finalize_answer_prompt_wrap(x, path) -> str:
        """Fill ``finalize_answer_prompt`` with the steps of *path* and puzzle *x*, then parse it.

        *path* is a sequence of step strings, joined with newlines.
        """
        steps = "\n".join(path)
        return ast.literal_eval(finalize_answer_prompt.replace("{steps}", steps).replace("{input}", x))

    @staticmethod
    def add_finalize_answer_turn(conversation: list):
        """Append ``finalize_answer_follow_up_prompt`` to *conversation* in place and return it."""
        conversation.append(finalize_answer_follow_up_prompt)
        return conversation

    @staticmethod
    def shortcut_prompt_wrap(parent: str, child: str) -> str:
        """Fill ``shortcut_prompt`` with *parent* and *child* and parse it."""
        return ast.literal_eval(shortcut_prompt.replace("{parent}", parent).replace("{child}", child))

    @staticmethod
    def add_shortcut_turn(conversation: list):
        """Append ``shortcut_follow_up_prompt`` to *conversation* in place and return it."""
        conversation.append(shortcut_follow_up_prompt)
        return conversation

    # for tasks with unsolvable samples
    def evaluate_results(self, results, unsolved_samples=None):
        '''
        Score proposed solutions, accounting for puzzles that have no solution.

        Args:
            results: DataFrame with columns "root" and "solution". A missing
                solution (None or NaN) means the sample was declared
                unsolvable.
            unsolved_samples: optional list of roots for which no solution was
                produced; they are appended with solution None and thus count
                as deemed unsolvable. If None, such samples are assumed to be
                in *results* already, with a missing solution.

        Returns a new DataFrame (the input is not modified) with added columns:
            "is_solvable": whether the root is a node of ``self.W``
            "deemed_solvable": whether a solution was proposed
            "solvable_check": is_solvable == deemed_solvable
            "eq_correct": whether the proposed solution passes check_solution
            "is_correct": for solvable roots, a solution was proposed and it
                is correct; for unsolvable roots, no solution was proposed
        '''
        import pandas as pd  # only this method needs pandas

        if unsolved_samples is not None:
            u = pd.DataFrame({"root": unsolved_samples, "solution": None})
            results = pd.concat([results, u], axis=0)
        else:
            # never add columns to the caller's DataFrame
            results = results.copy()

        roots = results["root"].tolist()
        solutions = results["solution"].tolist()

        results["is_solvable"] = [root in self.W.nodes for root in roots]
        # pd.notna: a NaN (e.g. after a CSV round-trip) is as missing as None
        results["deemed_solvable"] = [bool(pd.notna(s)) for s in solutions]
        results["solvable_check"] = results["is_solvable"] == results["deemed_solvable"]
        # list comprehensions instead of DataFrame.apply(axis=1), which
        # returns a DataFrame (not a column) on empty input
        results["eq_correct"] = [self.check_solution(r, s) for r, s in zip(roots, solutions)]
        results["is_correct"] = [
            (check and eq_ok) if solvable else check
            for solvable, check, eq_ok in zip(
                results["is_solvable"], results["solvable_check"], results["eq_correct"]
            )
        ]
        return results