from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """Split a PDDL fact string such as "(at rover1 waypoint1)" into tokens.

    The first and last characters (the enclosing parentheses) are dropped,
    and the remaining text is split on runs of whitespace.
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens


def match(fact, *args):
    """
    Tell whether a PDDL fact fits a token-wise shell-style pattern.

    - `fact`: full fact string, e.g. "(at rover1 waypoint1)".
    - `args`: one pattern per token; `fnmatch` wildcards such as `*` are allowed.
    - Returns `True` when every compared token matches its pattern, else `False`.

    NOTE: tokens and patterns are paired with `zip`, so only the first
    min(len(tokens), len(args)) positions are compared; a fact with more
    (or fewer) tokens than patterns can still match.
    """
    tokens = get_parts(fact)
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True


class RoversHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Rovers domain.

    # Summary
    This heuristic estimates the number of actions still required by looking
    only at goals that are NOT yet satisfied in the current state:
    - Soil/rock data goal: navigate + communicate, plus one sampling action
      unless some rover already holds the analysis for that waypoint.
    - Image data goal: navigate + communicate, plus calibrate + take image
      unless some rover already holds that (objective, mode) image.
    Satisfied goals contribute nothing, so the estimate is 0 in a goal state.

    # Assumptions
    - Goals of interest are `communicated_soil_data`, `communicated_rock_data`
      and `communicated_image_data` facts; any other goal facts are ignored.
    - Navigation is approximated as a single action per unsatisfied goal.
      Actual distances, rover equipment and store capacity are not checked,
      so the estimate may over- or under-count.
    - Held data is detected via `have_soil_analysis`, `have_rock_analysis`
      and `have_image` facts, assumed to follow the standard Rovers argument
      order (rover first).

    # Heuristic Initialization
    - Extract goal conditions and static facts from the task.
    - Record lander location, objective visibility, calibration targets and
      camera/mode support from the static facts (these are not read by
      `__call__`).

    # Step-By-Step Thinking for Computing Heuristic
    1. Scan the state once, collecting communicated and held soil/rock/image data.
    2. For each unsatisfied soil/rock goal add 2 (navigate, communicate), plus 1
       to sample if no rover holds the analysis yet.
    3. For each unsatisfied image goal add 2 (navigate, communicate), plus 2
       (calibrate, take image) if no rover holds the image yet.
    4. Return the sum.
    """

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts.

        Args:
            task: planning task exposing `goals` and `static` collections of
                PDDL fact strings such as "(at_lander general waypoint1)".
        """
        self.goals = task.goals
        self.static = task.static

        # Static information. The object sets below are never populated in this
        # class; they are kept as attributes for backward compatibility.
        self.waypoints = set()
        self.rovers = set()
        self.cameras = set()
        self.objectives = set()
        self.lander_location = None
        self.visible_from = {}  # objective -> set of waypoints
        self.calibration_targets = {}  # camera -> objective
        self.supports = {}  # camera -> set of modes

        for fact in self.static:
            parts = get_parts(fact)
            # All predicates read here have two arguments. `match` only
            # compares a prefix, so shorter facts must be skipped explicitly
            # to avoid an IndexError on parts[2].
            if len(parts) < 3:
                continue
            if match(fact, "at_lander", "*", "*"):
                # With several landers, the last one seen wins.
                self.lander_location = parts[2]
            elif match(fact, "visible_from", "*", "*"):
                objective, waypoint = parts[1], parts[2]
                self.visible_from.setdefault(objective, set()).add(waypoint)
            elif match(fact, "calibration_target", "*", "*"):
                camera, objective = parts[1], parts[2]
                self.calibration_targets[camera] = objective
            elif match(fact, "supports", "*", "*"):
                camera, mode = parts[1], parts[2]
                self.supports.setdefault(camera, set()).add(mode)

        # Goal conditions.
        self.goal_soil_data = set()  # waypoints
        self.goal_rock_data = set()  # waypoints
        self.goal_image_data = set()  # (objective, mode) pairs

        for goal in self.goals:
            parts = get_parts(goal)
            if match(goal, "communicated_soil_data", "*"):
                self.goal_soil_data.add(parts[1])
            elif match(goal, "communicated_rock_data", "*"):
                self.goal_rock_data.add(parts[1])
            elif match(goal, "communicated_image_data", "*", "*"):
                self.goal_image_data.add((parts[1], parts[2]))

    def __call__(self, node):
        """Estimate the number of actions required to reach the goal state.

        Args:
            node: search node; only `node.state` (an iterable of PDDL fact
                strings) is read.

        Returns:
            int: 0 when every tracked soil/rock/image goal is already
            communicated, otherwise a positive estimate.
        """
        state = node.state

        # A single pass over the state collects everything needed, instead of
        # rescanning the whole state once per goal.
        communicated_soil = set()
        communicated_rock = set()
        communicated_images = set()
        held_soil = set()
        held_rock = set()
        held_images = set()
        for fact in state:
            parts = get_parts(fact)
            if len(parts) < 2:
                continue
            predicate, args = parts[0], parts[1:]
            if predicate == "communicated_soil_data":
                communicated_soil.add(args[0])
            elif predicate == "communicated_rock_data":
                communicated_rock.add(args[0])
            elif predicate == "communicated_image_data" and len(args) >= 2:
                communicated_images.add((args[0], args[1]))
            # assumes (have_soil_analysis ?rover ?waypoint) — standard Rovers order
            elif predicate == "have_soil_analysis" and len(args) >= 2:
                held_soil.add(args[1])
            # assumes (have_rock_analysis ?rover ?waypoint)
            elif predicate == "have_rock_analysis" and len(args) >= 2:
                held_rock.add(args[1])
            # assumes (have_image ?rover ?objective ?mode)
            elif predicate == "have_image" and len(args) >= 3:
                held_images.add((args[1], args[2]))

        total_cost = 0

        # Soil and rock: navigate + communicate, plus sampling if not yet held.
        for waypoint in self.goal_soil_data - communicated_soil:
            total_cost += 2 + (0 if waypoint in held_soil else 1)
        for waypoint in self.goal_rock_data - communicated_rock:
            total_cost += 2 + (0 if waypoint in held_rock else 1)

        # Images: navigate + communicate, plus calibrate + take image if not yet held.
        for image_goal in self.goal_image_data - communicated_images:
            total_cost += 2 + (0 if image_goal in held_images else 2)

        return total_cost
