from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated components.

    The first and last characters of `fact` (the enclosing parentheses in a
    well-formed fact such as "(origin p1 f1)") are dropped before splitting.

    - `fact`: The fact as a string.
    - Returns a list of component strings, e.g. ["origin", "p1", "f1"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(origin p1 f1)".
    - `args`: The expected pattern, one shell-style pattern per component
      (wildcards such as `*` are interpreted by `fnmatch`).
    - Returns `True` if every pattern element matches the fact component at
      the same position, else `False`.

    A fact with fewer components than the pattern never matches. Components
    of the fact beyond the end of the pattern are ignored, so a short pattern
    acts as a prefix match.
    """
    parts = get_parts(fact)
    # `zip` stops at the shorter sequence, so without this guard a truncated
    # fact such as "(lift-at)" would match ("lift-at", "*") and callers that
    # then index the matched components would fail.
    if len(parts) < len(args):
        return False
    return all(fnmatch(part, arg) for part, arg in zip(parts, args))

class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all passengers in the Miconic domain.
    It considers the current floor of the elevator, the passengers' origins and destinations, and the
    `above` relationships between floors.

    # Assumptions
    - Each `above` fact links two floors the elevator can travel between in one move, in either
      direction (up along the fact, down against it).
    - Each passenger must be boarded and then served at their destination floor.
    - Travel is counted separately for every unserved passenger, so shared trips are not
      discounted and the estimate may exceed the true cost.

    # Heuristic Initialization
    - Store the goal conditions.
    - Extract static facts: the `origin` and `destin` floor of each passenger and the `above` pairs.
    - Build an undirected floor adjacency map once, plus an empty cache of pairwise floor distances.

    # Step-By-Step Thinking for Computing Heuristic
    1. Return 0 if all goal facts hold in the state.
    2. Identify the current floor of the elevator from the `lift-at` fact.
    3. For each passenger that is not served:
       - If boarded: add the distance from the elevator to the destination, plus 1 for departing.
       - Otherwise: add the distance from the elevator to the origin plus 1 for boarding, then the
         distance from the origin to the destination plus 1 for departing.
    4. Return the total; it is `float('inf')` if a required floor cannot be reached.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting goal conditions and static facts.

        - `task`: Planning task; this method reads `task.goals` and `task.static`.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        self.passenger_origins = {}       # passenger -> origin floor
        self.passenger_destinations = {}  # passenger -> destination floor
        self.above_relationships = set()  # (floor1, floor2) pairs, as given
        self._floor_graph = {}            # floor -> set of floors one move away

        # Single pass over the static facts; every predicate used here is binary,
        # so facts with fewer than three components are skipped.
        for fact in static_facts:
            parts = get_parts(fact)
            if len(parts) < 3:
                continue
            predicate, first, second = parts[0], parts[1], parts[2]
            if predicate == "origin":
                self.passenger_origins[first] = second
            elif predicate == "destin":
                self.passenger_destinations[first] = second
            elif predicate == "above":
                self.above_relationships.add((first, second))
                # Store the edge in both directions so that downward trips
                # are reachable as well as upward ones.
                self._floor_graph.setdefault(first, set()).add(second)
                self._floor_graph.setdefault(second, set()).add(first)

        # Memoized results of `_distance`, keyed by (floor1, floor2).
        self._distance_cache = {}

    def __call__(self, node):
        """
        Compute an estimate of the number of actions still required.

        - `node`: Search node; only `node.state` is read. The state is used as a
          set-like collection of fact strings (subset and membership tests).
        - Returns 0 for a goal state, otherwise the summed per-passenger cost,
          which is `float('inf')` when a required floor is unreachable.
        """
        state = node.state  # Current world state.

        # Check if the goal is already reached.
        if self.goals <= state:
            return 0

        # Find the current floor of the elevator (stays None if no such fact).
        elevator_floor = None
        for fact in state:
            if match(fact, "lift-at", "*"):
                elevator_floor = get_parts(fact)[1]
                break

        total_cost = 0  # Accumulated action estimate.

        for passenger, origin in self.passenger_origins.items():
            # Served passengers need no further actions.
            if f"(served {passenger})" in state:
                continue

            destination = self.passenger_destinations[passenger]
            if f"(boarded {passenger})" in state:
                # Already in the lift: travel to the destination, then depart.
                total_cost += self._distance(elevator_floor, destination) + 1
            else:
                # Travel to the origin and board ...
                total_cost += self._distance(elevator_floor, origin) + 1
                # ... then travel to the destination and depart.
                total_cost += self._distance(origin, destination) + 1

        return total_cost

    def _distance(self, floor1, floor2):
        """
        Compute the minimum number of moves between `floor1` and `floor2`.

        Each `above` fact is treated as an edge usable in both directions.
        Results are cached per pair of floors.

        - Returns 0 if both floors are equal, the length of the shortest path
          otherwise, or `float('inf')` if no path exists.
        """
        from collections import deque

        if floor1 == floor2:
            return 0

        cached = self._distance_cache.get((floor1, floor2))
        if cached is not None:
            return cached

        # Breadth-first search over the precomputed adjacency map.
        result = float('inf')
        visited = {floor1}
        queue = deque([(floor1, 0)])
        while queue:
            current_floor, distance = queue.popleft()
            if current_floor == floor2:
                result = distance
                break
            for neighbor in self._floor_graph.get(current_floor, ()):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, distance + 1))

        # The graph is undirected, so the distance is symmetric.
        self._distance_cache[(floor1, floor2)] = result
        self._distance_cache[(floor2, floor1)] = result
        return result
