import re
from llmsearch.run_prolog import run_prolog
import os
import argparse

class Node:
    """One step of a Prolog trace, stored as a node of a tree.

    ``execution_type`` is the step label (e.g. "Call", "Exit", or the root's
    "Start of execution"), ``level`` the trace depth, ``predicate`` and
    ``args`` describe the goal, ``result`` is a dict of result data attached
    to the step, and ``children`` lists the steps that follow it.
    """

    def __init__(self, execution_type, level, predicate, args, result=None):
        self.execution_type, self.level = execution_type, level
        self.predicate, self.args = predicate, args
        # Any falsy result (None, {}, "") becomes a fresh empty dict, so
        # nodes never share a result mapping by default.
        self.result = result or {}
        self.children = []

    def add_child(self, child):
        """Attach *child* as the last child of this node."""
        self.children.extend([child])

    def set_result(self, result):
        """Replace this node's result with *result* (stored as-is, not copied)."""
        self.result = result


class PrologTraceParser:
    """Turn Prolog debugger trace lines into a tree of ``Node`` objects.

    Goal lines are expected to look like ``Call: (10) foo(a, _1) ? creep``:
    a port name, the depth in parentheses, the goal, then two trailing
    tokens. Result lines are ``true``/``false`` answers or
    ``path_result:query(...)`` lines. After a tree is built, ``save_chains``
    flattens it into root-to-leaf chains, and ``print_chains`` /
    ``print_tree`` render it.
    """

    # Captures the text inside each "(...)" group of a path_result line,
    # e.g. "avgo,228.28" from "query(avgo,228.28)".
    _RESULT_ARGS_RE = re.compile(r"\(([^)]+)\)")

    # Port names whose presence marks a line as a goal line, not a result.
    _PORTS = ("Call", "Exit", "Redo", "Fail")

    def __init__(self, **kwargs):
        # Root Node of the most recently built tree.
        self.tree = kwargs.get('tree', None)
        # Depth of the first goal line; print_tree shows depths relative to it.
        self.min_level = kwargs.get('min_level', 0)
        # Raw variable name (e.g. "_123") -> readable placeholder name.
        self.replace_variable_map = kwargs.get('replace_variable_map', {})
        # NOTE(review): not read anywhere in this class; kept for callers.
        self.query_argument_meta_names = kwargs.get('search_meta_names', [])
        # Root-to-leaf lists of nodes, filled by save_chains().
        self.chains = []

    def _parse_result_arguments(self, line):
        """Return the comma-separated values inside every "(...)" group of *line*.

        For example ``"query(avgo,228.28)"`` gives ``['avgo', '228.28']``.
        """
        return [
            value.strip()
            for chunk in self._RESULT_ARGS_RE.findall(line)
            for value in chunk.split(',')
        ]

    def _rename_variable(self, arg):
        """Return a stable placeholder for a variable argument, else *arg* unchanged."""
        if arg in self.replace_variable_map:
            return self.replace_variable_map[arg]
        # Fresh trace variables are assumed to start with "_" (e.g. "_123").
        # Atoms that merely contain an underscore ("avgo_price") are kept.
        if not arg.startswith("_"):
            return arg
        new_var = f"Candidate Variable_{len(self.replace_variable_map)}"
        self.replace_variable_map[arg] = new_var
        return new_var

    def parse_line(self, line):
        """Parse one trace line into ``(execution_type, level, predicate, args, result)``.

        * Result lines with no port name return
          ``('Result', None, None, None, result)``. For ``path_result:...``
          lines, ``result`` is ``{"Result": [values...]}``. For lines
          containing ``true``/``false``, it is ``{"Result": "true"|"false"}``.
        * Every other line is parsed as a goal line. The return value is the
          port name, the integer depth, the predicate name, the argument
          string with variables renamed and comma-space joined, and ``{}``.

        Raises ``IndexError``/``ValueError`` when a goal line does not have
        the expected ``Port: (depth) goal x y`` shape.
        """
        stripped = line.strip()
        is_goal_line = any(port in line for port in self._PORTS)
        if not is_goal_line:
            if "path_result" in line:
                # path_result:query(jpm,107.52) -> {"Result": ["jpm", "107.52"]}
                # partition keeps everything after the first ":" intact.
                payload = line.partition(":")[2]
                return 'Result', None, None, None, {"Result": self._parse_result_arguments(payload)}
            if "true" in stripped or "false" in stripped:
                # Final answer of the query.
                answer = "true" if "true" in stripped else "false"
                return 'Result', None, None, None, {"Result": answer}

        parts = line.split()
        execution_type = parts[0].strip(':')
        level = int(parts[1].strip('()'))
        # Drop the port, the depth and the two trailing prompt tokens
        # (presumably "? creep"), leaving only the goal text.
        goal = ' '.join(parts[2:-2])
        # Split on the first "(" only, so nested terms stay inside the
        # argument text; [:-1] drops the goal's closing parenthesis.
        predicate_name, _, arg_text = goal.partition('(')
        arg_text = arg_text[:-1]
        if arg_text:
            args = ', '.join(self._rename_variable(arg.strip()) for arg in arg_text.split(','))
        else:
            args = ''
        return execution_type, level, predicate_name, args, {}

    def build_tree(self, trace_lines):
        """Build a tree, nesting each goal line under the closest earlier shallower one.

        Result lines are merged into the result dict of the most recent goal
        node. Sets ``self.min_level`` to the depth of the first goal line and
        stores the tree in ``self.tree``. Returns the synthetic root ``Node``.
        """
        root = Node("Start of execution", 0, "Begining Search", [])
        current_path = [root]
        min_level_set = False

        for line in trace_lines:
            execution_type, level, predicate_name, args, result = self.parse_line(line)

            if execution_type == "Result":
                current_path[-1].result.update(result)
                continue

            # Result lines carry no depth, so take it from the first goal line.
            if not min_level_set:
                self.min_level = level
                min_level_set = True

            # Climb back to the nearest shallower ancestor, but never pop the root.
            while len(current_path) > 1 and current_path[-1].level >= level:
                current_path.pop()
            new_node = Node(execution_type, level, predicate_name, args)
            current_path[-1].add_child(new_node)
            current_path.append(new_node)

        self.tree = root
        return root

    def build_tree_new(self, trace_lines):
        """Build a tree that mirrors backtracking in the trace.

        * Call and Exit lines become children of the current node and extend
          the current path.
        * Fail lines become leaf children with result
          ``{"Result": "Search Failed"}``. The current path is then shortened
          by one node.
        * Redo lines rewind the path to just before the latest Call/Redo node
          at the same depth, and start a sibling branch there.
        * Result lines are merged into the current node's result dict.

        Sets ``self.min_level`` to the depth of the first goal line, stores
        the tree in ``self.tree`` and returns the synthetic root ``Node``.
        """
        root = Node("Start of execution", 0, "Begining Search", [])
        current_path = [root]
        # depth -> Call/Redo nodes seen at that depth, most recent last.
        call_stack = {}
        min_level_set = False

        for line in trace_lines:
            execution_type, level, predicate_name, args, result = self.parse_line(line)

            if execution_type == "Result":
                current_path[-1].result.update(result)
                continue

            if not min_level_set:
                self.min_level = level
                min_level_set = True

            if execution_type == "Call":
                new_node = Node(execution_type, level, predicate_name, args)
                current_path[-1].add_child(new_node)
                current_path.append(new_node)
                call_stack.setdefault(level, []).append(new_node)

            elif execution_type == "Exit":
                new_node = Node(execution_type, level, predicate_name, args)
                current_path[-1].add_child(new_node)
                current_path.append(new_node)

            elif execution_type == "Fail":
                new_node = Node(execution_type, level, predicate_name, args)
                new_node.result = {"Result": "Search Failed"}
                current_path[-1].add_child(new_node)
                # Guards keep a trace that starts mid-search (its Call was never
                # seen) from popping the root or an absent depth.
                if len(current_path) > 1:
                    current_path.pop()
                if call_stack.get(level):
                    call_stack[level].pop()

            elif execution_type == "Redo":
                candidates = call_stack.get(level)
                # Rewind to just before the goal being retried. If that goal
                # is unknown or no longer on the path, continue from here.
                if candidates and candidates[-1] in current_path:
                    current_path = current_path[:current_path.index(candidates[-1])]

                new_node = Node(execution_type, level, predicate_name, args)
                current_path[-1].add_child(new_node)
                current_path.append(new_node)
                call_stack.setdefault(level, []).append(new_node)

        self.tree = root
        return root

    def save_chains(self, node=None, chain=None):
        """Append every root-to-leaf path below *node* to ``self.chains``.

        With no arguments the walk starts at ``self.tree``. *chain* is the
        list of nodes already above *node*; it is not modified. Each saved
        chain is a list of nodes ending at a leaf. Returns ``self.chains``.
        """
        if node is None:
            node = self.tree
            chain = []
            if node is None:
                return self.chains
        prefix = list(chain) if chain is not None else []

        # Use an explicit stack: build_tree_new nests every Call/Exit under
        # the previous node, so trees can be deeper than the recursion limit.
        # Children are pushed in reverse, so chains are saved in left-to-right
        # depth-first order.
        stack = [(node, prefix + [node])]
        while stack:
            current, path = stack.pop()
            if not current.children:
                self.chains.append(path)
            else:
                for child in reversed(current.children):
                    stack.append((child, path + [child]))
        return self.chains

    def print_chains(self):
        """Print each saved chain on one line, steps joined by " -> ".

        Each step is rendered as ``Label: predicate(args) | result``. The
        "Exit" port is shown as "Success".
        """
        for chain in self.chains:
            search_chain = []
            for node in chain:
                args = f"({node.args.strip()})" if len(node.args) > 0 else ""
                predicate = f" {node.predicate}" if len(node.predicate) > 0 else ""
                # Relabel for display only; the tree keeps the "Exit" port.
                label = "Success" if node.execution_type == "Exit" else node.execution_type
                result = f" | {node.result}" if len(node.result) > 0 else ""
                search_chain.append(f"{label}:{predicate}{args}{result}")
            print(" -> ".join(search_chain))

    def print_tree(self, node, indent=0):
        """Print *node* and its descendants, one per line.

        Each line is prefixed by one dash run per depth. Levels are shown
        relative to ``self.min_level``, clamped at 0.
        """
        # Iterative pre-order walk (children left to right), so deep trees
        # do not hit the recursion limit.
        stack = [(node, indent)]
        while stack:
            current, depth = stack.pop()
            result_str = f" -> {current.result}" if current.result else ""
            print('---------' * depth + f'{current.execution_type}: ({max(0, current.level - self.min_level)}) {current.predicate} ({current.args}){result_str}')
            stack.extend((child, depth + 1) for child in reversed(current.children))


