from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string into its predicate and argument tokens.

    The first and last characters (the enclosing parentheses) are dropped and
    the remainder is split on whitespace, e.g.
    "(at box1 loc_1_1)" -> ["at", "box1", "loc_1_1"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Report whether a PDDL fact matches a token pattern.

    - `fact`: the complete fact as a string, e.g. "(at-robot loc_1_1)".
    - `args`: one shell-style pattern per token (wildcards such as `*` allowed).
    - Returns `True` when every paired token matches its pattern, else `False`.

    Tokens and patterns are paired positionally and pairing stops at the
    shorter of the two, so a pattern with fewer (or more) entries than the
    fact has tokens is compared only on the overlapping prefix.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class SokobanHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Sokoban domain.

    # Summary
    This heuristic estimates the number of actions required to push all boxes to their goal positions.
    It combines the distance each misplaced box must travel to its goal with the distance the
    robot must walk to reach the nearest misplaced box.

    # Assumptions
    - The robot can only push one box at a time.
    - The robot must be adjacent to a box to push it.
    - The heuristic does not account for deadlocks, walls, or blocked paths.
    - Location names end in two underscore-separated integer coordinates (e.g. "loc_3_4");
      distances are Manhattan distances computed from those coordinates.

    # Heuristic Initialization
    - Extract goal positions for each box from the `at` goal facts.
    - Extract adjacency relationships between locations from the static `adjacent` facts
      (kept in `self.adjacency`; the distance estimate itself does not use them).

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current positions of the robot and all boxes.
    2. For each box not yet at its goal, add the Manhattan distance between its current position and its goal position.
    3. Calculate the Manhattan distance between the robot and each such misplaced box.
    4. Add the robot's distance to the nearest misplaced box (once, not once per box).
    5. The heuristic value is the total sum of these distances; it is 0 when every box is at its goal.
    """

    def __init__(self, task):
        """
        Precompute task-derived lookup tables.

        - `self.goals`: the task's goal facts.
        - `self.adjacency`: maps each location to the list of locations listed as
          adjacent to it by static `(adjacent loc1 loc2 dir)` facts.
        - `self.goal_positions`: maps each box to its goal location, taken from
          `(at box location)` goal facts.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Map locations to their adjacent locations using "adjacent" relationships.
        self.adjacency = {}
        for fact in static_facts:
            if match(fact, "adjacent", "*", "*", "*"):
                parts = get_parts(fact)
                loc1, loc2 = parts[1], parts[2]  # parts[3] is the direction.
                self.adjacency.setdefault(loc1, []).append(loc2)

        # Store goal positions for each box; ignore malformed `at` goals.
        self.goal_positions = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at" and len(args) == 2:
                box, location = args
                self.goal_positions[box] = location

    @staticmethod
    def _coordinates(location):
        """Return the (x, y) integers encoded at the end of a location name.

        The name is assumed to end in two underscore-separated integers,
        e.g. "loc_3_4" -> (3, 4).
        """
        x, y = location.split('_')[-2:]
        return int(x), int(y)

    def __call__(self, node):
        """
        Compute an estimate of the number of actions still required.

        Returns the sum, over boxes not at their goal, of the Manhattan distance
        from the box to its goal, plus the Manhattan distance from the robot to
        the nearest such box when the robot's position is known. Returns 0 when
        every box with a goal is already at its goal.
        """
        state = node.state  # Current world state.

        # Track the current positions of the robot and boxes.
        robot_position = None
        box_positions = {}
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "at-robot":
                robot_position = args[0]
            elif predicate == "at":
                box, location = args
                box_positions[box] = location

        box_cost = 0
        misplaced = []  # Coordinates of boxes that still need to be moved.
        for box, goal_location in self.goal_positions.items():
            current_location = box_positions.get(box)
            # Boxes missing from the state cannot be measured; placed boxes cost nothing.
            if current_location is None or current_location == goal_location:
                continue

            bx, by = self._coordinates(current_location)
            gx, gy = self._coordinates(goal_location)
            box_cost += abs(bx - gx) + abs(by - gy)
            misplaced.append((bx, by))

        # The robot walks to one box at a time, so only the nearest misplaced
        # box contributes its distance to the estimate.
        robot_cost = 0
        if robot_position is not None and misplaced:
            rx, ry = self._coordinates(robot_position)
            robot_cost = min(abs(rx - bx) + abs(ry - by) for bx, by in misplaced)

        return box_cost + robot_cost
