from collections import deque
from fnmatch import fnmatch

from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """Split a PDDL fact string such as "(at p1 f1)" into its tokens.

    The first and last characters (the enclosing parentheses) are dropped and
    the remaining text is split on runs of whitespace.
    """
    inner = fact[1:-1]
    return inner.split()

def match(fact, *args):
    """
    Test whether a PDDL fact string fits a token pattern.

    - `fact`: the fact as a string, e.g. "(origin p1 f1)".
    - `args`: one shell-style pattern per token (`*` is a wildcard).
    - Returns `True` when every paired token/pattern matches, else `False`.

    Tokens and patterns are paired positionally and the shorter sequence wins,
    so extra tokens or extra patterns are not checked.
    """
    tokens = fact[1:-1].split()
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True

class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all passengers by considering:
    - The number of passengers still to be served.
    - The number of move actions the elevator needs to reach each passenger's next stop.
    - The current state of the elevator (which floor it is on).

    # Assumptions
    - The elevator moves between two floors in one action when they are related by "above";
      the relation is used in both directions, because the elevator must be able to go
      down as well as up.
    - Each passenger must be picked up from their origin floor and dropped off at their destination floor.
    - The elevator can carry multiple passengers at once, but each boarding and departing action is counted separately.

    # Heuristic Initialization
    - Store the goal conditions.
    - Extract static facts: the "above" relationships between floors, each passenger's
      destination floor, and any origin floors that appear among the static facts.

    # Step-By-Step Thinking for Computing Heuristic
    1. Scan the state once for the elevator floor, served passengers, boarded passengers,
       and origin facts (which override the statically known origins).
    2. For each passenger not yet served:
        a. If the passenger is boarded: moves to the destination floor + 1 (depart).
        b. Otherwise: moves to the origin floor + 1 (board) + 1 (depart).
    3. Sum these per-passenger costs. A floor that is unknown or unreachable yields float('inf').
    """

    def __init__(self, task):
        """Initialize the heuristic from the task's goals and static facts.

        Args:
            task: the planning task; `task.goals` is stored and `task.static`
                is scanned for "above", "destin" and "origin" fact strings.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # "above" relation exactly as given: floor1 -> {floor2 | (above floor1 floor2)}.
        self.above = {}
        # Undirected adjacency derived from "above", used for distance computation
        # so that downward trips are found as well as upward ones.
        self._neighbors = {}
        # Destination floor of each passenger.
        self.destinations = {}
        # Origin floors found among static facts. NOTE(review): whether "origin"
        # facts are static depends on the domain encoding (e.g. if boarding does
        # not delete them); they are collected here as a fallback for when the
        # state does not contain them.
        self.origins = {}

        for fact in static_facts:
            if match(fact, "above", "*", "*"):
                parts = get_parts(fact)
                floor1, floor2 = parts[1], parts[2]
                self.above.setdefault(floor1, set()).add(floor2)
                self._neighbors.setdefault(floor1, set()).add(floor2)
                self._neighbors.setdefault(floor2, set()).add(floor1)
            elif match(fact, "destin", "*", "*"):
                parts = get_parts(fact)
                self.destinations[parts[1]] = parts[2]
            elif match(fact, "origin", "*", "*"):
                parts = get_parts(fact)
                self.origins[parts[1]] = parts[2]

        # Memoized floor-to-floor distances; the floor graph is static.
        self._distance_cache = {}

    def __call__(self, node):
        """Compute an estimate of the number of actions still required.

        Args:
            node: search node whose `state` is an iterable of fact strings.

        Returns:
            0 when every passenger with a destination is served; otherwise the
            summed per-passenger cost (see the class docstring), which is
            float('inf') if some required floor is unknown or unreachable.
        """
        state = node.state  # Current world state.

        # Single pass over the state instead of one scan per passenger.
        elevator_floor = None
        served_passengers = set()
        boarded_passengers = set()
        origins = dict(self.origins)  # State origin facts take precedence.
        for fact in state:
            if match(fact, "lift-at", "*"):
                elevator_floor = get_parts(fact)[1]
            elif match(fact, "served", "*"):
                served_passengers.add(get_parts(fact)[1])
            elif match(fact, "boarded", "*"):
                boarded_passengers.add(get_parts(fact)[1])
            elif match(fact, "origin", "*", "*"):
                parts = get_parts(fact)
                origins[parts[1]] = parts[2]

        total_cost = 0
        for passenger, destination in self.destinations.items():
            if passenger in served_passengers:
                continue  # Passenger already served, no cost.

            if passenger in boarded_passengers:
                # Ride to the destination, then depart.
                distance = self._calculate_distance(elevator_floor, destination)
                total_cost += distance + 1
            else:
                # Reach the origin and board; the passenger must also depart later.
                distance = self._calculate_distance(elevator_floor, origins.get(passenger))
                total_cost += distance + 2

        return total_cost

    def _calculate_distance(self, floor1, floor2):
        """Return the minimum number of move actions between two floors.

        Movement follows the "above" relation in either direction. Returns 0
        for identical floors and float('inf') when a floor is unknown (None)
        or no connecting path exists.
        """
        if floor1 is None or floor2 is None:
            # An unknown floor cannot be reached; never treat it as free.
            return float('inf')
        if floor1 == floor2:
            return 0

        cached = self._distance_cache.get((floor1, floor2))
        if cached is not None:
            return cached

        distance = self._bfs_distance(floor1, floor2)
        # The movement graph is undirected, so the distance is symmetric.
        self._distance_cache[(floor1, floor2)] = distance
        self._distance_cache[(floor2, floor1)] = distance
        return distance

    def _bfs_distance(self, start, target):
        """Breadth-first search over the undirected floor graph; inf if unreachable."""
        visited = {start}
        frontier = deque([(start, 0)])
        while frontier:
            current_floor, distance = frontier.popleft()
            for neighbor in self._neighbors.get(current_floor, ()):
                if neighbor == target:
                    return distance + 1
                if neighbor not in visited:
                    # Mark on enqueue so each floor is queued at most once.
                    visited.add(neighbor)
                    frontier.append((neighbor, distance + 1))

        # No path found (indicating an unsolvable state).
        return float('inf')
