import math
import os
import cv2
import numpy as np
import skimage.morphology
from PIL import Image
from torchvision import transforms
from typing import Dict
from envs.utils.fmm_planner import FMMPlanner
from envs.habitat.objectgoal_env import ObjectGoal_Env
from envs.habitat.objectgoal_env21 import ObjectGoal_Env21
from agents.utils.semantic_prediction import SemanticPredMaskRCNN
from constants import color_palette
import envs.utils.pose as pu
import agents.utils.visualization as vu
from agents.utils.yolo07 import YOLOSegmenter
from RedNet.RedNet_model import load_rednet
from constants import mp_categories_mapping
import torch
from collections import deque
from constants import categories2channels


class Sem_Exp_Env_Agent(ObjectGoal_Env21):
    """Semantic Exploration Environment Agent class. 
    A separate Sem_Exp_Env_Agent object is used for each environment thread.
    """

    def __init__(self, args, rank, config_env, dataset):
        """
        Initialize the semantic exploration agent.

        Builds the RGB resize transform, the Mask R-CNN and YOLO segmenters,
        and the (still empty) per-episode planning state; reset() fills the
        state in at the start of each episode.

        Args:
            args: Configuration arguments. Read here: frame_height,
                frame_width, sem_gpu_id, device, yolo_conf_thresh,
                yolo_mask_thresh, yolo_binary_thresh, tv_yolo_conf_thresh,
                tv_yolo_mask_thresh, tv_yolo_binary_thresh, optional
                tv_black_pixel_threshold (default 0.9), visualize,
                print_images, env_frame_width, env_frame_height.
            rank: Process rank/ID
            config_env: Environment configuration
            dataset: Dataset object
        """
        # Set before the parent constructor runs (kept in this order in case
        # the parent reads self.args -- TODO confirm).
        self.args = args
        super().__init__(args, rank, config_env, dataset)

        # Resize RGB observations to (frame_height, frame_width)
        self.res = transforms.Compose(
            [transforms.ToPILImage(),
             transforms.Resize((args.frame_height, args.frame_width),
                               interpolation=Image.NEAREST)])

        # sem_gpu_id == -1 means "use the simulator's GPU"; resolved before the
        # segmentation models are constructed (presumably they read it).
        if args.sem_gpu_id == -1:
            args.sem_gpu_id = config_env.SIMULATOR.HABITAT_SIM_V0.GPU_DEVICE_ID

        self.device = args.device
        self.MaskRcnn = SemanticPredMaskRCNN(args)

        # Initialize YOLO segmenter with configurable thresholds
        self.yolo_seg = YOLOSegmenter(
            conf_thresh=args.yolo_conf_thresh,
            mask_thresh=args.yolo_mask_thresh,
            binary_thresh=args.yolo_binary_thresh
        )

        # TV-specific parameters
        self.tv_black_pixel_threshold = getattr(
            args, 'tv_black_pixel_threshold', 0.9)
        self.tv_yolo_conf_thresh = args.tv_yolo_conf_thresh
        self.tv_yolo_mask_thresh = args.tv_yolo_mask_thresh
        self.tv_yolo_binary_thresh = args.tv_yolo_binary_thresh

        # Structuring element used to inflate obstacles before FMM planning
        self.selem = skimage.morphology.disk(3)

        # State tracking variables (populated by reset()/_plan())
        self.obs = None
        self.obs_shape = None
        self.collision_map = None
        self.visited = None
        self.visited_vis = None
        self.col_width = None
        self.curr_loc = None
        self.last_loc = None
        self.last_action = None
        self.count_forward_actions = None
        self.info = {}
        self.replan_count = 0
        self.collision_n = 0
        # Recent-history window read by _is_stuck(). reset() recreates it per
        # episode; created here too so _is_stuck() is safe before the first reset.
        self.history_position = deque(maxlen=10)
        self.kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))

        # Visualization setup
        if args.visualize or args.print_images:
            self.vis_image = None
            self.rgb_vis_with_bbox = None
            self.rgb_vis_without_bbox = None
            self.vis_rgb_width = args.env_frame_width
            self.vis_rgb_height = args.env_frame_height
            self.vis_sem_map_size = args.env_frame_height  # Square semantic map

    def reset(self):
        """
        Reset the environment to start a new episode.

        Clears all per-episode planning state: collision/visited maps,
        collision width, last action, action history and the pose trajectory
        recorded by _visualize.

        Returns:
            tuple: (obs, info, rgbd)
            - obs: Preprocessed observations (see _preprocess_obs)
            - info: Episode information dictionary, with 'trajectory_mask'
              added (all zeros at episode start)
            - rgbd: Raw RGBD observation in HWC layout
        """
        args = self.args
        obs, info = super().reset()

        # Publish this episode's info before preprocessing: _preprocess_obs
        # reads self.info['goal_name'] to pick the segmenter and must not see
        # the previous episode's goal (or the empty dict from __init__).
        self.info = info

        # Store raw RGBD data (CHW -> HWC)
        self.rgbd = obs.transpose(1, 2, 0)

        obs = self._preprocess_obs(obs)
        self.obs_shape = obs.shape

        # Per-episode planning state
        map_cells = args.map_size_cm // args.map_resolution
        map_shape = (map_cells, map_cells)
        self.collision_map = np.zeros(map_shape)
        self.visited = np.zeros(map_shape)
        self.visited_vis = np.zeros(map_shape)
        self.col_width = 1
        # Agent starts at the centre of the map with heading 0.
        self.curr_loc = [args.map_size_cm / 100.0 / 2.0,
                         args.map_size_cm / 100.0 / 2.0, 0.]
        self.last_action = None
        self.history_position = deque(maxlen=10)
        # Poses recorded by _visualize; start each episode empty instead of
        # accumulating points from every previous episode.
        self.trajectory_positions = []

        if args.visualize or args.print_images:
            self.vis_image = vu.init_vis_image(self.goal_name,
                                               self.vis_rgb_width,
                                               self.vis_rgb_height,
                                               self.vis_sem_map_size)

        # visited_vis was just zeroed, so the trajectory mask starts empty.
        info['trajectory_mask'] = (self.visited_vis > 0).astype(np.uint8)

        return obs, self.info, self.rgbd

    def plan_act_and_preprocess(self, planner_inputs):
        """
        Handle planning, action execution and observation preprocessing.

        Args:
            planner_inputs (dict): Planning inputs. Keys read here:
                - rotate_agent: (flag, direction) pair; when flag is truthy the
                  agent turns "right" or "left" instead of planning
                - map_pred: Map prediction array (its shape sizes the
                  trajectory mask)
                - rgb_vis: RGB image for visualization, or None
                - candidate_goal_masks: Masks forwarded to _save_tmp_files
                - plus everything read by _plan / _visualize (goal,
                  pose_pred, found_goal, ...)

        Returns:
            tuple: (obs, done, info, rgbd)
            - obs: Preprocessed observations
            - done: Episode completion flag
            - info: Updated info dictionary (with 'trajectory_mask')
            - rgbd: Raw RGBD observation (HWC)

        Raises:
            ValueError: If a rotation is requested with a direction other than
                "right" or "left".
        """
        map_shape = planner_inputs['map_pred'].shape

        def current_trajectory_mask():
            # uint8 mask of cells on the drawn trajectory (visited_vis > 0),
            # or all zeros when no trajectory map exists yet.
            if getattr(self, 'visited_vis', None) is not None:
                return (self.visited_vis > 0).astype(np.uint8)
            return np.zeros(map_shape, dtype=np.uint8)

        # Forced rotation: bypass the planner entirely.
        if planner_inputs['rotate_agent'][0]:
            direction = planner_inputs['rotate_agent'][1]
            if direction == "right":
                action = {'action': 3}
            elif direction == "left":
                action = {'action': 2}
            else:
                # Previously fell through with `action` unbound.
                raise ValueError(
                    "Unsupported rotate_agent direction: {!r}".format(direction))

            obs, done, info = super().step(action)

            # Visualize / dump with the pre-rotation frame; self.rgbd is only
            # refreshed below.
            if self.args.visualize or self.args.print_images:
                self._visualize(planner_inputs)
                self._save_tmp_files(
                    self.rgbd, planner_inputs['candidate_goal_masks'], None)
                self._save_trajectory_mask(map_shape)

            self.rgbd = obs.transpose(1, 2, 0)
            info['trajectory_mask'] = current_trajectory_mask()
            self.info = info

            obs = self._preprocess_obs(obs)
            return obs, done, self.info, self.rgbd

        # Regular planning and action execution
        fmm_stop, action = self._plan(planner_inputs)

        if planner_inputs['rgb_vis'] is not None:
            self.rgb_vis = planner_inputs['rgb_vis'][:, :, ::-1]  # RGB -> BGR

        if self.args.visualize or self.args.print_images:
            self._visualize(planner_inputs)

        if action >= 0:  # Valid action (0=stop)
            action = {'action': action}
            obs, done, info = super().step(action)
            info['fmm_stop'] = fmm_stop
            rgbd = obs.transpose(1, 2, 0)

            obs = self._preprocess_obs(obs)
            self.last_action = action['action']
            self.obs = obs

            if self.args.visualize or self.args.print_images:
                self._save_tmp_files(
                    rgbd, planner_inputs['candidate_goal_masks'], planner_inputs)

            info['trajectory_mask'] = current_trajectory_mask()
            self.info = info
            self.rgbd = rgbd

            return obs, done, self.info, rgbd

        # Negative action: no environment step is taken. `rgbd` is not bound
        # on this path, so the zero frame is shaped like the last stored one.
        self.last_action = None
        self.info["sensor_pose"] = [0., 0., 0.]
        self.info['trajectory_mask'] = current_trajectory_mask()
        return (np.zeros(self.obs_shape), False, self.info,
                np.zeros(self.rgbd.shape))

    def _is_stuck(self):
        """
        Decide whether the agent looks stuck from its recent history.

        The history window (self.history_position) must be full, i.e. hold 10
        entries; until then the agent is never considered stuck. Entries equal
        to 1 count as forward moves, and fewer than 3 of them means stuck.

        Returns:
            bool: True if agent appears stuck, False otherwise
        """
        recent = self.history_position
        window_full = len(recent) >= 10
        if not window_full:
            return False

        forward_moves = recent.count(1)
        return forward_moves < 3

    def _plan(self, planner_inputs):
        """
        Plan the next discrete action from the current map, pose and goal.

        Side effects: updates self.last_loc / self.curr_loc, marks the current
        cell in self.visited, adds cells to self.collision_map after a forward
        step that barely moved the agent, and (when visualizing) draws the
        trajectory into self.visited_vis.

        Args:
            planner_inputs (dict):
                - map_pred: Obstacle map prediction (rounded with np.rint)
                - goal: Goal map, same shape as map_pred
                - pose_pred: Agent pose; the first three entries are
                  (x, y, orientation in degrees). x/y are scaled by
                  100 / map_resolution to get map cells (presumably meters
                  with map_resolution in cm -- confirm).
                - found_goal: 1 when the goal object was found

        Returns:
            tuple: (fmm_stop, action)
            - fmm_stop: Whether FMM planner indicates stop
            - action: 0=stop, 1=forward, 2=turn left, 3=turn right
        """
        args = self.args
        self.last_loc = self.curr_loc

        # Get map and goal information
        map_pred = np.rint(planner_inputs['map_pred'])
        goal = planner_inputs['goal']

        # Get current pose
        start_x, start_y, start_o = planner_inputs['pose_pred'][:3]
        self.curr_loc = [start_x, start_y, start_o]

        # Convert pose to map coordinates: start = [row, col] = [y, x] in
        # cells, passed through pu.threshold_poses (presumably clamps to the
        # map bounds)
        r, c = start_y, start_x
        start = [int(r * 100.0 / args.map_resolution),
                 int(c * 100.0 / args.map_resolution)]
        start = pu.threshold_poses(start, map_pred.shape)

        # Mark current position as visited
        self.visited[start[0]:start[0] + 1, start[1]:start[1] + 1] = 1

        # Update trajectory visualization if enabled
        if args.visualize or args.print_images:
            # Calculate previous position
            last_start_x, last_start_y = self.last_loc[0], self.last_loc[1]
            r, c = last_start_y, last_start_x
            last_start = [int(r * 100.0 / args.map_resolution),
                          int(c * 100.0 / args.map_resolution)]
            last_start = pu.threshold_poses(last_start, map_pred.shape)

            # Draw trajectory line
            self.visited_vis = vu.draw_line(
                last_start, start, self.visited_vis)

            # Save trajectory mask
            self._save_trajectory_mask(map_pred.shape)

        # Collision detection: only meaningful after a forward action
        if self.last_action == 1:  # Forward action
            x1, y1, t1 = self.last_loc
            x2, y2, _ = self.curr_loc
            buf = 4
            length = 2

            # Barely moved in both x and y: widen the blocked stripe
            # (col_width 1 -> 3 -> 5, capped at 5). Once already at 5, the
            # +2 reaches 7, which switches to a longer stripe (length 4)
            # that starts one step closer (buf 3). Otherwise reset to 1.
            if abs(x1 - x2) < 0.05 and abs(y1 - y2) < 0.05:
                self.col_width += 2
                if self.col_width == 7:
                    length = 4
                    buf = 3
                self.col_width = min(self.col_width, 5)
            else:
                self.col_width = 1

            # If the move was shorter than collision_threshold, mark a
            # length x width grid of points (0.05 spacing) ahead of the
            # previous pose along its heading t1 as obstacles.
            dist = pu.get_l2_distance(x1, x2, y1, y2)
            if dist < args.collision_threshold:
                width = self.col_width
                for i in range(length):
                    for j in range(width):
                        wx = x1 + 0.05 * \
                            ((i + buf) * np.cos(np.deg2rad(t1))
                             + (j - width // 2) * np.sin(np.deg2rad(t1)))
                        wy = y1 + 0.05 * \
                            ((i + buf) * np.sin(np.deg2rad(t1))
                             - (j - width // 2) * np.cos(np.deg2rad(t1)))
                        r, c = wy, wx
                        r, c = int(r * 100 / args.map_resolution), \
                            int(c * 100 / args.map_resolution)
                        [r, c] = pu.threshold_poses([r, c],
                                                    self.collision_map.shape)
                        self.collision_map[r, c] = 1

        # Get short-term goal using FMM planner
        stg, fmm_stop = self._get_stg_global(map_pred, start, np.copy(goal))

        # Stop only if FMM says so AND the goal object has been found;
        # otherwise turn toward the short-term goal or move forward.
        if fmm_stop and planner_inputs['found_goal'] == 1:
            action = 0  # Stop
        else:
            (stg_x, stg_y) = stg
            angle_st_goal = math.degrees(math.atan2(stg_x - start[0],
                                                    stg_y - start[1]))
            # Normalize agent heading to (-180, 180]
            angle_agent = (start_o) % 360.0
            if angle_agent > 180:
                angle_agent -= 360

            # Signed heading error, wrapped to (-180, 180]
            relative_angle = (angle_agent - angle_st_goal) % 360.0
            if relative_angle > 180:
                relative_angle -= 360

            # Turn when the error exceeds half a turn step, else go forward
            if relative_angle > self.args.turn_angle / 2.:
                action = 3  # Right
            elif relative_angle < -self.args.turn_angle / 2.:
                action = 2  # Left
            else:
                action = 1  # Forward

        return fmm_stop, action

    def _get_stg_global(self, grid, start, goal):
        """
        Compute a short-term goal with the FMM planner on the global map.

        Obstacles in `grid` are inflated by self.selem to obtain free space.
        Collision cells are then blocked, visited cells and a 3x3 patch
        around the agent are forced free, and everything is padded with a
        one-cell border (free for the traversible map, empty for the goal).
        The goal is dilated with a radius-10 disk before being handed to the
        planner.

        Args:
            grid: Global obstacle map array
            start: Agent cell [row, col]
            goal: Goal map array (same shape as grid)

        Returns:
            tuple: (stg, fmm_stop)
            - stg: Short-term goal (row, col) in unpadded map coordinates
            - fmm_stop: Whether FMM indicates stopping
        """
        def pad_one(arr, fill=1):
            # Float64 copy of a 2-D array with a one-cell border of `fill`.
            rows, cols = arr.shape
            padded = np.full((rows + 2, cols + 2), fill, dtype=float)
            padded[1:rows + 1, 1:cols + 1] = arr
            return padded

        row, col = int(start[0]), int(start[1])

        # Free space = complement of the inflated obstacle map
        free = ~skimage.morphology.binary_dilation(grid, self.selem)
        free[self.collision_map == 1] = 0
        free[self.visited == 1] = 1
        free[row - 1:row + 2, col - 1:col + 2] = 1

        padded_free = pad_one(free)
        padded_goal = pad_one(goal, fill=0)

        planner = FMMPlanner(padded_free)

        goal_region = skimage.morphology.binary_dilation(
            padded_goal, skimage.morphology.disk(10))
        planner.set_multi_goal(goal_region.astype(float))

        # Planner works in padded coordinates: shift in by one, then back out.
        stg_row, stg_col, _, fmm_stop = planner.get_short_term_goal(
            [start[0] + 1, start[1] + 1])
        return (stg_row - 1, stg_col - 1), fmm_stop

    def _preprocess_obs(self, obs, use_seg=True):
        """
        Preprocess observations including RGB, depth and semantic segmentation.

        The segmenter is chosen from self.info['goal_name']:
        - "potted plant" uses Mask R-CNN;
        - TV goals ("tv" or "tv_monitor") use YOLO with the TV-specific
          thresholds, followed by the black-pixel TV filter;
        - anything else uses YOLO with its default thresholds.

        Args:
            obs: Raw observation array in CHW layout; channels 0-2 are RGB,
                channel 3 is depth.
            use_seg: Currently unused; kept for interface compatibility.

        Returns:
            np.ndarray: CHW state = RGB (3) + depth (1) + semantic channels,
            downsampled by env_frame_width // frame_width.
        """
        # Both spellings of the TV goal occur in this file; accept either so
        # the TV thresholds and the black-pixel filter apply consistently.
        tv_goal_names = ("tv", "tv_monitor")

        args = self.args
        obs = obs.transpose(1, 2, 0)  # CHW -> HWC
        rgb = obs[:, :, :3]
        depth = obs[:, :, 3:4]
        goal_name = self.info['goal_name']

        if goal_name == "potted plant":
            sem_seg_pred = self._get_MaskRcnn_pred(rgb.astype(np.uint8))
        else:
            # Keep the un-annotated frame current (the Mask R-CNN path sets it
            # in _get_MaskRcnn_pred); otherwise _save_tmp_files writes a
            # stale frame, or none, for YOLO goals.
            self.rgb_vis_without_bbox = rgb.astype(np.uint8)[:, :, ::-1]

            if goal_name in tv_goal_names:
                plotted_rgb, mask_dict = self.yolo_seg.seg_image(
                    rgb.astype(np.uint8),
                    conf_threshold=self.tv_yolo_conf_thresh,
                    mask_threshold=self.tv_yolo_mask_thresh,
                    binary_threshold=self.tv_yolo_binary_thresh
                )
            else:
                plotted_rgb, mask_dict = self.yolo_seg.seg_image(
                    rgb.astype(np.uint8))

            self.rgb_vis_with_bbox = cv2.cvtColor(
                plotted_rgb, cv2.COLOR_RGB2BGR)

            # Drop TV detections that are mostly pure-black pixels
            mask_dict = self._filter_tv_by_black_pixels(
                rgb, mask_dict, self.tv_black_pixel_threshold)

            sem_seg_pred = self._convert_masks_to_channels(
                mask_dict).astype(np.float32)

        depth = self._preprocess_depth(
            depth, args.min_depth, args.max_depth)

        # Downsample: RGB is resized, depth/semantics are strided (center sample)
        ds = args.env_frame_width // args.frame_width
        if ds != 1:
            rgb = np.asarray(self.res(rgb.astype(np.uint8)))
            depth = depth[ds // 2::ds, ds // 2::ds]
            sem_seg_pred = sem_seg_pred[ds // 2::ds, ds // 2::ds]

        depth = np.expand_dims(depth, axis=2)
        state = np.concatenate((rgb, depth, sem_seg_pred),
                               axis=2).transpose(2, 0, 1)

        return state

    def _preprocess_depth(self, depth, min_d, max_d):
        """
        Preprocess depth observations.

        Steps:
        1. Zeros in each column are replaced by that column's maximum.
        2. Values above 0.99, and any zeros still left (all-zero columns),
           become the sentinel 100.0.
        3. The result is mapped linearly as min_d*100 + d*(max_d-min_d)*100
           (presumably normalized depth -> centimetres; confirm units).

        Args:
            depth: Raw depth array of shape (H, W, >=1); only channel 0 is used
            min_d: Minimum depth value
            max_d: Maximum depth value

        Returns:
            np.ndarray: Processed (H, W) depth array
        """
        depth = depth[:, :, 0] * 1  # copy of channel 0

        # Vectorized per-column fill of invalid (zero) readings; equivalent
        # to looping over columns, without the Python-level loop.
        col_max = depth.max(axis=0, keepdims=True)
        depth = np.where(depth == 0., col_max, depth)

        # Filter extreme values
        depth[depth > 0.99] = 0.
        depth[depth == 0] = 100.0

        # Scale depth values
        depth = min_d * 100.0 + depth * (max_d - min_d) * 100.0

        return depth

    def _get_MaskRcnn_pred(self, rgb, use_seg=True):
        """
        Run Mask R-CNN on an RGB frame and return its semantic prediction.

        Side effects: stores the channel-reversed (BGR) input in
        self.rgb_vis_without_bbox and, when segmenting, the annotated frame
        returned by the model in self.rgb_vis_with_bbox.

        Args:
            rgb: HWC RGB image (callers pass uint8)
            use_seg: If False, skip the model and return zeros

        Returns:
            np.ndarray: The model's prediction cast to float32, or a float64
            zero array of shape (H, W, 16) when use_seg is False.
        """
        self.rgb_vis_without_bbox = rgb[:, :, ::-1]

        if not use_seg:
            height, width = rgb.shape[0], rgb.shape[1]
            return np.zeros((height, width, 16))

        prediction, annotated = self.MaskRcnn.get_prediction(rgb)
        self.rgb_vis_with_bbox = annotated
        return prediction.astype(np.float32)

    def _filter_tv_by_black_pixels(self, rgb_image, mask_dict, black_pixel_threshold=0.7):
        """
        Drop the "tv" detection when it consists mostly of pure-black pixels.

        Applies only when the current goal (self.info['goal_name']) is a TV,
        spelled either "tv" or "tv_monitor" (both spellings occur in this
        file), and mask_dict contains a "tv" entry.

        Args:
            rgb_image: HWC RGB image array
            mask_dict: Mapping from class name to an (H, W) mask; pixels with
                value 1 belong to the detection
            black_pixel_threshold: Fraction of all-zero pixels inside the TV
                mask above which the detection is discarded

        Returns:
            dict: mask_dict itself when nothing is filtered, otherwise a new
            dict without the "tv" entry.
        """
        tv_goal_names = ("tv", "tv_monitor")
        if self.info['goal_name'] not in tv_goal_names or "tv" not in mask_dict:
            return mask_dict

        tv_pixels = rgb_image[mask_dict["tv"] == 1]
        if len(tv_pixels) == 0:
            return mask_dict

        # A pixel is "black" when every channel is exactly 0.
        black_pixel_ratio = np.all(tv_pixels == 0, axis=1).mean()
        if black_pixel_ratio > black_pixel_threshold:
            return {k: v for k, v in mask_dict.items() if k != "tv"}

        return mask_dict

    def _convert_masks_to_channels(self, mask_dict, default_shape=None):
        """
        Convert a mask dictionary to a multi-channel uint8 array.

        Each category found in categories2channels writes its mask into the
        mapped channel. If several categories map to the same channel, their
        masks are merged (element-wise maximum) rather than the last one
        overwriting the others. Categories without a mapping are ignored.

        Args:
            mask_dict: Dictionary of class name -> (H, W) mask
            default_shape: (H, W) to use when mask_dict is empty; falls back
                to (env_frame_height, env_frame_width) when not given

        Returns:
            np.ndarray: uint8 array of shape (H, W, num_sem_categories)
        """
        if mask_dict:
            height, width = next(iter(mask_dict.values())).shape
        elif default_shape:
            height, width = default_shape
        else:
            height, width = (self.args.env_frame_height,
                             self.args.env_frame_width)

        results = np.zeros(
            (height, width, self.args.num_sem_categories), dtype=np.uint8)

        for category, mask in mask_dict.items():
            if category not in categories2channels:
                continue
            channel_idx = categories2channels[category]
            # Union with anything already in this channel.
            results[:, :, channel_idx] = np.maximum(
                results[:, :, channel_idx], mask)

        return results

    def _visualize(self, inputs):
        """
        Compose the RGB frame and colour-coded semantic map into
        self.vis_image, and write it to disk when args.print_images is set.

        Every call also saves the stacked (obstacle, explored) boolean masks
        to <dump_dir>/obs_exp/<rank>-<episode>-Obs-Exp-<timestep>.npy, and it
        may append the current pose to self.trajectory_positions.

        Args:
            inputs: Planning inputs dictionary. Reads 'map_pred', 'exp_pred',
                'pose_pred', 'goal', 'sem_map_pred' and, when present and not
                None, 'candidate_goals'. 'sem_map_pred' is copied, not
                modified.
        """
        args = self.args
        dump_dir = "{}/dump/{}/".format(args.dump_location, args.exp_name)

        # Scene name: simulator config first, then the episode's scene id.
        try:
            scene_path = self.habitat_env.sim.config.SCENE
            scene_name = scene_path.split("/")[-1].split(".")[0]
        except AttributeError:
            try:
                scene_id = self.habitat_env.current_episode.scene_id
                scene_name = scene_id.split("/")[-1].split(".")[0] if "/" in scene_id else scene_id
            except Exception:
                scene_name = "unknown_scene"

        scene_episode_no = self.scene_episode_count

        map_pred = inputs['map_pred']
        exp_pred = inputs['exp_pred']
        start_x, start_y, start_o = inputs['pose_pred'][:3]
        goal = inputs['goal']
        # Copy before rewriting into palette indices so the caller's semantic
        # map is not modified in place.
        sem_map = np.copy(inputs['sem_map_pred'])

        # Shift category ids by 5, freeing indices 0-4 for the background,
        # obstacle, explored, trajectory and goal colours assigned below.
        sem_map += 5
        map_mask = map_pred > 0
        no_cat_mask = sem_map == self.args.num_sem_categories + 4
        exp_mask = exp_pred > 0

        obs_exp_dir = '{}/obs_exp/'.format(dump_dir)
        os.makedirs(obs_exp_dir, exist_ok=True)
        obs_exp_filename = '{}/{}-{}-Obs-Exp-{}.npy'.format(
            obs_exp_dir, self.rank, scene_episode_no, self.timestep)
        np.save(obs_exp_filename, np.stack((map_mask, exp_mask), axis=0))

        # Record the pose when it lies inside the map and moved more than 0.1
        # in x or y since the last recorded point.
        current_pixel_x = int(start_x * 100.0 / args.map_resolution)
        current_pixel_y = int(start_y * 100.0 / args.map_resolution)
        if 0 <= current_pixel_y < map_pred.shape[0] and 0 <= current_pixel_x < map_pred.shape[1]:
            current_pos = (start_y, start_x, start_o)
            if not hasattr(self, 'trajectory_positions'):
                self.trajectory_positions = []
            if len(self.trajectory_positions) == 0 or \
               abs(self.trajectory_positions[-1][0] - current_pos[0]) > 0.1 or \
               abs(self.trajectory_positions[-1][1] - current_pos[1]) > 0.1:
                self.trajectory_positions.append(current_pos)

        # Drawn trajectory cells (visited_vis == 1), if it matches the map size
        visited_vis = getattr(self, 'visited_vis', None)
        if visited_vis is not None and visited_vis.shape == map_pred.shape:
            vis_mask = visited_vis == 1
        else:
            vis_mask = np.zeros_like(map_pred, dtype=bool)

        sem_map[no_cat_mask] = 0
        sem_map[np.logical_and(no_cat_mask, exp_mask)] = 2
        sem_map[np.logical_and(no_cat_mask, map_mask)] = 1
        sem_map[vis_mask] = 3

        if 'candidate_goals' in inputs and inputs['candidate_goals'] is not None:
            candidate_goals = inputs['candidate_goals']
            for i in range(candidate_goals.shape[0]):
                sem_map[candidate_goals[i] == 1] = 19

        # Goal cells dilated by a radius-4 disk
        goal_mask = skimage.morphology.binary_dilation(
            goal, skimage.morphology.disk(4))
        sem_map[goal_mask] = 4

        color_pal = [int(x * 255.) for x in color_palette]
        sem_map_vis = Image.new("P", (sem_map.shape[1], sem_map.shape[0]))
        sem_map_vis.putpalette(color_pal)
        sem_map_vis.putdata(sem_map.flatten().astype(np.uint8))
        sem_map_vis = sem_map_vis.convert("RGB")
        # Flip vertically, then RGB -> BGR for OpenCV
        sem_map_vis = np.flipud(sem_map_vis)
        sem_map_vis = sem_map_vis[:, :, [2, 1, 0]]

        layout = vu.calculate_vis_layout(
            self.vis_rgb_width, self.vis_rgb_height, self.vis_sem_map_size)

        if self.rgb_vis_with_bbox.shape[:2] != (self.vis_rgb_height, self.vis_rgb_width):
            # NOTE(review): the size test is on rgb_vis_with_bbox, but the
            # image resized and shown is self.rgb_vis (set from
            # planner_inputs['rgb_vis']). Confirm this fallback is intended
            # rather than resizing rgb_vis_with_bbox.
            rgb_resized = cv2.resize(self.rgb_vis, (self.vis_rgb_width, self.vis_rgb_height),
                                     interpolation=cv2.INTER_NEAREST)
        else:
            rgb_resized = self.rgb_vis_with_bbox

        if sem_map_vis.shape[:2] != (self.vis_sem_map_size, self.vis_sem_map_size):
            sem_map_vis = cv2.resize(sem_map_vis, (self.vis_sem_map_size, self.vis_sem_map_size),
                                     interpolation=cv2.INTER_NEAREST)

        rgb_y1, rgb_y2, rgb_x1, rgb_x2 = layout['rgb_region']
        sem_y1, sem_y2, sem_x1, sem_x2 = layout['sem_region']
        self.vis_image[rgb_y1:rgb_y2, rgb_x1:rgb_x2] = rgb_resized
        self.vis_image[sem_y1:sem_y2, sem_x1:sem_x2] = sem_map_vis

        # Agent arrow: map cell scaled to the displayed (vertically flipped) map
        r, c = start_y, start_x
        start_pixel = [int(r * 100.0 / args.map_resolution),
                       int(c * 100.0 / args.map_resolution)]
        start_pixel = pu.threshold_poses(start_pixel, map_pred.shape)
        pos = (
            start_pixel[1] * self.vis_sem_map_size / map_pred.shape[1],
            (map_pred.shape[0] - start_pixel[0]) * self.vis_sem_map_size / map_pred.shape[0],
            np.deg2rad(-start_o)
        )
        agent_arrow = vu.get_contour_points(
            pos, origin=layout['arrow_origin'], size=12)
        color = (int(color_palette[11] * 255),
                 int(color_palette[10] * 255),
                 int(color_palette[9] * 255))
        cv2.drawContours(self.vis_image, [agent_arrow], 0, color, -1)

        if args.print_images:
            vis_dir = '{}/episodes/process_{}/{}/episode_{}_{}/visualization/'.format(
                dump_dir, self.rank, scene_name, scene_episode_no, self.info['goal_name'].split()[-1])
            os.makedirs(vis_dir, exist_ok=True)
            fn = '{}/{}-{}-Vis-{}.png'.format(
                vis_dir, self.rank, scene_episode_no, self.timestep)
            cv2.imwrite(fn, self.vis_image)

    def _save_tmp_files(self, rgbd, candidate_goal_masks, inputs):
        """
        Dump per-step debugging artifacts for the current episode.

        Under <dump_dir>/episodes/process_<rank>/<scene>/episode_<n>_<goal>/
        this writes:
        - rgb_vis/: the un-annotated RGB frame (when available)
        - rgbd/: self.rgbd as a .npy file
        - pos_on_rgb/: the RGB channels of `rgbd` labelled with candidate
          goal indices
        The percent_goals/ directory is created but nothing is written to it
        here.

        Args:
            rgbd: HWC RGBD observation; its first three channels are annotated
            candidate_goal_masks: Candidate goal masks to label on the image
            inputs: Planning inputs dictionary (currently unused)
        """
        args = self.args
        dump_dir = "{}/dump/{}/".format(args.dump_location, args.exp_name)

        # Scene name: simulator config first, then the episode's scene id.
        try:
            scene_path = self.habitat_env.sim.config.SCENE
            scene_name = scene_path.split("/")[-1].split(".")[0]
        except AttributeError:
            try:
                scene_id = self.habitat_env.current_episode.scene_id
                scene_name = scene_id.split("/")[-1].split(".")[0] if "/" in scene_id else scene_id
            except Exception:
                scene_name = "unknown_scene"

        scene_episode_no = self.scene_episode_count

        base_ep_dir = '{}/episodes/process_{}/{}/episode_{}_{}'.format(
            dump_dir, self.rank, scene_name, scene_episode_no, self.info['goal_name'].split()[-1])
        rgb_vis_dir = '{}/rgb_vis/'.format(base_ep_dir)
        rgbd_dir = '{}/rgbd/'.format(base_ep_dir)
        pos_on_rgb_dir = '{}/pos_on_rgb/'.format(base_ep_dir)
        percent_goals_dir = '{}/percent_goals/'.format(base_ep_dir)
        for sub_dir in (rgb_vis_dir, rgbd_dir, pos_on_rgb_dir, percent_goals_dir):
            os.makedirs(sub_dir, exist_ok=True)

        if getattr(self, 'rgb_vis_without_bbox', None) is not None:
            rgb_filename = '{}/{}-{}-RGB-{}.png'.format(
                rgb_vis_dir, self.rank, scene_episode_no, self.timestep)
            cv2.imwrite(rgb_filename, self.rgb_vis_without_bbox)

        rgbd_filename = '{}/{}-{}-RGBD-{}.npy'.format(
            rgbd_dir, self.rank, scene_episode_no, self.timestep)
        pos_on_rgb_filename = '{}/{}-{}-Pos-On-RGB-{}.png'.format(
            pos_on_rgb_dir, self.rank, scene_episode_no, self.timestep)

        # NOTE(review): the .npy dump uses self.rgbd while the annotation
        # below uses the `rgbd` argument. On forward steps
        # plan_act_and_preprocess calls this before updating self.rgbd, so
        # the two can come from different frames. Confirm which is intended.
        np.save(rgbd_filename, self.rgbd)
        self.annotate_pos_and_save_rgb(
            rgbd[:, :, :3], candidate_goal_masks, pos_on_rgb_filename)

    def annotate_pos_and_save_rgb(self, rgb_image, candidate_goal_masks, output_path=None):
        """
        Label candidate goals on an RGB image and optionally save it.

        Each candidate i is labelled "i+1" in red at the first pixel (in
        np.argwhere order) where its mask equals 1. Candidates whose mask is
        empty are skipped, but the remaining labels keep their original
        numbering.

        Args:
            rgb_image: HWC RGB image array (not modified; a copy is annotated)
            candidate_goal_masks: Array of per-candidate masks indexed along
                axis 0, or None
            output_path: File to write the annotated image to (converted to
                BGR for OpenCV). Nothing is written when None.

        Returns:
            The annotated RGB image, or None when there are no candidates.
        """
        if candidate_goal_masks is None or candidate_goal_masks.shape[0] == 0:
            return None

        annotated_image = rgb_image.copy()

        for idx in range(candidate_goal_masks.shape[0]):
            coords = np.argwhere(candidate_goal_masks[idx] == 1)
            if len(coords) == 0:
                # Empty mask: no anchor pixel (indexing [0] would raise).
                continue
            y, x = coords[0]
            cv2.putText(
                img=annotated_image,
                text=str(idx + 1),
                org=(int(x), int(y)),  # OpenCV expects plain ints
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=0.8,
                color=(0, 0, 255),
                thickness=2,
                lineType=cv2.LINE_AA
            )

        if output_path is not None:
            cv2.imwrite(output_path, cv2.cvtColor(
                annotated_image, cv2.COLOR_RGB2BGR))

        return annotated_image

    def visualize_percent_goals(self, filled_mask, candidate_goals, obstacle_map, k, threshold, output_path=None):
        """
        Render the exploration map and annotate each candidate goal with the
        percentage of unexplored cells within radius ``k``.

        Args:
            filled_mask: 2-D array. Cells equal to 1 are drawn as explored, and
                cells equal to 0 count as unexplored in the percentage.
            candidate_goals: Stack of goal masks indexed along the first axis.
                The first pixel equal to 1 in each mask is the goal position.
                Masks without such a pixel are skipped.
            obstacle_map: Obstacle map with the same shape as ``filled_mask``.
                It may be a torch tensor (moved to CPU first) or a numpy array.
            k: Radius in pixels, used for the percentage disk and the drawn circle.
            threshold: Percentages below this value are drawn with color
                (0, 255, 0); all others with (255, 0, 0).
            output_path: Destination file. When None, nothing is written.

        Returns:
            The rendered (H, W, 3) uint8 image, or None if any input is None.
        """
        if filled_mask is None or candidate_goals is None or obstacle_map is None:
            return None

        # Accept torch tensors (possibly on GPU) as well as plain numpy arrays.
        if hasattr(obstacle_map, "cpu"):
            obstacle_map = obstacle_map.detach().cpu().numpy()
        obstacle_map = np.asarray(obstacle_map)

        # White background; explored free space is light gray, and explored
        # obstacles are dark gray.
        h, w = filled_mask.shape
        base_img = np.full((h, w, 3), 255, dtype=np.uint8)
        explored = filled_mask == 1
        base_img[explored & (obstacle_map == 0)] = (200, 200, 200)
        base_img[explored & (obstacle_map == 1)] = (100, 100, 100)

        # cv2 needs an integer radius, and slicing needs integer bounds.
        radius = int(k)

        # Collect (x, y, percent) per goal so positions and percentages can
        # never drift out of alignment.
        goals = []
        for goal_mask in candidate_goals:
            hits = np.argwhere(goal_mask == 1)
            if hits.size == 0:
                continue  # No position to annotate; hits[0] would raise.
            y, x = (int(v) for v in hits[0])

            # Fraction of unexplored cells inside the disk of `radius` around
            # the goal, clipped to the map bounds.
            x_min, x_max = max(0, x - radius), min(w, x + radius + 1)
            y_min, y_max = max(0, y - radius), min(h, y + radius + 1)
            region = filled_mask[y_min:y_max, x_min:x_max]
            yy, xx = np.ogrid[y_min:y_max, x_min:x_max]
            in_disk = (xx - x) ** 2 + (yy - y) ** 2 <= radius ** 2
            total = int(np.sum(in_disk))
            unexplored = int(np.sum((region == 0) & in_disk))
            percent = (unexplored / total) * 100 if total > 0 else 0
            goals.append((x, y, percent))

        placed_text = []
        for x, y, percent in goals:
            # Goal point and radius circle.
            cv2.circle(base_img, (x, y), 2, (0, 0, 255), -1)
            cv2.circle(base_img, (x, y), radius, (0, 0, 255), 1)

            # Move the label below the point when it is close to an earlier label.
            # NOTE(review): this compares the goal position against earlier
            # *text anchors*, which are already shifted 15 px left. It is kept
            # unchanged to preserve the existing layout.
            offset_y = -15
            if any(abs(x - px) < 40 and abs(y - py) < 20 for px, py in placed_text):
                offset_y = 15
            text_pos = (x - 15, y + offset_y)
            placed_text.append(text_pos)

            # NOTE(review): base_img is written as-is by cv2.imwrite, which
            # treats channels as BGR, so (255, 0, 0) renders blue. If red was
            # intended, use (0, 0, 255).
            text_color = (0, 255, 0) if percent < threshold else (255, 0, 0)
            cv2.putText(base_img, f"{percent:.1f}%", text_pos,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, text_color, 1)

        if output_path is not None:
            cv2.imwrite(output_path, base_img)
        return base_img

    def _save_trajectory_mask(self, map_shape):
        """
        Save agent trajectory mask to file.

        Args:
            map_shape: Map dimensions
        """
        args = self.args
        
        # Initialize trajectory mask
        trajectory_mask = np.zeros(map_shape, dtype=np.uint8)
        if hasattr(self, 'visited_vis') and self.visited_vis is not None:
            trajectory_mask = (self.visited_vis > 0).astype(np.uint8)

        # Set up save directory
        dump_dir = "{}/dump/{}/".format(args.dump_location, args.exp_name)
        try:
            scene_path = self.habitat_env.sim.config.SCENE
            scene_name = scene_path.split("/")[-1].split(".")[0]
        except AttributeError:
            try:
                scene_id = self.habitat_env.current_episode.scene_id
                scene_name = scene_id.split("/")[-1].split(".")[0] if "/" in scene_id else scene_id
            except:
                scene_name = "unknown_scene"
        scene_episode_no = self.scene_episode_count

        # Create save path
        base_ep_dir = '{}/episodes/process_{}/{}/episode_{}_{}'.format(
            dump_dir, self.rank, scene_name, scene_episode_no, self.info['goal_name'].split()[-1])
        trajectory_masks_dir = '{}/trajectory_masks/'.format(base_ep_dir)
        os.makedirs(trajectory_masks_dir, exist_ok=True)

        # Save trajectory data
        trajectory_mask_filename = '{}/{}-{}-Trajectory-Mask-{}.npy'.format(
            trajectory_masks_dir, self.rank, scene_episode_no, self.timestep)
        np.save(trajectory_mask_filename, trajectory_mask)

        # Optionally save visualization
        if args.print_images:
            trajectory_vis_filename = '{}/{}-{}-Trajectory-Vis-{}.png'.format(
                trajectory_masks_dir, self.rank, scene_episode_no, self.timestep)
            trajectory_vis = trajectory_mask * 255
            cv2.imwrite(trajectory_vis_filename, trajectory_vis)