from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact such as ``"(at p1 l1)"`` into its tokens.

    The first and last characters (the surrounding parentheses) are dropped
    and the remainder is split on whitespace, giving e.g. ``["at", "p1", "l1"]``.
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Report whether a PDDL fact matches a token-wise pattern.

    - `fact`: the full fact string, e.g. "(at p1 l1)".
    - `args`: one shell-style pattern per token (`*` and other fnmatch
      wildcards are allowed).
    - Returns `True` when every token matches its pattern, else `False`.

    Tokens and patterns are paired position by position and pairing stops at
    the shorter of the two, so surplus tokens or surplus patterns are ignored.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class TransportHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Transport domain.

    # Summary
    This heuristic estimates the number of actions required to transport all packages to their goal locations. For each package it adds the shortest road distance it still has to travel plus the load/unload actions it still needs.

    # Assumptions
    - Packages can be either on the ground (`(at package location)`) or inside a vehicle (`(in package vehicle)`).
    - Vehicle capacity is not modelled: the capacity-predecessor facts are recorded but not used in the estimate.
    - The road network is treated as bidirectional: every `(road a b)` fact also allows driving from b to a.
    - Drives are counted per package, so one vehicle carrying several packages is counted several times. The estimate is therefore not admissible.

    # Heuristic Initialization
    - Extract goal locations for each package from the task's `(at ...)` goal conditions.
    - Extract the road network and capacity constraints from the static facts.

    # Step-By-Step Thinking for Computing Heuristic
    1. For each package, determine whether it is on the ground or inside a vehicle, and where it (or its vehicle) physically is.
    2. If the package is on the ground at its goal location, no action is required.
    3. If the package is inside a vehicle, it needs the drive from the vehicle's location to the goal (possibly 0 drives) plus one unload.
    4. If the package is on the ground elsewhere, it needs one load, the drive to the goal, and one unload.
    5. Sum the estimated actions for all packages to get the total heuristic value. An unreachable goal location yields infinity.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting:
        - Goal locations for each package.
        - Road network and capacity constraints from static facts.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Road network as an undirected adjacency map: location -> set of neighbours.
        self.road_network = {}
        for fact in static_facts:
            if match(fact, "road", "*", "*"):
                parts = get_parts(fact)
                loc1, loc2 = parts[1], parts[2]
                self.road_network.setdefault(loc1, set()).add(loc2)
                self.road_network.setdefault(loc2, set()).add(loc1)

        # Capacity constraints: maps a size to its predecessor size (currently unused).
        self.capacity_predecessors = {}
        for fact in static_facts:
            if match(fact, "capacity-predecessor", "*", "*"):
                parts = get_parts(fact)
                s1, s2 = parts[1], parts[2]
                self.capacity_predecessors[s2] = s1

        # Goal location for each object named in an `(at obj loc)` goal.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at":
                package, location = args
                self.goal_locations[package] = location

        # Lazily filled cache: source location -> {reachable location: distance}.
        # The road network is static, so distances never need invalidating
        # unless `road_network` is mutated after construction.
        self._distance_cache = {}

    def __call__(self, node):
        """Compute an estimate of the minimal number of required actions."""
        state = node.state  # Current world state.

        # `(at obj loc)` -> where an object physically is on the map.
        location_of = {}
        # `(in package vehicle)` -> which vehicle carries a package.
        carrier_of = {}
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "at":
                obj, location = args
                location_of[obj] = location
            elif predicate == "in":
                package, vehicle = args
                carrier_of[package] = vehicle

        total_cost = 0  # Initialize action cost counter.

        for package, goal_location in self.goal_locations.items():
            vehicle = carrier_of.get(package)

            if vehicle is None:
                # Package is on the ground.
                current_location = location_of[package]
                if current_location == goal_location:
                    continue  # Goal already satisfied for this package.
                # Load + drive + unload.
                total_cost += 1 + self.estimate_drive_cost(current_location, goal_location) + 1
            else:
                # Package is carried; it moves with its vehicle. Even when the
                # vehicle already stands at the goal, the goal `(at package loc)`
                # still requires an unload action.
                vehicle_location = location_of[vehicle]
                total_cost += self.estimate_drive_cost(vehicle_location, goal_location) + 1

        return total_cost

    def estimate_drive_cost(self, start, goal):
        """
        Return the minimal number of drive actions from `start` to `goal`.

        Uses breadth-first search over the (unweighted) road network. Returns 0
        when `start == goal` and `float('inf')` when `goal` is unreachable.
        """
        if start == goal:
            return 0
        return self._distances_from(start).get(goal, float('inf'))

    def _distances_from(self, source):
        """Return (and cache) BFS distances from `source` to every reachable location."""
        cache = getattr(self, "_distance_cache", None)
        if cache is None:
            # Defensive: supports instances created without running __init__.
            cache = self._distance_cache = {}

        distances = cache.get(source)
        if distances is not None:
            return distances

        distances = {source: 0}
        frontier = deque([source])
        while frontier:
            current = frontier.popleft()
            next_cost = distances[current] + 1
            for neighbor in self.road_network.get(current, ()):
                if neighbor not in distances:
                    distances[neighbor] = next_cost
                    frontier.append(neighbor)

        cache[source] = distances
        return distances
