from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters (the surrounding parentheses in a fact such
    as "(at bob shed)") are dropped before splitting, so the example yields
    ["at", "bob", "shed"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens


def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(at bob shed)".
    - `args`: The expected pattern, one entry per token of the fact
      (predicate name first). Each entry is an `fnmatch` pattern, so the
      wildcard `*` is allowed.
    - Returns `True` only if the fact has exactly as many tokens as the
      pattern has entries and every token matches its entry, else `False`.
    """
    # Tokenize the fact: strip the surrounding parentheses, split on whitespace.
    parts = fact[1:-1].split()

    # `zip` silently stops at the shorter sequence, so without this check a
    # fact such as "(at bob)" would "match" the pattern ("at", "bob", "*") and
    # callers indexing the third token would fail. Require equal arity.
    if len(parts) != len(args):
        return False

    return all(fnmatch(part, arg) for part, arg in zip(parts, args))


class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    This heuristic estimates the number of actions required to tighten all
    loose goal nuts. For each nut it charges:
    - Walking to the location of a usable spanner and picking it up
      (skipped when the man already carries a usable spanner).
    - Walking to the location of the loose nut.
    - Tightening the nut.

    # Assumptions
    - Each usable spanner is counted for at most one nut.
    - `link` facts are treated as bidirectional when computing distances.
    - Only nuts that appear in a `(tightened <nut>)` goal are counted.
    - The estimate is greedy: nuts are processed in the order their `loose`
      facts are encountered in the state, and the man is assumed to stand at
      the previous nut's location afterwards.

    # Heuristic Initialization
    - Extract the goal nuts from the `tightened` goal facts.
    - Build a graph of locations from the static `link` facts.
    - Shortest-path distances are computed lazily by BFS and cached per
      source location, since the graph never changes.

    # Step-By-Step Thinking for Computing Heuristic
    1. In one pass over the state, collect `at` locations, loose nuts,
       usable spanners and the spanners carried by the man.
    2. Return infinity if the man has no location.
    3. For each loose goal nut (that has a location):
       - If a carried usable spanner is still available, use it:
         cost = distance(man, nut) + 1 (tighten).
       - Otherwise pick the usable spanner on the ground minimising
         distance(man, spanner) + 1 (pick up) + distance(spanner, nut)
         + 1 (tighten).
       - Return infinity if no spanner is left or the nut is unreachable.
       - Mark the spanner as consumed and move the man to the nut.
    4. Return the summed cost.
    """

    def __init__(self, task, man_name="bob"):
        """
        Initialize the heuristic by extracting goal conditions and static facts.

        - `task`: Planning task; `task.goals` and `task.static` are iterated
          as collections of PDDL fact strings.
        - `man_name`: Name of the man object in `at`/`carrying` facts.
          Defaults to "bob", the name that was previously hard-coded.
        """
        self.goals = task.goals  # Goal conditions.
        self.man_name = man_name

        # Nuts that must end up tightened according to the goal.
        self.goal_nuts = set()
        for goal in self.goals:
            parts = get_parts(goal)
            if len(parts) == 2 and parts[0] == "tightened":
                self.goal_nuts.add(parts[1])

        # Build an undirected graph of locations using the `link` facts.
        self.location_graph = {}
        for fact in task.static:  # Facts that are not affected by actions.
            parts = get_parts(fact)
            if len(parts) == 3 and parts[0] == "link":
                _, loc1, loc2 = parts
                self.location_graph.setdefault(loc1, set()).add(loc2)
                self.location_graph.setdefault(loc2, set()).add(loc1)

        # source location -> {reachable location: number of links}.
        # Filled on demand by `_distances_from`.
        self._distance_cache = {}

    def _distances_from(self, source):
        """Return (and cache) BFS distances from `source` to every reachable location."""
        distances = self._distance_cache.get(source)
        if distances is None:
            distances = {source: 0}
            frontier = [source]
            depth = 0
            # Level-by-level BFS; a location is recorded when first reached,
            # so each one is expanded exactly once.
            while frontier:
                depth += 1
                next_frontier = []
                for location in frontier:
                    for neighbor in self.location_graph.get(location, ()):
                        if neighbor not in distances:
                            distances[neighbor] = depth
                            next_frontier.append(neighbor)
                frontier = next_frontier
            self._distance_cache[source] = distances
        return distances

    def _distance(self, start, end):
        """Return the number of links on a shortest path, or infinity if none exists."""
        if start == end:
            return 0
        return self._distances_from(start).get(end, float("inf"))

    def __call__(self, node):
        """
        Estimate the number of actions required to tighten all loose goal nuts.

        - `node`: Search node whose `state` is iterated as PDDL fact strings.
        - Returns the summed action estimate (0 when no goal nut is loose),
          or `float("inf")` when the man has no location, a nut is
          unreachable, or there are not enough usable spanners.
        """
        state = node.state
        infinity = float("inf")

        # Single pass over the state instead of one scan per nut/spanner.
        locations = {}  # object -> location, from `at` facts
        loose = []      # nuts with a `loose` fact, in state order
        usable = []     # spanners with a usable fact, in state order
        carried = set()  # spanners carried by the man
        for fact in state:
            parts = get_parts(fact)
            if not parts:
                continue
            predicate = parts[0]
            arity = len(parts)
            if predicate == "at" and arity == 3:
                # Keep the first location seen for an object.
                locations.setdefault(parts[1], parts[2])
            elif predicate == "loose" and arity == 2:
                loose.append(parts[1])
            elif predicate in ("usable", "useable") and arity == 2:
                # NOTE(review): the IPC Spanner domain is believed to spell
                # this predicate `useable`; both spellings are accepted here.
                # Confirm against the domain file in use.
                usable.append(parts[1])
            elif predicate == "carrying" and arity == 3 and parts[1] == self.man_name:
                carried.add(parts[2])

        man_location = locations.get(self.man_name)
        if man_location is None:
            return infinity  # Invalid state.

        # Loose nuts that the goal requires tightened. Nuts without an `at`
        # fact are skipped, as before.
        nut_locations = [
            locations[nut]
            for nut in loose
            if nut in self.goal_nuts and nut in locations
        ]

        # Usable spanners lying at a location.
        ground_spanners = {
            spanner: locations[spanner]
            for spanner in usable
            if spanner in locations
        }
        # Usable spanners already in the man's hands: they travel with him
        # and need no pick-up action. They are interchangeable, so a count
        # is enough.
        held_count = sum(
            1
            for spanner in usable
            if spanner in carried and spanner not in locations
        )

        total_cost = 0
        for nut_location in nut_locations:
            if held_count:
                # With a link-count metric, using a held spanner is never
                # worse than detouring via a spanner on the ground.
                cost = self._distance(man_location, nut_location) + 1  # Tighten nut.
                if cost == infinity:
                    return infinity  # Nut unreachable.
                held_count -= 1
            else:
                cost = infinity
                best_spanner = None
                for spanner, spanner_location in ground_spanners.items():
                    # Walk to the spanner, pick it up, walk to the nut, tighten it.
                    candidate = (
                        self._distance(man_location, spanner_location)
                        + 1  # Pick up spanner.
                        + self._distance(spanner_location, nut_location)
                        + 1  # Tighten nut.
                    )
                    if candidate < cost:
                        cost = candidate
                        best_spanner = spanner
                if best_spanner is None:
                    return infinity  # No usable, reachable spanner for this nut.
                # Each spanner is counted for one nut only.
                del ground_spanners[best_spanner]

            total_cost += cost
            # The man ends up at the nut he just tightened.
            man_location = nut_location

        return total_cost
