from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated tokens.

    - `fact`: The fact as a string, e.g., "(at p1 l1)".
    - Returns a list such as `["at", "p1", "l1"]`.

    The first and last characters are dropped unconditionally (they are
    expected to be the enclosing parentheses) before splitting.
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Tell whether a PDDL fact agrees with a positional pattern.

    - `fact`: The complete fact as a string, e.g., "(at p1 l1)".
    - `args`: One shell-style pattern per token (wildcards such as `*` allowed).
    - Returns `True` if every compared token matches its pattern, else `False`.

    Tokens and patterns are paired positionally and the comparison stops at
    the shorter of the two sequences, so surplus tokens or surplus patterns
    are not checked.
    """
    tokens = get_parts(fact)
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True

class TransportHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Transport domain.

    # Summary
    This heuristic estimates the number of actions required to transport all
    packages to their goal locations. Each goal package is estimated
    independently and the estimates are summed:
    - The road distance (number of road edges on a shortest path) between the
      package's current position and its goal location.
    - One action to drop the package at the goal.
    - One action to pick the package up if it is not inside a vehicle yet.

    # Assumptions
    - A package is inside a vehicle exactly when the state contains an
      `(in package vehicle)` fact; its position is then the vehicle's `at`
      location.
    - The shortest path over the directed `road` facts is used for driving.
    - Vehicle positions and capacities are NOT taken into account by the
      estimate: the cost of bringing a vehicle to a waiting package is ignored.

    # Heuristic Initialization
    - Extract the goal location of each package from the `at` goals.
    - Extract the road network from the static `road` facts.
    - Store any static `capacity` facts in `self.capacities` (kept for callers;
      not used by the estimate).

    # Step-By-Step Thinking for Computing Heuristic
    1. Read every `at` and `in` fact of the state.
    2. For each goal package, determine its position (its own location, or the
       location of the vehicle carrying it).
    3. A package lying on the ground at its goal location costs nothing.
    4. Otherwise the cost is the shortest road distance to the goal, plus one
       drop action, plus one pick-up action if the package is on the ground.
    5. Packages whose position is unknown or whose goal is unreachable are
       skipped (they contribute nothing).
    6. Sum the per-package costs to get the total heuristic value.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts."""
        self.goals = task.goals
        self.static = task.static

        # Goal location of every package that has an `at` goal.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at":
                package, location = args
                self.goal_locations[package] = location

        # Directed road network (location -> set of successor locations) and
        # vehicle capacities, both collected in a single pass over the static facts.
        self.roads = {}
        self.capacities = {}
        for fact in self.static:
            predicate, *args = get_parts(fact)
            if predicate == "road":
                l1, l2 = args
                self.roads.setdefault(l1, set()).add(l2)
            elif predicate == "capacity":
                vehicle, size = args
                self.capacities[vehicle] = size

        # start location -> {reachable location: distance}; filled lazily by
        # `shortest_path_distance`. Assumes `self.roads` is not modified after
        # construction.
        self._distance_cache = {}

    def __call__(self, node):
        """
        Compute an estimate of the number of actions still required.

        - `node`: Search node whose `state` attribute is an iterable of fact strings.
        - Returns the summed per-package estimate; it is 0 when every goal
          package lies on the ground at its goal location.
        """
        state = node.state

        # Keep `at` and `in` facts apart so that "inside a vehicle" is decided
        # by the fact's predicate, never by how an object happens to be named.
        at_location = {}
        carried_by = {}
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "at":
                obj, location = args
                at_location[obj] = location
            elif predicate == "in":
                package, vehicle = args
                carried_by[package] = vehicle

        total_cost = 0

        for package, goal_location in self.goal_locations.items():
            vehicle = carried_by.get(package)
            if vehicle is not None:
                # Already loaded: only the drop action is still needed.
                position = at_location.get(vehicle)
                handling_actions = 1
            else:
                position = at_location.get(package)
                if position == goal_location:
                    continue  # Goal already satisfied for this package
                # On the ground elsewhere: pick-up plus drop.
                handling_actions = 2

            if position is None:
                continue  # Package (or its vehicle) not located in state, skip

            distance = self.shortest_path_distance(position, goal_location)
            if distance is None:
                continue  # No path, skip

            total_cost += distance + handling_actions

        return total_cost

    def shortest_path_distance(self, start, goal):
        """
        Compute the shortest path distance between two locations using BFS.

        - `start`, `goal`: Location names.
        - Returns the number of road edges on a shortest directed path from
          `start` to `goal`, 0 if they are equal, or `None` if `goal` cannot
          be reached.
        """
        if start == goal:
            return 0

        distances = self._distance_cache.get(start)
        if distances is None:
            distances = self._bfs_distances(start)
            self._distance_cache[start] = distances
        return distances.get(goal)

    def _bfs_distances(self, start):
        """Return a dict mapping every location reachable from `start` to its BFS distance."""
        distances = {start: 0}
        frontier = deque([start])
        while frontier:
            current = frontier.popleft()
            next_distance = distances[current] + 1
            for neighbor in self.roads.get(current, ()):
                # Recording the distance on first discovery also marks the
                # location as visited, so nothing is enqueued twice.
                if neighbor not in distances:
                    distances[neighbor] = next_distance
                    frontier.append(neighbor)
        return distances
