from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """Split a PDDL fact string such as "(at rover1 waypoint1)" into its tokens.

    The first and last characters (the surrounding parentheses) are dropped and
    the remaining text is split on runs of whitespace.
    """
    inner = fact[1:-1]
    return inner.split()


def match(fact, *args):
    """
    Tell whether a PDDL fact string fits a (possibly wildcarded) pattern.

    - `fact`: the full fact string, e.g. "(at rover1 waypoint1)".
    - `args`: one shell-style pattern per token (`*` is a wildcard).
    - Tokens are compared pairwise and comparison stops at the shorter of the
      two sequences, so a pattern shorter than the fact acts as a prefix match.
    - Returns `True` when every compared token matches its pattern.
    """
    tokens = fact[1:-1].split()
    for token, pattern in zip(tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True


class RoversHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Rovers domain.

    # Summary
    Estimates the number of actions still required to achieve all
    `communicated_soil_data`, `communicated_rock_data` and
    `communicated_image_data` goals. Each unachieved goal contributes a fixed
    cost that depends on whether some rover already holds the required data.
    Goals with any other predicate are ignored.

    # Assumptions
    - Every navigation is counted as one action, regardless of distance or of
      where the rover currently is.
    - Store capacity, equipment, camera calibration state and lander
      visibility are not checked. The estimate is a simple relaxation and can
      overestimate, e.g. a rover already positioned to communicate is still
      charged a navigation step.

    # Heuristic Initialization
    - Extract goal conditions and static facts from the task.
    - Record waypoints, rovers, objectives and the lander location from the
      static facts.
    - Index every supported goal by its full (canonical) fact. Goals that share
      an argument, such as a soil and a rock goal at the same waypoint, or
      several image goals for one objective, are then kept separately.

    # Step-By-Step Thinking for Computing Heuristic
    1. Collect the data items currently held by any rover
       (`have_soil_analysis`, `have_rock_analysis`, `have_image`).
    2. For each indexed goal that is not yet true in the state:
       - some rover holds the data: 2 actions (navigate, communicate);
       - otherwise soil/rock: 4 actions (navigate, sample, navigate,
         communicate);
       - otherwise image: 5 actions (navigate, calibrate, take image,
         navigate, communicate).
    3. Return the sum over all goals.
    """

    # Goal predicate -> goal type handled by this heuristic.
    _GOAL_TYPES = {
        "communicated_soil_data": "soil",
        "communicated_rock_data": "rock",
        "communicated_image_data": "image",
    }
    # State predicate recording held data -> goal type that data serves.
    _HELD_TYPES = {
        "have_soil_analysis": "soil",
        "have_rock_analysis": "rock",
        "have_image": "image",
    }
    # Cost of an unachieved goal when no rover holds the data yet.
    _ACQUIRE_COST = {"soil": 4, "rock": 4, "image": 5}
    # Cost of an unachieved goal when some rover already holds the data
    # (navigate to the lander + communicate).
    _COMMUNICATE_COST = 2

    def __init__(self, task):
        """Initialize the heuristic by extracting goal conditions and static facts.

        Args:
            task: planning task exposing `goals` and `static` collections of
                PDDL fact strings.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Information about waypoints, rovers, and objectives.
        self.waypoints = set()
        self.rovers = set()
        self.objectives = set()
        self.lander_location = None

        for fact in static_facts:
            predicate, *args = get_parts(fact)
            if predicate == "at_lander":
                self.lander_location = args[1]
            elif predicate == "visible":
                self.waypoints.add(args[0])
                self.waypoints.add(args[1])
            elif predicate == "visible_from":
                self.objectives.add(args[0])
                self.waypoints.add(args[1])
            elif predicate == "on_board":
                self.rovers.add(args[1])

        # goal_specs: canonical goal fact -> (goal type, argument tuple).
        # Keyed by the whole fact so that no two goals collide.
        self.goal_specs = {}
        # goal_locations: first goal argument -> goal type. Kept only for
        # backward compatibility. It is lossy when goals share a first
        # argument and is not used by __call__.
        self.goal_locations = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            goal_type = self._GOAL_TYPES.get(predicate)
            if goal_type is None or not args:
                continue
            self.goal_locations[args[0]] = goal_type
            # Rebuild the fact with single spaces so it compares equal to the
            # state's fact strings.
            canonical = f"({predicate} {' '.join(args)})"
            self.goal_specs[canonical] = (goal_type, tuple(args))

    @classmethod
    def _collect_held_data(cls, state):
        """Return the set of `(goal type, data args)` items some rover holds.

        `(have_soil_analysis ?r ?w)` and `(have_rock_analysis ?r ?w)` yield
        `(type, (w,))`, and `(have_image ?r ?o ?m)` yields `("image", (o, m))`.
        These are exactly the argument tuples of the matching communicated_*
        goal. The rover argument is dropped because any rover holding the data
        is enough. The result does not depend on the iteration order of
        `state`.
        """
        held = set()
        for fact in state:
            predicate, *args = get_parts(fact)
            goal_type = cls._HELD_TYPES.get(predicate)
            if goal_type is not None and len(args) >= 2:
                held.add((goal_type, tuple(args[1:])))
        return held

    def __call__(self, node):
        """Compute an estimate of the number of actions still required.

        Args:
            node: search node whose `state` is a collection of PDDL fact
                strings.

        Returns:
            The summed per-goal cost (an int), which is 0 when every indexed
            goal holds.
        """
        state = node.state
        held = self._collect_held_data(state)

        total_cost = 0
        for goal, (goal_type, args) in self.goal_specs.items():
            if goal in state:
                continue  # Data already communicated.
            if (goal_type, args) in held:
                total_cost += self._COMMUNICATE_COST
            else:
                total_cost += self._ACQUIRE_COST[goal_type]
        return total_cost
