from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters (the enclosing parentheses) are dropped
    before splitting, so "(at bob shed)" becomes ["at", "bob", "shed"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Report whether a PDDL fact string fits a pattern, token by token.

    - `fact`: the whole fact, parentheses included, e.g. "(at bob shed)".
    - `args`: one shell-style pattern per token (`*`, `?`, `[...]` allowed).
    - Returns `True` when every paired token matches its pattern, else `False`.
      Tokens and patterns are paired with `zip`, so the comparison stops at the
      shorter of the two; surplus tokens or patterns are not checked.
    """
    # Tokenise exactly like get_parts: strip the outer parentheses, split on whitespace.
    tokens = fact[1:-1].split()
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True

class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    Estimates the number of actions still needed to tighten all loose nuts:
    - every loose nut costs one tighten action;
    - every loose nut needs its own usable spanner; each spanner the man still
      has to collect costs one pickup action plus the walk to it;
    - the man has to walk to the location of every loose nut.

    Walking distances are all measured from the man's current location and
    summed, so route segments shared between targets are counted repeatedly.
    The estimate is therefore not admissible and should not be used where an
    admissible heuristic is required.

    # Assumptions
    - A spanner becomes unusable after tightening a nut, so the number of
      usable spanners must be at least the number of loose nuts.
    - Usability is stated by `usable` facts; the spelling `useable` is accepted
      too. NOTE(review): confirm which spelling the domain file uses.
    - Locations are connected by static `link` facts, treated here as
      bidirectional. NOTE(review): if the domain's walk action only follows a
      link in one direction, these distances may be underestimates.
    - The man is whoever carries a spanner. When he carries none, he is the
      located object that is neither a nut nor a usable spanner (preferring
      "bob" if several candidates remain).

    # Heuristic Initialization
    - Store the goal conditions and collect the nuts named in `tightened` goals.
    - Build an undirected adjacency map of locations from the static `link` facts.

    # Step-By-Step Thinking for Computing Heuristic
    1. Parse the state once: object locations, loose/tightened nuts, usable
       spanners and carried spanners.
    2. If no nut is loose, return 0.
    3. Identify the man and run one BFS from his location.
    4. Add one tighten action per loose nut.
    5. Count the usable spanners the man already carries. For each missing one,
       add a pickup plus the walk to one of the nearest usable spanners on the
       ground. Return infinity if too few usable spanners exist.
    6. Add the walking distance from the man to each loose nut.
    """

    # Predicate names accepted as "this spanner can still tighten a nut".
    USABLE_PREDICATES = ("usable", "useable")

    def __init__(self, task):
        """Store goals, index goal nuts and build the location adjacency map."""
        self.goals = task.goals  # Goal conditions.

        # Nuts named in (tightened ?n) goals; used to tell nuts apart from the man.
        self.goal_nuts = set()
        for goal in self.goals:
            parts = get_parts(goal)
            if len(parts) == 2 and parts[0] == "tightened":
                self.goal_nuts.add(parts[1])

        # Undirected adjacency map built from the static (link ?a ?b) facts.
        self.location_graph = {}
        for fact in task.static:
            if match(fact, "link", "*", "*"):
                _, loc1, loc2 = get_parts(fact)
                self.location_graph.setdefault(loc1, set()).add(loc2)
                self.location_graph.setdefault(loc2, set()).add(loc1)

    def __call__(self, node):
        """
        Estimate the number of actions still needed to tighten all loose nuts.

        Returns 0 when no nut is loose. Returns ``float("inf")`` when there are
        fewer usable spanners than loose nuts, or when a required location is
        unreachable from the man.
        """
        state = node.state

        # 1. Parse the state in a single pass.
        locations = {}  # object -> location, from (at ?obj ?loc)
        loose_nuts = []  # nuts that still have to be tightened
        nuts = set(self.goal_nuts)  # every object known to be a nut
        usable = set()  # spanners that can still tighten a nut
        carried = {}  # spanner -> carrier, from (carrying ?m ?s)
        for fact in state:
            parts = get_parts(fact)
            if len(parts) == 3 and parts[0] == "at":
                locations[parts[1]] = parts[2]
            elif len(parts) == 2 and parts[0] == "loose":
                loose_nuts.append(parts[1])
                nuts.add(parts[1])
            elif len(parts) == 2 and parts[0] == "tightened":
                nuts.add(parts[1])
            elif len(parts) == 2 and parts[0] in self.USABLE_PREDICATES:
                usable.add(parts[1])
            elif len(parts) == 3 and parts[0] == "carrying":
                carried[parts[2]] = parts[1]

        # 2. Nothing left to do.
        if not loose_nuts:
            return 0

        # 3. Locate the man and compute all walking distances from him at once.
        man = self._find_man(locations, nuts, usable, carried)
        man_location = locations.get(man)
        distances = None if man_location is None else self._distances_from(man_location)

        def walk_cost(target):
            # Unknown start: charge nothing rather than declare a dead end.
            if distances is None:
                return 0
            return distances.get(target, float("inf"))

        # 4. One tighten action per loose nut.
        total_cost = len(loose_nuts)

        # 5. Each tightening consumes a usable spanner; collect the missing ones.
        carried_usable = sum(1 for spanner in carried if spanner in usable)
        missing = len(loose_nuts) - carried_usable
        if missing > 0:
            ground_spanners = [
                spanner
                for spanner in usable
                if spanner not in carried and spanner in locations
            ]
            if len(ground_spanners) < missing:
                return float("inf")  # Not enough usable spanners: dead end.
            spanner_walks = sorted(walk_cost(locations[s]) for s in ground_spanners)
            total_cost += missing  # One pickup per missing spanner.
            total_cost += sum(spanner_walks[:missing])  # Walk to the nearest ones.

        # 6. Walk to every loose nut whose location is known.
        for nut in loose_nuts:
            nut_location = locations.get(nut)
            if nut_location is not None:
                total_cost += walk_cost(nut_location)

        return total_cost

    @staticmethod
    def _find_man(locations, nuts, usable, carried):
        """Return the name of the man, or None if it cannot be determined."""
        # Whoever carries a spanner is the man.
        holder = next(iter(carried.values()), None)
        if holder is not None:
            return holder
        # Otherwise: a located object that is neither a nut nor a spanner.
        candidates = sorted(
            obj
            for obj in locations
            if obj not in nuts and obj not in usable and obj not in carried
        )
        if "bob" in candidates:
            return "bob"
        return candidates[0] if candidates else None

    def _distances_from(self, start):
        """Return BFS hop counts from ``start`` to every location reachable from it."""
        distances = {start: 0}
        frontier = deque([start])
        while frontier:
            current = frontier.popleft()
            for neighbor in self.location_graph.get(current, ()):
                if neighbor not in distances:
                    distances[neighbor] = distances[current] + 1
                    frontier.append(neighbor)
        return distances

    def _shortest_path_length(self, start, end):
        """Return the number of link hops from start to end, or inf if unreachable."""
        if start == end:
            return 0
        return self._distances_from(start).get(end, float("inf"))
