from fnmatch import fnmatch
from heuristics.heuristic_base import Heuristic


def get_parts(fact):
    """Split a PDDL fact string such as "(painted tile_1_1 white)" into its tokens.

    The first and last characters (normally the surrounding parentheses) are dropped
    without being checked, and the remainder is split on whitespace, giving e.g.
    ["painted", "tile_1_1", "white"].
    """
    inner = fact[1:-1]
    tokens = inner.split()
    return tokens


def match(fact, *args):
    """
    Report whether a PDDL fact matches a token-wise pattern.

    - `fact`: The complete fact as a string, e.g., "(robot-at robot1 tile_0_1)".
    - `args`: One shell-style pattern per token (fnmatch wildcards such as `*` allowed).
    - Returns `True` when every fact token matches its corresponding pattern, else `False`.

    Tokens and patterns are paired positionally and pairing stops at the shorter of the
    two sequences, so surplus tokens or surplus patterns are not checked.
    """
    fact_tokens = get_parts(fact)
    for token, pattern in zip(fact_tokens, args):
        if not fnmatch(token, pattern):
            return False
    return True


class FloortileHeuristic(Heuristic):
    """
    A domain-dependent heuristic for the Floortile domain.

    # Summary
    This heuristic estimates the number of actions required to paint all tiles according to the goal configuration.
    It considers the number of tiles that still need to be painted, the distance of robots to these tiles, and the
    number of color changes required.

    # Assumptions
    - Robots can move in four directions (up, down, left, right) and can paint adjacent tiles.
    - Movement between two adjacent tiles is assumed possible in both directions, so the adjacency
      graph built from `up`/`down`/`left`/`right` facts is treated as undirected.
    - Each robot can carry one color at a time, and color changes are allowed.
    - The heuristic assumes that robots can move freely without blocking each other.

    # Heuristic Initialization
    - Extract the goal conditions for each tile (i.e., the required color for each tile).
    - Extract the static facts (e.g., adjacency relationships between tiles) to compute distances.
    - Tile-to-tile distances are computed lazily by BFS and cached per start tile, because the
      adjacency graph never changes.

    # Step-By-Step Thinking for Computing Heuristic
    1. Identify the tiles that are not yet painted in the goal color.
    2. For each unpainted tile, compute the minimum distance from any robot to that tile.
    3. Charge at least one action per unpainted tile (the paint action itself), even when the
       closest robot is already standing on it.
    4. If the closest robot does not hold the required color, add a cost for changing the color.
    5. Sum these costs. If some unpainted tile cannot be reached by any robot, return infinity.
    """

    def __init__(self, task):
        """
        Initialize the heuristic by extracting:
        - Goal conditions for each tile.
        - Static facts (adjacency relationships between tiles).

        :param task: Planning task exposing `goals` and `static` as collections of PDDL fact strings.
        """
        self.goals = task.goals  # Goal conditions.
        static_facts = task.static  # Facts that are not affected by actions.

        # Map each tile to its required color in the goal state.
        self.goal_colors = {}
        for goal in self.goals:
            predicate, *args = get_parts(goal)
            if predicate == "painted":
                tile, color = args
                self.goal_colors[tile] = color

        # Extract adjacency relationships between tiles. Both endpoints get each other as a
        # neighbour, so distances do not depend on whether the instance lists the reverse
        # fact (e.g. `down` for every `up`) for every pair of neighbouring tiles.
        neighbours = {}
        for fact in static_facts:
            predicate, *args = get_parts(fact)
            if predicate in ("up", "down", "left", "right"):
                tile1, tile2 = args
                neighbours.setdefault(tile1, set()).add(tile2)
                neighbours.setdefault(tile2, set()).add(tile1)
        # Sorted lists keep the neighbour order deterministic across runs.
        self.adjacent = {tile: sorted(adj) for tile, adj in neighbours.items()}

        # Cache: start tile -> {reachable tile: distance}. Valid for the lifetime of the
        # heuristic because the adjacency graph is built once from static facts.
        self._distance_cache = {}

    def __call__(self, node):
        """
        Compute an estimate of the minimal number of required actions.

        :param node: Search node whose `state` is a collection of PDDL fact strings.
        :return: The estimated cost. It is 0 only when every goal tile already has its goal
                 color. It is float("inf") when some unpainted goal tile is unreachable by
                 every robot, including the case where the state has no robots.
        """
        state = node.state  # Current world state.

        # Track the current color and location of each robot in a single pass over the state.
        robot_colors = {}
        robot_locations = {}
        for fact in state:
            predicate, *args = get_parts(fact)
            if predicate == "robot-has":
                robot, color = args
                robot_colors[robot] = color
            elif predicate == "robot-at":
                robot, tile = args
                robot_locations[robot] = tile

        total_cost = 0  # Initialize action cost counter.

        for tile, goal_color in self.goal_colors.items():
            if f"(painted {tile} {goal_color})" in state:
                continue  # Goal already satisfied for this tile.

            # Find the closest robot to this tile (the first one found wins on ties).
            min_distance = float("inf")
            closest_robot = None
            for robot, location in robot_locations.items():
                distance = self._compute_distance(location, tile)
                if distance < min_distance:
                    min_distance = distance
                    closest_robot = robot

            # No robot can reach this tile (or there are no robots): treat as a dead end.
            if closest_robot is None:
                return float("inf")

            # The tile still needs a paint action, so it never contributes 0, even when
            # the robot is already standing on it.
            total_cost += max(min_distance, 1)

            # Add a cost for changing the color if necessary. A robot with no recorded
            # color is treated as needing a color change too.
            if robot_colors.get(closest_robot) != goal_color:
                total_cost += 1  # Cost for changing color.

        return total_cost

    def _compute_distance(self, start_tile, goal_tile):
        """
        Compute the minimum number of moves required to go from `start_tile` to `goal_tile`.

        One breadth-first search (BFS) is run per distinct start tile. Its full distance map is
        cached, so later queries from the same start tile are dictionary lookups.
        Returns float("inf") if `goal_tile` is not reachable from `start_tile`.
        """
        if start_tile == goal_tile:
            return 0

        distances = self._distance_cache.get(start_tile)
        if distances is None:
            distances = self._bfs_distances(start_tile)
            self._distance_cache[start_tile] = distances

        return distances.get(goal_tile, float("inf"))

    def _bfs_distances(self, start_tile):
        """Return a dict mapping every tile reachable from `start_tile` to its BFS move count."""
        from collections import deque

        distances = {start_tile: 0}
        queue = deque([start_tile])
        while queue:
            current_tile = queue.popleft()
            next_distance = distances[current_tile] + 1
            for neighbor in self.adjacent.get(current_tile, []):
                # Mark on enqueue so each tile is expanded at most once.
                if neighbor not in distances:
                    distances[neighbor] = next_distance
                    queue.append(neighbor)
        return distances
