from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """
    Split a PDDL fact string such as "(at p1 f1)" into its tokens.

    The first and last characters (the enclosing parentheses) are dropped and the
    remainder is split on whitespace, yielding e.g. ["at", "p1", "f1"].
    """
    body = fact[1:-1]
    tokens = body.split()
    return tokens


def match(fact, *args):
    """
    Test whether a PDDL fact string matches a token pattern.

    - `fact`: The full fact string, e.g. "(origin p1 f1)".
    - `args`: One shell-style pattern per token (`*` and other `fnmatch` wildcards allowed).
    - Returns `True` when every compared token matches its pattern, otherwise `False`.

    Only as many tokens as the shorter of the fact and the pattern are compared, so a
    pattern with fewer (or more) entries than the fact has tokens can still match.
    """
    tokens = get_parts(fact)
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True


class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all passengers in the Miconic domain.
    It considers the current state of the elevator, the passengers' origins and destinations, and the
    relationships between floors (e.g., which floors are above others).

    Each unserved passenger is costed independently and the per-passenger costs are summed, so elevator
    trips that several passengers could share are counted more than once. The estimate is therefore
    generally not admissible.

    # Assumptions
    - The elevator can move between the two floors of an `above` fact in either direction (moving up
      uses the fact one way, moving down the other), and each such move costs one action.
    - Passengers must be boarded before they can be served.
    - The elevator must be at the correct floor to board or serve a passenger.

    # Heuristic Initialization
    - Extract the goal conditions (all passengers must be served).
    - Extract static facts, such as the `above` relationships between floors and the destinations of passengers.
    - Build a mapping of passengers to their destinations and origins.
    - Build an undirected floor graph from the `above` facts. Shortest move counts between floors are
      computed lazily by breadth-first search and cached per start floor.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current floor of the elevator.
    2. For each passenger with a known destination:
       - If the passenger is already served, no actions are needed.
       - If the passenger is boarded, the elevator must move to their destination floor and serve them.
       - If the passenger is not boarded, the elevator must move to their origin floor, board them,
         then move to their destination floor and serve them.
    3. Calculate the number of floor transitions required to reach the necessary floors for boarding and serving.
    4. Sum the actions required for all passengers to estimate the total cost.
       An unreachable floor, or a waiting passenger with no known origin, yields `float("inf")`.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting goal conditions and static facts.

        - `task`: The planning task; its `goals` and `static` attributes are read.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Map passengers to their destinations using "destin" relationships.
        self.passenger_destinations = {
            get_parts(fact)[1]: get_parts(fact)[2]
            for fact in static_facts
            if match(fact, "destin", "*", "*")
        }

        # Map passengers to their origins using "origin" relationships.
        self.passenger_origins = {
            get_parts(fact)[1]: get_parts(fact)[2]
            for fact in static_facts
            if match(fact, "origin", "*", "*")
        }

        # Build an undirected adjacency map of floors from "above" facts.
        # Moving up uses an `above` fact in one direction and moving down uses it in
        # the other, so both directions must be recorded. Storing only one direction
        # made every downward trip unreachable, and solvable states scored inf.
        self.floor_graph = {}
        for fact in static_facts:
            if match(fact, "above", "*", "*"):
                parts = get_parts(fact)
                floor1, floor2 = parts[1], parts[2]
                self.floor_graph.setdefault(floor1, set()).add(floor2)
                self.floor_graph.setdefault(floor2, set()).add(floor1)

        # Cache of single-source BFS results: start floor -> {reachable floor: move count}.
        # The floor graph is static, so each start floor needs to be searched only once.
        self._distance_cache = {}

    def __call__(self, node):
        """
        Compute an estimate of the number of actions needed to serve every passenger.

        - `node`: A search node whose `state` is iterable over fact strings and supports `in`.
        - Returns a non-negative integer estimate, or `float("inf")` when a required floor is
          unreachable or a waiting passenger has no known origin.
        """
        state = node.state  # Current world state.

        # Identify the current floor of the elevator (stays None if no lift-at fact exists).
        elevator_floor = None
        for fact in state:
            if match(fact, "lift-at", "*"):
                elevator_floor = get_parts(fact)[1]
                break

        total_cost = 0  # Initialize action cost counter.

        for passenger, destination_floor in self.passenger_destinations.items():
            # Served passengers need no further actions.
            if f"(served {passenger})" in state:
                continue

            if f"(boarded {passenger})" in state:
                # Boarded: travel to the destination, then serve. The origin is
                # irrelevant here, so it is not looked up at all.
                total_cost += self._calculate_floor_transitions(elevator_floor, destination_floor)
                total_cost += 1  # Serve action.
                continue

            # Waiting: travel to the origin, board, travel to the destination, serve.
            origin_floor = self._find_origin(passenger, state)
            if origin_floor is None:
                # With no known origin floor there is nowhere to pick the passenger up
                # (in the standard encoding boarding requires the origin), so treat the
                # state as a dead end rather than raising KeyError.
                return float("inf")
            total_cost += self._calculate_floor_transitions(elevator_floor, origin_floor)
            total_cost += 1  # Board action.
            total_cost += self._calculate_floor_transitions(origin_floor, destination_floor)
            total_cost += 1  # Serve action.

        return total_cost

    def _find_origin(self, passenger, state):
        """
        Return the origin floor of `passenger`, or None if it is unknown.

        Origins are normally read from the static facts. The current state is scanned
        as a fallback in case an encoding does not treat `origin` as static.
        """
        origin_floor = self.passenger_origins.get(passenger)
        if origin_floor is not None:
            return origin_floor
        for fact in state:
            if match(fact, "origin", passenger, "*"):
                return get_parts(fact)[2]
        return None

    def _calculate_floor_transitions(self, start_floor, end_floor):
        """
        Calculate the number of floor transitions required to move from `start_floor` to `end_floor`.

        - `start_floor`: The starting floor.
        - `end_floor`: The target floor.
        - Returns the number of transitions (up or down actions) required, or `float("inf")`
          if `end_floor` cannot be reached from `start_floor`.
        """
        if start_floor == end_floor:
            return 0

        distances = self._distance_cache.get(start_floor)
        if distances is None:
            distances = self._bfs_distances(start_floor)
            self._distance_cache[start_floor] = distances

        # If no path is found, return a large number (indicating an unsolvable state).
        return distances.get(end_floor, float("inf"))

    def _bfs_distances(self, start_floor):
        """Return a map from every floor reachable from `start_floor` to its shortest move count."""
        distances = {start_floor: 0}
        queue = deque([start_floor])  # O(1) pops from the front, unlike list.pop(0).
        while queue:
            current_floor = queue.popleft()
            for adjacent_floor in self.floor_graph.get(current_floor, ()):
                if adjacent_floor not in distances:
                    # Marking on enqueue keeps each floor in the queue at most once.
                    distances[adjacent_floor] = distances[current_floor] + 1
                    queue.append(adjacent_floor)
        return distances
