from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its tokens.

    The first and last characters (the enclosing parentheses) are sliced off,
    and the remainder is split on whitespace, e.g. "(at bob shed)" becomes
    ["at", "bob", "shed"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Tell whether a PDDL fact matches a pattern, position by position.

    - `fact`: the full fact string, e.g. "(at bob shed)".
    - `args`: one shell-style pattern per position (`*` matches anything).
    - Returns `True` when every compared position matches, otherwise `False`.
      Only the first `min(len(parts), len(args))` positions are compared
      (`zip` stops at the shorter sequence), so a shorter pattern acts as a
      prefix match.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    This heuristic estimates the number of actions needed to tighten all loose
    nuts by greedily simulating the man's plan. He repeatedly walks to the
    nearest useful target and then does one of two things:
    - picks up a usable spanner, while he still needs more usable spanners than
      he carries;
    - tightens a loose nut, while he carries at least one usable spanner.
    Each walk step, pickup and tighten action costs 1.

    # Assumptions:
    - `(link a b)` allows walking from `a` to `b` only. Links are treated as
      directed, matching a `walk` action whose precondition is
      `(link ?start ?end)`.
    - The man may carry any number of spanners. A spanner stops being usable
      after tightening one nut, so each spanner is used at most once.
    - Only usable spanners lying on the ground are worth walking to.
    - There is a single man. He is identified as the subject of a `carrying`
      fact. Otherwise he is an object with an `at` fact that is neither a nut
      (`loose`/`tightened`) nor a spanner (`usable`/carried).
    - The goal is to tighten all loose nuts.

    # Heuristic Initialization
    - Store the goal conditions.
    - Build a directed adjacency map of locations from the static `link` facts.
    - Compute shortest walking distances on demand with BFS, caching the
      result per start location.

    # Step-By-Step Thinking for Computing Heuristic
    1. Parse the state once: loose nuts, usable spanners, what the man
       carries, and the location of every object with an `at` fact.
    2. If there are no loose nuts, return 0.
    3. If the usable spanners (carried plus on the ground) are fewer than the
       loose nuts, the state is a dead end: return infinity.
    4. Greedily simulate. Walk to the nearest candidate target and add the
       walking distance plus 1 for the pickup/tighten action. Then update the
       man's location and the number of usable spanners he holds.
    5. If every candidate target is unreachable, return infinity.
    """

    def __init__(self, task):
        """Store the goals and build the directed location graph from `link` facts."""
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Directed adjacency map: `(link a b)` lets the man walk from a to b.
        # Both endpoints get an entry so every known location is a graph node.
        self.location_graph = {}
        for fact in static_facts:
            if match(fact, "link", "*", "*"):
                _, loc1, loc2 = get_parts(fact)
                self.location_graph.setdefault(loc1, set()).add(loc2)
                self.location_graph.setdefault(loc2, set())

        # Cache: start location -> {reachable location: shortest distance}.
        self._distance_cache = {}

    def __call__(self, node):
        """
        Estimate the number of actions needed to tighten all loose nuts.

        - `node`: search node whose `state` is an iterable of fact strings.
        - Returns 0 when no nut is loose.
        - Returns `float('inf')` for states recognised as dead ends.
        - Otherwise returns the cost of the greedy simulation.
        """
        state = node.state  # Current world state.
        inf = float('inf')

        # Parse every fact once instead of rescanning the state per query.
        loose_nuts = set()
        tightened_nuts = set()
        usable_spanners = set()
        carriers = set()  # Subjects of `carrying` facts (the man).
        carried = set()  # Spanners currently carried.
        positions = {}  # Object -> location, from `at` facts.
        for fact in state:
            parts = get_parts(fact)
            if len(parts) == 2:
                predicate, obj = parts
                if predicate == "loose":
                    loose_nuts.add(obj)
                elif predicate == "tightened":
                    tightened_nuts.add(obj)
                elif predicate == "usable":
                    usable_spanners.add(obj)
            elif len(parts) == 3:
                predicate, subject, obj = parts
                if predicate == "at":
                    positions[subject] = obj
                elif predicate == "carrying":
                    carriers.add(subject)
                    carried.add(obj)

        # If no loose nuts, heuristic is 0 (goal state).
        if not loose_nuts:
            return 0

        man_location = self._find_man_location(
            positions, carriers, loose_nuts | tightened_nuts, usable_spanners | carried
        )
        assert man_location is not None, "Man's location not found in state."

        remaining_nuts = {}
        for nut in loose_nuts:
            nut_location = positions.get(nut)
            assert nut_location is not None, f"Nut {nut} location not found."
            remaining_nuts[nut] = nut_location

        # Usable spanners lying on the ground that can still be picked up.
        ground_spanners = {
            spanner: positions[spanner]
            for spanner in usable_spanners
            if spanner in positions and spanner not in carried
        }
        # Only carried spanners that are still usable can tighten a nut.
        held = len(carried & usable_spanners)

        # Every nut consumes one usable spanner; too few means no plan exists.
        if held + len(ground_spanners) < len(remaining_nuts):
            return inf

        total_cost = 0
        location = man_location
        while remaining_nuts:
            distances = self._distances_from(location)
            # The invariant held + len(ground_spanners) >= len(remaining_nuts)
            # guarantees at least one option: nuts when held > 0, otherwise
            # spanners (since held < len(remaining_nuts)).
            options = []
            if held > 0:
                options += [
                    (distances.get(loc, inf), "nut", nut, loc)
                    for nut, loc in remaining_nuts.items()
                ]
            if held < len(remaining_nuts):
                options += [
                    (distances.get(loc, inf), "spanner", spanner, loc)
                    for spanner, loc in ground_spanners.items()
                ]

            distance, kind, name, target = min(options)
            if distance == inf:
                return inf  # Nothing still needed is reachable: dead end.

            total_cost += distance + 1  # Walk there, then pick up / tighten.
            location = target
            if kind == "nut":
                del remaining_nuts[name]
                held -= 1  # The spanner used is no longer usable.
            else:
                del ground_spanners[name]
                held += 1

        return total_cost

    @staticmethod
    def _find_man_location(positions, carriers, nuts, spanners):
        """
        Return the man's location, or None if no candidate is found.

        - `positions`: object -> location map built from `at` facts.
        - `carriers`: subjects of `carrying` facts.
        - `nuts`, `spanners`: object names known to be nuts / spanners.

        A `carrying` subject with a position is preferred. Otherwise the
        alphabetically first positioned object that is neither a known nut nor
        a known spanner is used; sorting keeps the choice deterministic.
        """
        for carrier in sorted(carriers):
            if carrier in positions:
                return positions[carrier]
        others = sorted(
            obj for obj in positions if obj not in nuts and obj not in spanners
        )
        return positions[others[0]] if others else None

    def _distances_from(self, start):
        """
        Return shortest walking distances from `start` to every reachable location.

        The distances come from a breadth-first search over the directed
        `link` graph. The resulting dict is cached per `start` and shared
        between calls, so callers must not mutate it. Unreachable locations
        are absent from the dict.
        """
        distances = self._distance_cache.get(start)
        if distances is None:
            distances = {start: 0}
            frontier = deque([start])
            while frontier:
                current = frontier.popleft()
                for neighbor in self.location_graph.get(current, ()):
                    if neighbor not in distances:
                        distances[neighbor] = distances[current] + 1
                        frontier.append(neighbor)
            self._distance_cache[start] = distances
        return distances

    def _compute_distance(self, start, end):
        """
        Compute the shortest path distance between two locations.

        - `start`: The starting location.
        - `end`: The target location.
        - Returns the number of walk actions needed to move from `start` to
          `end` along directed links, or `float('inf')` if `end` is
          unreachable.
        """
        return self._distances_from(start).get(end, float('inf'))
