from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its tokens.

    The first and last characters (the enclosing parentheses) are dropped and the
    remainder is split on whitespace, e.g. "(at bob shed)" -> ["at", "bob", "shed"].
    """
    inner = fact[1:-1]
    return inner.split()

def match(fact, *args):
    """
    Check whether a PDDL fact matches a pattern of shell-style globs.

    - `fact`: the complete fact as a string, e.g. "(at bob shed)".
    - `args`: one glob per token (`*` is a wildcard), e.g. ("at", "bob", "*").
    - Returns `True` when each token matches its glob, else `False`.

    Tokens are paired with globs positionally and pairing stops at the shorter of the
    two sequences, so extra tokens or extra globs are not checked.
    """
    fact_tokens = get_parts(fact)
    for token, pattern in zip(fact_tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True

class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    This heuristic estimates the number of actions needed to tighten every goal nut that
    is still loose. It simulates a greedy plan in which the man repeatedly:
    - uses a usable spanner he is already carrying, if he has one;
    - otherwise walks to the nearest usable spanner on the ground and picks it up;
    - then walks to the nearest remaining loose nut and tightens it.
    The estimate is 0 when no goal nut (with a known location) is still loose. It is
    `float('inf')` when fewer usable spanners remain than loose nuts, or when a required
    location is unreachable.

    # Assumptions:
    - A spanner can be used only once (it becomes unusable after tightening a nut), so
      every remaining nut needs its own usable spanner.
    - The simulated man only picks up a new spanner when he has no usable spanner in hand.
    - The man must walk to the location of each nut to tighten it.
    - The man must walk to the location of a spanner to pick it up.
    - The man is the object named by `MAN_NAME`.
    - `link` facts are treated as traversable in both directions.
      NOTE(review): if the domain's walk action only follows links one way, these
      distances under-estimate; confirm against the domain file.

    # Heuristic Initialization
    - Extract the goal nuts (from `tightened` goals) and the links between locations.
    - Precompute walking distances between all linked locations (locations are static).

    # Step-By-Step Thinking for Computing Heuristic
    1. In one pass over the state, read object locations, usable spanners, spanners
       carried by the man, and loose nuts.
    2. Keep only loose nuts that the goal requires to be tightened. If the goal names no
       `tightened` facts, keep every loose nut. If none remain, return 0.
    3. If carried usable spanners plus usable spanners on the ground are fewer than the
       remaining nuts, return infinity (each tightening consumes a spanner).
    4. Until every nut is handled:
       - if the man holds no usable spanner, add the distance to the nearest usable
         spanner on the ground plus 1 for picking it up, and move him there;
       - add the distance to the nearest remaining nut plus 1 for tightening it, move him
         there, and consume one spanner.
    Ties are broken by object name, so the estimate does not depend on set iteration order.
    """

    # Name of the man object; a class attribute so a subclass can override it.
    MAN_NAME = "bob"

    def __init__(self, task):
        """Initialize the heuristic by extracting goal nuts, links and walking distances."""
        self.goals = task.goals  # Goal conditions.
        self.static = task.static  # Static facts (links between locations).

        # Nuts the goal requires to be tightened.
        self.goal_nuts = frozenset(
            get_parts(goal)[1] for goal in self.goals if match(goal, "tightened", "*")
        )

        # Extract all links between locations (stored in both directions).
        self.links = {}
        for fact in self.static:
            if match(fact, "link", "*", "*"):
                parts = get_parts(fact)
                loc1, loc2 = parts[1], parts[2]
                self.links.setdefault(loc1, set()).add(loc2)
                self.links.setdefault(loc2, set()).add(loc1)

        # Links are static, so shortest distances are computed once here rather than
        # running a BFS for every distance query during search.
        self._distances = {loc: self._bfs_distances(loc) for loc in self.links}

    def __call__(self, node):
        """
        Estimate the number of actions needed to tighten all remaining goal nuts.

        Returns 0 when no such nut is loose and `float('inf')` for a detected dead end.
        Otherwise returns the cost of the greedy plan described in the class docstring.
        """
        state = node.state

        # Single pass over the state to collect everything the estimate needs.
        positions = {}  # object -> location, from (at ?obj ?loc)
        usable = set()  # objects with (usable ?s)
        carried = set()  # objects with (carrying <MAN_NAME> ?s)
        loose = set()  # objects with (loose ?n)
        for fact in state:
            parts = get_parts(fact)
            if not parts:
                continue
            predicate = parts[0]
            if predicate == "at" and len(parts) >= 3:
                positions[parts[1]] = parts[2]
            elif predicate == "usable" and len(parts) >= 2:
                usable.add(parts[1])
            elif predicate == "carrying" and len(parts) >= 3 and parts[1] == self.MAN_NAME:
                carried.add(parts[2])
            elif predicate == "loose" and len(parts) >= 2:
                loose.add(parts[1])

        # Only nuts the goal asks for matter; fall back to all loose nuts otherwise.
        if self.goal_nuts:
            loose &= self.goal_nuts
        # Nuts without a known location cannot be walked to and are skipped.
        nuts = sorted((nut, positions[nut]) for nut in loose if nut in positions)
        if not nuts:
            return 0

        spanners_in_hand = len(carried & usable)
        ground_spanners = sorted(
            (spanner, positions[spanner])
            for spanner in usable - carried
            if spanner in positions
        )

        # Each tightening consumes one spanner, so too few usable spanners is a dead end.
        if spanners_in_hand + len(ground_spanners) < len(nuts):
            return float("inf")

        # A missing man location yields infinite distances via _compute_distance.
        here = positions.get(self.MAN_NAME)
        total_cost = 0
        while nuts:
            if spanners_in_hand == 0:
                # Non-empty here: the dead-end check guarantees enough ground spanners.
                distance, here = self._pop_nearest(here, ground_spanners)
                total_cost += distance + 1  # Walk to the spanner, then pick it up.
                spanners_in_hand += 1
            distance, here = self._pop_nearest(here, nuts)
            total_cost += distance + 1  # Walk to the nut, then tighten it.
            spanners_in_hand -= 1

        return total_cost

    def _pop_nearest(self, origin, candidates):
        """
        Remove the candidate closest to `origin` and return `(distance, location)`.

        `candidates` is a non-empty list of `(name, location)` pairs. Ties in distance are
        broken by name so the choice is deterministic.
        """
        best = min(
            candidates,
            key=lambda item: (self._compute_distance(origin, item[1]), item[0]),
        )
        candidates.remove(best)
        return self._compute_distance(origin, best[1]), best[1]

    def _compute_distance(self, start, end):
        """
        Return the minimum number of walk actions required to move from `start` to `end`.

        Looks up the distances precomputed in `__init__`. Returns `float('inf')` if no
        path exists.
        """
        if start == end:
            return 0
        return self._distances.get(start, {}).get(end, float("inf"))

    def _bfs_distances(self, source):
        """Return a dict mapping each location reachable from `source` to its BFS distance."""
        distances = {source: 0}
        frontier = deque([source])
        while frontier:
            current = frontier.popleft()
            for neighbor in self.links.get(current, ()):
                if neighbor not in distances:
                    distances[neighbor] = distances[current] + 1
                    frontier.append(neighbor)
        return distances
