from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string such as "(origin p1 f1)" into its tokens.

    The first and last characters (expected to be the enclosing parentheses)
    are dropped, and the remainder is split on whitespace.
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Return whether a PDDL fact string matches a glob-style pattern.

    - `fact`: The complete fact as a string, e.g., "(origin p1 f1)".
    - `args`: One shell-style pattern per token (`*` is a wildcard), compared
      position by position with `fnmatch`.
    - Only the first min(len(tokens), len(args)) positions are compared, so a
      fact with more or fewer tokens than `args` can still match.
    """
    tokens = fact[1:-1].split()
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True

class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all passengers in the Miconic domain.
    It considers the current state of the elevator, the passengers' origins and destinations, and the
    relationships between floors (e.g., which floors are above others).

    # Assumptions
    - The elevator moves between two floors related by an `above` fact, in either direction
      (in the standard Miconic domain `up` uses `(above ?f1 ?f2)` to go f1 -> f2 and `down`
      uses the same fact to go f2 -> f1). Each such move costs one action.
    - Each passenger must be boarded and then served at their destination floor.
    - `origin` facts may be static and therefore absent from the state, so a passenger's origin
      is looked up in the state first and in the static facts as a fallback.
    - The heuristic does not need to be admissible, so it can overestimate the number of actions.

    # Heuristic Initialization
    - Extract the goal conditions (all passengers must be served).
    - Extract static facts: the `above` relationships between floors, the `destin` (destination)
      of each passenger and, when static, the `origin` of each passenger.
    - Build an undirected floor-adjacency map from the `above` relationships.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current floor of the elevator.
    2. For each passenger that is not yet served:
       - If the passenger is not yet boarded, take the distance from the elevator's current floor
         to the passenger's origin floor, plus 2 for the `board` and `depart` actions.
       - If the passenger is boarded, take the distance from the elevator's current floor to the
         passenger's destination floor, plus 1 for the `depart` action.
    3. Sum these values over all passengers. An unreachable floor makes the estimate infinite.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts.

        Args:
            task: Planning task; only its `goals` and `static` attributes are read.
                `static` is iterated once as a collection of fact strings.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Directed `above` facts exactly as given: floor1 -> {floor2, ...}.
        self.above_graph = {}
        # Undirected adjacency used for distances: the lift can travel along an
        # `above` fact in both directions (`up` and `down`).
        self._floor_neighbors = {}
        # passenger -> destination floor, from static `destin` facts.
        self.destin_floors = {}
        # passenger -> origin floor, for encodings where `origin` is static.
        self._static_origins = {}

        # Single pass, so a one-shot iterable of static facts also works.
        for fact in static_facts:
            if match(fact, "above", "*", "*"):
                floor1, floor2 = get_parts(fact)[1:3]
                self.above_graph.setdefault(floor1, set()).add(floor2)
                self._floor_neighbors.setdefault(floor1, set()).add(floor2)
                self._floor_neighbors.setdefault(floor2, set()).add(floor1)
            elif match(fact, "destin", "*", "*"):
                passenger, floor = get_parts(fact)[1:3]
                self.destin_floors[passenger] = floor
            elif match(fact, "origin", "*", "*"):
                passenger, floor = get_parts(fact)[1:3]
                self._static_origins[passenger] = floor

        # start floor -> {floor: fewest moves}. Valid for the heuristic's lifetime
        # because the floor graph is built only from static facts.
        self._distance_cache = {}

    def __call__(self, node):
        """Estimate the number of actions required to serve all passengers.

        Args:
            node: Search node whose `state` supports iteration over fact strings
                and membership tests such as "(served p1)" in state.

        Returns:
            The estimated number of actions, or float('inf') if a required floor
            is unreachable from the elevator's current floor.
        """
        state = node.state  # Current world state.

        # One pass over the state: find the elevator's floor and any `origin`
        # facts that are present in the state (i.e. not stripped as static).
        lift_at = None
        state_origins = {}
        for fact in state:
            if lift_at is None and match(fact, "lift-at", "*"):
                lift_at = get_parts(fact)[1]
            elif match(fact, "origin", "*", "*"):
                passenger, floor = get_parts(fact)[1:3]
                state_origins[passenger] = floor

        total_cost = 0  # Initialize the heuristic cost.

        for passenger, destin_floor in self.destin_floors.items():
            if f"(served {passenger})" in state:
                continue

            if f"(boarded {passenger})" in state:
                # Boarded but not served: travel to the destination, then `depart`.
                total_cost += self._calculate_distance(lift_at, destin_floor) + 1
                continue

            # Not boarded yet: both `board` and `depart` are still required.
            total_cost += 2
            origin_floor = state_origins.get(passenger, self._static_origins.get(passenger))
            if origin_floor is not None:
                total_cost += self._calculate_distance(lift_at, origin_floor)

        return total_cost

    def _calculate_distance(self, start_floor, end_floor):
        """
        Calculate the minimum number of elevator moves from `start_floor` to `end_floor`.

        Moves follow `above` facts in either direction. Shortest-path distances from
        each start floor are computed once by breadth-first search and then cached.

        Args:
            start_floor (str): The starting floor.
            end_floor (str): The destination floor.

        Returns:
            int | float: The number of moves, or float('inf') if `end_floor` is
            unreachable (including when `start_floor` is unknown).
        """
        if start_floor == end_floor:
            return 0

        distances = self._distance_cache.get(start_floor)
        if distances is None:
            distances = self._distances_from(start_floor)
            self._distance_cache[start_floor] = distances
        return distances.get(end_floor, float('inf'))

    def _distances_from(self, start_floor):
        """Return {floor: fewest moves} for every floor reachable from `start_floor`."""
        distances = {start_floor: 0}
        frontier = [start_floor]
        depth = 0
        # Level-by-level BFS: every floor first reached at this level is `depth` moves away.
        while frontier:
            depth += 1
            next_frontier = []
            for floor in frontier:
                for neighbor in self._floor_neighbors.get(floor, ()):
                    if neighbor not in distances:
                        distances[neighbor] = depth
                        next_frontier.append(neighbor)
            frontier = next_frontier
        return distances
