from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string such as "(at p1 l1)" into its tokens.

    The first and last characters (the enclosing parentheses) are dropped and
    the remainder is split on whitespace, e.g. "(at p1 l1)" -> ["at", "p1", "l1"].
    """
    inner = fact[1:-1]
    return inner.split()

def match(fact, *args):
    """
    Test whether a PDDL fact fits a pattern of shell-style wildcards.

    - `fact`: the full fact string, e.g. "(at p1 l1)".
    - `args`: one pattern per token (`*` matches anything).
    - Returns `True` when every compared token matches its pattern.

    Tokens and patterns are paired positionally with `zip`, so only the
    shorter of the two sequences is compared; extra tokens or patterns are
    ignored.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class TransportHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Transport domain.

    # Summary
    This heuristic estimates the number of actions required to transport all packages
    to their goal locations. For every package with an `at` goal it adds the shortest
    road distance from the package's current position to its goal, plus the pick-up
    and/or drop actions still required.

    # Assumptions
    - Packages are either at a location (`at` fact) or inside a vehicle (`in` fact).
    - Vehicles move between locations along directed `road` facts.
    - Vehicle capacities are extracted but not used in the estimate.
    - The cost of moving an empty vehicle to a waiting package is not counted.
    - The heuristic does not need to be admissible, so it can over- or under-estimate.

    # Heuristic Initialization
    - Extract the goal location for each package from the `at` goals.
    - Extract the road network from the static facts.
    - Extract the capacity-predecessor relation from the static facts.

    # Step-By-Step Thinking for Computing Heuristic
    1. Read the state: record each object's `at` location and, from `in` facts,
       which vehicle (if any) carries each package.
    2. For a package inside a vehicle, its position is the vehicle's location; the
       estimate is the road distance to the goal plus one drop action.
    3. For a package at a location: if it is already at its goal it costs 0,
       otherwise the estimate is pick-up + road distance + drop.
    4. Road distances are BFS hop counts over `road` facts, cached per start
       location since the road network is static; unreachable goals cost infinity.
    5. Sum the estimates over all packages.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting:
        - Goal locations for each package.
        - Road network from static facts.
        - Capacity relationships from static facts.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Extract goal locations for each package.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at":
                package, location = args
                self.goal_locations[package] = location

        # Road network (directed adjacency sets) and capacity relation:
        # capacity_predecessors maps s2 -> s1 for each (capacity-predecessor s1 s2).
        self.roads = {}
        self.capacity_predecessors = {}
        for fact in static_facts:
            if match(fact, "road", "*", "*"):
                _, l1, l2 = get_parts(fact)
                self.roads.setdefault(l1, set()).add(l2)
            elif match(fact, "capacity-predecessor", "*", "*"):
                _, s1, s2 = get_parts(fact)
                self.capacity_predecessors[s2] = s1

        # start location -> {reachable location: hop count}; valid for the whole
        # search because roads are static.
        self._distance_cache = {}

    def __call__(self, node):
        """Compute an estimate of the minimal number of required actions."""
        state = node.state  # Current world state.

        located_at = {}  # object (package or vehicle) -> location, from `at` facts
        carried_by = {}  # package -> vehicle, from `in` facts
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "at":
                obj, location = args
                located_at[obj] = location
            elif predicate == "in":
                package, vehicle = args
                carried_by[package] = vehicle

        total_cost = 0
        for package, goal_location in self.goal_locations.items():
            vehicle = carried_by.get(package)
            if vehicle is not None:
                # Already loaded: drive the vehicle to the goal, then drop (1 action).
                position = located_at[vehicle]
                total_cost += self._shortest_path_length(position, goal_location) + 1
            else:
                position = located_at[package]
                if position == goal_location:
                    continue  # Goal already satisfied for this package.
                # Pick up (1) + drive along the road path + drop (1).
                total_cost += self._shortest_path_length(position, goal_location) + 2

        return total_cost

    def _shortest_path_length(self, start, goal):
        """
        Return the number of road segments on a shortest path from start to goal.

        @param start: The starting location.
        @param goal: The goal location.
        @return: The hop count, 0 if start == goal, or float('inf') if goal is
                 unreachable from start.
        """
        if start == goal:
            return 0
        distances = self._distance_cache.get(start)
        if distances is None:
            distances = self._bfs_distances(start)
            self._distance_cache[start] = distances
        return distances.get(goal, float("inf"))

    def _bfs_distances(self, start):
        """Return {location: hop count} for every location reachable from start."""
        distances = {start: 0}
        queue = deque([start])
        while queue:
            current = queue.popleft()
            next_distance = distances[current] + 1
            for neighbor in self.roads.get(current, ()):
                if neighbor not in distances:
                    distances[neighbor] = next_distance
                    queue.append(neighbor)
        return distances
