from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated components.

    The first and last characters of `fact` (the enclosing parentheses in a
    well-formed fact such as "(on b1 b2)") are dropped unconditionally, and
    the remainder is split on runs of whitespace.

    - `fact`: The complete fact as a string, e.g., "(on b1 b2)".
    - Returns a list of tokens, e.g., ["on", "b1", "b2"]; an empty list for "()".
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(on b1 b2)".
    - `args`: The expected pattern, one entry per component of the fact
      (predicate name first). Each entry is an `fnmatch`-style pattern, so
      wildcards such as `*` are allowed.
    - Returns `True` if the fact has exactly as many components as the pattern
      and every component matches its pattern entry, else `False`.
    """
    parts = get_parts(fact)

    # zip() silently stops at the shorter sequence, so without this check a
    # pattern that is shorter or longer than the fact (e.g. ("on", "a") against
    # "(on a b)", or an empty pattern) would be reported as a match.
    if len(parts) != len(args):
        return False

    return all(fnmatch(part, arg) for part, arg in zip(parts, args))

class BlocksworldHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Blocksworld domain.

    # Summary
    This heuristic estimates the number of actions required to reach the goal state by counting, for every block mentioned in an `on` / `on-table` goal, whether the thing directly beneath it differs from what the goal demands. A misplaced block that currently sits on another block is counted twice; a misplaced block on the table, or one with no `on` / `on-table` fact in the state (e.g. a block being held), is counted once.

    # Assumptions:
    - The goal state specifies the desired stacking of blocks.
    - Blocks can only be moved one at a time.
    - The heuristic does not need to be admissible, so it can overestimate the number of actions.
    - Goal facts other than `on` and `on-table` are ignored.

    # Heuristic Initialization
    - Extract the goal conditions to determine the desired stacking of blocks.
    - The current state is not inspected here; it is read on every call.

    # Step-By-Step Thinking for Computing Heuristic
    1. Build a map from each block to what is directly beneath it in the current state (another block, or `None` for the table).
    2. For each block with a stacking goal, compare its current support with its goal support.
    3. Add 1 for every block whose support differs from the goal.
    4. Add 1 more if that misplaced block currently sits on another block.
    5. Return the sum as the estimate.
    """

    # Marker for "this block has no `on` / `on-table` fact in the state".
    # It must be distinct from None, which already means "on the table";
    # otherwise a held block whose goal is `on-table` would look correctly placed.
    _UNPLACED = object()

    def __init__(self, task):
        """
        Initialize the heuristic by extracting goal conditions and static facts.

        - `task`: The planning task; its `goals` and `static` attributes are read.
        """
        self.goals = task.goals  # Goal conditions.
        self.static = task.static  # Static facts (not used in this heuristic).

        # Map each goal block to the block it must rest on (None = table).
        self.goal_stacking = self._extract_stacking(self.goals)

    @staticmethod
    def _extract_stacking(facts):
        """
        Build a `{block: block_below}` map from `on` / `on-table` facts.

        - `facts`: An iterable of PDDL fact strings.
        - Returns a dict in which a value of `None` means the block is on the
          table. Facts with any other predicate are skipped.
        """
        stacking = {}
        for fact in facts:
            predicate, *args = get_parts(fact)
            if predicate == "on":
                block, under_block = args
                stacking[block] = under_block
            elif predicate == "on-table":
                stacking[args[0]] = None  # None indicates the block is on the table.
        return stacking

    def __call__(self, node):
        """
        Estimate the number of actions required to reach the goal state.

        - `node`: A search node; its `state` attribute (an iterable of fact strings) is read.
        - Returns an integer estimate; 0 when every `on` / `on-table` goal already holds.
        """
        current_stacking = self._extract_stacking(node.state)

        heuristic_value = 0
        for block, goal_under_block in self.goal_stacking.items():
            # Blocks absent from the map (e.g. held) get the sentinel, not None,
            # so they are never mistaken for blocks lying on the table.
            current_under_block = current_stacking.get(block, self._UNPLACED)

            if current_under_block == goal_under_block:
                continue  # Block already rests on its goal support.

            heuristic_value += 1  # Block is not in the correct position.

            # A misplaced block that sits on another block is counted once more.
            on_another_block = (
                current_under_block is not None
                and current_under_block is not self._UNPLACED
            )
            if on_another_block:
                heuristic_value += 1  # Block is not correctly stacked.

        return heuristic_value
