from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its predicate name and arguments.

    The first and last characters of `fact` (the enclosing parentheses in a
    well-formed fact) are dropped unconditionally and the remainder is split
    on whitespace, e.g. "(on b1 b2)" -> ["on", "b1", "b2"].
    """
    inner = fact[1:-1]
    return inner.split()

def match(fact, *args):
    """
    Tell whether a PDDL fact agrees with a positional pattern.

    - `fact`: The complete fact as a string, e.g., "(on b1 b2)".
    - `args`: One shell-style pattern per component (wildcards such as `*`
      are interpreted by `fnmatch`).
    - Returns `True` when every compared component matches its pattern,
      otherwise `False`.

    Components and patterns are paired positionally and the comparison stops
    at the end of the shorter sequence, so surplus components or surplus
    patterns are not checked.
    """
    for part, pattern in zip(get_parts(fact), args):
        if not fnmatch(part, pattern):
            return False
    return True

class BlocksworldHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Blocksworld domain.

    # Summary
    This heuristic estimates the number of actions required to reach the goal state by summing, over every block that has an `on` or `on-table` goal, the number of actions needed to move that block from its current support to its goal support.

    # Assumptions:
    - The goal state specifies the desired configuration of blocks through `on` and `on-table` facts; other goal predicates are ignored.
    - Each block is costed independently of the others: blocks stacked above or below it are not taken into account.
    - Blocks without an `on` or `on-table` goal contribute nothing to the estimate.
    - If a block has both an `on-table` and an `on` goal, the `on-table` goal takes precedence.

    # Heuristic Initialization
    - Extract the goal conditions for each block: the set of blocks that must be on the table and a mapping from each block to the block it must rest on.
    - Precompute the set of blocks that have a goal, since it depends only on the task and not on the evaluated state.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current state of each block, including whether it is on the table, on another block, or being held.
    2. Compare the current state of each goal block with its goal state.
    3. For each block whose goal is to be on the table:
       - If it is already on the table, add 0.
       - If it is on another block, add 2 (unstack and putdown).
       - If it is being held, add 1 (putdown).
    4. For each block whose goal is to be on another block:
       - If it is already on the goal block, add 0.
       - If it is on the wrong block, add 2 (unstack and stack).
       - If it is being held, add 1 (stack).
       - Otherwise (it is on the table), add 2 (pickup and stack).
    5. Return the sum of these per-block costs.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting goal conditions and static facts.

        - `task`: The planning task; its `goals` (iterable of PDDL fact strings)
          and `static` attributes are read.
        """
        self.goals = task.goals  # Goal conditions.
        self.static = task.static  # Static facts (not needed for this heuristic).

        # Extract goal conditions for each block.
        self.goal_on_table = set()
        self.goal_on = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "on-table":
                self.goal_on_table.add(args[0])
            elif predicate == "on":
                self.goal_on[args[0]] = args[1]

        # The blocks that have a positional goal depend only on the task, so
        # build this set once here instead of on every heuristic evaluation.
        self._goal_blocks = frozenset(self.goal_on_table).union(self.goal_on)

    def __call__(self, node):
        """
        Estimate the number of actions required to reach the goal state.

        - `node`: A search node whose `state` attribute is an iterable of PDDL
          fact strings.
        - Returns the summed per-block cost as an integer.
        """
        # Track the current state of each block.
        current_on_table = set()
        current_on = {}
        holding = None
        for fact in node.state:
            predicate, *args = get_parts(fact)
            if predicate == "on-table":
                current_on_table.add(args[0])
            elif predicate == "on":
                current_on[args[0]] = args[1]
            elif predicate == "holding":
                holding = args[0]

        total_cost = 0

        # Check each goal block's current state against its goal state.
        for block in self._goal_blocks:
            if block in self.goal_on_table:
                # Block should be on the table.
                if block in current_on_table:
                    continue
                if block in current_on:
                    # Block is on another block; unstack and putdown.
                    total_cost += 2
                elif holding == block:
                    # Block is being held; putdown.
                    total_cost += 1
            elif block in current_on:
                # Block should be on another block and is currently stacked.
                if current_on[block] != self.goal_on[block]:
                    # Block is on the wrong block; unstack and stack.
                    total_cost += 2
            elif holding == block:
                # Block is being held; stack.
                total_cost += 1
            else:
                # Block is on the table but should be on another block; pickup and stack.
                total_cost += 2

        return total_cost
