from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated tokens.

    - `fact`: The fact as a string, e.g., "(at p1 l1)".
    - Returns a list of tokens, e.g., ["at", "p1", "l1"].

    The first and last characters are dropped unconditionally (they are
    expected to be the enclosing parentheses); no check is made that they
    actually are parentheses.
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Test a PDDL fact against a positional pattern.

    - `fact`: The complete fact as a string, e.g., "(at p1 l1)".
    - `args`: One shell-style pattern per token (wildcards such as `*` allowed).
    - Returns `True` if every compared token matches its pattern, else `False`.

    Tokens and patterns are paired positionally and the comparison stops at
    the shorter of the two sequences, so surplus tokens or surplus patterns
    are ignored rather than treated as a mismatch.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class TransportHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Transport domain.

    # Summary
    This heuristic estimates the number of actions required to transport all packages to their goal locations. Each package is costed independently from its current position (on the ground or inside a vehicle) and the road network, and the per-package costs are summed.

    # Assumptions
    - Packages can be either on the ground (`at`) or inside a vehicle (`in`).
    - Vehicle capacity is ignored by this estimate.
    - The road network is treated as bidirectional: every `road` fact is inserted in both directions.
    - Vehicle positions are not taken into account for packages lying on the ground (the cost of driving a vehicle to the package is not estimated).

    # Heuristic Initialization
    - Extract goal locations for each package from the `at` facts in the task's goal conditions.
    - Extract the road network from the `road` facts in the static facts.
    - Map each location to its connected locations using the road network.

    # Step-By-Step Thinking for Computing Heuristic
    1. For each package with a goal, determine whether it is on the ground or inside a vehicle.
    2. If the package is inside a vehicle, use the vehicle's location as the package's position.
    3. If the package is on the ground at its goal location, no action is required.
    4. If the package is inside a vehicle that stands at the goal location, one unload action is required.
    5. Otherwise compute the shortest path (number of road segments) from the package's position to its goal location and add one unload action, plus one load action if the package is still on the ground.
    6. If the goal location cannot be reached over the road network, the estimate is infinite.
    7. Sum the estimated actions for all packages to compute the total heuristic value.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting:
        - Goal locations for each package (from `at` goal facts).
        - The road network (from `road` static facts).

        - `task`: The planning task; `task.goals` and `task.static` are read
          and are expected to be iterables of PDDL fact strings.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Extract goal locations for each package.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at":
                package, location = args
                self.goal_locations[package] = location

        # Build the road network as a dictionary mapping locations to their connected locations.
        # Each road is inserted in both directions.
        self.road_network = {}
        for fact in static_facts:
            predicate, *args = get_parts(fact)
            if predicate == "road":
                l1, l2 = args
                self.road_network.setdefault(l1, set()).add(l2)
                self.road_network.setdefault(l2, set()).add(l1)

        # Cache of BFS results: origin location -> {reachable location: distance}.
        # Filled lazily; valid as long as `road_network` is not modified afterwards.
        self._distances_from = {}

    def _distance(self, origin, destination):
        """
        Return the number of road segments on a shortest path between two locations.

        - `origin`: The location the path starts from.
        - `destination`: The location the path must reach.
        - Returns the segment count, or `None` if `destination` is not reachable from `origin`.
        """
        distances = self._distances_from.get(origin)
        if distances is None:
            # Breadth-first search from `origin`; a location is recorded the first
            # time it is seen, which is at its minimum distance.
            distances = {origin: 0}
            frontier = deque([origin])
            while frontier:
                location = frontier.popleft()
                for neighbor in self.road_network.get(location, ()):
                    if neighbor not in distances:
                        distances[neighbor] = distances[location] + 1
                        frontier.append(neighbor)
            self._distances_from[origin] = distances
        return distances.get(destination)

    def __call__(self, node):
        """
        Compute an estimate of the number of actions required to reach the goal state.

        - `node`: The search node; `node.state` is read and is expected to be
          an iterable of PDDL fact strings.
        - Returns the summed per-package estimate, or `float("inf")` if some
          package's goal location is unreachable over the road network.
        """
        state = node.state  # Current world state.

        # Track where objects stand and which packages are being carried.
        # Carried packages are identified by their `in` fact, not by object names.
        positions = {}
        carriers = {}
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "at":
                obj, location = args
                positions[obj] = location
            elif predicate == "in":
                package, vehicle = args
                carriers[package] = vehicle

        total_cost = 0  # Accumulated action estimate.

        for package, goal_location in self.goal_locations.items():
            vehicle = carriers.get(package)
            in_vehicle = vehicle is not None

            # A carried package is wherever its vehicle is.
            current_location = positions.get(vehicle if in_vehicle else package)
            if current_location is None:
                continue  # Package or vehicle not found in state (should not happen in valid states).

            if current_location == goal_location:
                if in_vehicle:
                    # The goal is an `at` fact, so the package still has to be unloaded.
                    total_cost += 1
                continue

            distance = self._distance(current_location, goal_location)
            if distance is None:
                # No route exists and roads are static, so the goal cannot be reached.
                # NOTE(review): assumes the search treats an infinite estimate as a
                # dead end; confirm against the Heuristic base class / search code.
                return float("inf")

            # Drive along the shortest path and unload at the goal.
            total_cost += distance + 1

            # A package still on the ground must be loaded first.
            if not in_vehicle:
                total_cost += 1

        return total_cost
