from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters (the enclosing parentheses) are dropped and
    the remainder is split on whitespace, so "(at bob shed)" becomes
    ["at", "bob", "shed"].
    """
    inner = fact[1:-1]
    return inner.split()


def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(at bob shed)".
    - `args`: The expected pattern, one shell-style wildcard pattern per fact
      component (`*` matches anything), compared with `fnmatch`.
    - Returns `True` if every compared component matches, else `False`.

    Components are paired with `zip`, so the comparison stops at the shorter of
    the fact and the pattern: `match("(at bob shed)", "at")` is `True`.
    """
    parts = get_parts(fact)
    return all(fnmatch(part, arg) for part, arg in zip(parts, args))


class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    This heuristic estimates the number of actions needed to tighten all loose nuts in the Spanner domain.
    It considers the following:
    - The man needs one usable spanner per loose nut; usable spanners he already carries count towards that.
    - For each missing spanner he must pick one up, after walking to the nearest usable spanner.
    - The man must walk to the location of each loose nut.
    - The man must tighten each loose nut using a usable spanner.

    # Assumptions:
    - A spanner becomes unusable after tightening a nut and nothing makes it usable again, so a
      state with fewer usable spanners than loose nuts is a dead end (estimate = infinity).
    - Usable spanners the man does not carry are picked up at the location given by their `at` fact.
    - The man must walk to the location of a spanner to pick it up.
    - The man must walk to the location of a nut to tighten it.
    - The man is `bob` whenever `bob` has an `at` fact; otherwise he is the single carrier in
      `carrying` facts, or else the single located object that is neither a nut nor a usable spanner.
    - `link` facts are treated as undirected edges.
      NOTE(review): if the domain's walk action only follows `link` in its stated direction, this
      under-estimates distances and hides dead ends -- confirm against the domain file.

    # Heuristic Initialization
    - Record the goal conditions and the nuts named in `tightened` goals.
    - Build an undirected graph of locations from the static `link` facts. Shortest-path distances
      from a start location are computed by BFS on first use and cached.

    # Step-By-Step Thinking for Computing Heuristic
    1. Index the state in one pass: object locations, carried spanners, usable spanners, nut status.
    2. If no nut is loose, return 0 (goal reached).
    3. Identify the man and his location; return infinity if he cannot be located.
    4. If he carries fewer usable spanners than there are loose nuts:
       - return infinity if too few usable spanners lie on the ground;
       - otherwise add the distance to the nearest usable spanner plus one pick-up per missing spanner.
    5. For each loose nut, add the distance from the man's current location to the nut, plus 1 to tighten it.
       A loose nut without an `at` fact contributes infinity.
    6. Return the sum.
    """

    # Name of the man in the standard Spanner instances; preferred whenever it has an `at` fact.
    DEFAULT_MAN = "bob"

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts."""
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Static facts (e.g., links between locations).

        # Nuts that must end up tightened; used to rule them out when inferring the man.
        self.goal_nuts = {
            get_parts(goal)[1] for goal in self.goals if match(goal, "tightened", "*")
        }

        # Build an undirected graph of locations using the static `link` facts.
        self.location_graph = {}
        for fact in static_facts:
            if match(fact, "link", "*", "*"):
                _, loc1, loc2 = get_parts(fact)
                self.location_graph.setdefault(loc1, set()).add(loc2)
                self.location_graph.setdefault(loc2, set()).add(loc1)

        # BFS results: start location -> {reachable location: distance}.
        # The link graph is static, so each start only needs to be explored once.
        self._distance_cache = {}

    def __call__(self, node):
        """
        Estimate the number of actions still needed to tighten all loose nuts.

        Returns 0 when no nut is loose. Returns float("inf") when the man cannot
        be located, when fewer usable spanners remain than loose nuts, or when a
        required location is unreachable.
        """
        state = node.state

        # Index the dynamic facts in a single pass instead of rescanning the
        # whole state once per nut and once per spanner.
        locations = {}  # object -> location, from (at ?obj ?loc)
        carried = {}  # carrier -> set of carried spanners, from (carrying ?m ?s)
        usable = set()  # spanners with (usable ?s)
        loose = set()  # nuts with (loose ?n)
        tightened = set()  # nuts with (tightened ?n)
        for fact in state:
            parts = get_parts(fact)
            if not parts:
                continue
            predicate, args = parts[0], parts[1:]
            if predicate == "at" and len(args) == 2:
                locations[args[0]] = args[1]
            elif predicate == "carrying" and len(args) == 2:
                carried.setdefault(args[0], set()).add(args[1])
            elif predicate == "usable" and len(args) == 1:
                usable.add(args[0])
            elif predicate == "loose" and len(args) == 1:
                loose.add(args[0])
            elif predicate == "tightened" and len(args) == 1:
                tightened.add(args[0])

        # If no loose nuts, return 0 (goal state).
        if not loose:
            return 0

        man = self._identify_man(locations, carried, usable, loose | tightened)
        man_location = locations.get(man)
        if man_location is None:
            # Without the man's position no nut can be reached.
            return float("inf")

        total_cost = 0

        # Each tighten consumes one usable spanner, so the man needs as many
        # usable spanners as there are loose nuts. Checking all carried spanners
        # (not just the first `carrying` fact seen) keeps the result independent
        # of the state's iteration order.
        carried_spanners = carried.get(man, set())
        spanners_needed = len(loose) - len(carried_spanners & usable)
        if spanners_needed > 0:
            ground_spanner_locations = [
                locations[spanner]
                for spanner in usable - carried_spanners
                if spanner in locations
            ]
            if len(ground_spanner_locations) < spanners_needed:
                # Spanners never become usable again: not enough left to finish.
                return float("inf")
            # Walk to the nearest usable spanner, then one pick-up per missing spanner.
            total_cost += min(
                self._compute_distance(man_location, location)
                for location in ground_spanner_locations
            )
            total_cost += spanners_needed

        # Walk (measured from the man's current location) to each loose nut and tighten it.
        for nut in loose:
            total_cost += self._compute_distance(man_location, locations.get(nut)) + 1

        return total_cost

    def _identify_man(self, locations, carried, usable, nuts):
        """
        Return the name of the man, or None if it cannot be determined.

        Prefers DEFAULT_MAN when it has an `at` fact. Otherwise uses the single
        carrier appearing in `carrying` facts. Otherwise uses the single located
        object that is neither a known nut nor a usable spanner.
        """
        if self.DEFAULT_MAN in locations:
            return self.DEFAULT_MAN
        if len(carried) == 1:
            return next(iter(carried))
        non_man = nuts | usable | self.goal_nuts
        candidates = [obj for obj in locations if obj not in non_man]
        return candidates[0] if len(candidates) == 1 else None

    def _compute_distance(self, start, end):
        """
        Compute the shortest path distance between two locations using BFS.

        Args:
            start (str): The starting location.
            end (str): The target location.

        Returns:
            int | float: The number of steps required to move from `start` to `end`;
            0 when they are equal, float("inf") when `end` is unreachable.
        """
        if start == end:
            return 0
        return self._distances_from(start).get(end, float("inf"))

    def _distances_from(self, start):
        """Return (and cache) the BFS distance from `start` to every reachable location."""
        distances = self._distance_cache.get(start)
        if distances is None:
            distances = {start: 0}
            frontier = deque([start])
            while frontier:
                current = frontier.popleft()
                # Record on enqueue so each location is queued at most once.
                for neighbor in self.location_graph.get(current, ()):
                    if neighbor not in distances:
                        distances[neighbor] = distances[current] + 1
                        frontier.append(neighbor)
            self._distance_cache[start] = distances
        return distances
