import numpy as np
import cv2
from scipy.spatial import distance_matrix
import networkx as nx

class GridWorldWithRewards:
    def __init__(self, size=100, num_rooms=9, room_scale=1, passage_scale=1, num_extra_passages=1, seed=None):
        """Record the generation parameters and build the world right away.

        Args:
            size: Side length of the square grid, in cells.
            num_rooms: Number of rooms to place.
            room_scale: Multiplier applied to each randomly drawn room radius.
            passage_scale: Multiplier applied to the random passage width range.
            num_extra_passages: Number of attempts at adding a passage on top
                of the spanning tree that connects the rooms.
            seed: When not None, used to seed NumPy's global random state
                before anything is generated.
        """
        if seed is not None:
            # This reseeds the module-level NumPy RNG, not a private generator.
            np.random.seed(seed)

        self.size, self.num_rooms = size, num_rooms
        self.room_scale, self.passage_scale = room_scale, passage_scale
        self.num_extra_passages = num_extra_passages

        # The grid starts all zero; generation paints walkable cells as 255.
        self.grid = np.zeros((size, size), dtype=np.uint8)
        self.room_info = dict()
        self.passage_coords = set()
        self.individual_passages = list()
        self.passage_centers = list()
        self.room_centers = list()

        self._generate_world()


    def _generate_world(self):
        """Paint rooms and connecting passages into ``self.grid``.

        Rooms are placed one per sector of a square layout, connected along a
        minimum spanning tree of their centres, optionally joined by extra
        passages, and finally the outermost rows/columns are reset to 0.
        Fills ``self.room_info``, ``self.room_centers``, ``self.passage_coords``
        and ``self.individual_passages`` as it goes, then runs the two clean-up
        helpers. All randomness comes from NumPy's global RNG, so the order of
        the random calls below determines the result for a given seed.
        """
        def add_circle_room(cx, cy, r, room_id):
            """Paint a filled disc centred at column ``cx``, row ``cy``."""
            # Open grids of row (y) and column (x) indices for a boolean disc mask.
            y, x = np.ogrid[:self.size, :self.size]
            mask = (x - cx)**2 + (y - cy)**2 <= r**2
        
            self.grid[mask] = 255
            coords = np.where(mask)
            # 'center' and 'coords' are stored as (row, col); 'bbox' is
            # (row_min, row_max, col_min, col_max) clipped to the grid.
            self.room_info[room_id] = {
                'type': 'circular',
                'center': (cy, cx),
                'radius': r,
                'coords': list(zip(coords[0], coords[1])),
                'bbox': (max(0, cy-r), min(self.size-1, cy+r), max(0, cx-r), min(self.size-1, cx+r))
            }
        
        def add_ellipse_room(cx, cy, r, room_id):
            """Paint a filled, randomly rotated ellipse centred at (cx, cy)."""
            # Second semi-axis is r scaled by a uniform factor in [0.7, 1.2).
            axes = (r, int(r * np.random.uniform(0.7, 1.2)))
            angle = np.random.randint(0, 360)
            
            # Draw on a scratch image so this room's own pixels can be read back.
            temp_img = np.zeros_like(self.grid)
            cv2.ellipse(temp_img, (cx, cy), axes, angle, 0, 360, 255, -1)
            
            self.grid[temp_img > 0] = 255
            
            coords = np.where(temp_img > 0)
            # bbox is taken from the painted pixels: (row_min, row_max, col_min, col_max).
            self.room_info[room_id] = {
                'type': 'elliptical',
                'center': (cy, cx),
                'axes': axes,
                'angle': angle,
                'coords': list(zip(coords[0], coords[1])),
                'bbox': (coords[0].min(), coords[0].max(), coords[1].min(), coords[1].max())
            }
        
        def add_square_room(cx, cy, r, room_id):
            """Paint a filled axis-aligned square of half-width ``r`` at (cx, cy)."""
            half = r
            # Corners are (x, y) pairs clamped into the grid.
            top_left = (max(0, cx - half), max(0, cy - half))
            bottom_right = (min(self.size - 1, cx + half), min(self.size - 1, cy + half))
            
            cv2.rectangle(self.grid, top_left, bottom_right, 255, -1)
            
            # Enumerate every (row, col) cell in the inclusive rectangle.
            y1, y2 = top_left[1], bottom_right[1]
            x1, x2 = top_left[0], bottom_right[0]
            coords_y, coords_x = np.meshgrid(range(y1, y2+1), range(x1, x2+1), indexing='ij')
            coords = list(zip(coords_y.flatten(), coords_x.flatten()))
            
            self.room_info[room_id] = {
                'type': 'rectangular',
                'center': (cy, cx),
                'half_size': half,
                'coords': coords,
                'bbox': (y1, y2, x1, x2),
                'corners': [(y1, x1), (y1, x2), (y2, x1), (y2, x2)]  # top-left, top-right, bottom-left, bottom-right
            }
        
        def add_polygon_room(cx, cy, r, room_id):
            """Paint a filled irregular polygon with 5 to 7 vertices around (cx, cy)."""
            points = []
            num_sides = 5 + np.random.randint(0, 3)
            for i in range(num_sides):
                # Evenly spaced angles, each perturbed by up to +/-0.2 rad.
                angle = 2 * np.pi * i / num_sides + np.random.uniform(-0.2, 0.2)
                px = int(cx + r * np.cos(angle))
                py = int(cy + r * np.sin(angle))
                # Vertices are stored as [x, y] and clipped into the grid.
                points.append([np.clip(px, 0, self.size - 1), np.clip(py, 0, self.size - 1)])
            
            temp_img = np.zeros_like(self.grid)
            cv2.fillPoly(temp_img, [np.array(points, dtype=np.int32)], 255)
            
            self.grid[temp_img > 0] = 255
            
            coords = np.where(temp_img > 0)
            self.room_info[room_id] = {
                'type': 'polygonal',
                'center': (cy, cx),
                'vertices': points,
                'coords': list(zip(coords[0], coords[1])),
                'bbox': (coords[0].min(), coords[0].max(), coords[1].min(), coords[1].max())
            }
        
        def draw_curvy_path(p1, p2, width=4, steps=10):
            """Carve a jittered polyline passage from ``p1`` to ``p2`` ((x, y) points)."""
            # Interior waypoints: linear interpolation plus Gaussian noise (std 2).
            points = [np.array(p1)]
            for i in range(1, steps):
                alpha = i / steps
                midpoint = (1 - alpha) * np.array(p1) + alpha * np.array(p2)
                jitter = np.random.randn(2) * 2
                points.append(midpoint + jitter)
            points.append(np.array(p2))

            # Rasterise the polyline on a scratch image.
            path_img = np.zeros_like(self.grid)
            for i in range(len(points) - 1):
                pt1 = tuple(np.round(points[i]).astype(int))
                pt2 = tuple(np.round(points[i+1]).astype(int))
                cv2.line(path_img, pt1, pt2, 255, width)

            # Blur then threshold; presumably this rounds off the polyline's edges.
            path_img = cv2.GaussianBlur(path_img, (5, 5), 1)
            passage_mask = path_img > 50

            # Record the cells both in the global passage set and per passage.
            passage_coords = np.where(passage_mask)
            individual_passage_coords = set()
            for y, x in zip(passage_coords[0], passage_coords[1]):
                self.passage_coords.add((y, x))
                individual_passage_coords.add((y, x))

            if individual_passage_coords:
                self.individual_passages.append(individual_passage_coords)

            self.grid[passage_mask] = 255
        
        # Square layout: ceil(sqrt(num_rooms)) sectors per side.
        rooms_per_row = int(np.ceil(np.sqrt(self.num_rooms)))
        sector_size = self.size // rooms_per_row
        centers = []
        
        room_functions = [add_circle_room, add_ellipse_room, add_square_room, add_polygon_room]
        
        # Start from two entries per room type (8 in total).
        min_per_type = 2
        guaranteed_functions = []
        for func in room_functions:
            guaranteed_functions.extend([func] * min_per_type)
        
        # Fill remaining slots randomly
        remaining_rooms = max(0, self.num_rooms - len(guaranteed_functions))
        random_functions = [np.random.choice(room_functions) for _ in range(remaining_rooms)]
        
        # Combine and shuffle
        all_functions = guaranteed_functions + random_functions
        np.random.shuffle(all_functions)
        
        # Truncate to exact number of rooms needed. Because this happens after
        # the shuffle, num_rooms < 8 no longer guarantees two rooms per type.
        all_functions = all_functions[:self.num_rooms]
        
        room_id = 0
        # Outer index i picks the x sector, inner index j the y sector.
        for i in range(rooms_per_row):
            for j in range(rooms_per_row):
                # Only leaves the inner loop; later outer iterations hit it again at once.
                if room_id >= len(all_functions):
                    break
                
                # Centre lands in the middle half of the sector; radius is 8..12 times room_scale.
                cx = int(i * sector_size + np.random.randint(sector_size // 4, 3 * sector_size // 4))
                cy = int(j * sector_size + np.random.randint(sector_size // 4, 3 * sector_size // 4))
                r = np.random.randint(8, 13) * self.room_scale
                
                shape_fn = all_functions[room_id]
                shape_fn(cx, cy, r, room_id)
                
                # Note the (x, y) order here, unlike the (row, col) 'center' in room_info.
                centers.append((cx, cy))
                room_id += 1
        
        self.room_centers = centers
        
        # Connect rooms with passages: complete graph weighted by centre
        # distance, reduced to its minimum spanning tree.
        D = distance_matrix(centers, centers)
        G = nx.Graph()
        for i in range(len(centers)):
            for j in range(i + 1, len(centers)):
                G.add_edge(i, j, weight=D[i, j])
        
        mst = nx.minimum_spanning_tree(G)
        for i, j in mst.edges:
            # NOTE(review): randint needs int(3*scale) < int(6*scale); a very small
            # passage_scale looks like it would raise or yield width 0 - confirm.
            draw_curvy_path(centers[i], centers[j], 
                          width=np.random.randint(int(3*self.passage_scale), int(6*self.passage_scale)))
        
        # Optional extra passages (at most num_extra_passages); each attempt
        # is taken with probability 0.5 between two distinct random rooms.
        for _ in range(self.num_extra_passages):
            if np.random.rand() < 0.5:
                # NOTE(review): sampling 2 without replacement presumably fails
                # when fewer than two rooms exist - confirm.
                i, j = np.random.choice(len(centers), 2, replace=False)
                draw_curvy_path(centers[i], centers[j], 
                              width=np.random.randint(int(3*self.passage_scale), int(6*self.passage_scale)))

        # Ensure borders are walls
        self.grid[0, :] = 0
        self.grid[-1, :] = 0
        self.grid[:, 0] = 0
        self.grid[:, -1] = 0
        
        # Remove border coordinates from room info and clean up passages
        self._clean_border_coords()
        self._clean_passages()
    
    def _clean_border_coords(self):
        """Remove border coordinates from room and passage data.

        The outermost rows and columns of the grid are forced back to walls
        after generation, so any recorded cell lying on them is no longer
        walkable. Room coordinate lists, the global passage set and every
        per-passage set are all filtered down to strictly interior cells.
        A per-passage set may end up empty; it is kept so list positions are
        not shifted here.
        """
        last = self.size - 1

        def is_interior(cell):
            # Strictly inside the one-cell wall frame.
            y, x = cell
            return 0 < y < last and 0 < x < last

        for room_info in self.room_info.values():
            room_info['coords'] = [cell for cell in room_info['coords'] if is_interior(cell)]

        # Passages were previously left untouched, so passage rewards could
        # fire on border cells that the grid marks as walls.
        self.passage_coords = {cell for cell in self.passage_coords if is_interior(cell)}
        self.individual_passages = [
            {cell for cell in passage if is_interior(cell)}
            for passage in self.individual_passages
        ]
        
    def _clean_passages(self):
        """Strip every room cell out of the recorded passage cells.

        After this runs, ``self.passage_coords`` holds only cells that belong
        to no room, and the per-passage sets and their centres are rebuilt
        through ``_filter_and_center_passages``.
        """
        # Union of all room cells across every room.
        room_cells = set().union(
            *(info['coords'] for info in self.room_info.values())
        )

        self.passage_coords = self.passage_coords.difference(room_cells)
        self._filter_and_center_passages(room_cells)


    def _filter_and_center_passages(self, all_room_coords):
        """Drop room cells from each passage and cache each passage's centroid.

        Args:
            all_room_coords: Set of (row, col) cells that belong to any room.

        Passages left with no cells are discarded. ``self.individual_passages``
        and ``self.passage_centers`` are replaced with index-aligned lists.
        """
        kept_passages = []
        centroids = []

        for cells in self.individual_passages:
            outside_rooms = cells - all_room_coords
            if not outside_rooms:
                continue  # passage lies entirely inside rooms

            count = len(outside_rooms)
            mean_row = sum(cell[0] for cell in outside_rooms) / count
            mean_col = sum(cell[1] for cell in outside_rooms) / count

            kept_passages.append(outside_rooms)
            centroids.append((mean_row, mean_col))

        self.passage_centers = centroids
        self.individual_passages = kept_passages

    # ORIGINAL REWARD FUNCTION GENERATORS
    
    def get_top_region_reward(self, fraction=0.33, *args, **kwargs):
        """Build a reward function for the upper part of the grid.

        The returned ``reward_func(row, col)`` yields 1.0 for rows up to and
        including ``int(size * fraction)`` and 0.0 below that. Extra arguments
        are accepted and ignored.
        """
        last_rewarded_row = int(self.size * fraction)

        def reward_func(row, col):
            if row <= last_rewarded_row:
                return 1.0
            return 0.0

        return reward_func
    
    def get_bottom_region_reward(self, fraction=0.33, *args, **kwargs):
        """Build a reward function for the lower part of the grid.

        The returned ``reward_func(row, col)`` yields 1.0 for rows at or past
        ``int(size * (1 - fraction))`` and 0.0 above that. Extra arguments are
        accepted and ignored.
        """
        first_rewarded_row = int(self.size * (1 - fraction))

        def reward_func(row, col):
            if row >= first_rewarded_row:
                return 1.0
            return 0.0

        return reward_func
    
    def get_left_region_reward(self, fraction=0.33, *args, **kwargs):
        """Build a reward function for the left part of the grid.

        The returned ``reward_func(row, col)`` yields 1.0 for columns up to and
        including ``int(size * fraction)`` and 0.0 otherwise. Extra arguments
        are accepted and ignored.
        """
        last_rewarded_col = int(self.size * fraction)

        def reward_func(row, col):
            if col <= last_rewarded_col:
                return 1.0
            return 0.0

        return reward_func
    
    def get_right_region_reward(self, fraction=0.33, *args, **kwargs):
        """Build a reward function for the right part of the grid.

        The returned ``reward_func(row, col)`` yields 1.0 for columns at or
        past ``int(size * (1 - fraction))`` and 0.0 otherwise. Extra arguments
        are accepted and ignored.
        """
        first_rewarded_col = int(self.size * (1 - fraction))

        def reward_func(row, col):
            if col >= first_rewarded_col:
                return 1.0
            return 0.0

        return reward_func
    
    def get_passage_reward(self, *args, **kwargs):
        """Build a reward function that pays 1.0 on passage cells.

        Membership is checked against ``self.passage_coords`` each time the
        returned function is called. Extra arguments are accepted and ignored.
        """
        def reward_func(row, col):
            cell = (row, col)
            if cell in self.passage_coords:
                return 1.0
            return 0.0

        return reward_func
    
    def get_room_reward(self, room_type=None, *args, **kwargs):
        """Build a reward function that pays 1.0 inside matching rooms.

        Args:
            room_type: Value compared against each room's ``'type'`` entry;
                None matches every room.

        The set of rewarded cells is collected once, when this method is
        called. Extra arguments are accepted and ignored.
        """
        matching_cells = (
            info['coords']
            for info in self.room_info.values()
            if room_type is None or info['type'] == room_type
        )
        room_coords = set().union(*matching_cells)

        def reward_func(row, col):
            if (row, col) in room_coords:
                return 1.0
            return 0.0

        return reward_func
    
    def get_circular_room_reward(self, *args, **kwargs):
        """Build a reward function that pays 1.0 inside any circular room.

        Extra arguments are accepted and ignored.
        """
        reward_func = self.get_room_reward(room_type='circular')
        return reward_func
    def get_elliptical_room_reward(self, *args, **kwargs):
        """Build a reward function that pays 1.0 inside any elliptical room.

        Extra arguments are accepted and ignored.
        """
        reward_func = self.get_room_reward(room_type='elliptical')
        return reward_func
    def get_rectangular_room_reward(self, *args, **kwargs):
        """Build a reward function that pays 1.0 inside any rectangular room.

        Extra arguments are accepted and ignored.
        """
        reward_func = self.get_room_reward(room_type='rectangular')
        return reward_func
    def get_polygonal_room_reward(self, *args, **kwargs):
        """Build a reward function that pays 1.0 inside any polygonal room.

        Extra arguments are accepted and ignored.
        """
        reward_func = self.get_room_reward(room_type='polygonal')
        return reward_func
      
    def get_room_center_reward(self, radius=5, *args, **kwargs):
        """Build a reward function that pays 1.0 near any room centre.

        Args:
            radius: Euclidean distance (inclusive) from a room's stored
                ``'center'`` within which a cell is rewarded.

        Room centres are read from ``self.room_info`` on every call of the
        returned function. Extra arguments are accepted and ignored.
        """
        def reward_func(row, col):
            limit = radius ** 2
            centres = (info['center'] for info in self.room_info.values())
            within_reach = any(
                (row - cy) ** 2 + (col - cx) ** 2 <= limit
                for cy, cx in centres
            )
            return 1.0 if within_reach else 0.0

        return reward_func
    
    def get_room_corner_reward(self, corner_type='any', corner_size=3, *args, **kwargs):
        """Build a reward function that pays 1.0 in corners of rectangular rooms.

        Args:
            corner_type: ``'any'`` for all four corners, or one of
                ``'top-left'``, ``'top-right'``, ``'bottom-left'``,
                ``'bottom-right'``. Any other value rewards nothing.
            corner_size: Extent of each corner region measured from the room's
                bounding-box corner; the region spans ``corner_size + 1`` cells
                per axis (inclusive ends), capped at the bounding box.

        Returns:
            ``reward_func(row, col)`` returning 1.0 for cells that lie in a
            selected corner region and belong to the room, else 0.0.

        The rewarded cells are collected once, when this method is called.
        Extra arguments are accepted and ignored.
        """
        corner_coords = set()

        for room_info in self.room_info.values():
            if room_info['type'] != 'rectangular':
                continue

            y1, y2, x1, x2 = room_info['bbox']

            # Inclusive (row_min, row_max, col_min, col_max) per corner.
            corners_regions = {
                'top-left': (y1, min(y1 + corner_size, y2), x1, min(x1 + corner_size, x2)),
                'top-right': (y1, min(y1 + corner_size, y2), max(x2 - corner_size, x1), x2),
                'bottom-left': (max(y2 - corner_size, y1), y2, x1, min(x1 + corner_size, x2)),
                'bottom-right': (max(y2 - corner_size, y1), y2, max(x2 - corner_size, x1), x2)
            }

            if corner_type == 'any':
                selected = list(corners_regions.values())
            elif corner_type in corners_regions:
                selected = [corners_regions[corner_type]]
            else:
                continue  # unknown corner name: nothing to reward in this room

            # 'coords' is a list; a set makes each membership test O(1)
            # instead of scanning the whole room for every candidate cell.
            room_cells = set(room_info['coords'])
            for ry1, ry2, rx1, rx2 in selected:
                for y in range(ry1, ry2 + 1):
                    for x in range(rx1, rx2 + 1):
                        if (y, x) in room_cells:  # only actual room cells
                            corner_coords.add((y, x))

        def reward_func(row, col):
            return 1.0 if (row, col) in corner_coords else 0.0

        return reward_func
    
    def get_wall_reward(self, step_size=4, wall_side='any', *args, **kwargs):
        """Build a reward function for walkable cells that are near a wall.

        Args:
            step_size: How many cells away from the queried cell to probe.
            wall_side: ``'any'`` probes all four directions; ``'top'``,
                ``'bottom'``, ``'left'`` or ``'right'`` probes only that one.
                Any other value rewards nothing.

        The returned ``reward_func(row, col)`` yields 0.0 on non-walkable
        cells (grid value 0). Otherwise it yields 1.0 when a probed cell is
        either outside the grid or has grid value 0, and 0.0 if none is.
        Extra arguments are accepted and ignored.
        """
        probe_offsets = {
            'top': (-step_size, 0),
            'bottom': (step_size, 0),
            'left': (0, -step_size),
            'right': (0, step_size)
        }

        def is_blocked(r, c):
            # Off-grid counts as wall; the bounds test must come first so the
            # grid is never indexed with an out-of-range position.
            in_bounds = 0 <= r < self.size and 0 <= c < self.size
            return (not in_bounds) or self.grid[r, c] == 0

        def reward_func(row, col):
            if self.grid[row, col] == 0:  # standing on a wall
                return 0.0

            if wall_side == 'any':
                to_probe = probe_offsets.values()
            elif wall_side in probe_offsets:
                to_probe = (probe_offsets[wall_side],)
            else:
                return 0.0

            for d_row, d_col in to_probe:
                if is_blocked(row + d_row, col + d_col):
                    return 1.0
            return 0.0

        return reward_func
    
    def get_distance_reward(self, target_coords, max_distance=None, *args, **kwargs):
        """Build a reward that falls off linearly with distance to a target.

        Args:
            target_coords: ``(row, col)`` of the target cell.
            max_distance: Distance at which the reward reaches zero; defaults
                to ``self.size`` when None.

        The returned ``reward_func(row, col)`` yields
        ``max(0, 1 - distance / max_distance)`` using Euclidean distance.
        Extra arguments are accepted and ignored.
        """
        falloff_range = self.size if max_distance is None else max_distance
        goal_row, goal_col = target_coords

        def reward_func(row, col):
            d_row = row - goal_row
            d_col = col - goal_col
            closeness = 1 - np.sqrt(d_row**2 + d_col**2) / falloff_range
            return max(0, closeness)

        return reward_func
    
    def get_room_boundary_reward(self, step_size=4, *args, **kwargs):
        """Build a reward function for room cells close to the room's edge.

        Args:
            step_size: How many cells away, in each of the four axis
                directions, to look for a non-room cell.

        The returned ``reward_func(row, col)`` yields 1.0 when the cell is a
        room cell, is not in ``self.passage_coords``, and at least one probed
        cell is not a room cell; otherwise 0.0. Room cells are collected once,
        when this method is called. Extra arguments are accepted and ignored.
        """
        room_coords = {
            cell
            for info in self.room_info.values()
            for cell in info['coords']
        }
        probe_offsets = ((-step_size, 0), (step_size, 0), (0, -step_size), (0, step_size))

        def reward_func(row, col):
            here = (row, col)
            if here not in room_coords or here in self.passage_coords:
                return 0.0

            leaves_rooms = any(
                (row + d_row, col + d_col) not in room_coords
                for d_row, d_col in probe_offsets
            )
            return 1.0 if leaves_rooms else 0.0

        return reward_func
    

    def get_closest_room_reward(self, target_row, target_col, room_type=None, *args, **kwargs):
        """Build a reward that pays 1.0 inside the room closest to a target.

        Args:
            target_row: Row of the reference position.
            target_col: Column of the reference position.
            room_type: Restrict the candidates to rooms whose ``'type'``
                equals this value; None considers every room.

        Returns:
            ``reward_func(row, col)`` returning 1.0 for cells of the single
            candidate room whose stored centre is nearest (Euclidean) to the
            reference position, else 0.0. On a distance tie the room that
            comes first in ``self.room_info`` wins. With no candidates the
            function returns 0.0 everywhere.

        The closest room depends only on the arguments of this method, so it
        is resolved once here instead of on every reward query. Extra
        arguments are accepted and ignored.
        """
        candidates = [
            info for info in self.room_info.values()
            if room_type is None or info['type'] == room_type
        ]

        closest_info = None
        best_distance = float('inf')
        for info in candidates:
            center_y, center_x = info['center']
            distance = np.sqrt((target_row - center_y)**2 + (target_col - center_x)**2)
            # Strict '<' keeps the first room on ties.
            if distance < best_distance:
                best_distance = distance
                closest_info = info

        # Empty when there is no candidate: the reward is then 0.0 everywhere.
        closest_cells = set(closest_info['coords']) if closest_info is not None else set()

        def reward_func(row, col):
            return 1.0 if (row, col) in closest_cells else 0.0

        return reward_func
    
    def get_closest_circular_room_reward(self, target_row, target_col, *args, **kwargs):
        """Build a reward for the circular room nearest to the given position.

        Extra arguments are accepted and ignored.
        """
        reward_func = self.get_closest_room_reward(target_row, target_col, room_type='circular')
        return reward_func
    
    def get_closest_rectangular_room_reward(self, target_row, target_col, *args, **kwargs):
        """Build a reward for the rectangular room nearest to the given position.

        Extra arguments are accepted and ignored.
        """
        reward_func = self.get_closest_room_reward(target_row, target_col, room_type='rectangular')
        return reward_func
    
    def get_closest_elliptical_room_reward(self, target_row, target_col, *args, **kwargs):
        """Build a reward for the elliptical room nearest to the given position.

        Extra arguments are accepted and ignored.
        """
        reward_func = self.get_closest_room_reward(target_row, target_col, room_type='elliptical')
        return reward_func
    
    def get_round_room_reward(self, *args, **kwargs):
        """Build a reward function that pays 1.0 inside round rooms.

        "Round" means a room whose ``'type'`` is ``'circular'`` or
        ``'elliptical'``. The rewarded cells are collected once, when this
        method is called. Extra arguments are accepted and ignored.
        """
        round_types = ['circular', 'elliptical']
        round_coords = {
            cell
            for info in self.room_info.values()
            if info['type'] in round_types
            for cell in info['coords']
        }

        def reward_func(row, col):
            if (row, col) in round_coords:
                return 1.0
            return 0.0

        return reward_func
    
    def get_closest_round_room_reward(self, target_row, target_col, *args, **kwargs):
        """Build a reward that pays 1.0 inside the closest round room.

        Args:
            target_row: Row of the reference position.
            target_col: Column of the reference position.

        Returns:
            ``reward_func(row, col)`` returning 1.0 for cells of the circular
            or elliptical room whose stored centre is nearest (Euclidean) to
            the reference position, else 0.0. On a distance tie the room that
            comes first in ``self.room_info`` wins. With no round rooms the
            function returns 0.0 everywhere.

        The closest room depends only on the arguments of this method, so it
        is resolved once here instead of on every reward query. Extra
        arguments are accepted and ignored.
        """
        round_rooms = [
            info for info in self.room_info.values()
            if info['type'] in ('circular', 'elliptical')
        ]

        closest_info = None
        best_distance = float('inf')
        for info in round_rooms:
            center_y, center_x = info['center']
            distance = np.sqrt((target_row - center_y)**2 + (target_col - center_x)**2)
            # Strict '<' keeps the first room on ties.
            if distance < best_distance:
                best_distance = distance
                closest_info = info

        # Empty when there is no round room: the reward is then 0.0 everywhere.
        closest_cells = set(closest_info['coords']) if closest_info is not None else set()

        def reward_func(row, col):
            return 1.0 if (row, col) in closest_cells else 0.0

        return reward_func
    
    def get_closest_polygonal_room_reward(self, target_row, target_col, *args, **kwargs):
        """Build a reward for the polygonal room nearest to the given position.

        Extra arguments are accepted and ignored.
        """
        reward_func = self.get_closest_room_reward(target_row, target_col, room_type='polygonal')
        return reward_func
    
    def get_closest_room_center_reward(self, target_row, target_col, radius=5, *args, **kwargs):
        """Build a reward that pays 1.0 near the room centre closest to a target.

        Args:
            target_row: Row of the reference position.
            target_col: Column of the reference position.
            radius: Euclidean distance (inclusive) from the chosen centre
                within which a cell is rewarded.

        Returns:
            ``reward_func(row, col)`` returning 1.0 within ``radius`` of the
            room centre nearest to the reference position, else 0.0. On a
            distance tie the room that comes first in ``self.room_info`` wins.
            With no rooms the function returns 0.0 everywhere.

        The closest centre depends only on the arguments of this method, so it
        is resolved once here instead of on every reward query. Extra
        arguments are accepted and ignored.
        """
        closest_center = None
        best_distance = float('inf')
        for room_info in self.room_info.values():
            center_y, center_x = room_info['center']
            distance = np.sqrt((target_row - center_y)**2 + (target_col - center_x)**2)
            # Strict '<' keeps the first room on ties.
            if distance < best_distance:
                best_distance = distance
                closest_center = (center_y, center_x)

        def reward_func(row, col):
            if closest_center is None:
                return 0.0
            center_y, center_x = closest_center
            if (row - center_y)**2 + (col - center_x)**2 <= radius**2:
                return 1.0
            return 0.0

        return reward_func
    
    def get_closest_room_corner_reward(self, target_row, target_col, corner_type='any', corner_size=3, *args, **kwargs):
        """Build a reward for corners of the rectangular room closest to a target.

        Args:
            target_row: Row of the reference position.
            target_col: Column of the reference position.
            corner_type: ``'any'`` for all four corners, or one of
                ``'top-left'``, ``'top-right'``, ``'bottom-left'``,
                ``'bottom-right'``. Any other value rewards nothing.
            corner_size: Extent of each corner region measured from the room's
                bounding-box corner; the region spans ``corner_size + 1`` cells
                per axis (inclusive ends), capped at the bounding box.

        Returns:
            ``reward_func(row, col)`` returning 1.0 for cells in a selected
            corner region of the rectangular room whose stored centre is
            nearest (Euclidean) to the reference position, else 0.0. On a
            distance tie the room that comes first in ``self.room_info`` wins.
            With no rectangular rooms the function returns 0.0 everywhere.

        Both the closest room and its corner cells depend only on the
        arguments of this method, so they are computed once here rather than
        on every reward query. Extra arguments are accepted and ignored.
        """
        closest_room = None
        best_distance = float('inf')
        for room_info in self.room_info.values():
            if room_info['type'] != 'rectangular':
                continue
            center_y, center_x = room_info['center']
            distance = np.sqrt((target_row - center_y)**2 + (target_col - center_x)**2)
            # Strict '<' keeps the first room on ties.
            if distance < best_distance:
                best_distance = distance
                closest_room = room_info

        corner_coords = set()
        if closest_room is not None:
            y1, y2, x1, x2 = closest_room['bbox']

            # Inclusive (row_min, row_max, col_min, col_max) per corner.
            corners_regions = {
                'top-left': (y1, min(y1 + corner_size, y2), x1, min(x1 + corner_size, x2)),
                'top-right': (y1, min(y1 + corner_size, y2), max(x2 - corner_size, x1), x2),
                'bottom-left': (max(y2 - corner_size, y1), y2, x1, min(x1 + corner_size, x2)),
                'bottom-right': (max(y2 - corner_size, y1), y2, max(x2 - corner_size, x1), x2)
            }

            if corner_type == 'any':
                selected = list(corners_regions.values())
            elif corner_type in corners_regions:
                selected = [corners_regions[corner_type]]
            else:
                selected = []  # unknown corner name: nothing is rewarded

            # 'coords' is a list; use a set for O(1) membership tests.
            room_cells = set(closest_room['coords'])
            for ry1, ry2, rx1, rx2 in selected:
                for y in range(ry1, ry2 + 1):
                    for x in range(rx1, rx2 + 1):
                        if (y, x) in room_cells:
                            corner_coords.add((y, x))

        def reward_func(row, col):
            return 1.0 if (row, col) in corner_coords else 0.0

        return reward_func

    def get_closest_passage_reward(self, target_row, target_col, *args, **kwargs):
        """Build a reward that pays 1.0 inside the passage closest to a target.

        Args:
            target_row: Row of the reference position.
            target_col: Column of the reference position.

        Returns:
            ``reward_func(row, col)`` returning 1.0 for cells of the passage
            whose pre-computed centre (``self.passage_centers``) is nearest
            (Euclidean) to the reference position, else 0.0. On a distance tie
            the passage with the lower index wins. With no passages the
            function returns 0.0 everywhere.

        The closest passage depends only on the arguments of this method, so
        it is resolved once here instead of on every reward query. Extra
        arguments are accepted and ignored.
        """
        closest_passage_idx = None
        best_distance = float('inf')
        for idx, (center_y, center_x) in enumerate(self.passage_centers):
            passage_dist = np.sqrt((target_row - center_y)**2 + (target_col - center_x)**2)
            # Strict '<' keeps the first passage on ties.
            if passage_dist < best_distance:
                best_distance = passage_dist
                closest_passage_idx = idx

        # passage_centers and individual_passages are index-aligned.
        if closest_passage_idx is None:
            closest_passage_coords = set()
        else:
            closest_passage_coords = self.individual_passages[closest_passage_idx]

        def reward_func(row, col):
            return 1.0 if (row, col) in closest_passage_coords else 0.0

        return reward_func

    def get_closest_wall_reward(self, step_size=4, wall_side='any', *args, **kwargs):
        """Build a reward function for walkable cells that are near a wall.

        Args:
            step_size: How many cells away from the queried cell to probe.
            wall_side: ``'any'`` probes all four directions; ``'top'``,
                ``'bottom'``, ``'left'`` or ``'right'`` probes only that one.
                Any other value rewards nothing, matching ``get_wall_reward``
                (it used to raise ``KeyError`` on walkable cells only).

        Returns:
            ``reward_func(row, col)`` returning 0.0 on non-walkable cells
            (grid value 0); otherwise 1.0 when a probed cell is outside the
            grid or has grid value 0, and 0.0 if none is.

        No ranking of walls by distance is performed: a cell with any wall at
        the probed offset is treated as being at its closest wall. Extra
        arguments are accepted and ignored.
        """
        directions = {
            'top': (-step_size, 0),
            'bottom': (step_size, 0),
            'left': (0, -step_size),
            'right': (0, step_size)
        }

        if wall_side == 'any':
            check_directions = tuple(directions.values())
        elif wall_side in directions:
            check_directions = (directions[wall_side],)
        else:
            check_directions = ()  # unknown side: never rewarded

        def is_wall(nr, nc):
            # Off-grid counts as wall; test bounds before indexing the grid.
            if nr < 0 or nr >= self.size or nc < 0 or nc >= self.size:
                return True
            return self.grid[nr, nc] == 0

        def reward_func(row, col):
            if self.grid[row, col] == 0:  # Not walkable
                return 0.0

            if any(is_wall(row + dr, col + dc) for dr, dc in check_directions):
                return 1.0
            return 0.0

        return reward_func
    
    def get_closest_disjoint_room_reward(self, room_a, room_b, *args, **kwargs):
        """Build a reward over two rooms, gated by which centre is nearer.

        Args:
            room_a: Key of the first room in ``self.room_info``.
            room_b: Key of the second room in ``self.room_info``.

        The returned ``reward_func(row, col)`` compares the cell's Euclidean
        distance to both room centres. If the centre of ``room_a`` is strictly
        nearer, the cell is rewarded (1.0) only when it belongs to ``room_a``;
        otherwise only when it belongs to ``room_b``. If either key is unknown
        the function returns 0.0 everywhere. Extra arguments are accepted and
        ignored.
        """
        both_known = room_a in self.room_info and room_b in self.room_info
        if not both_known:
            def reward_func(row, col):
                return 0.0
            return reward_func

        cells_a = set(self.room_info[room_a]['coords'])
        cells_b = set(self.room_info[room_b]['coords'])

        def reward_func(row, col):
            a_row, a_col = self.room_info[room_a]['center'][0], self.room_info[room_a]['center'][1]
            b_row, b_col = self.room_info[room_b]['center'][0], self.room_info[room_b]['center'][1]
            to_a = np.sqrt((row - a_row)**2 + (col - a_col)**2)
            to_b = np.sqrt((row - b_row)**2 + (col - b_col)**2)

            here = (row, col)
            if to_a < to_b:
                return 1.0 if here in cells_a else 0.0
            if to_a >= to_b:
                return 1.0 if here in cells_b else 0.0
            return 0.0

        return reward_func
    
    def get_disjoint_room_reward(self, room_a, room_b, *args, **kwargs):
        """Build a reward function that pays 1.0 inside either of two rooms.

        Args:
            room_a: Key of the first room in ``self.room_info``.
            room_b: Key of the second room in ``self.room_info``.

        If either key is unknown the returned function yields 0.0 everywhere.
        The rewarded cells are collected once, when this method is called.
        Extra arguments are accepted and ignored.
        """
        both_known = room_a in self.room_info and room_b in self.room_info
        if not both_known:
            def reward_func(row, col):
                return 0.0
            return reward_func

        rewarded_cells = set(self.room_info[room_a]['coords']) | set(self.room_info[room_b]['coords'])

        def reward_func(row, col):
            return 1.0 if (row, col) in rewarded_cells else 0.0

        return reward_func
    
    def get_specific_room_reward(self, room_a, *args, **kwargs):
        """Build a reward function that pays 1.0 inside one particular room.

        Args:
            room_a: Key of the room in ``self.room_info``.

        If the key is unknown the returned function yields 0.0 everywhere.
        The rewarded cells are collected once, when this method is called.
        Extra arguments are accepted and ignored.
        """
        try:
            room = self.room_info[room_a]
        except KeyError:
            def reward_func(row, col):
                return 0.0
            return reward_func

        room_cells = frozenset(room['coords'])

        def reward_func(row, col):
            if (row, col) in room_cells:
                return 1.0
            return 0.0

        return reward_func
    
