from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters (the enclosing parentheses) are dropped
    before splitting, so "(at bob shed)" yields ["at", "bob", "shed"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens


def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(at spanner1 location1)".
    - `args`: The expected pattern, one entry per leading component of the
      fact; shell-style wildcards (`*`, `?`, `[...]`) are allowed via `fnmatch`.
    - Returns `True` if the fact has at least as many components as the
      pattern and each of its first `len(args)` components matches the
      corresponding pattern entry, else `False`. Extra trailing components of
      the fact are ignored (prefix match).
    """
    parts = get_parts(fact)
    # zip() stops at the shorter sequence; without this guard a fact with
    # fewer components than the pattern (e.g. "(at bob)" against
    # "at", "*", "*") would match vacuously.
    if len(parts) < len(args):
        return False
    return all(fnmatch(part, arg) for part, arg in zip(parts, args))


class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    This heuristic estimates the number of actions needed to tighten every nut
    that the goal requires to be tightened and that is still loose. Each such
    nut is charged:
    - walking to a usable spanner and picking it up (only when the man is not
      already carrying an unused usable spanner),
    - walking to the nut's location,
    - tightening the nut.

    # Assumptions
    - The man may carry several spanners at once; every usable spanner can
      tighten exactly one nut.
    - The man must be at the location of a spanner or nut to interact with it.
    - Locations are connected by `link` facts, which are treated as
      bidirectional. NOTE(review): if the domain's walk action only follows
      links in one direction, this is a relaxation (it does not detect dead
      ends caused by walking past a needed spanner). Confirm against the
      domain file.
    - If the goal mentions no `(tightened ?n)` facts, every loose nut counts.

    # Heuristic Initialization
    - Store the goal conditions and record which nuts must be tightened.
    - Build the location adjacency from static `link` facts.
    - Record any `at` facts that appear among the static facts.

    # Step-By-Step Thinking for Computing Heuristic
    1. Scan the state once, collecting positions, loose/tightened nuts, usable
       spanners and `carrying` facts.
    2. Identify the man: an object that carries something; otherwise the
       located object that is neither a nut nor a spanner.
    3. Collect the loose goal nuts with a known location, the number of usable
       spanners the man carries, and the usable spanners lying on the ground.
    4. If there are fewer usable spanners than such nuts, return infinity.
    5. Greedily, until no nut remains:
       - if the man holds a usable spanner, walk to the nearest nut and
         tighten it;
       - otherwise walk to the nearest usable spanner, pick it up, then walk
         to the nut nearest to that spanner and tighten it.
       Ties are broken by object name, so the estimate is deterministic.
    6. Return the summed cost.
    """

    def __init__(self, task):
        """Initialize the heuristic from the task's goals and static facts.

        `task` must provide `goals` and `static`, iterables of fact strings
        such as "(tightened nut1)" and "(link shed location1)".
        """
        self.goals = task.goals  # Goal conditions.
        self.static = task.static  # Static facts (e.g. links between locations).

        # Undirected adjacency between locations, built from (link a b) facts.
        self.links = {}
        # Object positions that appear among the static facts. Typically
        # empty; kept in case the grounder classifies fixed `at` facts (such
        # as nut positions) as static.
        self._static_locations = {}
        for fact in self.static:
            parts = get_parts(fact)
            if len(parts) != 3:
                continue
            predicate, first, second = parts
            if predicate == "link":
                self.links.setdefault(first, set()).add(second)
                self.links.setdefault(second, set()).add(first)
            elif predicate == "at":
                self._static_locations[first] = second

        # Nuts that the goal requires to be tightened.
        self.goal_nuts = set()
        for fact in self.goals:
            parts = get_parts(fact)
            if len(parts) == 2 and parts[0] == "tightened":
                self.goal_nuts.add(parts[1])

        # Memoised breadth-first-search results: start -> {location: distance}.
        self._distance_cache = {}

    def __call__(self, node):
        """Estimate the number of actions required to tighten the loose goal nuts.

        Returns 0 when no relevant nut is loose. Returns `float("inf")` when
        the man cannot be located, or when there are fewer usable spanners
        than loose nuts to tighten.
        """
        state = node.state

        # 1. Single pass over the state.
        located = dict(self._static_locations)  # object -> location
        loose = set()
        tightened = set()
        usable = set()
        carrying = set()  # (carrier, spanner) pairs
        for fact in state:
            parts = get_parts(fact)
            if not parts:
                continue
            predicate, args = parts[0], parts[1:]
            if predicate == "at" and len(args) == 2:
                located[args[0]] = args[1]
            elif predicate == "loose" and len(args) == 1:
                loose.add(args[0])
            elif predicate == "tightened" and len(args) == 1:
                tightened.add(args[0])
            elif predicate == "usable" and len(args) == 1:
                usable.add(args[0])
            elif predicate == "carrying" and len(args) == 2:
                carrying.add((args[0], args[1]))

        # 2. Identify the man without relying on a fixed object name.
        carried_spanners = {spanner for _, spanner in carrying}
        non_men = loose | tightened | self.goal_nuts | usable | carried_spanners
        man = self._find_man(located, carrying, non_men)
        if man is None:
            return float("inf")  # Invalid state: the man cannot be located.

        # 3. Nuts still to tighten, and the spanners available for them.
        targets = loose & self.goal_nuts if self.goal_nuts else loose
        # Nuts without a known location cannot be placed on the map; ignored.
        remaining = {nut: located[nut] for nut in targets if nut in located}
        if not remaining:
            return 0
        # The man may carry several spanners; count every usable one he holds.
        held = sum(
            1 for carrier, spanner in carrying if carrier == man and spanner in usable
        )
        on_ground = {
            spanner: located[spanner]
            for spanner in usable
            if spanner in located and spanner not in carried_spanners
        }

        # 4. Each nut consumes one usable spanner.
        if len(remaining) > held + len(on_ground):
            return float("inf")  # Not enough usable spanners left.

        # 5. Greedy closest-first plan. The check above guarantees `on_ground`
        #    is non-empty whenever `held` reaches zero with nuts remaining.
        total_cost = 0
        here = located[man]
        while remaining:
            if held:
                held -= 1  # Use a spanner the man already carries.
            else:
                _, spanner = min(
                    (self._walk_cost(here, location), name)
                    for name, location in on_ground.items()
                )
                spanner_location = on_ground.pop(spanner)
                # Walk to the spanner, then pick it up.
                total_cost += self._walk_cost(here, spanner_location) + 1
                here = spanner_location
            _, nut = min(
                (self._walk_cost(here, location), name)
                for name, location in remaining.items()
            )
            nut_location = remaining.pop(nut)
            # Walk to the nut, then tighten it.
            total_cost += self._walk_cost(here, nut_location) + 1
            here = nut_location

        return total_cost

    @staticmethod
    def _find_man(located, carrying, non_men):
        """Return the man's name, or None if no suitable located object exists.

        Objects appearing as the carrier in `carrying` (pairs of carrier and
        spanner) are preferred. Otherwise, any located object not in `non_men`
        is a candidate; an object named "bob" is preferred, and remaining ties
        are broken by name.
        """
        carriers = sorted(carrier for carrier, _ in carrying if carrier in located)
        if carriers:
            return carriers[0]
        candidates = [obj for obj in located if obj not in non_men]
        if not candidates:
            return None
        return min(candidates, key=lambda obj: (obj != "bob", obj))

    def _walk_cost(self, start, end):
        """Return the number of walk actions on a shortest path from `start` to `end`.

        Uses the undirected `self.links` graph. Returns `float("inf")` if
        `end` is unreachable from `start`.
        """
        if start == end:
            return 0
        return self._distances_from(start).get(end, float("inf"))

    def _distances_from(self, start):
        """Return (and memoise) BFS distances from `start` to every reachable location."""
        distances = self._distance_cache.get(start)
        if distances is None:
            distances = {start: 0}
            frontier = deque([start])
            while frontier:
                current = frontier.popleft()
                for neighbor in self.links.get(current, ()):
                    if neighbor not in distances:
                        distances[neighbor] = distances[current] + 1
                        frontier.append(neighbor)
            self._distance_cache[start] = distances
        return distances
