from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its predicate and arguments.

    The first and last characters (the surrounding parentheses) are dropped and the
    remainder is split on whitespace, e.g. "(at bob shed)" -> ["at", "bob", "shed"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(at bob shed)".
    - `args`: The expected pattern, one shell-style glob per position (`*` matches anything).
    - Returns `True` if every pattern position matches the corresponding part of the fact,
      else `False`.

    A pattern shorter than the fact matches the fact's leading parts, e.g. `match(f, "at")`
    accepts any `at` fact. A pattern longer than the fact never matches. Plain `zip` would
    silently drop the extra pattern positions and report a match for a fact that lacks them,
    and callers that then index those positions would crash.
    """
    parts = get_parts(fact)
    if len(args) > len(parts):
        return False
    return all(fnmatch(part, arg) for part, arg in zip(parts, args))

class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    This heuristic estimates the number of actions needed to tighten all loose (goal) nuts in
    the Spanner domain. It considers the following steps:
    - Walking to the location of a spanner.
    - Picking up the spanner.
    - Walking to the location of a loose nut.
    - Tightening the nut.
    The estimate is not admissible. Every walk is measured independently from the man's
    current location, and the walks are summed.

    # Assumptions:
    - Each spanner can be used only once, so every nut needs its own usable spanner.
    - The man may carry any number of spanners; all carried usable spanners are counted.
    - The man must walk to the location of a spanner or nut to interact with it.
    - The nuts to tighten are the loose nuts named in `(tightened ?n)` goals. If the goals
      name no nut at all, every loose nut is targeted.
    - Spanner usability may be written `useable` (the spelling used by the IPC Spanner domain)
      or `usable`; both are accepted.
    - There is one man. He is identified from the state rather than by a hard-coded name:
      first as the carrier in a `carrying` fact, otherwise as the located object that is
      neither a nut nor a usable/carried spanner.

    # Heuristic Initialization
    - Extract the goal nuts from the goal conditions and build a graph of locations from the
      static `link` facts.
    - Precompute shortest walking distances (BFS) from every location in that graph.

    # Step-By-Step Thinking for Computing Heuristic
    1. Parse the state once: object locations, loose/tightened nuts, usable and carried
       spanners.
    2. Each nut to tighten costs the walk from the man to the nut plus one tighten action.
    3. Usable spanners the man already carries cover that many nuts. Each remaining nut needs
       a distinct usable spanner from the ground, costing the walk to it plus one pickup. The
       nearest spanners are used first.
    4. The state is treated as a dead end (heuristic infinite) in any of these cases: the man
       or a nut cannot be located, a nut is unreachable, or there are too few reachable usable
       spanners.
    """

    _INF = float("inf")
    # Both spellings of the usability predicate are recognised.
    _USABLE_PREDICATES = frozenset({"useable", "usable"})

    def __init__(self, task, *, directed_links=False):
        """
        Initialize the heuristic from the task's goals and static facts.

        Args:
            task: The planning task; its `goals` and `static` fact collections are read.
            directed_links: If False (the default and original behaviour), each
                `(link a b)` fact is walkable in both directions. If True, only `a -> b`
                is walkable, so spanners left behind become unreachable (dead ends).
                NOTE(review): the IPC Spanner `walk` action requires `(link ?start ?end)`,
                i.e. one-way links; confirm against the domain file in use.
        """
        self.goals = task.goals  # Goal conditions.
        self.static = task.static  # Static facts.

        # Nuts that the goal requires to be tightened.
        self.goal_nuts = set()
        for goal in self.goals:
            parts = get_parts(goal)
            if len(parts) == 2 and parts[0] == "tightened":
                self.goal_nuts.add(parts[1])

        # Build an adjacency-set graph of locations from the static `link` facts.
        self.location_graph = {}
        for fact in self.static:
            parts = get_parts(fact)
            if len(parts) != 3 or parts[0] != "link":
                continue
            _, loc1, loc2 = parts
            self.location_graph.setdefault(loc1, set()).add(loc2)
            # Register loc2 as a node even when the link is one-way.
            self.location_graph.setdefault(loc2, set())
            if not directed_links:
                self.location_graph[loc2].add(loc1)

        # The graph is static, so distances are computed once here instead of per state.
        self.distances = {
            location: self._bfs_distances(location) for location in self.location_graph
        }

    def __call__(self, node):
        """
        Estimate the number of actions still needed to tighten the target nuts.

        Args:
            node: Search node whose `state` is an iterable of PDDL fact strings.

        Returns:
            0 if no target nut is loose; otherwise an action-count estimate, or
            `float("inf")` when the state is judged a dead end (see the class docstring).
        """
        state = node.state

        # Parse every fact exactly once.
        located = {}  # object -> location, from `(at ?obj ?loc)`
        carriers = set()  # first argument of `(carrying ?m ?s)`
        carried = set()  # second argument of `(carrying ?m ?s)`
        usable = set()
        loose = set()
        tightened = set()
        for fact in state:
            parts = get_parts(fact)
            if not parts:
                continue
            predicate, args = parts[0], parts[1:]
            if predicate == "at" and len(args) == 2:
                located[args[0]] = args[1]
            elif predicate == "carrying" and len(args) == 2:
                carriers.add(args[0])
                carried.add(args[1])
            elif predicate in self._USABLE_PREDICATES and len(args) == 1:
                usable.add(args[0])
            elif predicate == "loose" and len(args) == 1:
                loose.add(args[0])
            elif predicate == "tightened" and len(args) == 1:
                tightened.add(args[0])

        # Only loose nuts the goal asks for need work (all loose nuts if the goal names none).
        targets = (loose & self.goal_nuts) if self.goal_nuts else loose
        if not targets:
            return 0

        known_non_men = loose | tightened | self.goal_nuts | usable | carried
        man_location = self._find_man_location(located, carriers, known_non_men)
        if man_location is None:
            # Without the man's position no walking cost can be estimated.
            return self._INF

        total_cost = 0
        for nut in targets:
            nut_location = located.get(nut)
            if nut_location is None:
                return self._INF
            distance = self._compute_walking_distance(man_location, nut_location)
            if distance == self._INF:
                return self._INF
            # Walk to the nut, then one tighten action.
            total_cost += distance + 1

        # Each spanner tightens one nut. Carried usable spanners cover some nuts; every other
        # nut needs a distinct usable spanner picked up from the ground.
        pickups_needed = len(targets) - len(carried & usable)
        if pickups_needed > 0:
            ground_distances = []
            for obj, location in located.items():
                # A located usable object is a spanner lying on the ground.
                if obj in usable:
                    distance = self._compute_walking_distance(man_location, location)
                    if distance != self._INF:
                        ground_distances.append(distance)
            if len(ground_distances) < pickups_needed:
                # Too few reachable usable spanners remain for the loose nuts.
                return self._INF
            ground_distances.sort()
            # Walk to each of the nearest spanners and pick it up (one action each).
            total_cost += sum(ground_distances[:pickups_needed]) + pickups_needed

        return total_cost

    @staticmethod
    def _find_man_location(located, carriers, known_non_men):
        """
        Return the man's location, or None if it cannot be determined.

        A carrier from a `carrying` fact is preferred. Otherwise the man is taken to be a
        located object not in `known_non_men` (nuts and usable/carried spanners). Candidates
        are scanned in sorted order so the result is deterministic.
        """
        for carrier in sorted(carriers):
            if carrier in located:
                return located[carrier]
        for obj in sorted(located):
            if obj not in known_non_men:
                return located[obj]
        return None

    def _bfs_distances(self, source):
        """Return a dict mapping every location reachable from `source` to its walk distance."""
        distances = {source: 0}
        frontier = deque([source])
        while frontier:
            current = frontier.popleft()
            for neighbor in self.location_graph.get(current, ()):
                # Marking on enqueue guarantees each location is expanded at most once.
                if neighbor not in distances:
                    distances[neighbor] = distances[current] + 1
                    frontier.append(neighbor)
        return distances

    def _compute_walking_distance(self, start, end):
        """
        Return the minimum number of walking actions required to move from `start` to `end`.

        Returns 0 when `start == end`. Returns `float("inf")` when `end` is unreachable or
        `start` is not a location in the link graph. Distances come from the BFS table
        precomputed in `__init__`.
        """
        if start == end:
            return 0
        return self.distances.get(start, {}).get(end, self._INF)
