from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters (the enclosing parentheses) are dropped
    before splitting, so "(at box1 loc_1_1)" yields ["at", "box1", "loc_1_1"].
    """
    inner = fact[1:-1]  # strip the surrounding "(" and ")"
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Check whether a PDDL fact matches a pattern of shell-style wildcards.

    - `fact`: the full fact string, e.g. "(at box1 loc_1_1)".
    - `args`: one pattern per fact component; `*` matches anything (fnmatch syntax).
    - Returns `True` when every compared component matches its pattern, else `False`.

    Components and patterns are paired with `zip`, so comparison stops at the
    shorter of the two sequences: surplus fact components or surplus patterns
    are ignored rather than causing a mismatch.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class SokobanHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Sokoban domain.

    # Summary
    This heuristic estimates the number of actions required to push all boxes to their goal locations.
    It considers the distance between the robot, boxes, and goal locations, as well as the number of
    boxes that are already in their goal positions.

    # Assumptions
    - The robot can only push one box at a time.
    - The robot must be adjacent to a box to push it.
    - The goal is to have all boxes in their respective goal locations.

    # Heuristic Initialization
    - Extract goal locations for each box from the task goals.
    - Extract adjacency relationships between locations from static facts.
    - Precompute the shortest path distances between all pairs of locations using the adjacency information.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current location of the robot and all boxes.
    2. For each box not in its goal location:
        a. Compute the distance from the robot to the box.
        b. Compute the distance from the box to its goal location.
        c. Add these distances to the heuristic value.
    3. If a box is already in its goal location, no additional actions are needed for that box.
    4. The heuristic value is the sum of the distances for all boxes not in their goal locations.

    # Notes
    - Distances are plain grid distances that ignore walls formed by other boxes and
      ignore the straight-line nature of pushes.
    - The robot-to-box term is measured to the box's own square and is added
      independently for every misplaced box, so the estimate can overestimate.
      NOTE(review): presumably intended for greedy (non-optimal) search; confirm
      before using with A*.
    - If any needed distance is undefined (unknown or disconnected location), the
      value is float('inf').
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts.

        Args:
            task: planning task. `task.goals` and `task.static` are read as iterables
                of PDDL fact strings such as "(at box1 loc_1_1)".
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Map each box to the location it must occupy in the goal.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at":
                box, location = args
                self.goal_locations[box] = location

        # Build an undirected location graph. The third argument of `adjacent`
        # (presumably a direction) does not matter for distances and is ignored;
        # both directions of every edge are stored.
        self.adjacency = {}
        for fact in static_facts:
            predicate, *args = get_parts(fact)
            if predicate == "adjacent":
                loc1, loc2, _direction = args
                self.adjacency.setdefault(loc1, set()).add(loc2)
                self.adjacency.setdefault(loc2, set()).add(loc1)

        # All-pairs shortest path lengths: distances[a][b] is the number of moves
        # from a to b. Pairs with no connecting path are absent.
        self.distances = {loc: self._bfs(loc) for loc in self.adjacency}

    def _bfs(self, start):
        """Return the shortest move counts from `start` to every reachable location.

        Args:
            start: location name to search from.

        Returns:
            Dict mapping each reachable location (including `start`, at 0) to its
            distance. A `start` with no adjacency entry yields `{start: 0}`.
        """
        distances = {start: 0}
        # A deque gives O(1) dequeues; list.pop(0) shifts the whole list on every pop.
        queue = deque([start])
        while queue:
            current = queue.popleft()
            for neighbor in self.adjacency.get(current, ()):
                if neighbor not in distances:
                    distances[neighbor] = distances[current] + 1
                    queue.append(neighbor)
        return distances

    def __call__(self, node):
        """Compute an estimate of the minimal number of required actions.

        Args:
            node: search node whose `state` is an iterable of PDDL fact strings.

        Returns:
            0 if every goal box is on its goal location. Otherwise the sum over
            misplaced boxes of (robot-to-box distance + box-to-goal distance).
            float('inf') if any of those distances is undefined, for example when
            the robot or box position is missing or disconnected.
        """
        state = node.state  # Current world state.

        # Identify the current location of the robot and all boxes.
        robot_location = None
        box_locations = {}
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "at-robot":
                robot_location = args[0]
            elif predicate == "at":
                box, location = args
                box_locations[box] = location

        # Distances from the robot's square. This is looked up once rather than
        # once per box, and is empty when the robot position is unknown.
        from_robot = self.distances.get(robot_location, {})

        total_cost = 0
        for box, goal_location in self.goal_locations.items():
            current_location = box_locations.get(box)
            if current_location == goal_location:
                continue  # Box already placed; contributes nothing.

            robot_to_box = from_robot.get(current_location)
            box_to_goal = self.distances.get(current_location, {}).get(goal_location)
            if robot_to_box is None or box_to_goal is None:
                # Any unreachable term makes the whole sum infinite, so return now.
                return float("inf")

            total_cost += robot_to_box + box_to_goal

        return total_cost
