from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters of `fact` (expected to be the enclosing
    parentheses) are discarded before splitting, so "(at p1 l1)" becomes
    ["at", "p1", "l1"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(at p1 l1)".
    - `args`: The expected pattern, one entry per token of the fact
      (predicate first). Each entry is an `fnmatch` pattern, so wildcards
      such as `*` are allowed.
    - Returns `True` if the fact has exactly as many tokens as the pattern
      and every token matches its pattern entry, else `False`.
    """
    parts = get_parts(fact)
    # zip() stops at the shorter sequence, so without this check a pattern
    # that is shorter or longer than the fact would be accepted as a match.
    if len(parts) != len(args):
        return False
    return all(fnmatch(part, arg) for part, arg in zip(parts, args))

class TransportHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Transport domain.

    # Summary
    This heuristic estimates the number of actions required to transport all packages to their goal locations. Each package is considered independently: the estimate is the sum, over all packages with a goal, of the pick-up, drive and drop actions that package still needs.

    # Assumptions
    - Packages can be either on the ground (`at`) or inside a vehicle (`in`).
    - Vehicle capacity is ignored, and the cost of bringing a vehicle to a package on the ground is not counted.
    - Roads are treated as bidirectional: every `road` fact adds an edge in both directions.
    - The estimate is not admissible, since packages sharing a vehicle are each charged for the full drive.

    # Heuristic Initialization
    - Extract goal locations for each package from the `at` goals of the task.
    - Extract the road network from the `road` facts in the static facts.

    # Step-By-Step Thinking for Computing Heuristic
    1. Read every `at` fact (object -> location) and every `in` fact (package -> vehicle) from the state.
    2. For each package with a goal location:
       - If the package is inside a vehicle, it needs the drive from the vehicle's location to the goal plus one drop action (even when the vehicle already stands at the goal).
       - If the package is on the ground at its goal, it needs nothing.
       - Otherwise it needs one pick-up, the drive from its location to the goal, and one drop.
    3. Sum the estimated actions for all packages to get the total heuristic value.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting goal conditions and static facts.

        - `task`: The planning task; its `goals` and `static` attributes are
          iterated as collections of PDDL fact strings.
        """
        self.goals = task.goals
        self.static = task.static

        # Extract goal locations for each package.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at":
                package, location = args
                self.goal_locations[package] = location

        # Extract the road network as an undirected adjacency map.
        self.roads = {}
        for fact in self.static:
            predicate, *args = get_parts(fact)
            if predicate == "road":
                l1, l2 = args
                self.roads.setdefault(l1, set()).add(l2)
                self.roads.setdefault(l2, set()).add(l1)

        # start location -> {reachable location: distance}. The road network
        # is built once here, so distances are computed lazily per source and
        # then reused across heuristic evaluations.
        self._distance_cache = {}

    def __call__(self, node):
        """
        Compute an estimate of the number of actions still required.

        - `node`: A search node whose `state` attribute is iterated as a
          collection of PDDL fact strings.
        - Returns the summed per-package estimate; `float('inf')` if some
          package cannot reach its goal over the road network.
        - Raises `KeyError` if a goal package has neither an `at` nor an `in`
          fact in the state, or if the vehicle carrying it has no `at` fact.
        """
        state = node.state

        # Keep "at" and "in" facts apart so that "inside a vehicle" is decided
        # by the predicate itself rather than by how objects happen to be named.
        at_location = {}
        carried_by = {}
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "at":
                obj, location = args
                at_location[obj] = location
            elif predicate == "in":
                package, vehicle = args
                carried_by[package] = vehicle

        total_cost = 0

        for package, goal_location in self.goal_locations.items():
            vehicle = carried_by.get(package)

            if vehicle is not None:
                # Loaded package: drive from the vehicle's position, then drop.
                # The drop is required even if the vehicle is already at the
                # goal, because the goal is `(at package goal)`.
                vehicle_location = at_location[vehicle]
                total_cost += self.shortest_path(vehicle_location, goal_location) + 1
                continue

            current_location = at_location[package]

            # A package lying on the ground at its goal needs no further action.
            if current_location == goal_location:
                continue

            # Pick up + drive to the goal + drop.
            total_cost += 1 + self.shortest_path(current_location, goal_location) + 1

        return total_cost

    def shortest_path(self, start, goal):
        """
        Compute the shortest path length between two locations using BFS.

        Distances from each start location are computed once and cached.

        @param start: The starting location.
        @param goal: The goal location.
        @return: The number of road segments in the shortest path, or
                 `float('inf')` if `goal` is not reachable from `start`.
        """
        if start == goal:
            return 0

        distances = self._distance_cache.get(start)
        if distances is None:
            # Single-source BFS. Locations are recorded when first enqueued,
            # so each one is queued at most once.
            distances = {start: 0}
            queue = deque([start])
            while queue:
                current = queue.popleft()
                for neighbor in self.roads.get(current, ()):
                    if neighbor not in distances:
                        distances[neighbor] = distances[current] + 1
                        queue.append(neighbor)
            self._distance_cache[start] = distances

        # If no path is found, return a large number (infinite cost).
        return distances.get(goal, float('inf'))
