from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters (the enclosing parentheses, e.g. in
    "(origin p1 f1)") are dropped before splitting, so the result for that
    example is ["origin", "p1", "f1"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Test whether a PDDL fact agrees with a wildcard pattern, token by token.

    - `fact`: the full fact string, e.g. "(origin p1 f1)".
    - `args`: one shell-style pattern per token (`*` matches anything).
    - Returns True when every compared token matches its pattern. Only as many
      tokens as there are patterns (or vice versa) are compared, so surplus
      tokens or patterns are ignored.
    """
    tokens = fact[1:-1].split()
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True

class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all passengers by considering:
    - The number of passengers not yet served.
    - The number of passengers not yet boarded.
    - The distance the elevator must travel to pick up and drop off passengers.

    # Assumptions
    - The elevator can only move between floors that are directly connected by the "above" relation
      (the relation is treated as undirected for distance purposes).
    - Each passenger must be picked up from their origin floor and dropped off at their destination floor.
    - The elevator can carry multiple passengers at once, but the heuristic does not account for this optimization.

    # Heuristic Initialization
    - Extract the goal conditions (all passengers must be served).
    - Extract the static facts: the "above" relationships between floors, the destination floor of
      each passenger, and any origin floors that are given as static facts.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current state of the elevator (which floor it is on).
    2. Identify the passengers who are not yet served.
    3. For each unserved passenger:
       - If the passenger is already boarded, add the distance from the elevator's current position
         to the passenger's destination floor.
       - Otherwise, if the origin floor is known, add the distance from the elevator to the origin and
         from the origin to the destination; if it is unknown, fall back to the distance from the
         elevator to the destination.
    4. Sum the distances for all unserved passengers to estimate the total number of elevator movements required.
    5. Add the number of unserved passengers to account for the boarding and departing actions.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts."""
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Map passengers to their destination floors.
        self.destinations = {
            get_parts(fact)[1]: get_parts(fact)[2]
            for fact in static_facts
            if match(fact, "destin", "*", "*")
        }

        # Origin facts may be static, like "destin" above, in which case they never
        # appear in a search state. Keep them as a fallback for __call__.
        self.origins = {
            get_parts(fact)[1]: get_parts(fact)[2]
            for fact in static_facts
            if match(fact, "origin", "*", "*")
        }

        # Extract the "above" relationships between floors.
        self.above_relations = {
            (get_parts(fact)[1], get_parts(fact)[2])
            for fact in static_facts
            if match(fact, "above", "*", "*")
        }

        # Undirected adjacency built once, so BFS does not rescan every "above" pair per visited floor.
        self._neighbors = {}
        for lower, upper in self.above_relations:
            self._neighbors.setdefault(lower, set()).add(upper)
            self._neighbors.setdefault(upper, set()).add(lower)

        # source floor -> {reachable floor: shortest hop count}; filled lazily by _calculate_distance.
        self._distance_cache = {}

    def __call__(self, node):
        """
        Estimate the number of actions required to serve all passengers.

        - `node`: search node whose `state` is a collection of fact strings.
        - Returns an int estimate, or float('inf') when some required floor is unreachable
          (or the elevator position is unknown and a move is needed).
        """
        state = node.state  # Current world state.

        # Single pass over the state: elevator position plus any origin facts present.
        elevator_floor = None
        origins = dict(self.origins)  # State facts override the static fallback.
        for fact in state:
            if elevator_floor is None and match(fact, "lift-at", "*"):
                elevator_floor = get_parts(fact)[1]
            elif match(fact, "origin", "*", "*"):
                parts = get_parts(fact)
                origins[parts[1]] = parts[2]

        # Identify passengers who are not yet served.
        unserved_passengers = [
            passenger
            for passenger in self.destinations
            if f"(served {passenger})" not in state
        ]

        total_cost = 0  # Initialize the heuristic cost.

        for passenger in unserved_passengers:
            destination_floor = self.destinations[passenger]

            if f"(boarded {passenger})" in state:
                # Already in the car: only the ride to the destination remains.
                # (Previously this reused a stale or unbound `origin_floor`.)
                total_cost += self._calculate_distance(elevator_floor, destination_floor)
                continue

            origin_floor = origins.get(passenger)
            if origin_floor is None:
                # Origin unknown: estimate from the elevator's position, as before.
                total_cost += self._calculate_distance(elevator_floor, destination_floor)
            else:
                # Travel to the pickup floor, then from pickup to drop-off.
                total_cost += self._calculate_distance(elevator_floor, origin_floor)
                total_cost += self._calculate_distance(origin_floor, destination_floor)

        # Add the number of unserved passengers to account for boarding and departing actions.
        total_cost += len(unserved_passengers)

        return total_cost

    def _calculate_distance(self, floor1, floor2):
        """
        Calculate the minimum number of elevator movements required to travel from `floor1` to `floor2`.

        - `floor1`: The starting floor.
        - `floor2`: The destination floor.
        - Returns the number of hops along the (undirected) "above" relation, 0 when the
          floors are equal, or float('inf') when `floor2` is unreachable from `floor1`.
        """
        if floor1 == floor2:
            return 0

        distances = self._distance_cache.get(floor1)
        if distances is None:
            distances = self._bfs_distances(floor1)
            self._distance_cache[floor1] = distances

        # If no path is found, return a large number (indicating an unsolvable state).
        return distances.get(floor2, float('inf'))

    def _bfs_distances(self, source):
        """Return a dict mapping every floor reachable from `source` to its BFS hop count."""
        from collections import deque

        distances = {source: 0}
        queue = deque([source])
        while queue:
            current = queue.popleft()
            next_distance = distances[current] + 1
            for neighbor in self._neighbors.get(current, ()):
                if neighbor not in distances:
                    distances[neighbor] = next_distance
                    queue.append(neighbor)
        return distances
