from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """Split a PDDL fact string such as ``"(at bob shed)"`` into its tokens.

    The first and last characters (the surrounding parentheses) are dropped
    and the remainder is split on whitespace, e.g. ``["at", "bob", "shed"]``.
    """
    body = fact[1:-1]
    tokens = body.split()
    return tokens


def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(at bob shed)".
    - `args`: The expected pattern, one entry per token of the fact
      (shell-style wildcards such as `*` are allowed in each entry).
    - Returns `True` if the fact has exactly as many tokens as `args` and every
      token matches its pattern entry, else `False`.
    """
    parts = get_parts(fact)
    # Without this check `zip` would silently truncate, so e.g. "(at bob)"
    # would match ("at", "*", "*") and callers unpacking three parts would crash.
    if len(parts) != len(args):
        return False
    return all(fnmatch(part, arg) for part, arg in zip(parts, args))


class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    Estimates the number of actions needed to tighten every nut that the goal
    requires to be tightened. The estimate simulates a greedy plan. At each
    step the man picks whichever target is closer:
    - walk to the nearest usable spanner and pick it up (only considered while
      he carries fewer usable spanners than there are nuts left), or
    - walk to the nearest remaining nut and tighten it (only considered while
      he carries at least one usable spanner).

    # Assumptions
    - `walk` follows `link` facts in their stated direction only, so the man
      cannot return to a location that has no link path back from where he is.
    - The man may carry any number of spanners.
    - Tightening a nut uses up one usable spanner, so every remaining nut
      needs its own usable spanner.
    - Walking one link, picking up a spanner and tightening a nut each cost
      one action.

    # Heuristic Initialization
    - Collect the nuts named in `(tightened ?n)` goal facts.
    - Build a directed adjacency map of locations from the static `link` facts.

    # Step-By-Step Thinking for Computing Heuristic
    1. Parse the state into object locations, loose/tightened nuts, usable
       spanners and `carrying` facts.
    2. Collect the loose goal nuts; if there are none, return 0.
    3. Identify the man and his location (infinity if he cannot be located).
    4. Return infinity (dead end) if any remaining nut is unreachable, or if
       the usable spanners in hand plus the reachable usable spanners on the
       ground are fewer than the remaining nuts.
    5. Simulate the greedy plan described above. Each pickup or tightening
       adds its walking distance + 1.
    6. If the greedy order gets stuck (its route cut off the remaining
       targets), fall back to counting the required pickups and tightenings
       plus the distance to the farthest remaining nut.
    """

    def __init__(self, task):
        """Extract the goal nuts and the directed location graph from `task`."""
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Nuts the goal requires to be tightened.
        self.goal_nuts = {
            get_parts(goal)[1] for goal in self.goals if match(goal, "tightened", "*")
        }

        # Directed adjacency map. `walk` requires (link ?start ?end), so a link
        # fact only allows movement from loc1 to loc2.
        self.links = {}
        for fact in static_facts:
            if match(fact, "link", "*", "*"):
                _, loc1, loc2 = get_parts(fact)
                self.links.setdefault(loc1, set()).add(loc2)
                self.links.setdefault(loc2, set())

        # Lazily filled BFS results: start location -> {reachable loc: distance}.
        self._distance_cache = {}

    def __call__(self, node):
        """Estimate the number of actions still needed to tighten all goal nuts.

        Returns 0 when no goal nut is loose, and `float("inf")` when the state
        is detected as a dead end or the man cannot be located.
        """
        state = node.state

        # 1. Parse the state in a single pass.
        location_of = {}  # object -> location, from (at ?obj ?loc)
        loose = set()
        tightened = set()
        usable = set()
        carrying = []  # (holder, spanner) pairs from (carrying ?m ?s)
        for fact in state:
            parts = get_parts(fact)
            if not parts:
                continue
            predicate, args = parts[0], parts[1:]
            if predicate == "at" and len(args) == 2:
                location_of[args[0]] = args[1]
            elif predicate == "loose" and len(args) == 1:
                loose.add(args[0])
            elif predicate == "tightened" and len(args) == 1:
                tightened.add(args[0])
            elif predicate == "usable" and len(args) == 1:
                usable.add(args[0])
            elif predicate == "carrying" and len(args) == 2:
                carrying.append((args[0], args[1]))

        # 2. Loose nuts that still have to be tightened. If the goal names no
        #    nuts at all, every loose nut counts. Sorted for deterministic ties.
        remaining = [
            nut for nut in sorted(loose) if not self.goal_nuts or nut in self.goal_nuts
        ]
        if not remaining:
            return 0

        # 3. Locate the man.
        spanners = usable | {spanner for _, spanner in carrying}
        man = self._identify_man(location_of, loose | tightened | self.goal_nuts,
                                 spanners, carrying)
        man_location = location_of.get(man)
        if man_location is None:
            return float("inf")

        # 4. Dead-end detection. A nut without a known location yields None,
        #    which is never a reachable location and is treated as unreachable.
        reach = self._distances_from(man_location)
        nut_locations = [location_of.get(nut) for nut in remaining]
        if any(loc not in reach for loc in nut_locations):
            return float("inf")
        in_hand = sum(
            1 for holder, spanner in carrying if holder == man and spanner in usable
        )
        ground_spanners = [
            location_of[spanner]
            for spanner in sorted(usable)
            if spanner in location_of and location_of[spanner] in reach
        ]
        needed = len(nut_locations)
        if in_hand + len(ground_spanners) < needed:
            return float("inf")

        # 5. Greedy plan simulation.
        cost = self._greedy_plan_cost(man_location, nut_locations, ground_spanners, in_hand)
        if cost is None:
            # 6. The greedy order got stuck even though the necessary conditions
            #    hold; return a finite estimate rather than pruning the state.
            pickups = max(0, needed - in_hand)
            cost = needed + pickups + max(reach[loc] for loc in nut_locations)
        return cost

    @staticmethod
    def _identify_man(location_of, nuts, spanners, carrying):
        """Return the man's name, or None if no candidate exists.

        The holder of any `carrying` fact is the man. Otherwise, the man is a
        located object that is neither a nut nor a spanner. "bob" is preferred
        when it is among several candidates.
        """
        if carrying:
            return carrying[0][0]
        candidates = sorted(
            obj for obj in location_of if obj not in nuts and obj not in spanners
        )
        if "bob" in candidates:
            return "bob"
        return candidates[0] if candidates else None

    @staticmethod
    def _nearest(reach, locations):
        """Return (distance, index) of the closest reachable entry of `locations`, or None."""
        reachable = [
            (reach[loc], index) for index, loc in enumerate(locations) if loc in reach
        ]
        return min(reachable) if reachable else None

    def _greedy_plan_cost(self, start, nut_locations, spanner_locations, in_hand):
        """Cost of a greedy nearest-target pickup/tighten plan, or None if it gets stuck.

        Spanner pickups are only considered while fewer usable spanners are in
        hand than nuts remain. Nuts are only considered while at least one
        usable spanner is in hand. On equal distance a pickup is preferred.
        """
        position = start
        nuts_left = list(nut_locations)
        spanners_left = list(spanner_locations)
        cost = 0
        while nuts_left:
            reach = self._distances_from(position)
            options = []  # (distance, kind, index); kind 0 = spanner, 1 = nut
            if in_hand < len(nuts_left):
                nearest = self._nearest(reach, spanners_left)
                if nearest is not None:
                    options.append((nearest[0], 0, nearest[1]))
            if in_hand > 0:
                nearest = self._nearest(reach, nuts_left)
                if nearest is not None:
                    options.append((nearest[0], 1, nearest[1]))
            if not options:
                return None
            distance, kind, index = min(options)
            cost += distance + 1  # walk there, then pick up / tighten
            if kind == 1:
                position = nuts_left.pop(index)
                in_hand -= 1
            else:
                position = spanners_left.pop(index)
                in_hand += 1
        return cost

    def _distances_from(self, start):
        """Return {location: walk distance} for every location reachable from `start`.

        The result is computed with a level-by-level BFS over the directed
        `link` graph and cached per start location. Callers must not mutate it.
        """
        distances = self._distance_cache.get(start)
        if distances is None:
            distances = {start: 0}
            frontier = [start]
            depth = 0
            while frontier:
                depth += 1
                next_frontier = []
                for loc in frontier:
                    for neighbor in self.links.get(loc, ()):
                        if neighbor not in distances:
                            distances[neighbor] = depth
                            next_frontier.append(neighbor)
                frontier = next_frontier
            self._distance_cache[start] = distances
        return distances

    def _compute_distance(self, start, end):
        """
        Compute the minimum number of walk actions required to move from `start` to `end`.
        Follows `link` facts in their stated direction; returns `float("inf")`
        when `end` is unreachable.
        """
        if start == end:
            return 0
        return self._distances_from(start).get(end, float("inf"))
