from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string such as "(at bob shed)" into its tokens.

    The first and last characters (the enclosing parentheses) are dropped and
    the remainder is split on whitespace, e.g. ["at", "bob", "shed"].
    """
    inner = fact[1:-1]
    return inner.split()

def match(fact, *args):
    """
    Report whether a PDDL fact fits a pattern of shell-style wildcards.

    - `fact`: the whole fact as a string, e.g. "(at bob shed)".
    - `args`: one pattern per token of the fact; `*` matches anything.
    - Returns True when every token matches the pattern in the same position,
      otherwise False.

    Tokens and patterns are compared pairwise, and the comparison stops at the
    shorter of the two sequences. Surplus tokens or surplus patterns are
    therefore ignored; for example, the single pattern "at" matches
    "(at bob shed)". Matching uses `fnmatch.fnmatch`, which follows the
    platform's filename case rules.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class SpannerHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Spanner domain.

    # Summary
    This heuristic estimates the number of actions needed to tighten all loose
    nuts required by the goal. It greedily simulates a plan:
    - Walk to the nearest usable spanner on the ground and pick it up. Repeat
      until the man holds one usable spanner per nut that still needs tightening.
    - Walk to the nearest remaining nut and tighten it. Repeat until no nuts
      remain.
    Walking costs are shortest-path (BFS) distances over the `link` graph. Each
    pickup and each tightening costs one action.

    # Assumptions:
    - Each spanner can be used only once, so one usable spanner is needed per
      nut still to be tightened.
    - The heuristic does not limit how many spanners the man carries. Carried
      spanners are read from `carrying` facts and count toward the required
      total when they are `usable`.
    - Spanners on the ground are the `usable` ones that are not carried and have
      an `at` fact.
    - The man must walk to the location of a spanner or nut to interact with it.
    - Links are treated as bidirectional. If the domain's links are one-way,
      distances may be underestimated.
    - The goal is to tighten nuts (`tightened` goal facts). If the goal names no
      nuts, every loose nut is assumed to need tightening.

    # Heuristic Initialization
    - Extract the goal conditions (tightened nuts) and static facts.
    - Build an undirected graph of locations from the static `link` facts, and
      precompute all-pairs shortest-path distances with BFS.
    - Record any `at` facts present among the static facts. These are objects
      whose location never changes, presumably nuts if the grounder classified
      them as static.

    # Step-By-Step Thinking for Computing Heuristic
    1. Parse the state once: loose and tightened nuts, usable spanners, carried
       spanners, and object locations.
    2. Determine the nuts still to tighten, and return 0 if there are none.
    3. Identify the man and his location.
    4. If fewer usable spanners exist (carried plus on the ground) than nuts to
       tighten, the state is a dead end, so return infinity.
    5. Greedily collect the nearest spanners until enough are carried. Then
       greedily visit and tighten the nearest remaining nuts, summing the walk
       distances plus one action per pickup and per tightening.
    """

    # Name of the man in the usual Spanner instances. It is only used to break
    # ties when the man cannot be identified unambiguously from the state.
    DEFAULT_MAN = "bob"

    def __init__(self, task):
        """Extract goals and static facts and precompute location distances."""
        self.goals = task.goals  # Goal conditions.
        self.static = task.static  # Static facts (e.g., links between locations).

        # Nuts that the goal requires to be tightened.
        self.goal_nuts = set()
        for fact in self.goals:
            parts = get_parts(fact)
            if len(parts) == 2 and parts[0] == "tightened":
                self.goal_nuts.add(parts[1])

        # Undirected location graph from the static `link` facts, plus the
        # locations of any objects whose `at` fact is static.
        self.location_graph = {}
        self.static_locations = {}
        for fact in self.static:
            parts = get_parts(fact)
            if len(parts) != 3:
                continue
            predicate, first, second = parts
            if predicate == "link":
                self.location_graph.setdefault(first, set()).add(second)
                self.location_graph.setdefault(second, set()).add(first)
            elif predicate == "at":
                self.static_locations[first] = second

        # All-pairs shortest-path distances, measured in walk actions.
        self.distances = {
            source: self._bfs_distances(source) for source in self.location_graph
        }

    def _bfs_distances(self, source):
        """Map every location reachable from `source` to its hop count."""
        distances = {source: 0}
        frontier = deque([source])
        while frontier:
            location = frontier.popleft()
            for neighbour in self.location_graph.get(location, ()):
                if neighbour not in distances:
                    distances[neighbour] = distances[location] + 1
                    frontier.append(neighbour)
        return distances

    def _distance(self, source, target):
        """Return the number of walk actions from `source` to `target`.

        Falls back to 1 (a single-step estimate) when either location is
        unknown or `target` is not reachable in the link graph.
        """
        if source == target:
            return 0
        return self.distances.get(source, {}).get(target, 1)

    def _nearest(self, position, locations):
        """Return the index of the entry of `locations` closest to `position`."""
        return min(
            range(len(locations)),
            key=lambda index: self._distance(position, locations[index]),
        )

    def _find_man(self, carriers, locations, non_man_objects):
        """Return the man's name, or None if no candidate is found.

        Objects appearing as the first argument of `carrying` are preferred.
        Otherwise the candidates are located objects that are not known nuts,
        spanners, or statically placed objects. Ties are broken in favour of
        DEFAULT_MAN, then by name.
        """
        candidates = carriers or {
            obj for obj in locations if obj not in non_man_objects
        }
        if not candidates:
            return None
        if self.DEFAULT_MAN in candidates:
            return self.DEFAULT_MAN
        return min(candidates)

    def __call__(self, node):
        """Estimate the number of actions needed to tighten all pending nuts.

        Returns 0 when no required nut is loose. Returns float("inf") when fewer
        usable spanners remain than nuts to tighten, because each tightening
        consumes a spanner.
        """
        state = node.state

        # Parse every state fact once.
        loose_nuts = set()
        tightened_nuts = set()
        usable_spanners = set()
        carried_spanners = set()
        carriers = set()
        locations = dict(self.static_locations)
        for fact in state:
            parts = get_parts(fact)
            if len(parts) == 2:
                predicate, obj = parts
                if predicate == "loose":
                    loose_nuts.add(obj)
                elif predicate == "tightened":
                    tightened_nuts.add(obj)
                elif predicate == "usable":
                    usable_spanners.add(obj)
            elif len(parts) == 3:
                predicate, first, second = parts
                if predicate == "at":
                    locations[first] = second
                elif predicate == "carrying":
                    carriers.add(first)
                    carried_spanners.add(second)

        # Nuts still to tighten: the loose nuts the goal asks for, or every
        # loose nut if the goal names none.
        if self.goal_nuts:
            pending_nuts = loose_nuts & self.goal_nuts
        else:
            pending_nuts = loose_nuts
        if not pending_nuts:
            return 0

        non_man_objects = (
            loose_nuts
            | tightened_nuts
            | self.goal_nuts
            | usable_spanners
            | carried_spanners
            | set(self.static_locations)
        )
        man = self._find_man(carriers, locations, non_man_objects)
        position = locations.get(man)

        # Each tightening consumes one usable spanner.
        spanners_in_hand = len(carried_spanners & usable_spanners)
        ground_spanners = [
            locations[spanner]
            for spanner in sorted(usable_spanners - carried_spanners)
            if spanner in locations
        ]
        if len(pending_nuts) > spanners_in_hand + len(ground_spanners):
            return float("inf")  # Not enough usable spanners left: dead end.

        nut_locations = [locations.get(nut) for nut in sorted(pending_nuts)]

        total_cost = 0
        # Phase 1: pick up the nearest spanners until one is held per pending nut.
        while spanners_in_hand < len(nut_locations):
            spanner_location = ground_spanners.pop(
                self._nearest(position, ground_spanners)
            )
            # Walk actions plus one pickup.
            total_cost += self._distance(position, spanner_location) + 1
            position = spanner_location
            spanners_in_hand += 1

        # Phase 2: repeatedly walk to the nearest remaining nut and tighten it.
        while nut_locations:
            nut_location = nut_locations.pop(self._nearest(position, nut_locations))
            # Walk actions plus one tightening.
            total_cost += self._distance(position, nut_location) + 1
            if nut_location is not None:
                position = nut_location

        return total_cost
