from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a parenthesised PDDL fact string into its whitespace-separated tokens.

    The first and last characters (the surrounding parentheses) are dropped
    before splitting, e.g. ``"(origin p1 f1)"`` -> ``["origin", "p1", "f1"]``.
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens

def match(fact, *args):
    """
    Test whether a PDDL fact string fits a pattern, token by token.

    - `fact`: a complete fact string, e.g. "(origin p1 f1)".
    - `args`: one shell-style pattern per token (`*` is a wildcard).
    - Returns `True` when every paired token matches its pattern, else `False`.
      Tokens and patterns are paired with `zip`, so only as many positions as
      the shorter of the two sequences are compared.
    """
    tokens = get_parts(fact)
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True

class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all
    passengers by summing, for every goal passenger not yet served:
    - if still waiting: the lift moves from the current floor to the
      passenger's origin, 1 boarding action, the lift moves from the origin
      to the destination, and 1 departing action;
    - if already boarded: the lift moves from the current floor to the
      destination and 1 departing action.

    # Assumptions
    - Passengers are estimated independently; lift moves that several
      passengers could share are counted once per passenger, so the estimate
      can exceed the true cost (it is not admissible).
    - The number of lift moves between two floors is the shortest-path length
      over the static `above` facts, treated as undirected edges. If `above`
      relates every pair of floors, any move between distinct floors costs 1.
    - `origin` / `destin` facts are looked up in the current state first and
      in the static facts otherwise, so the heuristic works whether or not
      static facts are kept in states.

    # Heuristic Initialization
    - Extract the goal passengers (arguments of `served` goals).
    - Extract the `above` relations and build an undirected floor graph.
    - Extract static `origin` / `destin` facts.

    # Step-By-Step Thinking for Computing Heuristic
    1. Scan the state once to find the lift's floor, the served and boarded
       passengers, and any `origin` / `destin` facts present in the state.
    2. For each unserved goal passenger, add the costs listed in the Summary.
    3. Return the total (0 exactly when every goal passenger is served).
    """

    def __init__(self, task):
        """Extract goal passengers, floor connectivity and static passenger floors from `task`."""
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # `above` pairs exactly as listed in the static facts (kept as a dict
        # for backward compatibility with code that inspects this attribute).
        self.above_relations = {
            (get_parts(fact)[1], get_parts(fact)[2]): True
            for fact in static_facts
            if match(fact, "above", "*", "*")
        }

        # Undirected adjacency between floors: `above` pairs are treated
        # symmetrically, as the lift may travel along them either way.
        self._floor_neighbours = {}
        for floor_a, floor_b in self.above_relations:
            self._floor_neighbours.setdefault(floor_a, set()).add(floor_b)
            self._floor_neighbours.setdefault(floor_b, set()).add(floor_a)

        # Passenger -> floor maps for `origin` / `destin` facts that are static.
        # They may never appear in `node.state`, so they must be read here.
        self._static_origin = {}
        self._static_destin = {}
        for fact in static_facts:
            if match(fact, "origin", "*", "*"):
                parts = get_parts(fact)
                self._static_origin[parts[1]] = parts[2]
            elif match(fact, "destin", "*", "*"):
                parts = get_parts(fact)
                self._static_destin[parts[1]] = parts[2]

        # Passengers that must end up served.
        self.goal_passengers = {
            get_parts(goal)[1] for goal in self.goals if match(goal, "served", "*")
        }

        # Memoised floor-to-floor move counts, filled lazily by `_compute_distance`.
        self._distance_cache = {}

    def __call__(self, node):
        """Estimate the number of actions required to serve all remaining goal passengers."""
        state = node.state  # Current world state.

        # One pass over the state instead of a full scan per passenger.
        lift_at = None
        served = set()
        boarded = set()
        state_origin = {}
        state_destin = {}
        for fact in state:
            parts = get_parts(fact)
            if len(parts) < 2:
                continue
            predicate = parts[0]
            if predicate == "lift-at":
                lift_at = parts[1]
            elif predicate == "served":
                served.add(parts[1])
            elif predicate == "boarded":
                boarded.add(parts[1])
            elif predicate == "origin" and len(parts) >= 3:
                state_origin[parts[1]] = parts[2]
            elif predicate == "destin" and len(parts) >= 3:
                state_destin[parts[1]] = parts[2]

        total_cost = 0  # Initialize the heuristic cost.
        for passenger in self.goal_passengers - served:
            destin_floor = state_destin.get(
                passenger, self._static_destin.get(passenger)
            )
            if passenger in boarded:
                # Already in the lift: it travels from where it is now, not
                # from the passenger's origin.
                total_cost += self._compute_distance(lift_at, destin_floor) + 1  # + depart
            else:
                origin_floor = state_origin.get(
                    passenger, self._static_origin.get(passenger)
                )
                total_cost += self._compute_distance(lift_at, origin_floor) + 1  # + board
                total_cost += self._compute_distance(origin_floor, destin_floor) + 1  # + depart

        return total_cost

    def _compute_distance(self, floor1, floor2):
        """
        Return the number of lift moves needed to go from `floor1` to `floor2`.

        The value is the shortest-path length in the undirected graph built
        from the `above` facts. Returns 0 when the floors are equal or when
        either floor is unknown (`None`), so a missing fact adds no arbitrary
        penalty. Results are memoised in both directions.
        """
        if floor1 == floor2 or floor1 is None or floor2 is None:
            return 0
        key = (floor1, floor2)
        distance = self._distance_cache.get(key)
        if distance is None:
            distance = self._shortest_path_length(floor1, floor2)
            # The graph is undirected, so the distance is symmetric.
            self._distance_cache[key] = distance
            self._distance_cache[(floor2, floor1)] = distance
        return distance

    def _shortest_path_length(self, start, goal):
        """Breadth-first search for the fewest `above` edges between two distinct floors."""
        from collections import deque

        depth = {start: 0}
        frontier = deque([start])
        while frontier:
            floor = frontier.popleft()
            for neighbour in self._floor_neighbours.get(floor, ()):
                if neighbour in depth:
                    continue
                depth[neighbour] = depth[floor] + 1
                if neighbour == goal:
                    return depth[neighbour]
                frontier.append(neighbour)
        # No path: a real shortest path never exceeds the number of floors,
        # so that count serves as a pessimistic fallback.
        return max(len(self._floor_neighbours), 1)
