from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string into its tokens.

    The first and last characters (the enclosing parentheses) are dropped and
    the remainder is split on runs of whitespace.

    Example: "(on b1 b2)" -> ["on", "b1", "b2"].
    """
    body = fact[1:len(fact) - 1]
    return body.split(None)

def match(fact, *args):
    """Return True if the PDDL fact string matches the pattern `args`.

    Each token of `fact` (as produced by `get_parts`) is compared with the
    pattern element at the same position using shell-style wildcards
    (`fnmatch`), so "*" matches any single token.

    Pairing stops at the shorter of the two sequences (this is how `zip`
    behaves), so extra tokens in the fact, or extra pattern elements, are not
    checked.

    Example: match("(on b1 b2)", "on", "*", "*") -> True.
    """
    tokens = get_parts(fact)
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True

class BlocksworldHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Blocksworld domain.

    # Summary
    This heuristic estimates the number of actions required to reach the goal state by counting
    the goal blocks that are not "well placed". A block is well placed only if it sits on its goal
    support AND the block beneath it is also well placed, following the goal structure all the
    way down to the table. A block resting on its goal support while the tower below it is wrong
    still has to be moved, so it is counted.

    # Assumptions
    - The goal specifies the desired position (`on` / `on-table`) of the blocks it cares about;
      blocks not mentioned in the goal are treated as unconstrained.
    - Blocks can only be moved one at a time, by a single arm.
    - The heuristic does not need to be admissible, so it can overestimate the number of actions.

    # Heuristic Initialization
    - Extract the goal conditions to determine the desired stack structure
      (`self.goal_structure`: block -> block it should be on, or "table").

    # Step-By-Step Thinking for Computing Heuristic
    1. If every goal fact already holds in the state, return 0.
    2. Scan the state once to record what each block currently rests on and which block is held.
    3. Determine which goal blocks are well placed (see Summary).
    4. Each goal block that is not well placed costs 2 actions (pick it up, then place it),
       or 1 action if it is already being held (it only needs to be placed).
    5. A held block that has no goal position costs 1 action (it must be put down to free the arm).
    6. Return the sum.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts."""
        self.goals = task.goals
        self.static = task.static

        # Extract the goal structure: a dictionary mapping each block to the block it should be on
        # ("table" for blocks that should be on the table).
        self.goal_structure = {}
        for goal in self.goals:
            if match(goal, "on", "*", "*"):
                _, block, under_block = get_parts(goal)
                self.goal_structure[block] = under_block
            elif match(goal, "on-table", "*"):
                _, block = get_parts(goal)
                self.goal_structure[block] = "table"

    @staticmethod
    def _current_layout(state):
        """Scan `state` once and return `(support, held)`.

        `support` maps each block to the block it is currently on, or "table".
        `held` is the set of blocks named in `holding` facts.
        """
        support = {}
        held = set()
        for fact in state:
            parts = get_parts(fact)
            if len(parts) == 3 and parts[0] == "on":
                support[parts[1]] = parts[2]
            elif len(parts) == 2 and parts[0] == "on-table":
                support[parts[1]] = "table"
            elif len(parts) == 2 and parts[0] == "holding":
                held.add(parts[1])
        return support, held

    def _well_placed(self, support):
        """Return the set of goal blocks that are well placed under the layout `support`.

        The goal chain below each block is walked iteratively (no recursion-depth limit), and the
        verdict is memoized for every block visited so each chain is walked only once.
        """
        status = {}
        for start in self.goal_structure:
            chain = []
            on_chain = set()
            block = start
            ok = True
            while True:
                if block in status:
                    # The rest of the chain was already resolved.
                    ok = status[block]
                    break
                goal_support = self.goal_structure.get(block)
                if goal_support is None:
                    # A support with no goal position of its own: nothing requires it to move.
                    break
                if block in on_chain:
                    # Cyclic goal structure (unsatisfiable): treat the whole chain as misplaced.
                    ok = False
                    break
                chain.append(block)
                on_chain.add(block)
                if support.get(block) != goal_support:
                    # Wrong support (or held): this block and everything above it must move.
                    ok = False
                    break
                if goal_support == "table":
                    break
                block = goal_support
            for visited in chain:
                status[visited] = ok
        return {block for block, ok in status.items() if ok}

    def __call__(self, node):
        """Estimate the number of actions required to reach the goal state.

        `node.state` is expected to be a collection of fact strings such as "(on a b)".
        Returns a non-negative integer, which is 0 whenever every goal fact holds.
        """
        state = node.state

        # A goal state must evaluate to 0, regardless of e.g. a non-goal block being held.
        if all(goal in state for goal in self.goals):
            return 0

        support, held = self._current_layout(state)
        well_placed = self._well_placed(support)

        total_cost = 0
        for block in self.goal_structure:
            if block in well_placed:
                continue
            # A held block only needs to be placed; any other misplaced block must first be
            # picked up. Charging the held block less means a useful pick-up lowers the estimate
            # instead of raising it.
            total_cost += 1 if block in held else 2

        # A held block without a goal position still has to be put down to free the arm.
        total_cost += sum(1 for block in held if block not in self.goal_structure)

        return total_cost
