from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string such as "(at bob shed)" into its tokens.

    The first and last characters (the surrounding parentheses) are dropped and
    the remainder is split on whitespace, e.g. ["at", "bob", "shed"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Report whether a PDDL fact string fits a pattern.

    - `fact`: the whole fact as a string, e.g. "(at bob shed)".
    - `args`: one pattern per token; shell-style wildcards such as `*` are allowed.
    - Tokens and patterns are paired positionally with `zip`, so only the shorter
      of the two sequences is compared; surplus tokens or patterns are ignored.
    - Returns True when every compared token matches its pattern, else False.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    This heuristic estimates the number of actions needed to tighten all loose nuts in the Spanner domain.
    For every loose nut it charges:
    - Walking from the man's current location to the cheapest usable spanner.
    - Picking up that spanner (skipped when the man is already carrying it).
    - Walking from the spanner's location to the nut's location.
    - Tightening the nut.
    The per-nut costs are summed independently, so the estimate may exceed the true
    cost (it is not admissible).

    # Assumptions:
    - Each spanner can be used only once, so a state with fewer usable spanners than
      loose nuts is treated as a dead end (infinite estimate).
    - The goal is to tighten all loose nuts.
    - `link` facts are treated as bidirectional. NOTE(review): if the domain's walk
      action only follows a link in one direction, these distances are
      underestimates - confirm against the domain file.

    # Heuristic Initialization
    - Store the goal conditions and static facts.
    - Build an undirected adjacency map from the static `link` facts.
    - Shortest walking distances are computed lazily with BFS and cached per start
      location, which is valid because the link graph is static.

    # Step-By-Step Thinking for Computing Heuristic
    1. Parse the state once: loose/tightened nuts, usable spanners, carried spanners,
       and object locations.
    2. Return 0 if no nut is loose.
    3. Identify the man: a carrier of some spanner, otherwise a located object that is
       neither a known nut nor a known spanner (preferring "bob" when present).
    4. Collect usable spanners whose position is known; a spanner the man carries is
       at the man's location and needs no pickup.
    5. If fewer such spanners than loose nuts exist, return infinity.
    6. For each loose nut, add the cheapest walk + pickup + walk + tighten cost over
       all candidate spanners, and sum these costs.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts."""
        self.goals = task.goals  # Goal conditions.
        self.static = task.static  # Static facts.

        # Undirected adjacency map built from static (link ?l1 ?l2) facts.
        self.links = {}
        for fact in self.static:
            if match(fact, "link", "*", "*"):
                parts = get_parts(fact)
                if len(parts) < 3:
                    continue  # malformed link fact; nothing to connect
                loc1, loc2 = parts[1], parts[2]
                self.links.setdefault(loc1, set()).add(loc2)
                self.links.setdefault(loc2, set()).add(loc1)

        # start location -> {reachable location: number of walk actions}.
        self._distance_cache = {}

    def __call__(self, node):
        """Estimate the number of actions still needed to tighten all loose nuts."""
        state = node.state

        # Single pass over the state instead of one scan per fact kind.
        loose_nuts = set()
        tightened_nuts = set()
        usable_spanners = set()
        carried_by = {}  # spanner -> object carrying it
        locations = {}  # object -> location, from (at ?obj ?loc)
        for fact in state:
            parts = get_parts(fact)
            if not parts:
                continue
            predicate, args = parts[0], parts[1:]
            if predicate == "loose" and args:
                loose_nuts.add(args[0])
            elif predicate == "tightened" and args:
                tightened_nuts.add(args[0])
            elif predicate == "usable" and args:
                usable_spanners.add(args[0])
            elif predicate == "carrying" and len(args) >= 2:
                carried_by[args[1]] = args[0]
            elif predicate == "at" and len(args) >= 2:
                locations[args[0]] = args[1]

        # If no loose nuts, heuristic is 0.
        if not loose_nuts:
            return 0

        non_man_objects = loose_nuts | tightened_nuts | usable_spanners | set(carried_by)
        man = self._find_man(locations, carried_by, non_man_objects)
        man_location = locations.get(man)
        if man_location is None:
            # Without the man's position no walking cost can be computed; fall back
            # to counting one tighten action per loose nut.
            return len(loose_nuts)

        # (location, pickup cost) for every usable spanner whose position is known.
        spanner_options = []
        for spanner in usable_spanners:
            if carried_by.get(spanner) == man:
                spanner_options.append((man_location, 0))
            elif spanner in locations:
                spanner_options.append((locations[spanner], 1))

        # Tightening consumes a spanner (see Assumptions), so too few usable
        # spanners means the remaining nuts can never all be tightened.
        if len(spanner_options) < len(loose_nuts):
            return float("inf")

        from_man = self._distances_from(man_location)
        total_cost = 0
        for nut in loose_nuts:
            nut_location = locations.get(nut)
            if nut_location is None:
                total_cost += 1  # location unknown: count only the tighten action
                continue
            best_cost = float("inf")
            for spanner_location, pickup_cost in spanner_options:
                cost = (
                    from_man.get(spanner_location, float("inf"))
                    + pickup_cost
                    + self._calculate_walking_distance(spanner_location, nut_location)
                    + 1  # tighten the nut
                )
                if cost < best_cost:
                    best_cost = cost
            total_cost += best_cost

        return total_cost

    @staticmethod
    def _find_man(locations, carried_by, non_man_objects):
        """
        Return the name of the man, or None if he cannot be identified.

        Carriers of a spanner that have a location are preferred. Otherwise, any located
        object that is not a known nut/spanner and whose name does not start with
        "spanner" or "nut" is a candidate. "bob" wins when he is a candidate. Otherwise
        the lexicographically smallest candidate is chosen, for determinism.
        """
        candidates = {obj for obj in carried_by.values() if obj in locations}
        if not candidates:
            candidates = {
                obj
                for obj in locations
                if obj not in non_man_objects
                and not obj.startswith(("spanner", "nut"))
            }
        if "bob" in candidates:
            return "bob"
        return min(candidates) if candidates else None

    def _distances_from(self, start):
        """
        Return {location: walk actions from start} for every location reachable from start.

        Computed by BFS over self.links and cached, since links come from static facts.
        """
        cached = self._distance_cache.get(start)
        if cached is not None:
            return cached

        distances = {start: 0}
        frontier = deque([start])
        while frontier:
            current = frontier.popleft()
            next_distance = distances[current] + 1
            for neighbor in self.links.get(current, ()):
                if neighbor not in distances:
                    distances[neighbor] = next_distance
                    frontier.append(neighbor)

        self._distance_cache[start] = distances
        return distances

    def _calculate_walking_distance(self, start, end):
        """
        Return the number of walk actions between two locations.

        Returns 0 when start == end, and float('inf') when end is not reachable from start.
        """
        if start == end:
            return 0
        return self._distances_from(start).get(end, float("inf"))
