from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters (the enclosing parentheses) are dropped
    before splitting, e.g. "(at p1 l1)" -> ["at", "p1", "l1"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Test whether a PDDL fact string fits a token pattern.

    - `fact`: the full fact, e.g. "(at p1 l1)".
    - `args`: one shell-style pattern per token (`*` is a wildcard).
    - Tokens and patterns are compared pairwise with `fnmatch`. Comparison
      stops at the shorter of the two sequences, so extra tokens or extra
      patterns are ignored.
    - Returns `True` when every compared pair matches, else `False`.
    """
    tokens = get_parts(fact)
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True

class TransportHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Transport domain.

    # Summary
    This heuristic estimates the number of actions required to transport all
    packages to their goal locations. For every package that is not yet at its
    goal it adds:
    - the shortest-path road distance (number of drive actions) from where the
      package is, or from where its vehicle is if it is loaded, to its goal;
    - one drop action, plus one pick-up action if the package is not loaded.

    # Assumptions
    - Roads are treated as bidirectional.
    - The shortest road path is used for driving actions.
    - Vehicle capacity is not used in the estimate.
    - The drive a vehicle needs to reach an unloaded package is not counted.
    - Per-package costs are summed, so a drive shared by several packages is
      counted once per package. The estimate is therefore not admissible in
      general.

    # Heuristic Initialization
    - Extract the goal location of every object named in an `at` goal.
    - Build an adjacency map of the road network from static `road` facts.
    - Record static `capacity` facts (currently unused by the estimate).

    # Step-By-Step Thinking for Computing Heuristic
    1. Parse the state:
       - `at` facts give the location of packages and vehicles;
       - `in` facts mark a package as loaded and name its vehicle.
    2. Skip packages already at their goal, or absent from the state.
    3. For a loaded package, add the drive distance from the vehicle's location
       to the goal, plus 1 (drop).
    4. For an unloaded package, add the drive distance from the package's
       location to the goal, plus 2 (pick-up and drop).
    5. Sum over all packages. Distances come from a breadth-first search over
       the road graph and are cached per start location.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting:
        - Goal locations for each package.
        - Road network from static facts, as edge pairs and an adjacency map.
        - Capacity facts of vehicles.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Extract goal locations for each package.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at":
                package, location = args
                self.goal_locations[package] = location

        # Extract the road network. `roads` keeps the original pair
        # representation; `_neighbors` is the adjacency map used by the BFS,
        # so a search step only looks at the roads leaving the current node.
        self.roads = set()
        self._neighbors = {}
        for fact in static_facts:
            if match(fact, "road", "*", "*"):
                parts = get_parts(fact)
                l1, l2 = parts[1], parts[2]
                # Roads are treated as bidirectional.
                self.roads.add((l1, l2))
                self.roads.add((l2, l1))
                self._neighbors.setdefault(l1, set()).add(l2)
                self._neighbors.setdefault(l2, set()).add(l1)

        # Extract vehicle capacities (raw object names, unused by the estimate).
        # NOTE(review): if capacity is a fluent in this domain encoding, no such
        # facts are static and this map stays empty -- confirm.
        self.capacities = {}
        for fact in static_facts:
            if match(fact, "capacity", "*", "*"):
                parts = get_parts(fact)
                self.capacities[parts[1]] = parts[2]

        # start location -> {reachable location: number of drive actions}.
        # The road network is static, so BFS results can be reused across calls.
        self._distance_cache = {}

    def __call__(self, node):
        """Compute an estimate of the minimal number of required actions.

        Returns a non-negative int. Returns `float('inf')` if some package's
        goal is unreachable over the road network.
        """
        state = node.state  # Current world state.

        # Map each object to its location. A loaded package maps to its vehicle.
        current_locations = {}
        # Packages currently inside a vehicle. These are identified from `in`
        # facts rather than from object names, so arbitrary naming is supported.
        loaded_packages = set()
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "at":
                obj, location = args
                current_locations[obj] = location
            elif predicate == "in":
                package, vehicle = args
                current_locations[package] = vehicle
                loaded_packages.add(package)

        total_cost = 0  # Initialize action cost counter.

        for package, goal_location in self.goal_locations.items():
            current_location = current_locations.get(package)

            if current_location is None:
                continue  # Package not in state (should not happen).

            # If the package is already at its goal, no cost is added.
            if current_location == goal_location:
                continue

            if package in loaded_packages:
                # `current_location` is the vehicle; drive it to the goal.
                vehicle_location = current_locations.get(current_location)
                if vehicle_location is None:
                    continue  # Vehicle not in state (should not happen).
                drive_cost = self._compute_drive_cost(vehicle_location, goal_location)
                total_cost += drive_cost + 1  # Add 1 for the drop action.
            else:
                # Package is waiting at a location. Count the drive from that
                # location to the goal. The drive a vehicle needs to reach the
                # package is deliberately not included.
                drive_cost = self._compute_drive_cost(current_location, goal_location)
                total_cost += drive_cost + 2  # Add 2 for pick-up and drop actions.

        return total_cost

    def _compute_drive_cost(self, start, goal):
        """
        Return the minimum number of drive actions from `start` to `goal`.

        Returns 0 when `start == goal`. Returns `float('inf')` when `goal` is
        not reachable over the road network. Distances from each start location
        are computed once and cached.
        """
        if start == goal:
            return 0

        distances = self._distance_cache.get(start)
        if distances is None:
            distances = self._bfs_distances(start)
            self._distance_cache[start] = distances
        return distances.get(goal, float('inf'))  # inf: no path (invalid instance).

    def _bfs_distances(self, start):
        """Breadth-first search from `start`; map every reachable location to its distance."""
        distances = {start: 0}
        frontier = deque([start])
        while frontier:
            current = frontier.popleft()
            next_cost = distances[current] + 1
            for neighbor in self._neighbors.get(current, ()):
                # Marking on enqueue ensures each location is queued at most once.
                if neighbor not in distances:
                    distances[neighbor] = next_cost
                    frontier.append(neighbor)
        return distances
