from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """Split a PDDL fact string into its predicate name and arguments.

    The first and last characters (the surrounding parentheses) are dropped and
    the remainder is split on whitespace, e.g. "(at p1 f1)" -> ["at", "p1", "f1"].
    """
    inner = fact[1:-1]
    return inner.split()


def match(fact, *args):
    """
    Test whether a PDDL fact string fits a pattern of shell-style wildcards.

    - `fact`: The full fact string, e.g., "(origin p1 f1)".
    - `args`: One pattern per fact component (predicate first); `*` matches anything.
    - Returns `True` when every paired component matches its pattern, else `False`.
      Components are paired positionally and pairing stops at the shorter of the
      two sequences, so extra components on either side are not compared.
    """
    tokens = fact[1:-1].split()
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True


class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all passengers by:
    - Calculating the number of passengers still needing to be served.
    - Estimating the number of elevator movements required to pick up and drop off passengers.

    # Assumptions
    - The elevator can only move between floors connected by the `above` relation, and a
      single `up`/`down` action goes directly between any two such floors, so every change
      of floor is counted as one action.
    - Each passenger must be picked up from their origin floor and dropped off at their destination floor.
    - The elevator can carry multiple passengers at once, but the heuristic does not account for this
      optimization: passengers are handled one after another, so the estimate may overestimate
      (it is not admissible).
    - Passenger origins may be static facts (no Miconic action changes `origin`), so they are read
      from the static facts and, if present, also from the current state.

    # Heuristic Initialization
    - Extract the goal conditions (all passengers must be served).
    - Extract static facts (`above` relationships, passenger origins and destinations).

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify passengers who are not yet served.
    2. For each unserved passenger:
       - If the passenger is not yet boarded, estimate the number of elevator movements required to reach
         their origin floor, plus one `board` action.
       - Estimate the number of elevator movements required to reach their destination floor, plus one
         `depart` action.
    3. Sum the total number of elevator movements and boarding/departing actions.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts.

        - `task`: the planning task; `task.goals` and `task.static` (facts not affected by
          actions) are read.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Map passengers to their destination floors.
        self.destinations = {
            get_parts(fact)[1]: get_parts(fact)[2]
            for fact in static_facts
            if match(fact, "destin", "*", "*")
        }

        # Map passengers to their origin floors. No Miconic action adds or deletes
        # `origin`, so these facts are typically static and may never appear in a
        # search state; looking them up only in the state would silently drop the
        # pickup cost for every unboarded passenger.
        self.origins = {
            get_parts(fact)[1]: get_parts(fact)[2]
            for fact in static_facts
            if match(fact, "origin", "*", "*")
        }

        # Extract the `above` relationships to determine floor connectivity.
        self.above_relations = {
            (get_parts(fact)[1], get_parts(fact)[2])
            for fact in static_facts
            if match(fact, "above", "*", "*")
        }

    def __call__(self, node):
        """Estimate the number of actions required to serve all passengers.

        - `node`: search node; only `node.state` (a collection of fact strings) is read.
        - Returns 0 when every passenger is served; otherwise the sum, over unserved
          passengers taken in turn, of lift moves plus one `board` (if not yet boarded)
          and one `depart` action.
        """
        state = node.state  # Current world state.

        # Identify passengers who are not yet served.
        unserved_passengers = [
            passenger
            for passenger in self.destinations
            if f"(served {passenger})" not in state
        ]

        # If all passengers are served, the heuristic is 0.
        if not unserved_passengers:
            return 0

        # Single pass over the state: find the lift position (first match wins) and
        # collect any origin facts that are part of the dynamic state (first match per
        # passenger wins). State origins take precedence over static ones.
        current_floor = None
        state_origins = {}
        for fact in state:
            if match(fact, "lift-at", "*"):
                if current_floor is None:
                    current_floor = get_parts(fact)[1]
            elif match(fact, "origin", "*", "*"):
                parts = get_parts(fact)
                state_origins.setdefault(parts[1], parts[2])

        total_cost = 0  # Initialize the heuristic cost.

        for passenger in unserved_passengers:
            # If the passenger is not yet boarded, estimate the cost to pick them up.
            if f"(boarded {passenger})" not in state:
                origin_floor = state_origins.get(passenger, self.origins.get(passenger))

                if origin_floor is not None:
                    # Estimate the number of elevator movements to reach the origin floor.
                    total_cost += self._estimate_movements(current_floor, origin_floor)
                    current_floor = origin_floor  # Update the elevator's current floor.
                    total_cost += 1  # Cost to board the passenger.

            # Estimate the number of elevator movements to reach the destination floor.
            destination_floor = self.destinations[passenger]
            total_cost += self._estimate_movements(current_floor, destination_floor)
            current_floor = destination_floor  # Update the elevator's current floor.
            total_cost += 1  # Cost to depart the passenger.

        return total_cost

    def _estimate_movements(self, from_floor, to_floor):
        """
        Estimate the number of elevator actions required to travel between two floors.

        - `from_floor`: The starting floor (None if the lift position was not found).
        - `to_floor`: The target floor.
        - Returns 0 if the floors are the same, otherwise 1.
        """
        if from_floor == to_floor:
            return 0

        # A single `up`/`down` action moves the lift directly between any two floors
        # related by `above`, so any change of floor costs exactly one action,
        # regardless of how many floors lie in between.
        return 1
