from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters (normally the enclosing parentheses) are
    dropped before splitting, e.g. "(origin p1 f1)" -> ["origin", "p1", "f1"].
    """
    inner = fact[1:len(fact) - 1]
    tokens = inner.split()
    return tokens


def match(fact, *args):
    """
    Test whether a PDDL fact fits a pattern.

    - `fact`: the full fact string, e.g. "(origin p1 f1)".
    - `args`: one shell-style pattern per component (`*` matches anything).
    - Returns True when every compared component matches its pattern, else False.

    Components and patterns are paired positionally, and only as many pairs as
    the shorter sequence provides are compared; surplus components or patterns
    are ignored (so "(at p1)" matches the patterns "at", "*", "*").
    """
    tokens = fact[1:-1].split()  # Same tokenization as get_parts().
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True


class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    Estimates the number of actions needed to serve all passengers. Unserved
    passengers are handled one after another in a fixed (dictionary) order.
    Each is charged the lift moves it needs plus one `board` and/or one
    `depart`. After each passenger, the lift is assumed to be at that
    passenger's destination. This is a greedy, non-admissible estimate; it
    does not search for an optimal serving order.

    # Assumptions:
    - The lift moves with `up`/`down`; one such action moves between two floors
      related by an `above` fact. The number of moves between two floors is the
      shortest path over the `above` relation taken as an undirected graph.
      With the usual transitively-closed `above` relation, any two distinct
      floors are one move apart.
    - Passengers must be boarded before they can be served.
    - `destin` facts are read from the static facts.
    - `origin` facts are read from the static facts too. In the standard Miconic
      encoding `board` never deletes `origin`, so planners that drop static
      facts from states never show it in the state. Any `origin` facts that do
      appear in the state take precedence.

    # Heuristic Initialization
    - Store the goal conditions.
    - From the static facts, extract each passenger's destination and origin,
      and the `above` relation. Build a floor adjacency map from `above`.

    # Step-By-Step Thinking for Computing Heuristic
    1. Scan the state once for the lift position (`lift-at`) and any `origin` facts.
    2. For each passenger with a known destination:
       - If the passenger is already served, no actions are needed.
       - If the passenger is boarded, move to the destination and `depart`.
       - Otherwise, move to the origin and `board`, then move to the destination
         and `depart`. If no origin is known, still charge `board` + `depart`
         plus the move to the destination, so an unserved passenger never
         looks free.
    3. Return the sum of these costs.
    """

    def __init__(self, task):
        """Initialize the heuristic from the task's goal conditions and static facts."""
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Map passengers to their destinations using "destin" relationships.
        self.passenger_destinations = {
            get_parts(fact)[1]: get_parts(fact)[2]
            for fact in static_facts
            if match(fact, "destin", "*", "*")
        }

        # Map passengers to their origins. `origin` is static in the standard
        # encoding, so when statics are stripped from states this is the only
        # place it can be found.
        self.passenger_origins = {
            get_parts(fact)[1]: get_parts(fact)[2]
            for fact in static_facts
            if match(fact, "origin", "*", "*")
        }

        # Extract the "above" relationships between floors.
        self.above_relationships = {
            (get_parts(fact)[1], get_parts(fact)[2])
            for fact in static_facts
            if match(fact, "above", "*", "*")
        }

        # A single `up`/`down` action connects any two floors related by `above`
        # (in either direction), so treat the relation as an undirected graph.
        self._floor_neighbors = {}
        for floor_a, floor_b in self.above_relationships:
            self._floor_neighbors.setdefault(floor_a, set()).add(floor_b)
            self._floor_neighbors.setdefault(floor_b, set()).add(floor_a)

        # Lazily filled cache: source floor -> {reachable floor: number of moves}.
        self._distance_cache = {}

    def __call__(self, node):
        """Estimate the number of actions required to serve all passengers."""
        state = node.state  # Current world state.

        # One pass over the state: the lift position plus any `origin` facts
        # that the state representation happens to carry.
        current_floor = None
        state_origins = {}
        for fact in state:
            parts = get_parts(fact)
            if not parts:
                continue
            if parts[0] == "lift-at" and len(parts) >= 2 and current_floor is None:
                current_floor = parts[1]
            elif parts[0] == "origin" and len(parts) >= 3:
                state_origins[parts[1]] = parts[2]

        total_cost = 0  # Initialize action cost counter.

        for passenger, destination in self.passenger_destinations.items():
            # Already served: nothing left to do for this passenger.
            if f"(served {passenger})" in state:
                continue

            if f"(boarded {passenger})" in state:
                # Ride to the destination floor, then `depart`.
                total_cost += self._compute_floor_distance(current_floor, destination) + 1
            else:
                origin = state_origins.get(passenger, self.passenger_origins.get(passenger))
                if origin is not None:
                    # Fetch the passenger: move to the origin and `board`...
                    total_cost += self._compute_floor_distance(current_floor, origin) + 1
                    # ...then ride to the destination and `depart`.
                    total_cost += self._compute_floor_distance(origin, destination) + 1
                else:
                    # Origin unknown: still charge `board` + `depart` and the
                    # ride to the destination rather than treating it as free.
                    total_cost += self._compute_floor_distance(current_floor, destination) + 2

            # The greedy tour continues from this passenger's destination.
            current_floor = destination

        return total_cost

    def _compute_floor_distance(self, floor1, floor2):
        """
        Return the number of `up`/`down` actions needed to move from `floor1` to `floor2`.

        - Returns 0 when the floors are equal.
        - Otherwise returns the shortest-path length over the undirected
          `above` graph. This is 1 for any two distinct floors when `above`
          is transitively closed.
        - Returns 1 when the distance cannot be determined. This happens when
          `floor1` is unknown (e.g. None because the state has no `lift-at`
          fact) or `floor2` is unreachable from it. The value is a minimal,
          non-crashing estimate.
        """
        if floor1 == floor2:
            return 0

        # Fast path: the two floors are directly related, so one move suffices.
        if (floor1, floor2) in self.above_relationships or (floor2, floor1) in self.above_relationships:
            return 1

        distances = self._distance_cache.get(floor1)
        if distances is None:
            distances = self._bfs_floor_distances(floor1)
            self._distance_cache[floor1] = distances
        return distances.get(floor2, 1)

    def _bfs_floor_distances(self, source):
        """Return {floor: number of moves} for every floor reachable from `source`."""
        distances = {source: 0}
        frontier = [source]
        while frontier:
            next_frontier = []
            for floor in frontier:
                for neighbor in self._floor_neighbors.get(floor, ()):
                    if neighbor not in distances:
                        distances[neighbor] = distances[floor] + 1
                        next_frontier.append(neighbor)
            frontier = next_frontier
        return distances
