from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string into its predicate and arguments.

    The first and last characters (the surrounding parentheses) are dropped
    and the remainder is split on whitespace, so "(on b1 b2)" becomes
    ["on", "b1", "b2"].
    """
    inner = fact[1:-1]
    return inner.split()

def match(fact, *args):
    """
    Test whether a PDDL fact fits a wildcard pattern, component by component.

    - `fact`: the full fact string, e.g. "(on b1 b2)".
    - `args`: one shell-style pattern (as understood by `fnmatch`) per
      component; `*` matches anything.
    - Returns `True` when every compared component matches, otherwise `False`.

    Components are compared pairwise only up to the shorter of the fact and
    the pattern list, so surplus fact components or surplus patterns are not
    checked.
    """
    # Same parsing as get_parts: strip the outer parentheses, split on whitespace.
    components = fact[1:-1].split()
    for component, pattern in zip(components, args):
        if not fnmatch(component, pattern):
            return False
    return True

class BlocksworldHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Blocksworld domain.

    # Summary
    This heuristic estimates the number of actions required to reach the goal
    state. It counts blocks that are not where the goal wants them (on a
    specific block, or on the table). It adds extra cost when a misplaced
    block sits on another misplaced block, and when a block is held in the arm.
    Goal states always receive 0 and non-goal states always receive at least 1.

    # Assumptions:
    - The goal state specifies the desired stacking of blocks through
      "(on x y)" facts and, optionally, "(on-table x)" / "(ontable x)" facts.
    - Blocks can only be moved one at a time.
    - The arm can hold only one block at a time ("(holding x)").

    # Heuristic Initialization
    - Extract the desired "on" relationships from the goal.
    - Extract the blocks that the goal requires to be directly on the table.

    # Step-By-Step Thinking for Computing Heuristic
    1. If every goal fact holds in the current state, return 0.
    2. Read the current "on", on-table and "holding" facts from the state.
    3. For each block with an "on" goal:
       - If it is not currently on any block (on the table or held), add 1.
       - If it is on the wrong block, add 1; if the block underneath is itself
         misplaced with respect to its own "on" goal, add 1 more for the
         extra unstacking.
    4. For each block with an on-table goal that is not on the table, add 1.
    5. If a block is being held by the arm, add 1 to account for placing it.
    6. Return the sum, but at least 1, since a non-goal state needs an action.
    """

    # Both spellings of the on-table predicate appear in Blocksworld encodings.
    _ON_TABLE_PREDICATES = frozenset({"on-table", "ontable"})

    def __init__(self, task):
        """Initialize the heuristic from the task's goal facts.

        Args:
            task: planning task; `task.goals` (iterable of fact strings) and
                `task.static` are read.
        """
        self.goals = task.goals  # Goal conditions.
        self.static = task.static  # Static facts (not needed for this heuristic).

        # Goal "(on x y)" relations as block -> block it must rest on.
        self.goal_on = {}
        # Blocks the goal requires to rest directly on the table.
        self.goal_on_table = set()
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "on":
                block, under_block = args
                self.goal_on[block] = under_block
            elif predicate in self._ON_TABLE_PREDICATES:
                (block,) = args
                self.goal_on_table.add(block)

    def __call__(self, node):
        """Estimate the number of actions required to reach the goal state.

        Args:
            node: search node; only `node.state` (an iterable of fact strings)
                is read.

        Returns:
            0 if every goal fact is present in the state, otherwise a
            positive integer estimate.
        """
        state = node.state  # Current world state.

        # A goal state must score 0 even if, e.g., the arm holds a block that
        # no goal mentions (which would otherwise add 1 below).
        if all(goal in state for goal in self.goals):
            return 0

        # Single pass over the state: current "on" relations, blocks on the
        # table, and the held block (first "holding" fact wins).
        current_on = {}
        on_table = set()
        holding_block = None
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "on":
                block, under_block = args
                current_on[block] = under_block
            elif predicate in self._ON_TABLE_PREDICATES:
                on_table.add(args[0])
            elif predicate == "holding" and holding_block is None:
                holding_block = args[0]

        total_cost = 0  # Initialize the heuristic cost.

        # Check each block with an "on" goal against its current support.
        for block, goal_under_block in self.goal_on.items():
            current_under_block = current_on.get(block)
            if current_under_block is None:
                total_cost += 1  # Block is not in any stack (on-table or being held).
                continue
            if current_under_block == goal_under_block:
                continue
            total_cost += 1  # Block is on the wrong block.
            # The block underneath is also misplaced: an extra unstack is needed.
            if (
                current_under_block in self.goal_on
                and self.goal_on[current_under_block] != current_on.get(current_under_block)
            ):
                total_cost += 1

        # Blocks that must end on the table but currently are not.
        total_cost += sum(1 for block in self.goal_on_table if block not in on_table)

        # If a block is being held, add 1 to the cost to account for placing it.
        if holding_block is not None:
            total_cost += 1

        # Not a goal state (checked above), so at least one action is required.
        return max(total_cost, 1)
