from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated tokens.

    - `fact`: The fact as a string, e.g., "(origin p1 f1)".
    - Returns the list of tokens, e.g., ["origin", "p1", "f1"].

    The first and last characters are dropped unconditionally (they are
    expected to be the enclosing parentheses; this is not verified).
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens


def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(origin p1 f1)".
    - `args`: The expected pattern, one entry per token of the fact
      (predicate name first). Each entry is an `fnmatch`-style pattern, so
      wildcards such as `*` are allowed.
    - Returns `True` if the fact has exactly as many tokens as `args` and every
      token matches the corresponding pattern, else `False`.
    """
    parts = get_parts(fact)
    # `zip` silently stops at the shorter sequence, so without this arity check
    # a fact like "(served)" would match ("served", "*") and callers that then
    # index into the parts would fail with an IndexError.
    if len(parts) != len(args):
        return False
    return all(fnmatch(part, arg) for part, arg in zip(parts, args))


class MiconicHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Miconic domain.

    # Summary
    This heuristic estimates the number of actions required to serve all passengers by considering:
    - The current location of the lift.
    - The origin and destination floors of each passenger.
    - Whether passengers are already boarded or served.

    # Assumptions
    - The lift can move between floors using the `up` and `down` actions.
    - Passengers must be boarded before they can be served.
    - Each passenger is estimated independently and the estimates are summed, so
      lift movements that could be shared between passengers are counted once per
      passenger. The value is therefore an estimate, not a lower bound.

    # Heuristic Initialization
    - Store the goal conditions.
    - Extract from the static facts the `destin` of each passenger, the `above`
      relationships between floors and, when present there, the `origin` of each
      passenger.

    # Step-By-Step Thinking for Computing Heuristic
    1. Scan the state once to identify:
       - The current floor of the lift (`lift-at`).
       - Which passengers are already boarded or served.
       - The `origin` floor of passengers, when it is part of the state.
    2. For each passenger not yet served:
       - If the passenger is boarded: move the lift to the destination floor and
         perform the `depart` action.
       - If the passenger is not boarded: move the lift to the origin floor,
         perform `board`, move to the destination floor, and perform `depart`.
       - If the passenger is not boarded and no origin is known: count only the
         `board` and `depart` actions.
    3. Sum the estimated actions for all passengers to compute the heuristic value.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting goal conditions and static facts.

        - `task`: The planning task; only `task.goals` and `task.static` are read.
        """
        self.goals = task.goals  # Goal conditions.

        # passenger -> destination floor.
        self.destin_floors = {}
        # passenger -> origin floor, for `origin` facts listed among the static facts.
        # NOTE(review): no Miconic action appears to change `origin`, so the planner
        # may report it as static rather than as part of the state - confirm.
        self.origin_floors = {}
        # Set of (floor, floor) pairs taken from the `above` facts.
        self.above_relationships = set()

        # Parse every static fact once; all predicates of interest are binary,
        # so anything that does not have exactly three tokens is ignored.
        for fact in task.static:
            parts = get_parts(fact)
            if len(parts) != 3:
                continue
            predicate, first, second = parts
            if predicate == "destin":
                self.destin_floors[first] = second
            elif predicate == "origin":
                self.origin_floors.setdefault(first, second)
            elif predicate == "above":
                self.above_relationships.add((first, second))

    def __call__(self, node):
        """
        Estimate the number of actions required to serve all passengers.

        - `node`: A search node; only `node.state` (an iterable of fact strings) is read.
        - Returns the summed per-passenger estimate (0 when every passenger with a
          known destination is served).
        """
        state = node.state  # Current world state.

        lift_at = None
        served_passengers = set()
        boarded_passengers = set()
        state_origins = {}

        # Single pass over the state instead of one scan per passenger.
        for fact in state:
            parts = get_parts(fact)
            if not parts:
                continue
            predicate = parts[0]
            if len(parts) == 2:
                if predicate == "lift-at":
                    if lift_at is None:  # Keep the first `lift-at` fact encountered.
                        lift_at = parts[1]
                elif predicate == "served":
                    served_passengers.add(parts[1])
                elif predicate == "boarded":
                    boarded_passengers.add(parts[1])
            elif len(parts) == 3 and predicate == "origin":
                state_origins.setdefault(parts[1], parts[2])

        total_cost = 0  # Initialize the heuristic cost.

        for passenger, destin_floor in self.destin_floors.items():
            if passenger in served_passengers:
                continue  # Passenger is already served.

            if passenger in boarded_passengers:
                # Boarded but not served: move to the destination and `depart`.
                total_cost += self._estimate_move_cost(lift_at, destin_floor) + 1
                continue

            # Not boarded: prefer an origin found in the state, fall back to the
            # one recorded from the static facts.
            origin_floor = state_origins.get(passenger)
            if origin_floor is None:
                origin_floor = self.origin_floors.get(passenger)

            if origin_floor is None:
                # Origin unknown: no movement can be estimated, but the passenger
                # still needs one `board` and one `depart`. Counting them keeps
                # the estimate positive while the passenger is unserved.
                total_cost += 2
                continue

            total_cost += (
                self._estimate_move_cost(lift_at, origin_floor)
                + 1  # Board action.
                + self._estimate_move_cost(origin_floor, destin_floor)
                + 1  # Depart action.
            )

        return total_cost

    def _estimate_move_cost(self, from_floor, to_floor):
        """
        Estimate the number of `up` or `down` actions required to move the lift from `from_floor` to `to_floor`.

        - `from_floor`: The current floor of the lift (may be `None` if unknown).
        - `to_floor`: The target floor (may be `None` if unknown).
        - Returns the number of actions required; 0 when the floors are equal or
          either one is unknown.
        """
        if from_floor is None or to_floor is None:
            return 0  # Unknown position: no movement estimate is possible.
        if from_floor == to_floor:
            return 0

        # Floors related by `above` (in either direction) are treated as one move apart.
        pairs = self.above_relationships
        if (from_floor, to_floor) in pairs or (to_floor, from_floor) in pairs:
            return 1

        # Floors are not related by `above`: fall back to the difference of the
        # numeric suffixes of the floor names (e.g. "f2" vs "f5" -> 3).
        try:
            distance = abs(int(from_floor[1:]) - int(to_floor[1:]))
        except ValueError:
            # Floor names do not follow the "<letter><number>" scheme.
            return 1
        # Two distinct floors are never zero moves apart.
        return max(1, distance)
