from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic

def get_parts(fact):
    """
    Split a PDDL fact string into its whitespace-separated tokens.

    The first and last characters of `fact` (the enclosing parentheses in a
    fact such as "(robot-at robot1 tile_0_1)") are dropped before splitting,
    so the result is a list like ["robot-at", "robot1", "tile_0_1"].
    """
    inner = fact[1:-1]
    return inner.split()

def match(fact, *args):
    """
    Check if a PDDL fact matches a given pattern.

    - `fact`: The complete fact as a string, e.g., "(robot-at robot1 tile_0_1)".
    - `args`: The expected pattern, one `fnmatch`-style entry per fact
      component (wildcards such as `*` allowed).
    - Returns `True` if every pattern entry matches the fact component at the
      same position, else `False`.

    A pattern shorter than the fact matches on the leading components only.
    A pattern longer than the fact never matches.
    """
    parts = get_parts(fact)
    # zip() stops at the shorter sequence, so without this guard a fact with
    # too few components would "match" and callers unpacking the expected
    # number of arguments would fail afterwards.
    if len(parts) < len(args):
        return False
    return all(fnmatch(part, arg) for part, arg in zip(parts, args))

class FloortileHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Floortile domain.

    # Summary
    For every goal tile that is not yet painted with its goal color, the
    heuristic charges the cheapest robot's walking distance to that tile,
    plus one action if that robot currently holds a different color, plus
    one action for painting. The per-tile costs are summed.

    # Assumptions
    - The `up`/`down`/`left`/`right` static facts describe the tile graph;
      each fact `(direction a b)` is used as an edge from `a` to `b`.
    - Distances are computed on that static graph only; the current state is
      not consulted for blocked or occupied tiles.
    - Goal tiles are costed independently of each other, so the same robot
      may be charged for several tiles without accounting for its movement
      between them.

    # Heuristic Initialization
    - Extract the goal color of each tile from the `painted` goals.
    - Extract tile adjacencies and available colors from the static facts.
    - Create an empty cache of shortest-path distances.

    # Step-By-Step Thinking for Computing Heuristic
    1. Read every robot's position and held color from the state.
    2. Skip goal tiles already painted with the goal color.
    3. For each remaining tile, take the minimum over robots of
       (distance to the tile + 1 if the robot's color differs from the goal).
    4. Add one action for painting and sum over all tiles.
    5. If a tile cannot be reached by any robot, return infinity.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting:
        - Goal conditions for each tile (`task.goals`).
        - Static facts (`task.static`): tile adjacencies and available colors.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Map tiles to their goal colors.
        self.goal_colors = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "painted":
                tile, color = args
                self.goal_colors[tile] = color

        # Tile adjacencies (tile -> list of (direction, neighbor)) and the set
        # of available colors, collected in one pass over the static facts.
        self.adjacent = {}
        self.available_colors = set()
        for fact in static_facts:
            predicate, *args = get_parts(fact)
            if predicate in ("up", "down", "left", "right"):
                tile1, tile2 = args
                self.adjacent.setdefault(tile1, []).append((predicate, tile2))
            elif predicate == "available-color":
                self.available_colors.add(args[0])

        # start tile -> {reachable tile: number of moves}. The graph is built
        # from static facts only, so cached distances stay valid for the
        # lifetime of this object.
        self._distance_cache = {}

    def __call__(self, node):
        """
        Compute an estimate of the number of actions still required.

        - `node`: search node whose `state` attribute is the collection of
          fact strings that currently hold.
        - Returns the summed estimate, or `float('inf')` when some unpainted
          goal tile cannot be reached by any robot (including the case where
          the state contains no robot at all).
        """
        state = node.state  # Current world state.

        # Collect each robot's position and held color in a single pass.
        robot_positions = {}
        robot_colors = {}
        for fact in state:
            parts = get_parts(fact)
            if len(parts) != 3:
                continue
            predicate, robot, value = parts
            if predicate == "robot-at":
                robot_positions[robot] = value
            elif predicate == "robot-has":
                robot_colors[robot] = value

        unreachable = float("inf")
        total_cost = 0  # Accumulated action estimate.

        for tile, goal_color in self.goal_colors.items():
            if f"(painted {tile} {goal_color})" in state:
                continue  # Tile is already correctly painted.

            # Cheapest robot for this tile: walking distance plus one action
            # when its color has to change. Because the color surcharge is at
            # most 1, this equals "closest robot + its color change" except
            # when several robots are equally close, where it picks the best
            # one instead of depending on the iteration order of the state.
            # A robot without a known color is treated as needing a change.
            best = min(
                (
                    self._calculate_distance(position, tile)
                    + (0 if robot_colors.get(robot) == goal_color else 1)
                    for robot, position in robot_positions.items()
                ),
                default=unreachable,
            )
            if best == unreachable:
                return unreachable  # No robot can reach this tile.

            # Movement (and color change) plus the paint action itself.
            total_cost += best + 1

        return total_cost

    def _calculate_distance(self, start_tile, goal_tile):
        """
        Calculate the minimum number of moves required to go from
        `start_tile` to `goal_tile` on the static tile graph.

        Returns 0 when both tiles are the same and `float('inf')` when
        `goal_tile` is not reachable from `start_tile`. Distances from each
        start tile are computed once by BFS and then served from the cache.
        """
        if start_tile == goal_tile:
            return 0

        distances = self._distance_cache.get(start_tile)
        if distances is None:
            distances = self._distances_from(start_tile)
            self._distance_cache[start_tile] = distances
        return distances.get(goal_tile, float("inf"))

    def _distances_from(self, start_tile):
        """Return {tile: number of moves} for every tile reachable from `start_tile`."""
        from collections import deque

        distances = {start_tile: 0}
        queue = deque([start_tile])
        while queue:
            current_tile = queue.popleft()
            next_distance = distances[current_tile] + 1
            for _, neighbor in self.adjacent.get(current_tile, ()):
                # Record a tile when it is first discovered so it is enqueued
                # at most once.
                if neighbor not in distances:
                    distances[neighbor] = next_distance
                    queue.append(neighbor)
        return distances
