from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string such as "(at p1 f1)" into its tokens.

    The first and last characters (expected to be the surrounding
    parentheses) are dropped and the remainder is split on whitespace.
    """
    body = fact[1:-1]
    tokens = body.split()
    return tokens

def match(fact, *args):
    """
    Report whether a PDDL fact string fits a token pattern.

    - `fact`: a full fact string, e.g. "(origin p1 f1)".
    - `args`: one shell-style pattern per token (e.g. `*` as a wildcard), as
      understood by `fnmatch`.
    - Returns `True` when every compared token matches its pattern, else `False`.

    Tokens are paired positionally and comparison stops at the shorter of the
    two sequences. A fact with more (or fewer) tokens than patterns can
    therefore still match.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True

class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all passengers by considering:
    - The number of passengers still to be served.
    - The current position of the lift.
    - The origin and destination floors of each passenger.

    # Assumptions
    - The lift can move between floors using the `up` and `down` actions.
    - Passengers must be boarded at their origin floor and served at their destination floor.
    - The heuristic does not need to be admissible, so it can overestimate the number of actions.

    # Heuristic Initialization
    - Extract the goal conditions (all passengers must be served).
    - Extract static facts: the `above` relationships between floors, and each
      passenger's origin and destination floor.
    - Derive a rank for every floor from the `above` relationships. Move costs
      are then computed from floor positions rather than from floor names.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current state of the lift and the passengers:
        - Check which passengers are already served.
        - Check which passengers are boarded.
        - Check the current floor of the lift.
    2. For each unserved passenger:
        - If the passenger is not boarded, estimate the cost to board them:
            - Move the lift to their origin floor.
            - Board the passenger.
        - Estimate the cost to serve them:
            - Move the lift to their destination floor.
            - Depart the passenger.
    3. Sum the estimated costs for all unserved passengers.
    4. Return the total estimated cost as the heuristic value.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts.

        Args:
            task: the planning task. `task.goals` is stored as-is and
                `task.static` is iterated as PDDL fact strings such as
                "(origin p1 f1)".
        """
        self.goals = task.goals  # Goal conditions (stored; not read by this class).
        static_facts = task.static  # Facts that are not affected by actions.

        # Passenger origin/destination floors and the `above` relation, all
        # collected in a single pass over the static facts.
        self.passenger_origins = {}
        self.passenger_destinations = {}
        self.above_relationships = set()
        for fact in static_facts:
            if match(fact, "origin", "*", "*"):
                _, passenger, floor = get_parts(fact)
                self.passenger_origins[passenger] = floor
            elif match(fact, "destin", "*", "*"):
                _, passenger, floor = get_parts(fact)
                self.passenger_destinations[passenger] = floor
            elif match(fact, "above", "*", "*"):
                _, floor1, floor2 = get_parts(fact)
                self.above_relationships.add((floor1, floor2))

        # Position of each floor in the vertical order, used by `_estimate_move_cost`.
        self.floor_rank = self._compute_floor_ranks(self.above_relationships)

    @staticmethod
    def _compute_floor_ranks(above_pairs):
        """Map each floor mentioned in `above_pairs` to an integer rank.

        A floor's rank is the number of distinct floors reachable from it by
        following `(first, second)` pairs from `first` to `second`. This gives
        consecutive, distinct ranks along a single linear chain of floors. It
        does so whether the pairs list every ordered pair (transitive closure)
        or only neighbouring floors. Only rank differences are used, so which
        end of the order gets rank 0 does not matter.

        Args:
            above_pairs: iterable of `(floor1, floor2)` tuples taken from
                `(above floor1 floor2)` facts.

        Returns:
            dict mapping floor name -> int rank.
        """
        successors = {}
        for first, second in above_pairs:
            successors.setdefault(first, set()).add(second)
            successors.setdefault(second, set())

        ranks = {}
        for start in successors:
            # Iterative DFS collecting every floor reachable from `start`.
            reached = set()
            stack = [start]
            while stack:
                current = stack.pop()
                for nxt in successors[current]:
                    if nxt not in reached:
                        reached.add(nxt)
                        stack.append(nxt)
            reached.discard(start)  # Ignore self-reachability from malformed cyclic input.
            ranks[start] = len(reached)
        return ranks

    def __call__(self, node):
        """Estimate the number of actions required to serve all passengers.

        Args:
            node: search node; `node.state` is iterated as PDDL fact strings.

        Returns:
            int: 0 when every passenger in `self.passenger_origins` is served,
            otherwise a positive estimate.
        """
        state = node.state  # Current world state.

        # Single pass over the state: lift position, served and boarded passengers.
        lift_at = None
        served_passengers = set()
        boarded_passengers = set()
        for fact in state:
            if match(fact, "lift-at", "*"):
                if lift_at is None:  # Keep the first lift-at fact, as before.
                    _, lift_at = get_parts(fact)
            elif match(fact, "served", "*"):
                _, passenger = get_parts(fact)
                served_passengers.add(passenger)
            elif match(fact, "boarded", "*"):
                _, passenger = get_parts(fact)
                boarded_passengers.add(passenger)

        total_cost = 0  # Initialize the heuristic cost.

        # Passengers are handled in the iteration order of `passenger_origins`.
        # The simulated lift position carries over from one passenger to the
        # next, so the estimate depends on that order.
        for passenger in self.passenger_origins:
            if passenger in served_passengers:
                continue  # Passenger is already served.

            origin_floor = self.passenger_origins[passenger]
            destination_floor = self.passenger_destinations[passenger]

            # Cost to board the passenger.
            if passenger not in boarded_passengers:
                # Move the lift to the origin floor, then board.
                total_cost += self._estimate_move_cost(lift_at, origin_floor)
                total_cost += 1
                lift_at = origin_floor

            # Cost to serve the passenger: move to the destination, then depart.
            total_cost += self._estimate_move_cost(lift_at, destination_floor)
            total_cost += 1
            lift_at = destination_floor

        return total_cost

    def _estimate_move_cost(self, current_floor, target_floor):
        """
        Estimate the number of actions required to move the lift from `current_floor` to `target_floor`.

        - If the floors are the same, no action is needed.
        - Otherwise, return the number of floor positions between them, using the
          ranks derived from the `above` relationships (at least 1).
        - If either floor has no known rank (e.g. it appears in no `above` fact,
          or the lift position is unknown), fall back to 1 move.

        NOTE(review): if the task's `above` facts list every ordered pair, a
        single up/down action may cross several floors. This distance can then
        overestimate the action count, which is acceptable because the
        heuristic need not be admissible.
        """
        if current_floor == target_floor:
            return 0

        rank = self.floor_rank
        if current_floor not in rank or target_floor not in rank:
            return 1

        # Distinct floors always need at least one move action.
        return max(1, abs(rank[current_floor] - rank[target_floor]))
