from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters (the enclosing parentheses of a fact such as
    "(at box1 loc_1_1)") are dropped before splitting, so the example yields
    ["at", "box1", "loc_1_1"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Tell whether a PDDL fact fits a positional pattern.

    - `fact`: The complete fact as a string, e.g., "(at box1 loc_1_1)".
    - `args`: One shell-style pattern per fact component (wildcards such as `*` allowed).
    - Returns `True` when every compared component matches its pattern, else `False`.

    Components and patterns are paired positionally and the comparison stops at
    the shorter of the two sequences, so surplus fact components or surplus
    patterns are not checked.
    """
    for component, pattern in zip(get_parts(fact), args):
        if not fnmatch(component, pattern):
            return False
    return True

class SokobanHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Sokoban domain.

    # Summary
    This heuristic estimates the number of actions required to push all boxes to their goal locations.
    For every box that is not yet on its goal it adds the distance from the robot to the box and the
    distance from the box to its goal, both measured on the static location graph.

    # Assumptions
    - `adjacent` facts have three arguments (two locations and one further argument, which is ignored)
      and are treated as undirected edges.
    - Distances are plain graph distances: other boxes are not treated as obstacles, and the heuristic
      does not account for deadlocks or blocked paths.
    - The estimate is not admissible: the robot-to-box distance is counted once per unsolved box.

    # Heuristic Initialization
    - Extract goal locations for each box from the `at` facts in the task goals.
    - Extract adjacency relationships between locations from the static facts.
    - Compute the shortest path distances between all pairs of locations using the adjacency relationships.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current location of the robot (`at-robot`) and of each box (`at`).
    2. If the robot or the boxes cannot be found, return 0.
    3. Skip every box that has no goal or that already stands on its goal, so a goal state evaluates to 0.
    4. For each remaining box, add the robot-to-box distance and the box-to-goal distance.
    5. A pair of locations with no connecting path contributes infinity to the sum.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting:
        - Goal locations for each box.
        - Adjacency relationships between locations.
        - Compute shortest path distances between all pairs of locations.

        `task` must expose `goals` and `static`, both iterables of PDDL fact strings.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Extract goal locations for each box.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at":
                box, location = args
                self.goal_locations[box] = location

        # Build an undirected adjacency graph from static facts.
        self.adjacency = {}
        for fact in static_facts:
            predicate, *args = get_parts(fact)
            if predicate == "adjacent":
                loc1, loc2, _ = args  # Third argument is not needed for distances.
                self.adjacency.setdefault(loc1, set()).add(loc2)
                self.adjacency.setdefault(loc2, set()).add(loc1)

        # Compute shortest path distances between all pairs of locations.
        self.distances = {loc: self._bfs(loc) for loc in self.adjacency}

    def _bfs(self, start):
        """
        Compute the shortest path distances from `start` to all other locations using BFS.

        Returns a dict mapping each reachable location (including `start`, at
        distance 0) to its number of edges from `start`. Unreachable locations
        are absent from the result.
        """
        distances = {start: 0}
        # deque gives O(1) pops from the left; list.pop(0) is O(n) per dequeue.
        queue = deque([start])
        while queue:
            current = queue.popleft()
            next_distance = distances[current] + 1
            # .get tolerates a start location that has no adjacency entry.
            for neighbor in self.adjacency.get(current, ()):
                if neighbor not in distances:
                    distances[neighbor] = next_distance
                    queue.append(neighbor)
        return distances

    def __call__(self, node):
        """
        Compute an estimate of the minimal number of required actions.

        `node.state` must be an iterable of PDDL fact strings. Returns 0 when
        the robot or the boxes cannot be found or when every box with a goal is
        already on it, and `float('inf')` when a required path does not exist.
        """
        state = node.state  # Current world state.

        # Identify the current location of the robot and each box.
        robot_location = None
        box_locations = {}
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "at-robot":
                robot_location = args[0]
            elif predicate == "at":
                box, location = args
                box_locations[box] = location

        if robot_location is None or not box_locations:
            return 0  # No boxes or robot not found.

        unreachable = float("inf")
        # A location that appears in no `adjacent` fact has no distance table;
        # treat everything as unreachable from it instead of raising KeyError.
        robot_distances = self.distances.get(robot_location, {})

        # Compute the heuristic value.
        total_cost = 0
        for box, current_location in box_locations.items():
            goal_location = self.goal_locations.get(box)
            if goal_location is None:
                continue  # Box has no goal location.
            if current_location == goal_location:
                continue  # Already solved: must not add robot travel cost.

            # Distance from box to its goal.
            box_distances = self.distances.get(current_location, {})
            box_to_goal = box_distances.get(goal_location, unreachable)

            # Distance from robot to box.
            robot_to_box = robot_distances.get(current_location, unreachable)

            # Total cost for this box.
            total_cost += robot_to_box + box_to_goal

        return total_cost
