from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Return the predicate and arguments of a fact string such as "(on b1 b2)".

    The first and last characters (the enclosing parentheses) are dropped and
    the remainder is split on whitespace, e.g. "(on b1 b2)" -> ["on", "b1", "b2"].
    """
    inner = fact[1:-1]
    return inner.split()

def match(fact, *args):
    """
    Tell whether a PDDL fact string fits a pattern of fnmatch-style tokens.

    - `fact`: The complete fact as a string, e.g., "(on b1 b2)".
    - `args`: One pattern per fact component; shell-style wildcards such as `*`
      are accepted.
    - Returns `True` when every compared component matches its pattern, else
      `False`. Components are paired positionally and comparison stops at the
      shorter of the two sequences, so extra components or extra patterns are
      not checked.
    """
    for component, pattern in zip(get_parts(fact), args):
        if not fnmatch(component, pattern):
            return False
    return True

class BlocksworldHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Blocksworld domain.

    # Summary
    This heuristic estimates the number of actions required to reach the goal
    state. It examines every block that appears in an `on` or `on-table` goal:
    - A block that is not at its goal position costs one action to move it.
    - It costs one more if it currently sits on a block other than its goal
      support, because it must be unstacked first.

    Goal facts with any other predicate (e.g. `clear`) cost one each while
    unsatisfied. Holding a block costs one more.

    # Assumptions:
    - Facts are strings of the form "(predicate arg1 arg2 ...)".
    - The estimate is not admissible; it is meant to guide (greedy) search.
    - Blocks that appear in no `on`/`on-table` goal are not examined
      individually. A block sitting on top of a misplaced block is therefore
      not charged for being in the way.

    # Heuristic Initialization
    - Split the goal into `on` goals (block -> goal support), `on-table` goals,
      and all remaining goal facts.
    - Precompute the set of blocks whose final position is specified.

    # Step-By-Step Thinking for Computing Heuristic
    1. Parse the state into `on`, `on-table` and `holding` facts.
    2. For each goal block, add 1 if it is not at its goal position (on the
       table, or on its goal support).
    3. For each goal block that currently sits on a block other than its goal
       support, add 1 (for the action to unstack it).
    4. Add 1 for every unsatisfied goal fact with another predicate (e.g.
       `clear`). Together with steps 2-3, this makes the estimate 0 only when
       every goal fact holds.
    5. If a block is being held, add 1 (for the action to put it down or stack it).
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts."""
        self.goals = task.goals  # Goal conditions.
        self.static = task.static  # Static facts (not needed for this heuristic).

        # Extract goal conditions for each block.
        self.goal_on = {}  # block -> block it must end up stacked on
        self.goal_on_table = set()  # blocks that must end up on the table
        # Goal facts of any other predicate (e.g. `clear`). They were
        # previously ignored, which let the estimate reach 0 in non-goal states.
        self.other_goals = set()
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "on":
                block, under_block = args
                self.goal_on[block] = under_block
            elif predicate == "on-table":
                self.goal_on_table.add(args[0])
            else:
                self.other_goals.add(goal)

        # Every block whose final position is specified; invariant across calls.
        self.goal_blocks = self.goal_on_table | set(self.goal_on)

    def __call__(self, node):
        """
        Estimate the number of actions required to reach the goal state.

        Returns a non-negative int. The result is 0 only if every goal fact
        holds in `node.state`.
        """
        state = node.state  # Current world state (collection of fact strings).

        # Track the current state of each block.
        current_on = {}
        current_on_table = set()
        holding = None
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "on":
                block, under_block = args
                current_on[block] = under_block
            elif predicate == "on-table":
                current_on_table.add(args[0])
            elif predicate == "holding":
                holding = args[0]

        total_cost = 0  # Initialize the heuristic cost.

        # Check each goal block's position against the goal.
        for block in self.goal_blocks:
            if block in self.goal_on_table:
                # Block should be on the table.
                if block not in current_on_table:
                    total_cost += 1  # Need to put it on the table.
            elif current_on.get(block) != self.goal_on[block]:
                # Block should be on a specific block but is not.
                total_cost += 1  # Need to move it to the correct position.

            # A block resting on something other than its goal support must be
            # unstacked first (for on-table goals, any support counts).
            if block in current_on and current_on[block] != self.goal_on.get(block):
                total_cost += 1  # Need to unstack it.

        # Unsatisfied goals of other predicates (e.g. `clear`) each need work.
        total_cost += sum(1 for goal in self.other_goals if goal not in state)

        # If a block is being held, it needs to be put down or stacked.
        if holding is not None:
            total_cost += 1

        return total_cost
