from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """Split a PDDL fact string such as "(at box1 loc_1_1)" into its tokens.

    The first and last characters (the enclosing parentheses) are dropped and
    the remaining text is split on runs of whitespace, e.g.
    "(at box1 loc_1_1)" -> ["at", "box1", "loc_1_1"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens


def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(at box1 loc_1_1)".
    - `args`: The expected pattern, one entry per token of the fact
      (shell-style wildcards such as `*` are allowed, via `fnmatch`).
    - Returns `True` if the fact has exactly as many tokens as the pattern
      and every token matches its corresponding pattern entry, else `False`.
    """
    parts = get_parts(fact)
    # zip() stops at the shorter sequence, so without this check a pattern
    # would silently match facts of a different arity (and "()" would match
    # every pattern, since all() of an empty sequence is True).
    if len(parts) != len(args):
        return False
    return all(fnmatch(part, arg) for part, arg in zip(parts, args))


class SokobanHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Sokoban domain.

    # Summary
    This heuristic estimates the number of actions required to push all boxes to their goal locations.
    It considers the distance between the robot and the boxes that still need to move, as well as the
    distance between each box and its goal location.

    # Assumptions
    - The robot can only push one box at a time.
    - The robot must be adjacent to a box to push it.
    - Location names have the form `loc_<x>_<y>`; coordinates are parsed from the name.
    - The heuristic does not account for obstacles or deadlocks, as it is designed to be efficiently computable.
    - Boxes without an `(at box loc)` goal condition are ignored.

    # Heuristic Initialization
    - Extract goal locations for each box from the task goals.
    - Extract location coordinates from the locations mentioned in the static `adjacent` facts.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current location of the robot and all boxes.
    2. For each box that has a goal, compute the Manhattan distance between its current location and its goal location.
    3. Sum these distances, each multiplied by `BOX_DISTANCE_WEIGHT`.
    4. Add the Manhattan distance between the robot and the nearest box that is not yet on its goal.
    5. A state in which every box with a goal is on that goal evaluates to 0.
    """

    # Multiplier applied to each box's Manhattan distance to its goal.
    # NOTE(review): if a single push action moves the box by one cell, this
    # over-counts pushes, so the estimate is not admissible; it appears to be
    # meant as guidance for non-optimal (e.g. greedy) search.
    BOX_DISTANCE_WEIGHT = 2

    def __init__(self, task):
        """
        Initialize the heuristic by extracting:
        - Coordinates for every location mentioned in static `adjacent` facts.
        - Goal locations (and their coordinates) for each box.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Map locations to their coordinates. More entries may be added lazily
        # by `_coords` for locations that never appear in an adjacency fact.
        self.location_to_coords = {}
        for fact in static_facts:
            if match(fact, "adjacent", "*", "*", "*"):
                loc1, loc2, _ = get_parts(fact)[1:]
                for loc in (loc1, loc2):
                    if loc not in self.location_to_coords:
                        self.location_to_coords[loc] = self._parse_location(loc)

        # Store goal locations for each box.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at":
                box, location = args
                self.goal_locations[box] = location

        # Goal coordinates never change between calls, so compute them once.
        self.goal_coords = {
            box: self._coords(location) for box, location in self.goal_locations.items()
        }

    def _parse_location(self, location):
        """Convert a location string (e.g., 'loc_1_2') into coordinates (e.g., (1, 2))."""
        parts = location.split("_")
        return (int(parts[1]), int(parts[2]))

    def _coords(self, location):
        """Return the coordinates of `location`, parsing and caching them on first use."""
        coords = self.location_to_coords.get(location)
        if coords is None:
            coords = self._parse_location(location)
            self.location_to_coords[location] = coords
        return coords

    def _manhattan_distance(self, loc1, loc2):
        """Compute the Manhattan distance between two coordinate pairs."""
        return abs(loc1[0] - loc2[0]) + abs(loc1[1] - loc2[1])

    def __call__(self, node):
        """Compute an estimate of the number of actions still required.

        Returns `float("inf")` if the state contains no `at-robot` fact, and 0
        when every box that has a goal is already on it.
        """
        state = node.state  # Current world state.

        # Track the current location of the robot and all boxes.
        robot_location = None
        box_locations = {}
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "at-robot":
                robot_location = args[0]
            elif predicate == "at":
                box, location = args
                box_locations[box] = location

        # If the robot's location is not found, return a high heuristic value.
        if not robot_location:
            return float("inf")

        robot_coords = self._coords(robot_location)

        total_cost = 0  # Initialize action cost counter.
        misplaced_box_coords = []  # Boxes that still have to be pushed.

        for box, location in box_locations.items():
            goal = self.goal_coords.get(box)
            if goal is None:
                # No goal condition for this box: it never needs to move.
                continue
            current = self._coords(location)
            distance = self._manhattan_distance(current, goal)
            if distance == 0:
                # Already on its goal; contributes nothing and is not a target
                # for the robot.
                continue
            total_cost += distance * self.BOX_DISTANCE_WEIGHT
            misplaced_box_coords.append(current)

        # Add the distance from the robot to the nearest box that still needs
        # to move. Only misplaced boxes are considered, so goal states score 0.
        if misplaced_box_coords:
            total_cost += min(
                self._manhattan_distance(robot_coords, coords)
                for coords in misplaced_box_coords
            )

        return total_cost
