from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string such as "(origin p1 f1)" into its tokens.

    The first and last characters (the enclosing parentheses) are dropped and
    the remainder is split on whitespace, e.g. ["origin", "p1", "f1"].
    """
    inner = fact[1:-1]
    return inner.split()

def match(fact, *args):
    """
    Test whether a PDDL fact string fits a token pattern.

    - `fact`: the full fact, e.g. "(origin p1 f1)".
    - `args`: one shell-style pattern per token (`*`, `?`, `[...]` allowed via
      `fnmatch`).
    - Returns `True` when every compared token matches its pattern.

    Tokens and patterns are paired with `zip`, so comparison stops at the
    shorter of the two: extra patterns or extra tokens are ignored.
    """
    tokens = fact[1:-1].split()
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True

class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all passengers by considering:
    - The number of passengers not yet served.
    - The number of floors the lift must travel to pick up and drop off passengers.
    - The current state of the lift (which floor it is on).

    # Assumptions
    - The lift can only move between floors that are directly connected by the `above` relation;
      movement along an `above` pair is allowed in either direction.
    - Each passenger must be picked up from their origin floor and dropped off at their destination floor.
    - The lift can carry multiple passengers at once, but each passenger must be boarded and served individually.
    - Lift movements are estimated per passenger and summed, so moves shared by several
      passengers are counted more than once; the estimate is therefore not admissible.

    # Heuristic Initialization
    - Extract the goal conditions (all passengers must be served).
    - Extract static facts, including the `above` relationships between floors and the `origin` and `destin` of each passenger.
    - Build an undirected floor adjacency map once, so distance queries do not rescan the `above` facts.

    # Step-By-Step Thinking for Computing Heuristic
    1. In a single pass over the state, find the lift's floor and the sets of served and boarded passengers.
    2. For each passenger not yet served:
        - If boarded: distance(lift floor -> destination) + 1 `depart`.
        - If not boarded: distance(lift floor -> origin) + 1 `board`
          + distance(origin -> destination) + 1 `depart`.
          Counting the remaining ride and `depart` here means that boarding a
          passenger never increases the estimate.
    3. Sum these per-passenger costs.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts."""
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Extract `above` relationships between floors.
        self.above = {
            (get_parts(fact)[1], get_parts(fact)[2])
            for fact in static_facts
            if match(fact, "above", "*", "*")
        }

        # Undirected adjacency between floors: the lift may move either way
        # along an `above` pair. Built once here instead of on every query.
        self._neighbors = {}
        for f1, f2 in self.above:
            self._neighbors.setdefault(f1, set()).add(f2)
            self._neighbors.setdefault(f2, set()).add(f1)

        # Cache of BFS results: source floor -> {reachable floor: distance}.
        self._distance_cache = {}

        # Extract `origin` and `destin` for each passenger.
        self.origins = {}
        self.destins = {}
        for fact in static_facts:
            if match(fact, "origin", "*", "*"):
                passenger, floor = get_parts(fact)[1], get_parts(fact)[2]
                self.origins[passenger] = floor
            elif match(fact, "destin", "*", "*"):
                passenger, floor = get_parts(fact)[1], get_parts(fact)[2]
                self.destins[passenger] = floor

    def __call__(self, node):
        """
        Estimate the number of actions required to serve all passengers.

        Returns 0 when every passenger listed in `destin` facts is served, and
        `float('inf')` when some required floor is unreachable (including when
        the state has no `lift-at` fact).
        """
        state = node.state  # Current world state.

        # One pass over the state instead of one rescan per passenger.
        lift_at = None
        served = set()
        boarded = set()
        for fact in state:
            if lift_at is None and match(fact, "lift-at", "*"):
                # Like the original `break`, the first `lift-at` fact wins.
                lift_at = get_parts(fact)[1]
            elif match(fact, "served", "*"):
                served.add(get_parts(fact)[1])
            elif match(fact, "boarded", "*"):
                boarded.add(get_parts(fact)[1])

        total_cost = 0
        for passenger, destination in self.destins.items():
            if passenger in served:
                continue

            if passenger in boarded:
                # Ride to the destination, then depart.
                total_cost += self._distance(lift_at, destination)
                total_cost += 1  # Depart action.
            else:
                # Fetch the passenger, board, ride to the destination, depart.
                origin = self.origins[passenger]
                total_cost += self._distance(lift_at, origin)
                total_cost += 1  # Board action.
                total_cost += self._distance(origin, destination)
                total_cost += 1  # Depart action.

        return total_cost

    def _distance(self, floor1, floor2):
        """
        Calculate the minimum number of lift movements required to travel from `floor1` to `floor2`.

        Returns 0 when the floors are equal and `float('inf')` when `floor2` is
        not reachable from `floor1` through `above` pairs.
        """
        if floor1 == floor2:
            return 0
        return self._distances_from(floor1).get(floor2, float("inf"))

    def _distances_from(self, source):
        """Return {floor: BFS distance} for every floor reachable from `source`, cached per source."""
        from collections import deque

        row = self._distance_cache.get(source)
        if row is None:
            row = {source: 0}
            frontier = deque([source])
            while frontier:
                current = frontier.popleft()
                for neighbor in self._neighbors.get(current, ()):
                    # Record on enqueue so each floor enters the queue once.
                    if neighbor not in row:
                        row[neighbor] = row[current] + 1
                        frontier.append(neighbor)
            self._distance_cache[source] = row
        return row
