from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string such as "(at p1 l1)" into its tokens.

    The first and last characters (the enclosing parentheses) are discarded
    and the remaining text is split on whitespace, e.g.
    "(at p1 l1)" -> ["at", "p1", "l1"].
    """
    inner = fact[1:-1]
    return inner.split()

def match(fact, *args):
    """
    Test whether a PDDL fact matches a wildcard pattern, token by token.

    - `fact`: the full fact string, e.g. "(at p1 l1)".
    - `args`: one `fnmatch`-style pattern per token (`*` matches anything).
    - Returns `True` when every compared token matches its pattern.

    Tokens and patterns are paired with `zip`, so only the shorter of the two
    sequences is compared: a fact with fewer tokens than patterns (or vice
    versa) can still match.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class TransportHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Transport domain.

    # Summary
    Estimates the number of actions needed to bring every goal package to its
    goal location. Each package is costed independently: the number of `drive`
    actions its carrier needs (road distance) plus the `pick-up`/`drop`
    actions it still requires. The per-package costs are summed.

    # Assumptions
    - A package is either at a location (`(at p l)`) or inside a vehicle
      (`(in p v)`); a vehicle's own position is given by `(at v l)`.
    - Roads are treated as bidirectional: every `(road l1 l2)` fact also adds
      the reverse edge. NOTE(review): if a problem lists one-way roads this
      underestimates the drive cost; confirm roads are always symmetric.
    - Vehicle capacity facts are parsed into `capacity_predecessors` but are
      not used by the estimate, and packages that could share one vehicle trip
      are costed separately, so the value can overcount and is not admissible.

    # Heuristic Initialization
    - Extract the goal location of each object from the `at` goals.
    - Extract `road` and `capacity-predecessor` facts from the static facts.
    - Build a road graph; BFS distances from a start location are computed on
      first use and cached for later calls.

    # Step-By-Step Thinking for Computing Heuristic
    1. Record the location of every object (`at`) and the vehicle carrying
       each package (`in`).
    2. For each goal package:
       - If it is inside a vehicle, its effective location is the vehicle's
         location and it still needs one `drop`, even when the vehicle is
         already at the goal location (the `at` goal is not yet true).
       - Otherwise, if it already lies at its goal location it costs nothing;
         if not, it needs a `pick-up` and a `drop`.
       - Add the road distance (number of `drive` actions) from the effective
         location to the goal location.
       - If the goal location is unreachable by road, return infinity.
    3. Return the sum over all goal packages.
    """

    def __init__(self, task):
        """Extract goal locations, the road graph and capacity facts from `task`."""
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Road graph: location -> set of adjacent locations (edges added both ways).
        self.road_graph = {}
        for fact in static_facts:
            if match(fact, "road", "*", "*"):
                _, l1, l2 = get_parts(fact)
                self.road_graph.setdefault(l1, set()).add(l2)
                self.road_graph.setdefault(l2, set()).add(l1)

        # Capacity relationships: size -> its predecessor size (unused by __call__).
        self.capacity_predecessors = {}
        for fact in static_facts:
            if match(fact, "capacity-predecessor", "*", "*"):
                _, s1, s2 = get_parts(fact)
                self.capacity_predecessors[s2] = s1

        # Goal location for each object named in an `at` goal.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at":
                package, location = args
                self.goal_locations[package] = location

        # Cache: start location -> {reachable location: road distance}.
        self._distance_cache = {}

    def __call__(self, node):
        """Return the estimated number of actions left, or inf if a goal is unreachable."""
        state = node.state  # Current world state.

        located_at = {}  # object (package or vehicle) -> location, from `at` facts
        carried_by = {}  # package -> vehicle, from `in` facts
        for fact in state:
            predicate, *args = get_parts(fact)
            if len(args) != 2:
                continue
            if predicate == "at":
                located_at[args[0]] = args[1]
            elif predicate == "in":
                carried_by[args[0]] = args[1]

        total_cost = 0  # Accumulated action estimate.

        for package, goal_location in self.goal_locations.items():
            vehicle = carried_by.get(package)
            if vehicle is not None:
                # Carried packages move with their vehicle and always need a
                # final `drop`, even if the vehicle already stands at the goal.
                current_location = located_at[vehicle]
                handling_actions = 1  # `drop`
            else:
                current_location = located_at[package]
                if current_location == goal_location:
                    continue  # Goal already satisfied for this package.
                handling_actions = 2  # `pick-up` + `drop`

            drive_actions = self._distance(current_location, goal_location)
            if drive_actions is None:
                # No road connection: treat the state as a dead end.
                return float('inf')

            total_cost += drive_actions + handling_actions

        return total_cost

    def _distance(self, start, goal):
        """Return the road distance from `start` to `goal`, or None if unreachable."""
        distances = self._distance_cache.get(start)
        if distances is None:
            distances = self._bfs_distances(start)
            self._distance_cache[start] = distances
        return distances.get(goal)

    def _bfs_distances(self, start):
        """Return {location: hop count} for every location reachable from `start`."""
        distances = {start: 0}
        frontier = deque([start])
        while frontier:
            current = frontier.popleft()
            for neighbor in self.road_graph.get(current, ()):
                if neighbor not in distances:
                    distances[neighbor] = distances[current] + 1
                    frontier.append(neighbor)
        return distances

    def _shortest_path(self, start, goal):
        """Return a shortest list of locations from `start` to `goal` (BFS), or None."""
        if start == goal:
            return [start]

        parents = {start: None}  # Visited set doubling as BFS back-pointers.
        frontier = deque([start])
        while frontier:
            current = frontier.popleft()
            for neighbor in self.road_graph.get(current, ()):
                if neighbor in parents:
                    continue
                parents[neighbor] = current
                if neighbor == goal:
                    # Walk the back-pointers from goal to start, then reverse.
                    path = [goal]
                    while parents[path[-1]] is not None:
                        path.append(parents[path[-1]])
                    path.reverse()
                    return path
                frontier.append(neighbor)

        return None  # No path found.
