from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Extract the components of a PDDL fact by removing parentheses and splitting the string."""
    return fact[1:-1].split()

def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(on b1 b2)".
    - `args`: The expected pattern (wildcards `*` allowed).
    - Returns `True` if the fact matches the pattern, else `False`.
    """
    parts = get_parts(fact)
    return all(fnmatch(part, arg) for part, arg in zip(parts, args))

class BlocksworldHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Blocksworld domain.

    # Summary
    This heuristic estimates the number of actions required to reach the goal state by counting the number of blocks that are not in their correct position or not correctly stacked. It also considers the number of blocks that need to be moved to achieve the correct stacking.

    # Assumptions
    - The goal state specifies the desired stacking of blocks.
    - Blocks can only be moved one at a time.
    - The arm can hold only one block at a time.
    - Blocks must be clear to be picked up or stacked.

    # Heuristic Initialization
    - Extract the goal conditions to determine the desired stacking of blocks.
    - Identify the current state of the blocks (on-table, on, clear, etc.).

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current state of each block (on-table, on, clear, etc.).
    2. Compare the current state with the goal state to determine which blocks are not in their correct position.
    3. Count the number of blocks that are not in their correct position.
    4. For each block that is not in its correct position, estimate the number of actions required to move it to its correct position:
       - If a block is on the wrong block, it must be unstacked and then restacked.
       - If a block is on the table but should be on another block, it must be picked up and stacked.
       - If a block is on another block but should be on the table, it must be unstacked and put down.
    5. Sum the estimated actions to get the heuristic value.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts."""
        self.goals = task.goals  # Goal conditions.
        self.static = task.static  # Static facts (not needed for this heuristic).

        # Extract the goal conditions for each block.
        self.goal_on = {}
        self.goal_on_table = set()
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "on":
                block, under_block = args
                self.goal_on[block] = under_block
            elif predicate == "on-table":
                block = args[0]
                self.goal_on_table.add(block)

    def __call__(self, node):
        """Estimate the number of actions required to reach the goal state."""
        state = node.state  # Current world state.

        # Track the current state of each block.
        current_on = {}
        current_on_table = set()
        current_clear = set()
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "on":
                block, under_block = args
                current_on[block] = under_block
            elif predicate == "on-table":
                block = args[0]
                current_on_table.add(block)
            elif predicate == "clear":
                block = args[0]
                current_clear.add(block)

        total_cost = 0  # Initialize the heuristic cost.

        # Check each block to see if it is in its correct position.
        for block in self.goal_on_table:
            if block not in current_on_table:
                total_cost += 1  # Need to move this block to the table.

        for block, under_block in self.goal_on.items():
            if block in current_on:
                if current_on[block] != under_block:
                    total_cost += 2  # Need to unstack and restack this block.
            else:
                total_cost += 1  # Need to stack this block on the correct block.

        return total_cost
