from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """Split a PDDL fact such as "(at bob shed)" into its whitespace-separated tokens."""
    # Drop the first and last character (the enclosing parentheses), then tokenise.
    inner = fact[1:-1]
    return inner.split()


def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern, component by component.

    - `fact`: The complete fact as a string, e.g., "(at bob shed)".
    - `args`: One `fnmatch`-style pattern per component (wildcards `*` allowed).
    - Returns `True` only if the fact has exactly as many components as there
      are patterns and every component matches its pattern, else `False`.
    """
    # Same tokenisation as `get_parts`: strip the outer parentheses, split on whitespace.
    parts = fact[1:-1].split()

    # `zip` stops at the shorter input, so without this check a fact of the
    # wrong arity (e.g. "(loose nut1)" against three patterns) would match.
    if len(parts) != len(args):
        return False

    return all(fnmatch(part, arg) for part, arg in zip(parts, args))


class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    This heuristic estimates the number of actions needed to tighten all loose nuts in the Spanner domain.
    The estimate is the sum of:
    - the walking distance from the man to each loose nut,
    - the walking distance to the nearest usable spanner whenever the man carries no usable spanner, and
    - one tighten action per loose nut.

    # Assumptions:
    - Facts are strings such as "(at bob shed)", "(loose nut1)", "(usable spanner1)",
      "(carrying bob spanner1)" and "(link shed gate)".
    - `loose` and `usable` are unary facts; the position of a nut or spanner is taken from its `at` fact.
    - Links are treated as bidirectional when computing distances.
    - The man is the object named `man_name` (default "bob").

    # Limitations
    - Pick-up actions are not charged and spanner consumption is not tracked.
    - Distances to the nuts are summed individually, so the value may overestimate (it is not admissible).

    # Heuristic Initialization
    - Store the goal conditions and read the static facts from the task.
    - Build a graph of locations using the static `link` facts to compute distances between locations.

    # Step-By-Step Thinking for Computing Heuristic
    1. Collect object locations, loose nuts, usable spanners and carried spanners from the state.
    2. For each loose nut (in sorted order, so the result is deterministic):
       - Add the distance from the man's current location to the nut's location.
       - If the man is not carrying a usable spanner, add the distance to the nearest usable spanner
         and continue the estimate from that spanner's location.
       - Add one action for tightening the nut.
    3. Return infinity if the man, a loose nut, or any usable spanner cannot be located.
    """

    def __init__(self, task, man_name="bob"):
        """
        Initialize the heuristic by extracting goal conditions and static facts.

        - `task`: The planning task; its `goals` and `static` attributes are read.
        - `man_name`: Name of the man object as it appears in `at` and `carrying` facts.
        """
        self.goals = task.goals  # Goal conditions (stored; not consulted by the estimate).
        self.man_name = man_name
        static_facts = task.static  # Facts that are not affected by actions.

        # Build an undirected graph of locations using the static `link` facts.
        self.location_graph = {}
        for fact in static_facts:
            parts = get_parts(fact)
            if len(parts) == 3 and parts[0] == "link":
                _, loc1, loc2 = parts
                self.location_graph.setdefault(loc1, set()).add(loc2)
                self.location_graph.setdefault(loc2, set()).add(loc1)

        # The graph is built once from static facts, so single-source BFS
        # results can be reused across every heuristic evaluation.
        self._distance_cache = {}

    def __call__(self, node):
        """
        Estimate the cost to tighten all loose nuts in `node.state`.

        Returns 0 when no nut is loose, and `float("inf")` when the man, a
        loose nut, or a needed usable spanner cannot be located or reached.
        """
        state = node.state

        # Single pass over the state: `loose` and `usable` carry no location,
        # so positions are resolved through the `at` facts afterwards.
        locations = {}
        loose_nuts = set()
        usable_spanners = set()
        carried_spanners = set()
        for fact in state:
            parts = get_parts(fact)
            if not parts:
                continue
            predicate = parts[0]
            if predicate == "at" and len(parts) == 3:
                locations[parts[1]] = parts[2]
            elif predicate == "loose" and len(parts) == 2:
                loose_nuts.add(parts[1])
            elif predicate == "usable" and len(parts) == 2:
                usable_spanners.add(parts[1])
            elif predicate == "carrying" and len(parts) == 3 and parts[1] == self.man_name:
                carried_spanners.add(parts[2])

        if not loose_nuts:
            return 0

        unreachable = float("inf")

        man_location = locations.get(self.man_name)
        if man_location is None:
            return unreachable

        has_usable_spanner = not usable_spanners.isdisjoint(carried_spanners)

        # Locations of usable spanners lying on the ground (carried ones have no `at` fact).
        ground_spanner_locations = sorted(
            {locations[spanner] for spanner in usable_spanners if spanner in locations}
        )

        total_cost = 0
        # Sorted so the estimate does not depend on the iteration order of the state.
        for nut in sorted(loose_nuts):
            nut_location = locations.get(nut)
            if nut_location is None:
                return unreachable

            # Walk from the man's current location to the nut.
            total_cost += self._shortest_path(man_location, nut_location)

            # Without a usable spanner in hand, detour to the nearest one.
            if not has_usable_spanner:
                if not ground_spanner_locations:
                    # No spanner left to fetch: the nut can never be tightened.
                    return unreachable
                distance_to_spanner, nearest_spanner_location = min(
                    (self._shortest_path(man_location, spanner_location), spanner_location)
                    for spanner_location in ground_spanner_locations
                )
                total_cost += distance_to_spanner
                # Continue the estimate from where the spanner was fetched.
                man_location = nearest_spanner_location

            total_cost += 1  # One action to tighten the nut.

        return total_cost

    def _shortest_path(self, start, end):
        """
        Return the number of links on a shortest path from `start` to `end`.

        Returns 0 when both are the same location and `float("inf")` when
        `end` is not reachable from `start`.
        """
        if start == end:
            return 0

        distances = self._distance_cache.get(start)
        if distances is None:
            distances = self._distances_from(start)
            self._distance_cache[start] = distances
        return distances.get(end, float("inf"))

    def _distances_from(self, source):
        """Run a BFS from `source` and map each reachable location to its hop count."""
        distances = {source: 0}
        queue = deque([source])
        while queue:
            current = queue.popleft()
            for neighbor in self.location_graph.get(current, ()):
                if neighbor not in distances:
                    distances[neighbor] = distances[current] + 1
                    queue.append(neighbor)
        return distances
