from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters (normally the enclosing parentheses) are
    dropped without being checked, and the rest is split on whitespace, e.g.
    "(at bob shed)" -> ["at", "bob", "shed"].
    """
    inner = fact[1:-1]
    return inner.split()

def match(fact, *args):
    """
    Test whether a PDDL fact matches a pattern of fnmatch-style tokens.

    - `fact`: the full fact string, e.g. "(at bob shed)".
    - `args`: one pattern per token; shell-style wildcards such as `*` are allowed.
    - Returns `True` when every compared token matches its pattern, else `False`.

    Tokens and patterns are paired positionally, and comparison stops at the
    shorter of the two. A pattern with fewer entries than the fact therefore
    acts as a prefix match, and pattern entries beyond the fact's length are
    ignored.
    """
    tokens = get_parts(fact)
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True

class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    This heuristic estimates the number of actions needed to tighten all loose nuts in the Spanner domain.
    It considers the following factors:
    - The number of loose nuts that still need to be tightened.
    - The distance the man must travel to collect a spanner and reach the nuts.
    - Whether the man is carrying a usable spanner.

    # Assumptions:
    - The man can carry only one spanner at a time.
    - Each spanner can be used only once to tighten a nut.
    - The man must walk to the location of a spanner to pick it up.
    - The man must walk to the location of a nut to tighten it.
    - Objects are classified by the predicates they occur in, not by their names:
      nuts occur in `loose`/`tightened`, spanners occur in `usable`/`carrying`, and
      the man is the carrier in a `carrying` fact or, failing that, the remaining
      object that has an `at` fact.
    - `link` facts are treated as bidirectional. NOTE(review): if walking only
      follows a link in its stated direction, these distances are optimistic.

    # Heuristic Initialization
    - Extract the goal conditions (tightened nuts) and static facts (links between locations).
    - Build a graph of locations using the static `link` facts to compute distances.
    - Shortest-path distances are computed by BFS on demand and cached per start location.

    # Step-By-Step Thinking for Computing Heuristic
    1. Collect the loose nuts; if there are none, return 0.
    2. Identify the man, his location, and the location of every nut and spanner.
    3. If fewer usable spanners remain than loose nuts, return infinity: every
       tightening consumes a usable spanner, so the goal cannot be reached.
    4. If the man is not carrying a usable spanner, add the distance to the
       nearest usable spanner plus one pick-up action.
    5. For each loose nut, add the distance from the man's current location to
       the nut's location plus one tighten action.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts."""
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Static facts (links between locations).

        # Nuts named in `tightened` goals. These help distinguish nuts from the
        # man when identifying him in a state.
        # Assumes goals are fact strings in the same format as state facts.
        self.goal_nuts = set()
        for goal in self.goals:
            parts = get_parts(goal)
            if len(parts) == 2 and parts[0] == "tightened":
                self.goal_nuts.add(parts[1])

        # Build an undirected adjacency map from the static `link` facts.
        self.location_graph = {}
        for fact in static_facts:
            parts = get_parts(fact)
            if len(parts) == 3 and parts[0] == "link":
                loc1, loc2 = parts[1], parts[2]
                self.location_graph.setdefault(loc1, set()).add(loc2)
                self.location_graph.setdefault(loc2, set()).add(loc1)

        # Cache of BFS results: start location -> {reachable location: distance}.
        self._distance_cache = {}

    def __call__(self, node):
        """
        Estimate the number of actions needed to tighten all loose nuts.

        - `node`: search node whose `state` is an iterable of fact strings.
        - Returns 0 when no nut is loose, `float('inf')` when the goal cannot
          be reached (too few usable spanners, or none reachable), and
          otherwise a non-negative action-count estimate.
        """
        state = node.state

        # Parse the state once, indexing every fact the estimate needs.
        loose_nuts = set()
        tightened_nuts = set()
        usable_spanners = set()
        carrier_of = {}   # spanner -> object carrying it
        location_of = {}  # object -> its location, from `at` facts
        for fact in state:
            parts = get_parts(fact)
            if len(parts) == 2:
                predicate, obj = parts
                if predicate == "loose":
                    loose_nuts.add(obj)
                elif predicate == "tightened":
                    tightened_nuts.add(obj)
                elif predicate == "usable":
                    usable_spanners.add(obj)
            elif len(parts) == 3:
                predicate, first, second = parts
                if predicate == "at":
                    location_of[first] = second
                elif predicate == "carrying":
                    carrier_of[second] = first

        # If there are no loose nuts, the heuristic is 0 (goal state).
        if not loose_nuts:
            return 0

        nuts = loose_nuts | tightened_nuts | self.goal_nuts
        spanners = usable_spanners | set(carrier_of)
        man = self._find_man(location_of, carrier_of, nuts, spanners)
        man_location = location_of.get(man)
        if man_location is None:
            # Without the man's position no walking cost can be estimated;
            # fall back to the number of tighten actions still required.
            return len(loose_nuts)

        # Each tightening consumes one usable spanner, so too few spanners
        # means no plan exists from this state.
        if len(usable_spanners) < len(loose_nuts):
            return float('inf')

        total_cost = 0

        # The carried spanner must itself be usable. A used spanner does not
        # let the man skip the trip to fetch a fresh one.
        carrying_usable = any(carrier_of.get(s) == man for s in usable_spanners)
        if not carrying_usable:
            spanner_locations = {
                location_of[s] for s in usable_spanners if s in location_of
            }
            nearest = min(
                (self._compute_distance(man_location, loc) for loc in spanner_locations),
                default=float('inf'),
            )
            total_cost += nearest + 1  # 1 action to pick up the spanner.

        # Add the cost of walking to each loose nut's location and tightening it.
        for nut in loose_nuts:
            nut_location = location_of.get(nut)
            if nut_location is None:
                # NOTE(review): a nut with no `at` fact is not expected; only
                # the tighten action is counted for it.
                distance = 0
            else:
                distance = self._compute_distance(man_location, nut_location)
            total_cost += distance + 1  # 1 action to tighten the nut.

        return total_cost

    @staticmethod
    def _find_man(location_of, carrier_of, nuts, spanners):
        """
        Return the name of the man in the current state, or None if unknown.

        - `location_of`: object -> location, from `at` facts.
        - `carrier_of`: spanner -> carrier, from `carrying` facts.
        - `nuts` / `spanners`: names already known to be nuts / spanners.
        """
        # Only the man appears as the carrier in `carrying` facts.
        carriers = set(carrier_of.values())
        if len(carriers) == 1:
            return next(iter(carriers))

        candidates = sorted(
            obj for obj in location_of if obj not in nuts and obj not in spanners
        )
        # Prefer the name the heuristic previously hard-coded; otherwise pick
        # deterministically.
        if "bob" in candidates:
            return "bob"
        return candidates[0] if candidates else None

    def _distances_from(self, start):
        """
        Return a dict mapping every location reachable from `start` to its
        shortest-path (BFS) distance. `start` itself maps to 0.

        Results are cached per start location.
        """
        cached = self._distance_cache.get(start)
        if cached is not None:
            return cached

        distances = {start: 0}
        queue = deque([start])
        while queue:
            current = queue.popleft()
            for neighbor in self.location_graph.get(current, ()):
                # Marking on enqueue keeps each location in the queue at most once.
                if neighbor not in distances:
                    distances[neighbor] = distances[current] + 1
                    queue.append(neighbor)

        self._distance_cache[start] = distances
        return distances

    def _compute_distance(self, start, end):
        """
        Compute the shortest path distance between two locations using BFS.

        - `start`: The starting location.
        - `end`: The target location.
        - Returns the number of steps (distance) between the locations, or
          `float('inf')` if `end` is not reachable from `start`.
        """
        if start == end:
            return 0
        return self._distances_from(start).get(end, float('inf'))
