from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string such as "(at p1 l1)" into its tokens.

    The first and last characters (the surrounding parentheses) are dropped
    and the remaining text is split on whitespace, e.g.
    "(at p1 l1)" -> ["at", "p1", "l1"].
    """
    inner = fact[1:-1]
    return inner.split()

def match(fact, *args):
    """
    Return whether a PDDL fact matches a token pattern.

    - `fact`: the full fact string, e.g. "(at p1 l1)".
    - `args`: one shell-style pattern per token (`*` is a wildcard).

    Tokens and patterns are compared pairwise; comparison stops at the
    shorter of the two sequences, so extra tokens or extra patterns are
    not checked.
    """
    tokens = fact[1:-1].split()
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True

class TransportHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Transport domain.

    # Summary
    This heuristic estimates the number of actions required to transport all
    packages to their goal locations. For each package it looks at where the
    package currently is (on the ground at a location, or inside a vehicle
    that is at a location) and adds the road distance to its goal plus the
    pick-up/drop actions it still needs.

    # Assumptions
    - Packages can be at a location (`at`) or inside a vehicle (`in`).
    - Vehicles can move between connected locations via (directed) roads.
    - Vehicle capacity and the cost of bringing a vehicle to a package are
      ignored, and every package is treated as if it had its own vehicle.
    - The heuristic does not need to be admissible, so it can overestimate the cost.

    # Heuristic Initialization
    - Extract the goal locations for each package from the task goals.
    - Extract the road network from the static facts to compute distances between locations.
    - Extract the capacity relationships between sizes from the static facts
      (stored for reference; not used by the estimate).

    # Step-By-Step Thinking for Computing Heuristic
    1. For each package, determine its current location (directly, or via the
       location of the vehicle it is in).
    2. Compute the shortest path (number of roads) from that location to the
       goal location (BFS, cached per source location).
    3. If the package is in a vehicle: road distance + 1 (drop). The drop is
       needed even when the vehicle is already at the goal location.
    4. If the package is on the ground away from its goal: road distance + 2
       (pick-up and drop).
    5. Unreachable goals are charged a distance larger than any real
       shortest path.
    6. Sum the actions required for all packages to get the total heuristic value.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting:
        - Goal locations for each package.
        - Road network from static facts.
        - Capacity relationships from static facts.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Road network: location -> set of directly reachable locations.
        self.roads = {}
        for fact in static_facts:
            if match(fact, "road", "*", "*"):
                _, l1, l2 = get_parts(fact)
                self.roads.setdefault(l1, set()).add(l2)

        # Capacity relationships: size s2 -> its predecessor size s1.
        self.capacity_predecessors = {}
        for fact in static_facts:
            if match(fact, "capacity-predecessor", "*", "*"):
                _, s1, s2 = get_parts(fact)
                self.capacity_predecessors[s2] = s1

        # Goal location for each package, from `(at ?p ?l)` goals.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "at":
                package, location = args
                self.goal_locations[package] = location

        # A shortest path over N road locations has at most N - 1 edges, so
        # N + 1 is strictly larger than any reachable distance.
        road_locations = set(self.roads)
        for targets in self.roads.values():
            road_locations |= targets
        self._unreachable_distance = len(road_locations) + 1

        # Roads are static, so BFS results can be reused across calls.
        self._distance_cache = {}

    def _distances_from(self, source):
        """Return {location: road distance} for every location reachable from `source`."""
        distances = self._distance_cache.get(source)
        if distances is None:
            distances = {source: 0}
            frontier = deque([source])
            while frontier:
                loc = frontier.popleft()
                next_dist = distances[loc] + 1
                for neighbor in self.roads.get(loc, ()):
                    if neighbor not in distances:
                        distances[neighbor] = next_dist
                        frontier.append(neighbor)
            self._distance_cache[source] = distances
        return distances

    def _distance(self, source, target):
        """Return the road distance from `source` to `target`, or the unreachable penalty."""
        return self._distances_from(source).get(target, self._unreachable_distance)

    def __call__(self, node):
        """Compute an estimate of the minimal number of required actions."""
        state = node.state  # Current world state.

        # `(at ?x ?l)`: physical location of packages and vehicles.
        # `(in ?p ?v)`: vehicle currently carrying a package.
        at_location = {}
        carried_by = {}
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "at":
                obj, location = args
                at_location[obj] = location
            elif predicate == "in":
                package, vehicle = args
                carried_by[package] = vehicle

        total_cost = 0  # Initialize action cost counter.

        for package, goal_location in self.goal_locations.items():
            vehicle = carried_by.get(package)
            if vehicle is not None:
                # Drive the vehicle to the goal, then drop. The drop is still
                # required when the vehicle is already there.
                current_location = at_location[vehicle]
                total_cost += self._distance(current_location, goal_location) + 1
            else:
                current_location = at_location[package]
                if current_location != goal_location:
                    # Pick-up, drive to the goal, then drop.
                    total_cost += self._distance(current_location, goal_location) + 2

        return total_cost
