from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters (the surrounding parentheses) are dropped
    before splitting, e.g. "(at p1 l1)" -> ["at", "p1", "l1"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Test whether a PDDL fact fits a pattern given token by token.

    - `fact`: the whole fact string, e.g. "(at p1 l1)".
    - `args`: one `fnmatch`-style pattern per token (`*` is a wildcard).
    - Returns `True` when every compared token matches its pattern.

    Tokens are paired positionally and comparison stops at the shorter of the
    two sequences, so extra tokens on either side are not checked.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class TransportHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Transport domain.

    # Summary
    This heuristic estimates the number of actions required to transport all packages to
    their goal locations. For every package that is not yet at its goal it adds:
    - The shortest road distance (number of drive steps) from the package's current
      position (its own location, or its vehicle's location) to its goal location.
    - One drop action.
    - One pick-up action if the package is still on the ground.
    Vehicle capacity, and the cost of bringing a vehicle to a package that is waiting on
    the ground, are ignored. Packages sharing one drive are each charged for it, so the
    estimate is not admissible.

    # Assumptions
    - A package appears in the state either as `(at package location)` or as
      `(in package vehicle)`; a vehicle appears as `(at vehicle location)`.
    - `road` facts are static and treated as directed edges of the road graph.

    # Heuristic Initialization
    - Extract the goal locations for each package from the task goals.
    - Extract the road network from the static facts to compute distances between locations.
    - Extract static `capacity` facts (stored, but not used by the estimate).

    # Step-By-Step Thinking for Computing Heuristic
    1. For each goal package, decide from the predicate (`at` vs `in`) whether it is on the
       ground or inside a vehicle.
    2. If it is inside a vehicle, use that vehicle's location as the package's position.
    3. A package on the ground at its goal location needs no action.
    4. A package inside a vehicle always needs a drop, even when the vehicle is already at
       the goal location; add the remaining distance (possibly 0) plus 1.
    5. A package on the ground elsewhere needs pick-up, driving and drop; add the distance
       plus 2.
    6. Sum these costs over all packages to get the total heuristic value.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting:
        - Goal locations for each package.
        - Road network from static facts.
        - Capacity facts found among the static facts.

        @param task: planning task exposing `goals` and `static` collections of fact strings.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Road network as an adjacency map: location -> set of directly reachable locations.
        self.roads = {}
        for fact in static_facts:
            if match(fact, "road", "*", "*"):
                _, l1, l2 = get_parts(fact)
                self.roads.setdefault(l1, set()).add(l2)

        # Vehicle -> capacity value, taken from static facts only.
        # NOTE(review): if `capacity` is changed by pick-up/drop actions in this domain
        # encoding, it will not be among the static facts and this dict stays empty.
        # It is not used by __call__ either way.
        self.capacity = {}
        for fact in static_facts:
            if match(fact, "capacity", "*", "*"):
                _, vehicle, size = get_parts(fact)
                self.capacity[vehicle] = size

        # Goal location for every object named in an `(at obj loc)` goal.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at":
                package, location = args
                self.goal_locations[package] = location

        # Memoized BFS results: source location -> {reachable location: distance}.
        # The road graph is static, so each source only needs one BFS per task.
        self._distance_cache = {}

    def __call__(self, node):
        """Compute an estimate of the number of actions still required in `node.state`."""
        state = node.state  # Current world state.

        # Keep ground positions and vehicle membership apart, so that "is the package
        # loaded?" is answered by the predicate rather than by how objects are named.
        at_location = {}  # object -> location, from `(at obj loc)`
        carried_by = {}  # package -> vehicle, from `(in package vehicle)`
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "at":
                obj, location = args
                at_location[obj] = location
            elif predicate == "in":
                package, vehicle = args
                carried_by[package] = vehicle

        total_cost = 0  # Initialize action cost counter.

        for package, goal_location in self.goal_locations.items():
            vehicle = carried_by.get(package)

            if vehicle is None:
                # Package is on the ground.
                current_location = at_location[package]
                if current_location == goal_location:
                    continue  # Goal already satisfied.
                # pick-up + drive + drop
                total_cost += self.shortest_path_distance(current_location, goal_location) + 2
            else:
                # Package is loaded: it moves with its vehicle and must still be dropped,
                # even if the vehicle is already at the goal location (distance 0).
                current_location = at_location[vehicle]
                total_cost += self.shortest_path_distance(current_location, goal_location) + 1

        return total_cost

    def shortest_path_distance(self, start, goal):
        """
        Compute the shortest path distance between two locations using BFS.

        @param start: The starting location.
        @param goal: The goal location.
        @return: The number of road steps on the shortest path, 0 if `start == goal`,
                 or `float('inf')` if `goal` is unreachable from `start`.
        """
        if start == goal:
            return 0
        return self._distances_from(start).get(goal, float('inf'))

    def _distances_from(self, start):
        """Return (and memoize) BFS distances from `start` to every reachable location."""
        distances = self._distance_cache.get(start)
        if distances is None:
            distances = {start: 0}
            frontier = deque([start])  # deque: O(1) popleft, unlike list.pop(0).
            while frontier:
                current = frontier.popleft()
                for neighbor in self.roads.get(current, ()):
                    if neighbor not in distances:
                        distances[neighbor] = distances[current] + 1
                        frontier.append(neighbor)
            self._distance_cache[start] = distances
        return distances
