from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string such as "(on b1 b2)" into its tokens.

    The first and last characters (the enclosing parentheses) are dropped and
    the remainder is split on whitespace, e.g. ["on", "b1", "b2"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Report whether a PDDL fact string fits a wildcard pattern.

    - `fact`: the full fact as a string, e.g. "(on b1 b2)".
    - `args`: one shell-style pattern per fact component (`*` matches anything).
    - Returns `True` when every compared component matches its pattern, else `False`.

    Components and patterns are paired positionally, and the comparison stops at
    the shorter of the two sequences. Extra components or extra patterns are ignored.
    """
    for part, pattern in zip(get_parts(fact), args):
        if not fnmatch(part, pattern):
            return False
    return True

class BlocksworldHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Blocksworld domain.

    # Summary
    Estimates the number of actions needed to reach the goal. It counts the
    blocks that are not "well placed" with respect to the goal stacking, and
    adds one more action when the arm is holding a block.

    A block with a goal position is well placed when two conditions hold:
    - it currently sits on its goal base (another block or the table), and
    - the block beneath it is also well placed.
    A block resting on its correct base above a misplaced block still has to
    be moved out of the way, so it is counted too.

    # Assumptions
    - Only `on` and `on-table` goals define the goal stacking. Other goal facts
      (e.g. `clear`) are not scored individually. A state in which every goal
      fact holds always gets 0.
    - Blocks without a goal position are treated as acceptable supports for
      the blocks stacked on them.
    - Not admissible: a held block with a goal position is counted both as not
      well placed and by the holding penalty, although one put-down/stack
      action may be enough.

    # Heuristic Initialization
    - Build `goal_stacking`: block -> goal base ("table" for `on-table` goals).

    # Step-By-Step Thinking for Computing Heuristic
    1. If every goal fact holds in the state, return 0.
    2. Build `current_stacking` (block -> current base) from the `on` and
       `on-table` facts. A held block has no entry. Also note whether any
       `holding` fact is present.
    3. For each block with a goal position, add 1 unless it is well placed.
    4. If the arm is holding a block, add 1 (it must be put down or stacked).
    """

    def __init__(self, task):
        """Initialize the heuristic from the task's goal and static facts.

        Builds `self.goal_stacking`, which maps each block mentioned in an
        `on` / `on-table` goal to its goal base (a block name or "table").
        """
        self.goals = task.goals  # Goal conditions.
        self.static = task.static  # Static facts (not needed for this heuristic).

        # Extract the desired stacking from the goal conditions.
        self.goal_stacking = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "on":
                block, base = args
                self.goal_stacking[block] = base
            elif predicate == "on-table":
                self.goal_stacking[args[0]] = "table"

    def _is_well_placed(self, block, current_stacking, cache):
        """Return True if `block` and the tower beneath it already match the goal.

        Walks down the current tower iteratively, so tall towers cannot hit the
        recursion limit. Results are memoized in `cache`, so each block's status
        is computed at most once per shared cache.
        """
        chain = []  # Blocks on correct bases whose status equals the walk's outcome.
        visited = set()
        current = block
        while True:
            if current in cache:
                result = cache[current]
                break
            goal_base = self.goal_stacking.get(current)
            if goal_base is None:
                # No goal position: this block never forces the blocks above it to move.
                result = True
                break
            base = current_stacking.get(current)  # None when the block is held.
            chain.append(current)
            visited.add(current)
            if base != goal_base:
                result = False
                break
            if base == "table":
                result = True
                break
            if base in visited:
                # Defensive: a cyclic "on" chain is malformed; never call it well placed.
                result = False
                break
            current = base
        for settled in chain:
            cache[settled] = result
        return result

    def __call__(self, node):
        """Estimate the number of actions required to reach the goal state.

        Returns 0 whenever every goal fact holds in `node.state`.
        """
        state = node.state  # Current world state.

        # A goal state must score 0, even if the arm holds a block no goal mentions.
        if all(goal in state for goal in self.goals):
            return 0

        # One pass over the state: current stacking plus whether the arm is busy.
        current_stacking = {}
        holding = False
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "on":
                block, base = args
                current_stacking[block] = base
            elif predicate == "on-table":
                current_stacking[args[0]] = "table"
            elif predicate == "holding":
                holding = True

        cache = {}
        not_well_placed = sum(
            1
            for block in self.goal_stacking
            if not self._is_well_placed(block, current_stacking, cache)
        )

        # A held block still needs at least one put-down/stack action.
        return not_well_placed + (1 if holding else 0)
