from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string such as "(on b1 b2)" into its tokens.

    The first and last characters (the surrounding parentheses) are dropped
    unconditionally and the remainder is split on whitespace, so
    "(on b1 b2)" becomes ["on", "b1", "b2"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Test whether a PDDL fact string fits a wildcard pattern.

    - `fact`: the fact as a string, e.g. "(on b1 b2)".
    - `args`: one shell-style pattern per token (`*` matches anything).
    - Returns True when every paired token matches its pattern, else False.

    Tokens and patterns are paired positionally with `zip`, so the comparison
    stops at the shorter of the two sequences; surplus tokens or surplus
    patterns are not checked.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class BlocksworldHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Blocksworld domain.

    # Summary
    This heuristic estimates the number of actions required to reach the goal state by counting the blocks that are not in their correct position or not correctly stacked (i.e. resting on a tower that is itself out of place), weighting each by the moves needed to fix it.

    # Assumptions
    - The goal state specifies the desired stacking of blocks via `on` / `on-table` facts; other goal facts (e.g. `clear`) are not considered.
    - Blocks can only be moved one at a time.
    - The arm can hold only one block at a time.
    - The heuristic does not need to be admissible, so it can overestimate the number of actions.

    # Heuristic Initialization
    - Extract the goal conditions to determine the desired stacking of blocks.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current stacking of blocks from the state.
    2. Determine which blocks are "well placed": a block is well placed when it rests on its goal support (if it has one) and everything beneath it is also well placed, all the way down to the table.
    3. For each goal block that is not well placed, add 1 for moving it into position, plus 1 more if it currently rests on another block and must be unstacked first.
    4. Sum these costs to get the heuristic value.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting the desired stacking from the goal conditions."""
        self.goals = task.goals
        self.static = task.static

        # Map each block named in an `on`/`on-table` goal to what it must rest on
        # (another block's name, or "table").
        self.goal_stacking = self._stacking_from_facts(self.goals)

    @staticmethod
    def _stacking_from_facts(facts):
        """Map each block to what it rests on ("table" or a block) per the `on`/`on-table` facts given."""
        stacking = {}
        for fact in facts:
            predicate, *args = get_parts(fact)
            if predicate == "on":
                block, under_block = args
                stacking[block] = under_block
            elif predicate == "on-table":
                stacking[args[0]] = "table"
        return stacking

    def _is_well_placed(self, block, current_stacking, memo):
        """
        Return True if `block` and every block beneath it rest on their goal supports, ending on the table.

        Blocks without a goal placement impose no constraint on their own support,
        but they still must sit on a well-placed tower. Results are cached in `memo`
        so each block's tower is walked at most once per heuristic evaluation.
        """
        path = []
        on_path = set()
        current = block
        while True:
            if current in memo:
                result = memo[current]
                break
            path.append(current)
            on_path.add(current)
            support = current_stacking.get(current)
            if support is None:
                # Resting on nothing in this state (e.g. held by the arm).
                result = False
                break
            if support == "table":
                result = True
                break
            if support in on_path:
                # Defensive: a cyclic `on` relation cannot occur in a valid state.
                result = False
                break
            current = support

        # Resolve the walked path bottom-up: each block is well placed only if its
        # support is well placed and, when it has a goal, it rests on its goal support.
        for walked in reversed(path):
            goal_support = self.goal_stacking.get(walked)
            result = result and (
                goal_support is None or goal_support == current_stacking.get(walked)
            )
            memo[walked] = result
        return memo[block]

    def __call__(self, node):
        """
        Estimate the number of actions required to reach the goal state.

        Returns 0 when every goal block sits in a correctly built tower; otherwise
        each goal block that is not well placed contributes 1, plus 1 more if it
        currently rests on another block (a block on the table or resting on
        nothing, e.g. held, needs no unstacking first).
        """
        current_stacking = self._stacking_from_facts(node.state)
        memo = {}

        heuristic_value = 0
        for block in self.goal_stacking:
            if self._is_well_placed(block, current_stacking, memo):
                continue

            # The block, or something beneath it, is out of place: it must move.
            heuristic_value += 1

            # Only a block resting on another block must be unstacked first; a
            # block on the table or resting on nothing (None) does not.
            if current_stacking.get(block) not in ("table", None):
                heuristic_value += 1

        return heuristic_value