def sanitize_trace(trace_text):
    """Split raw trace text into stripped, non-blank lines worth parsing.

    Lines mentioning ``protocol``, ``forall``, ``write`` or ``halt`` are
    dropped, along with blank lines.
    """
    skip_markers = ("protocol", "forall", "write", "halt")
    kept_lines = []
    for raw_line in trace_text.strip().split('\n'):
        cleaned = raw_line.strip()
        if not cleaned:
            continue
        if any(marker in raw_line for marker in skip_markers):
            continue
        kept_lines.append(cleaned)
    return kept_lines


def replace_results(trace_lines, results):
    """Check that the trace has one ``path_result`` line per actual result.

    Prints the trace lines, the ``path_result`` lines found among them, and
    *results*, then verifies that the two counts match.

    NOTE(review): despite its name, this function does not replace anything.

    Returns:
        The list of trace lines that contain ``path_result``.

    Raises:
        AssertionError: if the number of ``path_result`` lines differs from
            ``len(results)``.
    """
    print(f"Trace lines: {trace_lines}")
    predicted_results = [line for line in trace_lines if "path_result" in line]

    print(f"Predicted results: {predicted_results}")
    print(f"Actual results: {results}")
    # Raised explicitly rather than via `assert`, so the check survives
    # `python -O`. AssertionError is kept for callers that already expect it.
    if len(predicted_results) != len(results):
        raise AssertionError(
            f"Found {len(predicted_results)} path_result line(s) in the trace "
            f"but {len(results)} actual result(s)"
        )
    return predicted_results


if __name__ == '__main__':
    # Command-line entry point: parse a saved trace, cross-check its
    # path_result lines against run_prolog(), then print every search chain.
    arg_parser = argparse.ArgumentParser(description='Parse Prolog trace')
    arg_parser.add_argument('--trace_file', type=str, required=True, help='Path to the file with prolog trace')
    arg_parser.add_argument('--prolog_file', type=str, required=True, help='Path to the file with prolog code')

    args = arg_parser.parse_args()

    with open(args.trace_file, 'r', encoding='utf-8') as f:
        trace_text = f.read()

    trace_lines = sanitize_trace(trace_text)

    results = run_prolog(args.prolog_file)
    replace_results(trace_lines, results)

    # Named trace_parser so it does not shadow the argparse parser above.
    trace_parser = PrologTraceParser()
    tree_root = trace_parser.build_tree_new(trace_lines)
    trace_parser.save_chains(node=tree_root, chain=[])
    trace_parser.print_chains()