from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Return the whitespace-separated tokens of a PDDL fact string.

    The first and last characters (the surrounding parentheses) are dropped before
    splitting, e.g. ``get_parts("(at bob shed)") == ["at", "bob", "shed"]``.
    """
    without_parens = fact[1:-1]
    return without_parens.split()

def match(fact, *args):
    """
    Check whether a PDDL fact matches a pattern of fnmatch-style wildcards.

    - `fact`: The complete fact as a string, e.g., "(at bob shed)".
    - `args`: One pattern per token, compared position by position (`*` matches anything).
    - Returns `True` when every compared token matches its pattern, else `False`.

    Tokens and patterns are paired with `zip`, so only the first
    min(len(tokens), len(args)) positions are compared; a pattern whose length
    differs from the fact's arity is not rejected on that basis alone.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    This heuristic estimates the number of actions needed to tighten all loose goal nuts in the
    Spanner domain. It considers the following:
    - The man must pick up one usable spanner for every loose nut he carries no spanner for.
    - The man must walk to the location of each loose nut.
    - The man must tighten each loose nut using a spanner.

    # Assumptions:
    - The man can carry multiple spanners at once.
    - A spanner can only be used once (it becomes unusable after tightening a nut), so every
      loose nut needs its own usable spanner.
    - The man must walk to the location of each nut to tighten it.
    - `(link a b)` only allows walking from `a` to `b`; links are treated as one-way edges.
    - Spanners lying on the ground are recognised by their `usable` fact, not by their name.
    - The man is the object in a `carrying` fact or, failing that, a located object that is
      neither a nut nor a spanner (this finds `bob` in the usual instances without
      hard-coding the name).
    - The estimate is not admissible: walking distances to different nuts are summed.

    # Heuristic Initialization
    - Extract the goal nuts (`tightened` goals) and the static `link` facts from the task.
    - Build a directed graph of locations and precompute all shortest walking distances (BFS).

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify all loose nuts that need to be tightened.
    2. Determine the man and his current location.
    3. Count the usable spanners the man is carrying.
    4. If that is fewer than the number of loose nuts, estimate the cost to pick up the rest:
       - If fewer usable spanners are reachable than are missing, the state is a dead end.
       - Add the walking distance to the nearest reachable usable spanner.
       - Add one pick-up action per missing spanner.
    5. For each loose nut:
       - Add the walking distance from the man's current location to the nut's location
         (infinity if the nut can no longer be reached).
       - Add one action to tighten the nut.
    6. Sum the total cost of picking up spanners and tightening all loose nuts.
    """

    def __init__(self, task):
        """Initialize the heuristic from the task's goals and static facts.

        Args:
            task: planning task exposing `goals` and `static` as collections of fact
                strings such as "(link shed location1)".
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Nuts the goal requires to be tightened. If the goals contain no `tightened`
        # facts this stays empty and every loose nut is treated as needing tightening.
        self.goal_nuts = {
            get_parts(goal)[1] for goal in self.goals if match(goal, "tightened", "*")
        }

        # Directed location graph: `(link a b)` lets the man walk from a to b only, so no
        # reverse edge is added (a reverse edge hides dead ends and shortens distances).
        self.location_graph = {}
        for fact in static_facts:
            if match(fact, "link", "*", "*"):
                _, loc1, loc2 = get_parts(fact)
                self.location_graph.setdefault(loc1, set()).add(loc2)
                self.location_graph.setdefault(loc2, set())

        # The map is static, so compute every shortest walking distance once up front
        # instead of running a fresh BFS for every nut in every evaluated state.
        self.distances = {loc: self._bfs(loc) for loc in self.location_graph}

    def __call__(self, node):
        """Estimate the number of actions still needed to tighten all loose goal nuts.

        Returns 0 when no goal nut is loose, and float('inf') when the state is invalid
        (man or nut without a location) or recognised as a dead end.
        """
        state = node.state
        infinity = float('inf')

        # Collect everything needed from the state in a single pass.
        locations = {}  # object -> location, from `at` facts
        loose = set()  # objects with a `loose` fact
        tightened = set()  # objects with a `tightened` fact
        usable = set()  # objects with a `usable` fact
        carried = set()  # objects appearing as the carried item of a `carrying` fact
        carriers = set()  # objects appearing as the carrier of a `carrying` fact
        for fact in state:
            parts = get_parts(fact)
            if len(parts) == 3 and parts[0] == "at":
                locations[parts[1]] = parts[2]
            elif len(parts) == 3 and parts[0] == "carrying":
                carriers.add(parts[1])
                carried.add(parts[2])
            elif len(parts) == 2 and parts[0] == "loose":
                loose.add(parts[1])
            elif len(parts) == 2 and parts[0] == "tightened":
                tightened.add(parts[1])
            elif len(parts) == 2 and parts[0] == "usable":
                usable.add(parts[1])

        # Identify all loose nuts that need to be tightened.
        loose_nuts = loose & self.goal_nuts if self.goal_nuts else loose
        if not loose_nuts:
            return 0  # Goal reached.

        # Determine the man's current location.
        man = self._find_man(
            locations, carriers, loose | tightened | self.goal_nuts, usable | carried
        )
        man_location = locations.get(man)
        if man_location is None:
            return infinity  # Man has no location, state is invalid.

        total_cost = 0

        # Each tightening consumes one usable spanner; plan pickups for the ones not carried.
        missing = len(loose_nuts) - len(carried & usable)
        if missing > 0:
            # Usable spanners that have an `at` fact are lying on the ground.
            reachable = [
                distance
                for distance in (
                    self._shortest_path(man_location, locations[spanner])
                    for spanner in usable
                    if spanner in locations
                )
                if distance < infinity
            ]
            if len(reachable) < missing:
                # Not enough usable spanners can still be reached to tighten every nut.
                return infinity
            # Walk to the nearest spanner, plus one pick-up action per missing spanner.
            total_cost += min(reachable) + missing

        # For each loose nut, add the walking distance and one tightening action.
        for nut in loose_nuts:
            nut_location = locations.get(nut)
            if nut_location is None:
                return infinity  # Nut has no location, state is invalid.
            distance = self._shortest_path(man_location, nut_location)
            if distance == infinity:
                return infinity  # The nut can no longer be reached via one-way links.
            total_cost += distance + 1

        return total_cost

    @staticmethod
    def _find_man(locations, carriers, nuts, spanners):
        """Return the man's name, or None if it cannot be determined.

        Prefers an object that carries something; otherwise picks (alphabetically first)
        a located object that is neither a known nut nor a known spanner.
        """
        if carriers:
            return min(carriers)
        candidates = sorted(
            obj for obj in locations if obj not in nuts and obj not in spanners
        )
        return candidates[0] if candidates else None

    def _bfs(self, source):
        """Return a dict mapping every location reachable from `source` to its walk distance."""
        dist = {source: 0}
        queue = deque([source])
        while queue:
            current = queue.popleft()
            for neighbor in self.location_graph.get(current, ()):
                if neighbor not in dist:
                    dist[neighbor] = dist[current] + 1
                    queue.append(neighbor)
        return dist

    def _shortest_path(self, start, end):
        """Return the number of walk actions from `start` to `end` (infinity if unreachable)."""
        if start == end:
            return 0
        return self.distances.get(start, {}).get(end, float('inf'))
