from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated components.

    - `fact`: A fact written as a string, e.g., "(at rover1 waypoint1)".
    - Returns a list of tokens, e.g., ["at", "rover1", "waypoint1"].

    The first and last characters are dropped unconditionally (they are
    expected to be the surrounding parentheses); no validation is performed.
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens


def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(at rover1 waypoint1)".
    - `args`: The expected pattern, one shell-style pattern per component
      (wildcards such as `*` are allowed; matching is done with `fnmatch`).
    - Returns `True` if every pattern element matches the component at the
      same position, else `False`.

    A pattern that is shorter than the fact is matched against the leading
    components only. A pattern that is longer than the fact never matches.
    """
    parts = get_parts(fact)
    # zip() stops at the shorter sequence, so without this guard a fact with
    # fewer components than the pattern would match on its prefix alone
    # (e.g. "(at)" would match the pattern ("at", "rover1")).
    if len(parts) < len(args):
        return False
    return all(fnmatch(part, arg) for part, arg in zip(parts, args))


class RoversHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Rovers domain.

    # Summary
    This heuristic estimates the number of actions required to achieve all goals in the Rovers domain
    by summing a fixed cost for every communication goal that is not yet satisfied:
    - Soil data:  1 (communicate) + 2 (navigate and sample) if no rover holds the soil analysis.
    - Rock data:  1 (communicate) + 2 (navigate and sample) if no rover holds the rock analysis.
    - Image data: 1 (communicate) + 3 (calibrate, navigate, take image) if no rover holds the image.

    # Assumptions
    - Facts are strings of the form "(predicate arg1 arg2 ...)".
    - `task.goals` and `node.state` are sets of such strings (subset comparison and membership are used).
    - Goals are costed independently of each other; rover positions and travel distances are not
      taken into account, so every navigation step is counted as a single action.
    - Goals whose predicate is not one of the three `communicated_*` predicates contribute no cost.

    # Heuristic Initialization
    - Store the goal conditions of the task.
    - Scan the static facts once and record the lander location, the waypoints from which each
      objective is visible, the calibration target of each camera, and any objects declared through
      unary `waypoint` / `rover` / `camera` / `objective` facts.
      These structures are collected for reference; the cost estimate itself does not read them.

    # Step-By-Step Thinking for Computing Heuristic
    1. If every goal already holds in the state, return 0.
    2. For each goal that does not hold in the state:
       - Look for a `have_soil_analysis` / `have_rock_analysis` / `have_image` fact (for any rover)
         that corresponds to the goal.
       - If such a fact exists, only the communication action is counted.
       - Otherwise the fixed acquisition cost is added on top of the communication action.
    3. Sum the estimated actions for all goals to compute the heuristic value.
    """

    # Fixed action counts used by the estimate.
    _COMMUNICATE_COST = 1  # One communicate action.
    _SAMPLE_COST = 2  # Navigate and sample.
    _IMAGE_COST = 3  # Calibrate, navigate, and take image.

    def __init__(self, task):
        """
        Initialize the heuristic by extracting goal conditions and static facts.

        - `task`: The planning task; its `goals` and `static` attributes are read.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Extract static information into suitable data structures.
        self.lander_location = None
        self.waypoints = set()
        self.rovers = set()
        self.cameras = set()
        self.objectives = set()
        self.visible_from = {}  # Maps objectives to waypoints where they are visible.
        self.calibration_targets = {}  # Maps cameras to their calibration targets.

        # Unary type-like predicates and the set each one populates.
        object_sets = {
            "waypoint": self.waypoints,
            "rover": self.rovers,
            "camera": self.cameras,
            "objective": self.objectives,
        }

        # The predicates handled below are mutually exclusive, so a single pass
        # over the static facts is sufficient.
        for fact in static_facts:
            predicate, *args = get_parts(fact)
            if predicate == "at_lander":
                # Second argument of at_lander: landers are always at a waypoint.
                self.lander_location = args[1]
            elif predicate == "visible_from":
                objective, waypoint = args
                self.visible_from.setdefault(objective, set()).add(waypoint)
            elif predicate == "calibration_target":
                camera, objective = args
                # A later fact for the same camera replaces an earlier one.
                self.calibration_targets[camera] = objective
            elif predicate in object_sets:
                object_sets[predicate].add(args[0])

    def __call__(self, node):
        """
        Compute an estimate of the minimal number of required actions.

        - `node`: A search node; its `state` attribute is the current world state.
        - Returns 0 if all goals hold, otherwise the summed per-goal estimates.
        """
        state = node.state  # Current world state.

        # If the goal is already reached, return 0.
        if self.goals <= state:
            return 0

        total_cost = 0  # Accumulated action estimate.

        for goal in self.goals:
            if goal in state:
                continue  # Already achieved; contributes nothing.

            predicate, *args = get_parts(goal)
            if predicate == "communicated_soil_data":
                total_cost += self._estimate_soil_data_cost(state, args[0])
            elif predicate == "communicated_rock_data":
                total_cost += self._estimate_rock_data_cost(state, args[0])
            elif predicate == "communicated_image_data":
                objective, mode = args
                total_cost += self._estimate_image_data_cost(state, objective, mode)
            # Goals with any other predicate are not costed.

        return total_cost

    def _estimate_goal_cost(self, state, pattern, acquisition_cost):
        """
        Estimate the cost of one unsatisfied communication goal.

        - `pattern`: Pattern (as accepted by `match`) of the fact showing that
          some rover already holds the data to be communicated.
        - `acquisition_cost`: Actions charged when no such fact is in `state`.
        - Returns the communication cost, plus `acquisition_cost` if the data
          still has to be acquired.
        """
        cost = self._COMMUNICATE_COST
        if not any(match(fact, *pattern) for fact in state):
            cost += acquisition_cost
        return cost

    def _estimate_soil_data_cost(self, state, waypoint):
        """Estimate the cost to collect and communicate soil data from a waypoint."""
        return self._estimate_goal_cost(
            state, ("have_soil_analysis", "*", waypoint), self._SAMPLE_COST
        )

    def _estimate_rock_data_cost(self, state, waypoint):
        """Estimate the cost to collect and communicate rock data from a waypoint."""
        return self._estimate_goal_cost(
            state, ("have_rock_analysis", "*", waypoint), self._SAMPLE_COST
        )

    def _estimate_image_data_cost(self, state, objective, mode):
        """Estimate the cost to take and communicate an image of an objective."""
        return self._estimate_goal_cost(
            state, ("have_image", "*", objective, mode), self._IMAGE_COST
        )
