from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its tokens.

    The first and last characters (the enclosing parentheses) are dropped and the
    remainder is split on runs of whitespace, e.g.
    "(robot-at robot1 tile_0_1)" -> ["robot-at", "robot1", "tile_0_1"].
    """
    without_parens = fact[1:-1]
    return without_parens.split()

def match(fact, *args):
    """
    Check whether a PDDL fact matches a pattern, field by field.

    - `fact`: the complete fact as a string, e.g. "(robot-at robot1 tile_0_1)".
    - `args`: the pattern fields; shell-style wildcards such as `*` are allowed
      (each field is compared with `fnmatch`). A pattern with fewer fields than
      the fact only constrains the leading fields.
    - Returns `True` if every pattern field matches the corresponding fact field,
      else `False`. A pattern with more fields than the fact never matches.
    """
    parts = get_parts(fact)
    # zip() stops at the shorter sequence, so without this guard a pattern that is
    # longer than the fact (or any pattern against an empty fact) would be reported
    # as a match because the unmatched pattern fields are never compared.
    if len(args) > len(parts):
        return False
    return all(fnmatch(part, arg) for part, arg in zip(parts, args))

class FloortileHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Floortile domain.

    # Summary
    This heuristic estimates the number of actions required to paint all tiles according to the goal conditions.
    For every goal tile that is not yet painted with its goal color, it adds the Manhattan distance from the
    nearest robot to that tile, plus one action when that robot does not hold the goal color.

    # Assumptions
    - Robots can move in four directions (up, down, left, right).
    - Robots can change their color, but this requires an action.
    - Each tile must be painted with the correct color as specified in the goal.
    - Tile names encode grid coordinates in their last two numeric fields, separated by
      `_` or `-` (e.g. "tile_1_2" or "tile_1-2").

    # Heuristic Initialization
    - Extract the goal conditions for each tile.
    - Extract the static facts (e.g., adjacency relationships between tiles).
    - Build a map from each tile to its neighbours using the adjacency facts.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the goal tiles that are not yet painted with the correct color.
    2. For each such tile, calculate the Manhattan distance from the nearest robot.
    3. If that robot does not hold the goal color (or holds no color at all), add one action for changing color.
    4. Sum the distances and color change costs. The paint action itself is not counted separately.
    5. If goal tiles remain unpainted but the state contains no robot, return infinity.
    """

    def __init__(self, task):
        """
        Precompute static information from the planning task.

        - `task`: must provide `goals` (goal facts) and `static` (facts no action changes).

        Sets:
        - `self.goals`: the task's goal facts.
        - `self.adjacency`: tile -> {direction predicate: related tile}, built from the
          static `up`/`down`/`left`/`right` facts.
        - `self.goal_paintings`: tile -> goal color, for every `painted` goal.
        """
        self.goals = task.goals
        static_facts = task.static

        # Map each tile to the tiles it is related to, keyed by the direction predicate.
        self.adjacency = {}
        for fact in static_facts:
            predicate, *args = get_parts(fact)
            if predicate in ("up", "down", "left", "right"):
                tile1, tile2 = args
                self.adjacency.setdefault(tile1, {})[predicate] = tile2

        # Goal color for each tile that must end up painted.
        self.goal_paintings = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "painted":
                tile, color = args
                self.goal_paintings[tile] = color

    def __call__(self, node):
        """
        Estimate the number of actions still needed from `node.state`.

        Returns 0 when every goal tile is already painted with its goal color, and
        `float('inf')` when some goal tile is unpainted but the state has no robot.
        """
        state = node.state

        # Current position and held color of each robot.
        robot_positions = {}
        robot_colors = {}
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "robot-at":
                robot, tile = args
                robot_positions[robot] = tile
            elif predicate == "robot-has":
                robot, color = args
                robot_colors[robot] = color

        unpainted = [
            (tile, goal_color)
            for tile, goal_color in self.goal_paintings.items()
            if f"(painted {tile} {goal_color})" not in state
        ]
        if not unpainted:
            return 0
        if not robot_positions:
            # No robot is available to reach the remaining tiles. The original code
            # crashed here with a KeyError on `robot_colors[None]`.
            return float("inf")

        total_cost = 0
        for tile, goal_color in unpainted:
            # min() returns the first minimal element, matching the original
            # "first robot with strictly smaller distance wins" tie-break.
            nearest_robot, min_distance = min(
                (
                    (robot, self.calculate_distance(position, tile))
                    for robot, position in robot_positions.items()
                ),
                key=lambda item: item[1],
            )
            total_cost += min_distance

            # A robot with no `robot-has` fact also needs a color change action.
            if robot_colors.get(nearest_robot) != goal_color:
                total_cost += 1

        return total_cost

    def calculate_distance(self, start_tile, goal_tile):
        """
        Calculate the Manhattan distance between two tiles.

        - `start_tile`: The starting tile name.
        - `goal_tile`: The goal tile name.
        - Both names must end in two integer coordinate fields separated by `_` or `-`,
          e.g. "tile_0_1" or "tile_0-1".
        - Returns the Manhattan distance between the two tiles.
        - Raises `ValueError` if a tile name does not encode two integer coordinates.
        """
        start_x, start_y = self._tile_coordinates(start_tile)
        goal_x, goal_y = self._tile_coordinates(goal_tile)
        return abs(start_x - goal_x) + abs(start_y - goal_y)

    @staticmethod
    def _tile_coordinates(tile):
        """Return the (x, y) integers encoded in the last two fields of a tile name."""
        # Normalise the separator so "tile_0-1" and "tile_0_1" split the same way;
        # splitting on "_" alone turns "tile_0-1" into the unparseable field "0-1".
        fields = tile.replace("-", "_").split("_")
        if len(fields) < 3:
            raise ValueError(f"cannot extract grid coordinates from tile name {tile!r}")
        return int(fields[-2]), int(fields[-1])
