from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """
    Split a PDDL fact string into its tokens.

    The first and last characters (the surrounding parentheses) are dropped and
    the remainder is split on whitespace, e.g. "(origin p1 f1)" becomes
    ["origin", "p1", "f1"].
    """
    inner = fact[1:-1]  # strip the enclosing "(" and ")"
    tokens = inner.split()
    return tokens


def match(fact, *args):
    """
    Check whether a PDDL fact matches a pattern, token by token.

    - `fact`: The complete fact as a string, e.g., "(origin p1 f1)".
    - `args`: One shell-style pattern per token (wildcards such as `*` allowed).
    - Returns `True` if every compared token matches its pattern, else `False`.

    Tokens and patterns are paired positionally and comparison stops at the
    shorter of the two sequences, so surplus patterns or tokens are ignored.
    """
    for token, pattern in zip(get_parts(fact), args):
        if not fnmatch(token, pattern):
            return False
    return True


class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all passengers.
    For each unserved passenger it sums the elevator moves needed to reach them (if
    not yet boarded) and carry them to their destination, plus the board/depart
    actions. Moves shared between passengers are counted once per passenger, so the
    estimate is not admissible.

    # Assumptions:
    - The elevator can only move between floors that are connected by an `above`
      fact. The number of moves between two floors is the shortest-path length in the
      undirected graph formed by those facts. This is 1 for any linked pair, whether
      `above` lists all floor pairs or only adjacent ones.
    - Each passenger must be boarded and then departed at their destination floor.
    - The elevator can carry multiple passengers at once, but each boarding and
      departing action is counted separately.
    - `origin`/`destin` facts are read from the static facts. Any such facts found in
      the state override the static values.

    # Heuristic Initialization
    - Extract the `above` relations and build an undirected floor adjacency map.
    - Extract each passenger's origin and destination floor from the static facts.
    - Extract the goal conditions to identify which passengers need to be served.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the current floor of the elevator using the `lift-at` fact, and
       collect the `boarded` and `served` passengers from the state.
    2. For each passenger that is not yet served:
       - If boarded: add the distance from the current floor to the destination,
         plus 1 action for departing.
       - Otherwise: add the distance from the current floor to the origin, the
         distance from the origin to the destination, plus 2 actions (board, depart).
    3. Sum the total number of actions required for all passengers.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts."""
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # `(above f1 f2)` pairs describing the floor hierarchy.
        self.above_relations = {
            (get_parts(fact)[1], get_parts(fact)[2])
            for fact in static_facts
            if match(fact, "above", "*", "*")
        }

        # Undirected adjacency: the elevator may travel either way along an `above` edge.
        self._neighbors = {}
        for lower, upper in self.above_relations:
            self._neighbors.setdefault(lower, set()).add(upper)
            self._neighbors.setdefault(upper, set()).add(lower)

        # Per-source BFS distance maps, filled lazily by `calculate_distance`.
        self._distance_cache = {}

        # Passenger -> floor maps taken from the static `origin` / `destin` facts.
        self.origins = {}
        self.destinations = {}
        for fact in static_facts:
            self._record_location(get_parts(fact), self.origins, self.destinations)

        # Passengers that must be served; the value is always True (dict kept for
        # compatibility with existing users of this attribute).
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "served":
                passenger = args[0]
                self.goal_locations[passenger] = True

    @staticmethod
    def _record_location(parts, origins, destinations):
        """Store a tokenized `origin`/`destin` fact into the matching passenger->floor map."""
        if len(parts) != 3:
            return
        predicate, passenger, floor = parts
        if predicate == "origin":
            origins[passenger] = floor
        elif predicate == "destin":
            destinations[passenger] = floor

    def __call__(self, node):
        """Compute an estimate of the minimal number of required actions."""
        state = node.state  # Current world state.

        current_floor = None
        boarded = set()
        served = set()
        # Start from the static locations; state facts (if any) take precedence.
        origins = dict(self.origins)
        destinations = dict(self.destinations)

        for fact in state:
            parts = get_parts(fact)
            if not parts:
                continue
            predicate = parts[0]
            if predicate == "lift-at" and len(parts) == 2:
                current_floor = parts[1]
            elif predicate == "boarded" and len(parts) == 2:
                boarded.add(parts[1])
            elif predicate == "served" and len(parts) == 2:
                served.add(parts[1])
            else:
                self._record_location(parts, origins, destinations)

        total_cost = 0  # Initialize action cost counter.

        for passenger in self.goal_locations:
            if passenger in served:
                continue

            destination_floor = destinations.get(passenger)

            if passenger in boarded:
                # Already inside the elevator: ride to the destination, then depart.
                total_cost += self.calculate_distance(current_floor, destination_floor)
                total_cost += 1
            else:
                # Fetch the passenger, carry them to the destination, board + depart.
                origin_floor = origins.get(passenger)
                total_cost += self.calculate_distance(current_floor, origin_floor)
                total_cost += self.calculate_distance(origin_floor, destination_floor)
                total_cost += 2

        return total_cost

    def calculate_distance(self, floor1, floor2):
        """
        Calculate the number of elevator moves from `floor1` to `floor2`.

        - `floor1`: The starting floor.
        - `floor2`: The target floor.
        - Returns 0 if the floors are equal or either one is unknown (None).
          Otherwise returns the shortest-path length between them in the undirected
          `above` graph, or 1 if no path exists (at least one move is needed).
        """
        if floor1 is None or floor2 is None or floor1 == floor2:
            return 0

        distances = self._distance_cache.get(floor1)
        if distances is None:
            distances = self._bfs_distances(floor1)
            self._distance_cache[floor1] = distances

        return distances.get(floor2, 1)

    def _bfs_distances(self, source):
        """Return a dict mapping every floor reachable from `source` to its hop count."""
        distances = {source: 0}
        frontier = [source]
        depth = 0
        # Level-by-level breadth-first search over the undirected `above` graph.
        while frontier:
            depth += 1
            next_frontier = []
            for floor in frontier:
                for neighbor in self._neighbors.get(floor, ()):
                    if neighbor not in distances:
                        distances[neighbor] = depth
                        next_frontier.append(neighbor)
            frontier = next_frontier
        return distances
