from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """Tokenize a PDDL fact string.

    The first and last characters (the enclosing parentheses) are dropped and
    the remainder is split on whitespace, e.g. "(origin p1 f1)" becomes
    ["origin", "p1", "f1"].
    """
    inner = fact[1:-1]
    return inner.split()


def match(fact, *args):
    """
    Check whether a PDDL fact matches a pattern of shell-style wildcards.

    - `fact`: the complete fact string, e.g. "(origin p1 f1)".
    - `args`: one pattern per token (`fnmatch` syntax, so `*` and `?` work).
    - Returns `True` when every compared token matches its pattern.

    Tokens and patterns are paired up positionally and comparison stops at the
    shorter of the two, so a pattern shorter than the fact (or longer than it)
    is effectively matched as a prefix.
    """
    tokens = fact[1:-1].split()
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True


class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all passengers by:
    - Moving the lift to the origin floors of passengers who are not yet boarded.
    - Boarding passengers who are at the lift's current floor.
    - Moving the lift to the destination floors of boarded passengers.
    - Departing passengers who are at their destination floors.

    # Assumptions
    - Passengers must be boarded before they can be served.
    - Every lift move is charged as a single action, however many floors it spans.
    - Moves are counted per passenger, so two passengers waiting on the same
      (non-current) floor are charged two moves. The estimate is therefore not
      admissible; it is meant to guide non-optimal (e.g. greedy) search.
    - In the standard Miconic domain `origin`, like `destin`, is never added or
      deleted by an action, so a grounder that separates static facts will not
      keep it in the state. Both are therefore read from the static facts; any
      `origin`/`destin` facts that do appear in the state are honoured too and
      take precedence.

    # Heuristic Initialization
    - Store the goal conditions (all passengers must be served).
    - Extract from the static facts the `above` relations between floors and
      the origin and destination floor of each passenger.

    # Step-By-Step Thinking for Computing Heuristic
    1. Return 0 if the goal is already satisfied.
    2. Scan the state once for the lift position, boarded passengers and any
       `origin`/`destin` facts stored there.
    3. For each passenger that is not yet served:
       - If not boarded: +1 if the lift is not on their origin floor, +1 to board.
       - If boarded: +1 if the lift is not on their destination floor, +1 to depart.
    4. Return the sum over all passengers.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts.

        Args:
            task: The planning task; only its `goals` and `static` attributes are read.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # `above` relations between floors. The estimate below does not use
        # them; the attribute is kept because other code may rely on it.
        self.above_relations = {
            (get_parts(fact)[1], get_parts(fact)[2])
            for fact in static_facts
            if match(fact, "above", "*", "*")
        }

        # Destination floor of each passenger.
        self.destinations = self._floor_map(static_facts, "destin")

        # Origin floor of each passenger. Without this, passengers whose
        # `origin` fact is static would never be seen by `__call__`.
        self.origins = self._floor_map(static_facts, "origin")

    @staticmethod
    def _floor_map(facts, predicate):
        """Map passenger -> floor for every `(predicate passenger floor)` fact in `facts`."""
        return {
            get_parts(fact)[1]: get_parts(fact)[2]
            for fact in facts
            if match(fact, predicate, "*", "*")
        }

    def __call__(self, node):
        """Estimate the number of actions required to serve all passengers.

        Args:
            node: Search node whose `state` is a collection of fact strings.

        Returns:
            0 if all goals hold, otherwise the summed per-passenger estimate.
        """
        state = node.state  # Current world state.

        # Check if the goal is already reached.
        if self.goals <= state:
            return 0

        # Single pass over the state: lift position, boarded passengers, and any
        # origin/destination facts that the task keeps in the dynamic state.
        lift_location = None
        boarded = set()
        state_origins = {}
        state_destinations = {}
        for fact in state:
            if match(fact, "lift-at", "*"):
                lift_location = get_parts(fact)[1]
            elif match(fact, "boarded", "*"):
                boarded.add(get_parts(fact)[1])
            elif match(fact, "origin", "*", "*"):
                parts = get_parts(fact)
                state_origins[parts[1]] = parts[2]
            elif match(fact, "destin", "*", "*"):
                parts = get_parts(fact)
                state_destinations[parts[1]] = parts[2]
        assert lift_location is not None, "Lift location not found in state."

        # Facts found in the state override the static ones.
        origins = {**self.origins, **state_origins}
        destinations = {**self.destinations, **state_destinations}

        total_cost = 0  # Initialize the heuristic cost.

        # Boarded passengers are included explicitly in case their `origin`
        # fact is no longer available.
        for passenger in origins.keys() | boarded:
            # Already-served passengers cost nothing.
            if f"(served {passenger})" in state:
                continue

            if passenger in boarded:
                target_floor = destinations.get(passenger)  # Drop-off floor.
            else:
                target_floor = origins.get(passenger)  # Pick-up floor.

            # Move the lift unless it is already on the target floor. If the
            # target floor is unknown, only the board/depart step is counted.
            if target_floor is not None and lift_location != target_floor:
                total_cost += 1
            total_cost += 1  # Board or depart the passenger.

        return total_cost
