import os
import subprocess
import tempfile


class TreeNode:
    """
    Node of a tree whose children are indexed by the tuple form of their path.

    Attributes:
    - value: The list handed to the constructor as `path` (stored as-is,
      not copied).
    - children: Dict mapping `tuple(path)` to the child TreeNode created
      for that path.
    """

    def __init__(self, path):
        self.children = {}
        self.value = path

    def add_child(self, path):
        """
        Return the child registered under `path`, creating it if needed.

        Parameters:
        - path: A list of hashable items; its tuple form is the lookup key.

        Returns:
        - The existing child for that key, or a newly created one.
        """
        key = tuple(path)
        child = self.children.get(key)
        if child is None:
            child = TreeNode(path)
            self.children[key] = child
        return child

    def display(self, depth=0):
        """
        Print this node, then every descendant one level deeper.

        The whole `value` list is joined with newlines and printed after
        two spaces per `depth` level, so only the first printed line of a
        multi-line value carries the indentation.

        Parameters:
        - depth: Nesting level of this node (0 for the starting node).
        """
        indent = "  " * depth
        text = "\n".join(self.value)
        print(f"{indent}{text}")
        deeper = depth + 1
        for key in self.children:
            self.children[key].display(deeper)

def create_tree_from_strings(strings):
    """
    Build a prefix tree from the lines of several multi-line strings.

    Each string is split on "\n". For every line position, a node holding
    the list of lines up to and including that position is added beneath
    the node for the previous position, so strings that start with the
    same lines share the corresponding nodes.

    Parameters:
    - strings: Iterable of strings to insert.

    Returns:
    - The root TreeNode, whose value is an empty list.
    """
    root = TreeNode([])

    for text in strings:
        lines = text.split("\n")
        node = root
        # Each slice is a fresh list, so every node owns its own prefix.
        for end in range(1, len(lines) + 1):
            node = node.add_child(lines[:end])

    return root


def check_correctness(llm_answer, filename):
    """
    Run the external validator on a candidate plan and report its verdict.

    The plan text is written to a uniquely named temporary file, then
    `VAL/validate` (resolved relative to the current working directory) is
    invoked with "domain.pddl", `filename`, and the temporary plan file.
    The temporary file is removed afterwards.

    Parameters:
    - llm_answer: The plan text to validate.
    - filename: Path of the file passed to the validator after the domain
      file (presumably the PDDL problem file; confirm against callers).

    Returns:
    - True if the validator's standard output contains "Plan valid".
    - False otherwise, including when the validator cannot be started.
    """
    domain_file = "domain.pddl"

    # A unique scratch file keeps concurrent or repeated runs from
    # overwriting each other's plan and leaves nothing behind in the
    # working directory.
    fd, plan_path = tempfile.mkstemp(suffix=".txt")
    try:
        with os.fdopen(fd, "w") as plan_file:
            plan_file.write(llm_answer)

        # Passing an argument list (no shell) means `filename` reaches the
        # validator verbatim even if it contains spaces or shell
        # metacharacters.
        # NOTE(review): upstream VAL documents `-t` as taking a numeric
        # tolerance value; confirm that `-t -v` is the intended flag pair.
        command = ["VAL/validate", "-t", "-v", domain_file, filename, plan_path]
        try:
            val_output = subprocess.run(
                command,
                capture_output=True,
                text=True,
                check=False,
            )
        except OSError:
            # With a shell, a missing or non-executable validator produced
            # empty output and therefore False; keep that outcome instead
            # of raising.
            return False
    finally:
        os.unlink(plan_path)

    return "Plan valid" in val_output.stdout


def dfs_with_error_check(node, filename, error_limit=10):
    """
    Depth-first search for a leaf whose plan passes `check_correctness`.

    Only leaf nodes (nodes without children) are checked; a leaf's value
    is joined with newlines and handed to `check_correctness` together
    with `filename`. Children are visited in the iteration order of each
    node's `children` mapping. The search stops at the first successful
    leaf, or as soon as `error_limit` leaves have failed.

    Parameters:
    - node: The node to start from; must expose `value` and `children`.
    - filename: Forwarded unchanged to `check_correctness`.
    - error_limit: Maximum number of failing leaves before giving up. With
      a limit of 0 or less nothing is checked.

    Returns:
    - A tuple (success, error_count):
      - success: True if any leaf node succeeds.
      - error_count: The number of errors encountered.
    """
    error_count = 0

    # An explicit stack replaces recursion so that very deep trees (one
    # level per plan line) cannot exceed the interpreter's recursion limit.
    pending = [node]

    while pending and error_count < error_limit:
        current = pending.pop()

        if not current.children:
            if check_correctness("\n".join(current.value), filename):
                return True, error_count
            error_count += 1
            continue

        # Push in reverse so the first child is popped (visited) first,
        # matching a left-to-right recursive traversal.
        pending.extend(reversed(list(current.children.values())))

    return False, error_count
