from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """Split a parenthesised PDDL fact into its whitespace-separated tokens.

    The first and last characters (normally "(" and ")") are discarded before
    splitting, e.g. "(origin p1 f1)" -> ["origin", "p1", "f1"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens


def match(fact, *args):
    """
    Tell whether a PDDL fact agrees with a shell-style pattern, token by token.

    - `fact`: full fact string, e.g. "(origin p1 f1)".
    - `args`: one pattern per token; `fnmatch` wildcards such as `*` are allowed.
    - Returns `True` when every paired token/pattern matches, otherwise `False`.
      Only as many tokens as there are patterns (or vice versa) are compared,
      because the pairing stops at the shorter of the two sequences.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True


class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all goal
    passengers by considering:
    - The current position of the lift.
    - The origin and destination floors of each passenger.
    - Whether passengers are already boarded or served.

    # Assumptions
    - The lift changes floors with the `up` and `down` actions; every floor
      change is charged as exactly one action, regardless of distance.
    - Passengers must be boarded before they can be served.
    - Passengers are handled greedily, one after another, in a fixed order (the
      order in which their `origin` facts were read). The lift position is
      simulated along that order, so the estimate can overestimate the true cost
      (e.g. when several passengers share floors); it is not admissible.

    # Heuristic Initialization
    - Extract the passengers that must be served from the `(served ?p)` goals.
      If no such goal is found, every passenger with an origin is considered.
    - Extract static facts: `origin`/`destin` floors of each passenger and the
      `above` relation between floors.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current state of the lift and passengers:
       - Check the current floor of the lift (`lift-at`).
       - Identify which passengers are already boarded or served.
    2. For each goal passenger not yet served:
       - If the passenger is not boarded, charge a move to their origin floor
         (0 if already there) plus one `board` action.
       - Then charge a move to their destination floor plus one `depart` action.
    3. Sum the estimated actions for all passengers to compute the heuristic value.
    """

    def __init__(self, task):
        """Index goal passengers, their origin/destination floors, and `above`.

        - `task`: planning task whose `goals` and `static` attributes are
          iterated as PDDL fact strings such as "(origin p1 f1)".
        """
        self.goals = task.goals  # Goal conditions.

        # One pass over the static facts collects passenger floors and the
        # `above` relation between floors.
        self.origin = {}
        self.destin = {}
        self.above = set()
        for fact in task.static:
            predicate, *args = get_parts(fact)
            if predicate == "origin":
                passenger, floor = args
                self.origin[passenger] = floor
            elif predicate == "destin":
                passenger, floor = args
                self.destin[passenger] = floor
            elif predicate == "above":
                floor1, floor2 = args
                self.above.add((floor1, floor2))

        # Only passengers that appear in a `(served ?p)` goal need to be counted;
        # otherwise a goal state could still receive a non-zero estimate.
        goal_passengers = set()
        for fact in self.goals:
            predicate, *args = get_parts(fact)
            if predicate == "served" and args:
                goal_passengers.add(args[0])

        # Keep the origin-fact order so the greedy simulation order is stable.
        # NOTE(review): falls back to all passengers when no `served` goal is
        # recognised (e.g. an unexpected goal encoding) — confirm this is wanted.
        if goal_passengers:
            self.passengers = [p for p in self.origin if p in goal_passengers]
        else:
            self.passengers = list(self.origin)

    def __call__(self, node):
        """Compute an estimate of the number of actions still required.

        - `node`: search node whose `state` is an iterable of PDDL fact strings.
        - Returns a non-negative integer; 0 once every counted passenger is served.
        - Raises `KeyError` if a counted passenger has no `destin` fact.
        """
        # Single pass over the state: lift position plus boarded/served sets.
        lift_at = None
        boarded = set()
        served = set()
        for fact in node.state:
            predicate, *args = get_parts(fact)
            if predicate == "lift-at":
                if lift_at is None:  # Use the first `lift-at` fact encountered.
                    lift_at = args[0]
            elif predicate == "boarded":
                boarded.add(args[0])
            elif predicate == "served":
                served.add(args[0])

        total_cost = 0
        for passenger in self.passengers:
            if passenger in served:
                continue  # Passenger is already served.

            if passenger not in boarded:
                origin_floor = self.origin[passenger]
                # Move to the origin floor, then one `board` action.
                total_cost += self._estimate_move_cost(lift_at, origin_floor) + 1
                lift_at = origin_floor  # Simulated lift position.

            destin_floor = self.destin[passenger]
            # Move to the destination floor, then one `depart` action.
            total_cost += self._estimate_move_cost(lift_at, destin_floor) + 1
            lift_at = destin_floor  # Simulated lift position.

        return total_cost

    def _estimate_move_cost(self, current_floor, target_floor):
        """
        Estimate the number of `up`/`down` actions needed to move the lift
        from `current_floor` to `target_floor`.

        - `current_floor`: the current floor of the lift (may be `None` if the
          state had no `lift-at` fact; then any real target costs 1).
        - `target_floor`: the floor to move to.
        - Returns 0 if the floors are equal, otherwise 1.

        A single action is charged irrespective of direction or distance, so the
        `above` relation is not consulted here. NOTE(review): this presumes the
        domain's `up`/`down` actions can jump directly between any two floors
        related by `above` — confirm against the domain file.
        """
        return 0 if current_floor == target_floor else 1
