from collections import deque
import random
from matplotlib import cm
import math
from math import sqrt, pi
from datetime import datetime
import re
import torch
import os
import numpy as np
from PIL import Image
import torch.nn as nn
from scipy import ndimage
import matplotlib.pyplot as plt
from model import Semantic_Mapping
import cv2
import logging
from pathlib import Path
import time
from collections import deque, defaultdict
import json
from scipy.ndimage import binary_fill_holes, binary_dilation
from scipy.ndimage import label
import torch
import numpy as np
from math import sqrt, pi


def get_local_map_boundaries(agent_loc, local_sizes, full_sizes, global_downscaling):
    """
    Compute the window of the full map that the local map should cover.

    Args:
        agent_loc: (row, col) cell of the agent in the full map.
        local_sizes: (width, height) of the local map in cells.
        full_sizes: (width, height) of the full map in cells.
        global_downscaling: when greater than 1 the window is centred on the
            agent and shifted back inside the full map; otherwise the window
            is the whole full map.

    Returns:
        list [gx1, gx2, gy1, gy2] with the start/end index along each axis.
    """
    loc_r, loc_c = agent_loc
    local_w, local_h = local_sizes
    full_w, full_h = full_sizes

    if not global_downscaling > 1:
        # No downscaling: the "local" window is simply the full map.
        return [0, full_w, 0, full_h]

    def _window(center, size, limit):
        # Centre the window on the agent, then push it back inside [0, limit].
        start = center - size // 2
        stop = start + size
        if start < 0:
            start, stop = 0, size
        if stop > limit:
            start, stop = limit - size, limit
        return start, stop

    gx1, gx2 = _window(loc_r, local_w, full_w)
    gy1, gy2 = _window(loc_c, local_h, full_h)
    return [gx1, gx2, gy1, gy2]


def init_map_and_pose(full_map, full_pose, map_size_cm, origins, planner_pose_inputs,
                      map_resolution, num_scenes, lmb, local_map, local_pose,
                      device, local_w, local_h, full_w, full_h, global_downscaling):
    """
    Reset the full map/pose of every scene and derive the matching local ones.

    All array/tensor arguments are modified in place and also returned.
    The first two pose components are set to half of the map size (converted
    from centimetres to metres), i.e. the agent starts at the map centre.

    Returns:
        (full_map, full_pose, planner_pose_inputs, origins, local_map, local_pose)
    """
    full_map.fill_(0.)
    full_pose.fill_(0.)
    full_pose[:, :2] = map_size_cm / 100.0 / 2.0

    poses_np = full_pose.cpu().numpy()
    planner_pose_inputs[:, :3] = poses_np

    for e in range(num_scenes):
        # Pose index 1 selects the map row, pose index 0 the map column.
        row = int(poses_np[e, 1] * 100.0 / map_resolution)
        col = int(poses_np[e, 0] * 100.0 / map_resolution)

        # Mark a 3x3 patch around the agent in channels 2 and 3.
        full_map[e, 2:4, row - 1:row + 2, col - 1:col + 2] = 1.0

        lmb[e] = get_local_map_boundaries(
            (row, col), (local_w, local_h), (full_w, full_h), global_downscaling)
        planner_pose_inputs[e, 3:] = lmb[e]

        # Offset of the local window, converted from cells back to metres.
        origins[e] = [lmb[e][2] * map_resolution / 100.0,
                      lmb[e][0] * map_resolution / 100.0,
                      0.]

    for e in range(num_scenes):
        # Copy the window around the agent out of the full map; the local map
        # keeps its size, only its content is refreshed.
        row_window = slice(lmb[e, 0], lmb[e, 1])
        col_window = slice(lmb[e, 2], lmb[e, 3])
        local_map[e] = full_map[e, :, row_window, col_window]

        origin_offset = torch.from_numpy(origins[e]).to(device).float()
        local_pose[e] = full_pose[e] - origin_offset

    return full_map, full_pose, planner_pose_inputs, origins, local_map, local_pose


def update_full_map_for_env(full_map, local_map, lmb, e):
    """
    Write scene ``e``'s local map back into its window of the full map.

    ``lmb[e]`` holds the window as (row_start, row_end, col_start, col_end).
    ``full_map`` is modified in place and returned.
    """
    row_window = slice(lmb[e, 0], lmb[e, 1])
    col_window = slice(lmb[e, 2], lmb[e, 3])
    full_map[e, :, row_window, col_window] = local_map[e]
    return full_map


def init_map_and_pose_for_env(full_map, full_pose, planner_pose_inputs, origins, local_map, local_pose, map_size_cm,
                              map_resolution, local_w, local_h, full_w, full_h, global_downscaling, device, lmb, e):
    """
    Reset the map and pose of a single scene ``e``.

    Single-scene counterpart of ``init_map_and_pose``: the full map and pose
    of scene ``e`` are cleared, the agent is placed at the map centre, and the
    local window, origin and local pose are recomputed. All containers are
    modified in place and also returned.

    Returns:
        (full_map, full_pose, planner_pose_inputs, origins, lmb, local_map, local_pose)
    """
    full_map[e].fill_(0.)
    full_pose[e].fill_(0.)
    full_pose[e, :2] = map_size_cm / 100.0 / 2.0

    pose_np = full_pose[e].cpu().numpy()
    planner_pose_inputs[e, :3] = pose_np

    # Pose index 1 selects the map row, pose index 0 the map column.
    row = int(pose_np[1] * 100.0 / map_resolution)
    col = int(pose_np[0] * 100.0 / map_resolution)

    # Mark a 3x3 patch around the agent in channels 2 and 3.
    full_map[e, 2:4, row - 1:row + 2, col - 1:col + 2] = 1.0

    lmb[e] = get_local_map_boundaries(
        (row, col), (local_w, local_h), (full_w, full_h), global_downscaling)
    planner_pose_inputs[e, 3:] = lmb[e]

    # Offset of the local window, converted from cells back to metres.
    origins[e] = [lmb[e][2] * map_resolution / 100.0,
                  lmb[e][0] * map_resolution / 100.0,
                  0.]

    row_window = slice(lmb[e, 0], lmb[e, 1])
    col_window = slice(lmb[e, 2], lmb[e, 3])
    local_map[e] = full_map[e, :, row_window, col_window]

    origin_offset = torch.from_numpy(origins[e]).to(device).float()
    local_pose[e] = full_pose[e] - origin_offset

    return full_map, full_pose, planner_pose_inputs, origins, lmb, local_map, local_pose


def save_image_batch(batch_array, save_dir="policy_vis", img_format='png'):
    """
    Write every image of a batch to ``save_dir`` as ``image_NNN.<img_format>``.

    Parameters:
        batch_array: sequence of image arrays, each passed to PIL as an RGB image
        save_dir: target directory (created if missing)
        img_format: file extension; PIL picks the encoder from it

    Non-uint8 images are converted first: values with a maximum <= 1.0 are
    treated as normalised and scaled by 255, anything else is cast directly.
    Any failure is reported on stdout and then re-raised.
    """
    os.makedirs(save_dir, exist_ok=True)
    try:
        for index, frame in enumerate(batch_array):
            if frame.dtype != np.uint8:
                # Normalised data is rescaled to 0..255 before the cast.
                rescaled = frame * 255 if frame.max() <= 1.0 else frame
                frame = rescaled.astype(np.uint8)

            picture = Image.fromarray(frame, mode='RGB')
            target = os.path.join(save_dir, f"image_{index:03d}.{img_format}")
            picture.save(target)
            print(f"Successfully saved: {target}")
    except Exception as e:
        print(f"Save failed, error type: {type(e).__name__}, details: {str(e)}")
        raise


def prepare_planner_inputs(full_map, goal_maps, candidate_goal_map, candidate_goal_masks, filled_mask, obstacle_map,
                           planner_pose_inputs,
                           num_scenes, found_goal, visualize, print_images,
                           is_rotatation=None, policy_vis: list = None):
    """
    Build one planner-input dict per scene from the global map.

    Args:
        full_map: global map tensor indexed as [scene, channel, row, col];
            channel 0 is used as the obstacle prediction, channel 1 as the
            explored area and channels 4+ as semantic categories.
        goal_maps: per-scene goal maps (indexed by scene).
        candidate_goal_map, candidate_goal_masks, filled_mask, obstacle_map:
            passed through unchanged to every scene's dict.
        planner_pose_inputs: per-scene pose rows; only the first 3 entries of
            each row are forwarded.
        num_scenes: number of scenes to prepare.
        found_goal: per-scene "goal found" flags.
        visualize, print_images: when either is truthy a semantic argmax map
            is added and the last channel of ``full_map`` is overwritten with
            1e-5 (in place).
        is_rotatation: rotation request forwarded as ``'rotate_agent'``;
            defaults to ``[False, None]``.
        policy_vis: ``[enabled, payload]``; the payload is forwarded as
            ``'rgb_vis'`` only when ``enabled`` is truthy. Defaults to
            ``[False, []]``.

    Returns:
        (planner_inputs, full_map)
    """
    # Fresh defaults per call: the list ends up inside the returned dicts, so
    # a shared default object would leak one caller's mutation into the next.
    if is_rotatation is None:
        is_rotatation = [False, None]
    if policy_vis is None:
        policy_vis = [False, []]

    rgb_vis = policy_vis[1] if policy_vis[0] else None
    need_semantic_map = visualize or print_images

    planner_inputs = []
    for e in range(num_scenes):
        p_input = {
            'map_pred': full_map[e, 0, :, :].cpu().numpy(),
            'exp_pred': full_map[e, 1, :, :].cpu().numpy(),
            # Only (x, y, orientation); the trailing entries are map bounds.
            'pose_pred': planner_pose_inputs[e][:3],
            'candidate_goals': candidate_goal_map,
            'candidate_goal_masks': candidate_goal_masks,
            'filled_mask': filled_mask,
            'obstacle_map': obstacle_map,
            'rgb_vis': rgb_vis,
            'goal': goal_maps[e],
            'found_goal': found_goal[e],
            'rotate_agent': is_rotatation,
        }
        if need_semantic_map:
            # A tiny constant in the last channel makes it win the argmax
            # wherever every other semantic channel is zero.
            full_map[e, -1, :, :] = 1e-5
            p_input['sem_map_pred'] = full_map[e, 4:, :, :].argmax(0).cpu().numpy()
        planner_inputs.append(p_input)

    return planner_inputs, full_map


def update_full_and_local_map(full_map, local_map, lmb, planner_pose_inputs, full_pose, local_pose, origins, num_scenes,
                              map_resolution, global_downscaling, local_w, local_h, full_w, full_h, device):
    """
    Commit each local map to the full map and re-centre the local window.

    For every scene the local map is written back, the global pose is rebuilt
    from local pose + origin, and a new window/origin/local map/local pose is
    derived around that pose. ``full_pose`` and ``origins`` are updated in
    place but are not part of the return value.

    Returns:
        (full_map, local_map, lmb, planner_pose_inputs, local_pose)
    """
    for e in range(num_scenes):
        # 1) Persist what was gathered in the current local window.
        full_map = update_full_map_for_env(full_map, local_map, lmb, e)
        previous_origin = torch.from_numpy(origins[e]).to(device).float()
        full_pose[e] = local_pose[e] + previous_origin

        # 2) Locate the agent's cell in the full map
        #    (pose index 1 selects the row, pose index 0 the column).
        pose_np = full_pose[e].cpu().numpy()
        row = int(pose_np[1] * 100.0 / map_resolution)
        col = int(pose_np[0] * 100.0 / map_resolution)

        # 3) Move the window and refresh everything that depends on it.
        lmb[e] = get_local_map_boundaries(
            (row, col), (local_w, local_h), (full_w, full_h), global_downscaling)
        planner_pose_inputs[e, 3:] = lmb[e]
        origins[e] = [lmb[e][2] * map_resolution / 100.0,
                      lmb[e][0] * map_resolution / 100.0,
                      0.]

        row_window = slice(lmb[e, 0], lmb[e, 1])
        col_window = slice(lmb[e, 2], lmb[e, 3])
        local_map[e] = full_map[e, :, row_window, col_window]

        new_origin = torch.from_numpy(origins[e]).to(device).float()
        local_pose[e] = full_pose[e] - new_origin

    return full_map, local_map, lmb, planner_pose_inputs, local_pose


def reset_current_loction(local_pose, planner_pose_inputs, origins, local_map, num_scenes, map_resolution):
    """
    Refresh the planner pose and re-stamp the agent position on the local map.

    The first three planner pose entries become local pose + origin. Channel 2
    of every local map is cleared, then a 3x3 patch around the agent is set to
    1 in channels 2 and 3 (channel 3 is not cleared first).

    Returns:
        (planner_pose_inputs, local_map), both modified in place.
    """
    poses_np = local_pose.cpu().numpy()
    planner_pose_inputs[:, :3] = poses_np + origins

    local_map[:, 2, :, :].fill_(0.)

    for e in range(num_scenes):
        # Pose index 1 selects the map row, pose index 0 the map column.
        row = int(poses_np[e, 1] * 100.0 / map_resolution)
        col = int(poses_np[e, 0] * 100.0 / map_resolution)
        local_map[e, 2:4, row - 1:row + 2, col - 1:col + 2] = 1.

    return planner_pose_inputs, local_map


def whether_seen_the_goal_and_set_goal_map(num_scenes, infos, full_map, goal_maps, found_goal):
    """
    Detect the goal category on the global map and build per-scene goal maps.

    For every scene whose goal-category channel (``goal_cat_id + 4``) is not
    all zero, the goal map becomes a binary mask of the largest connected
    region of that channel and ``found_goal[e]`` is set to 1.

    Args:
        num_scenes: number of scenes.
        infos: per-scene dicts providing ``'goal_cat_id'``.
        full_map: global map tensor indexed as [scene, channel, row, col];
            it is only read, never modified.
        goal_maps: ignored; fresh zero maps are always created (kept for
            interface compatibility).
        found_goal: per-scene flags, updated in place.

    Returns:
        (goal_maps, found_goal)
    """
    goal_maps = [np.zeros((full_map.shape[2], full_map.shape[3]))
                 for _ in range(num_scenes)]

    for e in range(num_scenes):
        cn = infos[e]['goal_cat_id'] + 4
        channel = full_map[e, cn, :, :]
        if channel.sum() == 0.:
            continue

        # ``.cpu().numpy()`` shares memory with a CPU tensor, so binarise a
        # copy; otherwise the global map would be overwritten on CPU only.
        binary_scores = channel.cpu().numpy().copy()
        binary_scores[binary_scores > 0] = 1.

        labeled_array, num_features = label(binary_scores)
        if num_features > 0:
            # Pixel count per label in one pass; index 0 is the background.
            region_areas = np.bincount(labeled_array.ravel())[1:]
            largest_label = int(np.argmax(region_areas)) + 1
            goal_maps[e] = (labeled_array == largest_label).astype(float)
        else:
            goal_maps[e] = binary_scores
        found_goal[e] = 1

    return goal_maps, found_goal


def generate_candidate_goal_map(full_map, num_candidate_goals, start_channel=4):
    """
    Build one single-pixel goal map per candidate channel of scene 0.

    For each of the ``num_candidate_goals`` channels starting at
    ``start_channel``, the rounded centroid of the positive cells is marked
    with 1; channels without positive cells stay all zero.

    Returns:
        numpy array of shape (num_candidate_goals, rows, cols).
    """
    rows, cols = full_map.shape[2], full_map.shape[3]
    candidate_goal_map = np.zeros((num_candidate_goals, rows, cols))

    for slot in range(num_candidate_goals):
        layer = full_map[0, start_channel + slot, :, :]
        if layer.sum() == 0.:
            continue

        ys, xs = np.where(layer.cpu().numpy() > 0)
        if len(ys) == 0:
            continue

        centroid_y = int(np.round(np.mean(ys)))
        centroid_x = int(np.round(np.mean(xs)))
        candidate_goal_map[slot, centroid_y, centroid_x] = 1

    return candidate_goal_map


def set_the_selected_goal_map(num_scenes, full_map):
    """
    Turn the selected-goal channel (index 14) into single-pixel goal maps.

    For each scene the rounded centroid of the positive cells of channel 14 is
    marked with 1. Scenes without positive cells keep an all-zero map.

    Returns:
        list of ``num_scenes`` 2-D numpy arrays.
    """
    selected_channel = 14
    map_shape = (full_map.shape[2], full_map.shape[3])
    goal_maps = [np.zeros(map_shape) for _ in range(num_scenes)]

    for e in range(num_scenes):
        layer = full_map[e, selected_channel, :, :]
        if layer.sum() == 0.:
            continue

        layer_np = layer.cpu().numpy()
        ys, xs = np.where(layer_np > 0)
        if len(ys) == 0:
            continue

        # The marker map takes over the dtype of the map channel.
        marker = np.zeros_like(layer_np)
        marker[int(np.round(np.mean(ys))), int(np.round(np.mean(xs)))] = 1
        goal_maps[e] = marker

    return goal_maps


def set_random_goal_map(num_scenes, goal_maps, local_w, local_h, e=0):
    """
    Mark one random cell of scene ``e``'s goal map.

    A random (x, y) pair is drawn for every scene (normal noise scaled by
    ``num_scenes`` and squashed with a sigmoid), scaled to the map size and
    clamped to the last valid index. Only scene ``e``'s map is modified.

    Returns:
        goal_maps (modified in place).
    """
    noise = torch.randn(num_scenes, 2) * num_scenes
    fractions = torch.sigmoid(noise).cpu().numpy()

    last_x, last_y = int(local_w - 1), int(local_h - 1)
    global_goals = []
    for frac_x, frac_y in fractions:
        cell_x = min(int(frac_x * local_w), last_x)
        cell_y = min(int(frac_y * local_h), last_y)
        global_goals.append([cell_x, cell_y])

    goal_x, goal_y = global_goals[e]
    goal_maps[e][goal_x, goal_y] = 1
    return goal_maps


def update_semantic_map(infos, obs, local_map, local_pose, sem_map_module, num_scenes, device):
    """
    Run the semantic mapping module for one step.

    The ``'sensor_pose'`` of every scene is stacked into a float tensor on
    ``device`` and passed, together with the observation and the current
    local map/pose, to ``sem_map_module``. Of its four outputs only the second
    (local map) and fourth (local pose) are kept.

    Returns:
        (local_map, local_pose)
    """
    sensor_poses = np.asarray(
        [infos[scene]['sensor_pose'] for scene in range(num_scenes)])
    poses = torch.from_numpy(sensor_poses).float().to(device)

    module_outputs = sem_map_module(obs, poses, local_map, local_pose)
    _, local_map, _, local_pose = module_outputs
    return local_map, local_pose


def expand_semantic_mask(obs, target_pixels=200, start_channel=4, end_channel=9):
    """
    Grow small regions of selected mask channels to about ``target_pixels``.

    Channels ``start_channel`` .. ``end_channel`` (inclusive) of every scene
    are processed. A channel is skipped when it is empty or already holds at
    least ``target_pixels`` ones. Otherwise each connected region is dilated
    until it reaches ``target_pixels`` (or fills the whole image, whichever
    comes first), randomly trimmed along its edge if it overshoots, and
    OR-ed back into the channel.

    Args:
        obs: tensor indexed as [scene, channel, row, col]; not modified.
        target_pixels: pixel count each region is grown to, default 200.
        start_channel: first channel to process (inclusive).
        end_channel: last channel to process (inclusive).

    Returns:
        A clone of ``obs`` with the processed channels replaced.
    """
    output = obs.clone()

    for scene_idx in range(obs.shape[0]):
        for channel_idx in range(start_channel, end_channel + 1):
            current_mask = obs[scene_idx, channel_idx].cpu().numpy()

            num_ones = np.sum(current_mask)
            if num_ones == 0 or num_ones >= target_pixels:
                continue

            # A region can never exceed the image itself; without this cap
            # the dilation loop would spin forever on images smaller than
            # ``target_pixels``.
            growth_limit = min(target_pixels, current_mask.size)

            labeled_array, num_features = ndimage.label(current_mask)
            for region_idx in range(1, num_features + 1):
                expanded_mask = labeled_array == region_idx

                # Grow one dilation step at a time until large enough.
                while np.sum(expanded_mask) < growth_limit:
                    expanded_mask = ndimage.binary_dilation(expanded_mask)

                # The last step may overshoot: drop random edge pixels.
                while np.sum(expanded_mask) > target_pixels:
                    edges = expanded_mask ^ ndimage.binary_erosion(
                        expanded_mask)
                    if np.sum(edges) == 0:
                        break
                    edge_pixels = np.where(edges)
                    remove_count = int(np.sum(expanded_mask) - target_pixels)
                    if len(edge_pixels[0]) <= remove_count:
                        # Not enough edge pixels to trim; keep the overshoot.
                        break
                    remove_indices = np.random.choice(
                        len(edge_pixels[0]),
                        remove_count,
                        replace=False
                    )
                    expanded_mask[
                        edge_pixels[0][remove_indices],
                        edge_pixels[1][remove_indices]
                    ] = False

                # Union keeps every original pixel, including trimmed ones.
                current_mask = np.logical_or(current_mask, expanded_mask)

            output[scene_idx, channel_idx] = torch.from_numpy(
                current_mask).to(obs.device)

    return output


def whether_seen_the_goal(obs, infos, num_scenes):
    """
    Report, per scene, whether the goal category appears in the observation.

    The goal channel of scene ``e`` is ``infos[e]['goal_cat_id'] + 4``; the
    goal counts as seen when that channel does not sum to zero.

    Returns:
        list of ``num_scenes`` booleans.
    """
    return [
        bool(obs[e, infos[e]['goal_cat_id'] + 4, :, :].sum() != 0.)
        for e in range(num_scenes)
    ]


def visualize_and_save(tensor_before, tensor_after, channel_idx=8):
    """
    Save a side-by-side picture of one channel before and after expansion.

    Scene 0, channel ``channel_idx`` of both tensors is drawn in grayscale and
    written to ``visualization_results/expansion_comparison_running.png``.
    The sum of each channel is then printed.
    """
    save_dir = "visualization_results"
    os.makedirs(save_dir, exist_ok=True)

    plt.figure(figsize=(12, 6))
    panels = (
        (121, tensor_before, 'Before Expansion'),
        (122, tensor_after, 'After Expansion'),
    )
    for position, tensor, title in panels:
        plt.subplot(position)
        plt.imshow(tensor[0, channel_idx].cpu().numpy(), cmap='gray')
        plt.title(title)
        plt.colorbar()

    output_path = os.path.join(save_dir, 'expansion_comparison_running.png')
    plt.savefig(output_path)
    plt.close()

    # Statistics are computed only after the figure has been written.
    ones_before = torch.sum(tensor_before[0, channel_idx]).item()
    ones_after = torch.sum(tensor_after[0, channel_idx]).item()
    print(f"Number of 1s before expansion: {ones_before}")
    print(f"Number of 1s after expansion: {ones_after}")


def extract_number(input_str: str) -> int:
    """
    Return the first integer inside the last ``{...}`` group of the text.

    The group is delimited by the last ``}`` and the nearest ``{`` before it.

    Parameters:
        input_str: text expected to contain something like ``{...}``
    Returns:
        int: the first run of digits inside that group, or -2 when there is
             no brace pair or no digit inside it
    """
    closing = input_str.rfind('}')
    opening = input_str.rfind('{', 0, closing) if closing != -1 else -1
    if closing == -1 or opening == -1:
        return -2

    digits = re.search(r'\d+', input_str[opening:closing + 1])
    return int(digits.group(0)) if digits else -2


def add_image_number(rgb, save_images=False):
    """
    Stamp a running number (1, 2, ...) in the top-left corner of each image.

    The number is drawn in black inside a white disc with a red border.

    Parameters:
        rgb: iterable of images in RGB channel order, as accepted by
             ``cv2.cvtColor(..., cv2.COLOR_RGB2BGR)``
        save_images: when true, each annotated image is also written to
             ``numbered_images/image_<n>.jpg``
    Returns:
        list of annotated images in RGB channel order (new arrays; the inputs
        are not drawn on)
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.2
    thickness = 2
    circle_radius = 25
    margin = 15
    border_thickness = 3

    # The badge sits at a fixed offset from the top-left corner.
    center = (margin + circle_radius, margin + circle_radius)

    if save_images:
        os.makedirs("numbered_images", exist_ok=True)

    processed_images = []
    for idx, img in enumerate(rgb):
        # OpenCV draws and writes in BGR; the conversion also yields a copy.
        canvas = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # White disc first, red ring on top (colours are BGR tuples).
        cv2.circle(canvas, center, circle_radius, (255, 255, 255), -1)
        cv2.circle(canvas, center, circle_radius, (0, 0, 255), border_thickness)

        # Centre the label: putText anchors at the bottom-left of the text.
        text = str(idx + 1)
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        text_origin = (center[0] - text_w // 2, center[1] + text_h // 2)
        cv2.putText(canvas, text, text_origin, font, font_scale,
                    (0, 0, 0), thickness, cv2.LINE_AA)

        processed_images.append(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

        if save_images:
            # imwrite expects BGR, so the BGR canvas is written directly.
            cv2.imwrite(f"numbered_images/image_{idx+1}.jpg", canvas)

    return processed_images


def select_view_at_beginning(vlm, history_rgb, infos, logger, save_images=True) -> int:
    """
    Ask the VLM which of the numbered initial views to head for.

    The views are numbered with ``add_image_number``, sent to
    ``vlm.ask_multiple_images`` together with the goal name of scene 0, and
    the answer is logged with an episode/step prefix.

    Returns:
        The number parsed from the answer by ``extract_number`` (-2 when the
        answer contains no usable number).
    """
    numeric_rotate_views = add_image_number(
        history_rgb, save_images=save_images)
    view_response = vlm.ask_multiple_images(
        numeric_rotate_views, infos[0]['goal_name'])

    # Build the prefix as one string: joining a *string* with " " would put a
    # space between every single character of it.
    prefix = f"[Ep {infos[0]['episode_no']:03d} | Step {infos[0]['time']+1:03d}]"
    logger.info(prefix + view_response + "--" * 40)

    return extract_number(view_response)


def exp_goal_based_cur_view(rgbd, vlm, trav, infos, logger) -> tuple:
    """
    Let the VLM pick an exploration goal from the current view.

    ``trav`` first analyses the view (first three channels of ``rgbd`` as RGB,
    fourth as depth) and annotates candidate goals; the annotated image is
    sent to ``vlm.ask_single_image`` with the goal name of scene 0. The
    number parsed from the answer selects the candidate, whose highlighted
    image is saved via ``save_image_batch``.

    Returns:
        (goal_mask, action_highlighted_np) on success, or (None, None) when
        no traversable area was found or no annotated image was produced.
    """
    success = trav.run_current_view(
        rgb=rgbd[:, :, :3], depth=rgbd[:, :, 3])
    if not success:
        logger.warning(_format_step_log(
            infos,
            "WARNING: No traversable area found in current view. Skipping VLM processing."))
        return None, None

    rgb_annotated_np = trav.get_numpy_of_annotated_candidate_exp_goals()
    if rgb_annotated_np is None:
        logger.error(_format_step_log(
            infos,
            "ERROR: Failed to generate annotated image despite successful traversable area detection."))
        return None, None

    action_response = vlm.ask_single_image(
        rgb_annotated_np, infos[0]['goal_name'])
    logger.info(_format_step_log(infos, action_response))

    # NOTE(review): extract_number yields -2 when the answer has no number;
    # how ``trav`` treats that value is not visible here - confirm.
    action = extract_number(action_response)
    action_highlighted_np = trav.highlight_selected_goal(
        selected_goal=action)
    save_image_batch([action_highlighted_np])

    goal_mask = trav.get_goal_mask_of_select_point(action)
    return goal_mask, action_highlighted_np


def _format_step_log(infos, message):
    """Build one log entry: episode/step prefix line, message, dashed rule."""
    prefix = f"[Ep {infos[0]['episode_no']:03d} | Step {infos[0]['time']+1:03d}]\n"
    return prefix + message + "--" * 40


def set_base(args):
    """
    Allocate the map and pose buffers shared by all scenes.

    Reads ``num_processes``, ``num_eval_episodes``, ``cuda``,
    ``num_sem_categories``, ``map_size_cm``, ``map_resolution`` and
    ``global_downscaling`` from ``args`` and stores the chosen device back
    into ``args.device``.

    Map channels: 0 obstacle map, 1 explored area, 2 current agent location,
    3 past agent locations, 4+ semantic categories.

    Returns:
        (full_w, full_h, local_w, local_h, num_scenes, num_episodes, device,
         finished, full_map, local_map, full_pose, local_pose, origins, lmb,
         planner_pose_inputs)
    """
    num_scenes = args.num_processes
    num_episodes = int(args.num_eval_episodes)
    device = args.device = torch.device("cuda" if args.cuda else "cpu")
    finished = np.zeros((args.num_processes))

    # Four bookkeeping channels plus one per semantic category.
    nc = args.num_sem_categories + 4

    # Map sizes in cells; the local map is the full map shrunk by the
    # global downscaling factor.
    map_size = args.map_size_cm // args.map_resolution
    full_w = full_h = map_size
    local_w = int(full_w / args.global_downscaling)
    local_h = int(full_h / args.global_downscaling)

    def _float_zeros(*shape):
        return torch.zeros(*shape).float().to(device)

    full_map = _float_zeros(num_scenes, nc, full_w, full_h)
    local_map = _float_zeros(num_scenes, nc, local_w, local_h)
    full_pose = _float_zeros(num_scenes, 3)
    local_pose = _float_zeros(num_scenes, 3)

    # Origin of each local map and its boundaries inside the full map.
    origins = np.zeros((num_scenes, 3))
    lmb = np.zeros((num_scenes, 4)).astype(int)

    # Seven entries per scene: 3 for the continuous global agent location,
    # 4 for the local map boundaries.
    planner_pose_inputs = np.zeros((num_scenes, 7))

    return (full_w, full_h, local_w, local_h, num_scenes, num_episodes, device,
            finished, full_map, local_map, full_pose, local_pose, origins, lmb,
            planner_pose_inputs)


def set_more(args, num_scenes, full_w, full_h, num_episodes):
    """
    Create the semantic mapping module and the per-run bookkeeping state.

    Side effects: stores the device in ``args.device`` and disables autograd
    globally via ``torch.set_grad_enabled(False)``.

    Returns:
        (rotation_history, rotatation_over_at_beginning, new_goal_required,
         dones, policy_vis, dones, sem_map_module, found_goal, goal_maps,
         spl_per_category, success_per_category, start)
        Note that the same ``dones`` list appears twice in the tuple.
    """
    device = args.device = torch.device("cuda" if args.cuda else "cpu")

    # Semantic mapping network, used for inference only.
    sem_map_module = Semantic_Mapping(args).to(device)
    sem_map_module.eval()

    found_goal = [0] * num_scenes
    goal_maps = [np.zeros((full_w, full_h)) for _ in range(num_scenes)]

    torch.set_grad_enabled(False)

    spl_per_category = defaultdict(list)
    success_per_category = defaultdict(list)
    start = time.time()

    policy_vis = [False, []]
    dones = [False] * num_scenes
    rotatation_over_at_beginning = [False] * num_episodes
    new_goal_required = [False] * num_scenes

    # Twelve empty slots, one per recorded view of the initial rotation.
    rotation_history = [
        dict(rgbd=None, local_pose=None, obs=None, sensor_pose=None)
        for _ in range(12)
    ]

    return (rotation_history, rotatation_over_at_beginning, new_goal_required,
            dones, policy_vis, dones,
            sem_map_module, found_goal, goal_maps,
            spl_per_category, success_per_category, start)


def set_doneInfo(args):
    """
    Create the containers that collect per-episode evaluation results.

    In evaluation mode (``args.eval`` truthy) each of success, SPL and
    distance gets one bounded deque per process, holding at most
    ``args.num_eval_episodes`` entries. Outside evaluation mode the three
    containers are ``None`` instead of being left undefined.

    Returns:
        (episode_success, episode_spl, episode_dist,
         spl_per_category, success_per_category)
    """
    episode_success = episode_spl = episode_dist = None

    if args.eval:
        def _per_process_buffers():
            return [deque(maxlen=args.num_eval_episodes)
                    for _ in range(args.num_processes)]

        episode_success = _per_process_buffers()
        episode_spl = _per_process_buffers()
        episode_dist = _per_process_buffers()

    spl_per_category = defaultdict(list)
    success_per_category = defaultdict(list)
    return episode_success, episode_spl, episode_dist, spl_per_category, success_per_category


def set_logger(args):
    """
    Create the dump/log directories and return the ``'navigation'`` logger.

    Directories ``<dump_location>/dump/<exp_name>`` and
    ``<dump_location>/logs/<exp_name>`` are created if missing. The logger
    writes INFO and above to ``navigation.log`` inside the log directory.

    Calling this function again for the same log file reuses the existing
    handler, so messages are not written multiple times.
    """
    dump_dir = Path("{}/dump/{}/".format(args.dump_location, args.exp_name))
    log_dir = Path("{}/logs/{}/".format(args.dump_location, args.exp_name))
    for directory in (log_dir, dump_dir):
        # exist_ok avoids the check-then-create race of os.path.exists().
        directory.mkdir(parents=True, exist_ok=True)

    logger = logging.getLogger('navigation')
    logger.setLevel(logging.INFO)

    # getLogger returns a process-wide singleton; attach the file handler
    # only once per target file.
    log_file = os.path.abspath(log_dir / 'navigation.log')
    already_attached = any(
        isinstance(existing, logging.FileHandler)
        and existing.baseFilename == log_file
        for existing in logger.handlers
    )
    if not already_attached:
        handler = logging.FileHandler(log_file)
        handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        ))
        logger.addHandler(handler)

    return logger


def log_interval(logger, infos, args, num_scenes, start, episode_success, episode_spl, episode_dist):
    """
    Print and log a progress line: elapsed time, step count, FPS and, once
    results exist, the running ObjectNav success / SPL / distance averages.

    Nothing is printed or logged unless ``args.eval`` is truthy.
    """
    step = infos[0]['time']
    now = time.time()
    elapsed = time.gmtime(now - start)
    total_steps = step * num_scenes

    log = " ".join((
        # gmtime counts days of the month from 1, hence the "- 1".
        "Time: {0:0=2d}d".format(elapsed.tm_mday - 1),
        "{},".format(time.strftime("%Hh %Mm %Ss", elapsed)),
        "num timesteps {},".format(total_steps),
        "FPS {},".format(int(total_steps / (now - start))),
    ))

    if not args.eval:
        return

    # Flatten the per-process result buffers.
    processes = range(args.num_processes)
    total_success = [value for e in processes for value in episode_success[e]]
    total_dist = [value for e in processes for value in episode_dist[e]]
    total_spl = [value for e in processes for value in episode_spl[e]]

    if total_spl:
        log += " ObjectNav succ/spl/dtg:"
        log += " {:.3f}/{:.3f}/{:.3f}({:.0f}),".format(
            np.mean(total_success),
            np.mean(total_spl),
            np.mean(total_dist),
            len(total_spl))

    print(log)
    logger.info(log + "--" * 40)


def obtain_final_res(args, logger, episode_success, episode_spl, episode_dist, spl_per_category, success_per_category):
    """
    Report the final evaluation numbers and dump the per-category results.

    Only active when ``args.eval`` is truthy. The overall success / SPL /
    distance averages and the per-category success and SPL are printed and
    logged; the two per-category dicts are then written as JSON files into
    ``<dump_location>/dump/<exp_name>/`` (the directory must already exist).
    """
    if not args.eval:
        return

    print("Dumping eval details...")
    separator = "--" * 40

    # Flatten the per-process result buffers.
    processes = range(args.num_processes)
    total_success = [value for e in processes for value in episode_success[e]]
    total_dist = [value for e in processes for value in episode_dist[e]]
    total_spl = [value for e in processes for value in episode_spl[e]]

    log = ""
    if total_spl:
        log = "Final ObjectNav succ/spl/dtg:"
        log += " {:.3f}/{:.3f}/{:.3f}({:.0f}),".format(
            np.mean(total_success),
            np.mean(total_spl),
            np.mean(total_dist),
            len(total_spl))
    print(log)
    logger.info(log + separator)

    # Per-category mean success and mean SPL.
    log = "Success | SPL per category"
    for key in success_per_category:
        mean_success = sum(success_per_category[key]) / len(success_per_category[key])
        mean_spl = sum(spl_per_category[key]) / len(spl_per_category[key])
        log += "{}: {} | {}".format(key, mean_success, mean_spl)
    print(log)
    logger.info(log + separator)

    dump_dir = Path(
        "{}/dump/{}/".format(args.dump_location, args.exp_name))
    dumps = (
        ('{}/{}_spl_per_cat_pred_thr.json', spl_per_category),
        ('{}/{}_success_per_cat_pred_thr.json', success_per_category),
    )
    for name_template, payload in dumps:
        with open(name_template.format(dump_dir, args.split), 'w') as f:
            json.dump(payload, f)


def replace_batch_slices(tmp_goals_masks, obs, start_index=6):
    """
    Return a copy of obs whose channels starting at start_index are overwritten with tmp_goals_masks.
    Parameters:
        tmp_goals_masks: numpy array (or tensor), shape (num, height, width)
        obs: torch tensor, shape (batch, channels, height, width)
        start_index: first channel index of obs to overwrite (default is 6)
    Returns:
        modified clone of obs; obs itself is left untouched. Channels
        [start_index, start_index + num) of every batch entry hold the masks.
    Exceptions:
        Raises ValueError if the spatial sizes differ or obs has too few channels.
    """
    # Get dimension information
    batch1 = tmp_goals_masks.shape[0]
    batch2 = obs.shape[1]  # number of channels available in obs
    h1, w1 = tmp_goals_masks.shape[1:]
    h2, w2 = obs.shape[2:]
    # Validate dimension matching
    if h1 != h2 or w1 != w2:
        raise ValueError(f"Spatial dimensions do not match: n1({h1}x{w1}) vs n2({h2}x{w2})")
    # Validate sufficient channel count
    if start_index + batch1 > batch2:
        required = start_index + batch1
        raise ValueError(
            f"n2 batch count ({batch2}) is insufficient, at least {required} batches are needed."
            f"Current batch1 size is {batch1}, minimum replacement index is {start_index}"
        )
    # Work on a clone so the caller's tensor is not modified
    result = obs.clone()
    masks = torch.as_tensor(tmp_goals_masks).to(
        device=result.device, dtype=result.dtype)
    # The range validated above is a *channel* range, so the write has to go
    # to dim 1. Slicing dim 0 (the batch axis) would select an empty range on
    # a (1, C, H, W) tensor and the masks would never be written.
    result[:, start_index:start_index + batch1] = masks
    return result


def generate_centroid_masks(full_map, min_index, max_index):
    """
    Build one single-pixel mask per selected channel, marking the centroid of its non-zero pixels.
    Parameters:
        full_map: numpy array, shape (1, channels, height, width); only batch entry 0 is read
        min_index: first channel to process (inclusive)
        max_index: last channel to process (inclusive)
    Returns:
        numpy array, shape (max_index - min_index + 1, height, width), dtype=np.uint8.
        A channel that is entirely zero yields an all-zero mask.
    Exceptions:
        Raises ValueError if the indices are out of range or min_index > max_index.
    """
    # Reject invalid channel ranges up front
    if min_index < 0 or max_index >= full_map.shape[1]:
        raise ValueError(f"Index out of range. Valid channel range: [0, {full_map.shape[1]-1}]")
    if min_index > max_index:
        raise ValueError("min_index cannot be greater than max_index")
    selected = full_map[0, min_index:max_index+1]
    n_slices, height, width = selected.shape
    masks = np.zeros((n_slices, height, width), dtype=np.uint8)
    for idx, plane in enumerate(selected):
        # Nothing to mark for an empty channel
        if np.all(plane == 0):
            continue
        rows, cols = np.nonzero(plane)
        # Mean coordinate of the non-zero pixels, rounded with np.round
        centre_row = int(np.round(np.mean(rows)))
        centre_col = int(np.round(np.mean(cols)))
        masks[idx, centre_row, centre_col] = 1
    return masks


def get_agent_position_mask(planner_pose_inputs, full_map_shape, args, env_idx=0):
    """
    Build a binary mask with a single 1 at the agent's map cell.
    Args:
        planner_pose_inputs: indexable as [env_idx, 0] (x) and [env_idx, 1] (y)
        full_map_shape: shape of the full map; entries [2] and [3] are used as height and width
        args: argument object providing map_resolution
        env_idx: environment index, default is 0
    Returns:
        numpy.ndarray: uint8 mask of shape (height, width). It is all zeros when
        the computed cell falls outside the map.
    """
    rows, cols = full_map_shape[2], full_map_shape[3]
    pose_x = planner_pose_inputs[env_idx, 0]
    pose_y = planner_pose_inputs[env_idx, 1]
    # pose * 100 / map_resolution -> cell index (presumably metres -> cm -> cells; confirm units)
    cell_row = int(pose_y * 100.0 / args.map_resolution)
    cell_col = int(pose_x * 100.0 / args.map_resolution)
    mask = np.zeros((rows, cols), dtype=np.uint8)
    inside_map = (0 <= cell_row < rows) and (0 <= cell_col < cols)
    if inside_map:
        mask[cell_row, cell_col] = 1
    return mask


def expand_masks(original_masks, target_area=10000):
    """
    Replace every channel by a filled disk centred on that channel's first pixel equal to 1.
    Parameters:
        original_masks: mask array, shape (num, height, width); left unmodified
        target_area: desired disk area in pixels; the radius is sqrt(target_area / pi)
    Returns:
        new array with the same shape, or None when original_masks has no channels.
        The radius is capped by the distance from the centre to the nearest image
        border, so a centre lying on the border yields just that pixel.
    Exceptions:
        Raises ValueError if target_area <= 1 or a channel contains no 1.
    """
    if original_masks.shape[0] == 0:
        return None
    expanded_masks = original_masks.copy()
    num, height, width = expanded_masks.shape
    if target_area <= 1:
        raise ValueError("Target area value must be greater than 1")
    # Radius and coordinate grids are identical for every channel
    ideal_radius = sqrt(target_area / pi)
    grid_y, grid_x = np.ogrid[:height, :width]
    for idx in range(num):
        hits = np.argwhere(expanded_masks[idx] == 1)
        if len(hits) == 0:
            raise ValueError(f"No position with value 1 found in channel {idx}")
        # First hit in row-major order is the disk centre
        cy, cx = hits[0]
        # Largest radius that keeps the disk inside the image
        border_limit = min(cy, height - 1 - cy, cx, width - 1 - cx)
        radius = ideal_radius if border_limit >= ideal_radius else border_limit
        inside_disk = np.sqrt((grid_x - cx)**2 + (grid_y - cy)**2) <= radius
        expanded_masks[idx] = inside_disk.astype(np.uint8)
    return expanded_masks


# def visualize_obs_exp_agent_masks(obstacle_mask, exp_mask, agent_position_mask, output_path="./visualize_obs_exp_agent_masks.png", dpi=300):
#     """
#     Visualize three binary masks.
#     Parameters:
#         mask1: first mask (height, width), 1=light gray, 0=white
#         mask2: second mask (height, width), 1=light gray, 0=white
#         mask3: third mask (height, width), 1=red, 0=white
#         output_path: save path (None to not save)
#         dpi: image resolution
#     Returns:
#         matplotlib Figure object
#     """
#     # Check if input shapes are consistent
#     if not (mask1.shape == mask2.shape == mask3.shape):
#         raise ValueError("All masks must have the same shape")
#     height, width = mask1.shape
#     visualized = np.ones((height, width, 3))  # white background (RGB)
#     # Define colors
#     light_gray = [0.8, 0.8, 0.8]  # light gray
#     red = [1.0, 0.0, 0.0]        # red
#     # Process first mask (light gray)
#     visualized[mask1 == 1] = light_gray
#     # Process second mask (light gray)
#     visualized[mask2 == 1] = light_gray
#     # Process third mask (red)
#     visualized[mask3 == 1] = red
#     # Create image
#     fig = plt.figure(figsize=(10, 10))
#     plt.imshow(visualized)
#     plt.axis('off')
#     # Save image
#     if output_path is not None:
#         plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=dpi)
#         print(f"Visualization result saved to {output_path}")
#     return fig
# def get_agent_position_mask(planner_pose_inputs, full_map, args, env_idx=0):
#     """
#     Get a binary mask of the agent's position in the full_map.
#     Args:
#         planner_pose_inputs: planner pose inputs (num_scenes, 7)
#         full_map: full map (num_scenes, channels, height, width)
#         args: argument object, containing map_resolution
#         env_idx: environment index, default is 0
#     Returns:
#         numpy.ndarray: binary mask of the same size as full_map (height, width), 1 at agent position, 0 elsewhere
#     """
#     # Get agent's global position coordinates
#     global_x = planner_pose_inputs[env_idx, 0]
#     global_y = planner_pose_inputs[env_idx, 1]
#     # Get map dimensions
#     map_height, map_width = full_map.shape[2], full_map.shape[3]  # typically 480x480
#     # Convert to map pixel coordinates
#     pixel_row = int(global_y * 100.0 / args.map_resolution)
#     pixel_col = int(global_x * 100.0 / args.map_resolution)
#     # Create zero mask
#     agent_mask = np.zeros((map_height, map_width), dtype=np.uint8)
#     # Boundary check and set agent position
#     if 0 <= pixel_row < map_height and 0 <= pixel_col < map_width:
#         agent_mask[pixel_row, pixel_col] = 1
#     return agent_mask
# def fill_agent_position_mask(mask1, mask2, agent_position_mask):
#     """Merge masks and fill the area around the agent"""
#     # 1. Merge two masks
#     combined_mask = np.logical_or(mask1, mask2).astype(np.uint8)
#     # 2. Locate agent position
#     agent_pos = np.argwhere(agent_position_mask == 1)
#     if len(agent_pos) == 0:
#         raise ValueError("No value of 1 found in agent_position_mask")
#     y, x = agent_pos[0]
#     # 3. Verify that the position is 0 (not occupied)
#     if combined_mask[y, x] != 0:
#         raise ValueError("Agent position is already occupied in combined_mask")
#     # 4. Create filled mask (ensure original image is not modified)
#     filled_mask = combined_mask.copy()
#     # 5. Fill connected region
#     # Create temporary image (add borders to prevent contour overflow)
#     temp_mask = np.zeros(
#         (filled_mask.shape[0]+2, filled_mask.shape[1]+2), dtype=np.uint8)
#     temp_mask[1:-1, 1:-1] = filled_mask
#     # Perform flood fill (start filling 0 region from agent position)
#     cv2.floodFill(
#         image=temp_mask,
#         mask=None,
#         seedPoint=(x+1, y+1),  # adjust boundary offset
#         newVal=1,
#         loDiff=0,
#         upDiff=0,
#         flags=4 | (1 << 8)  # 4-connectivity + fixed range fill
#     )
#     # Remove temporary borders
#     filled_mask = temp_mask[1:-1, 1:-1]
#     return combined_mask, filled_mask


def fill_agent_position_mask(mask1, mask2, agent_position_mask):
    """
    Merge two masks and flood-fill the empty region that contains the agent.
    Parameters:
        mask1: first binary mask (height, width), numpy array or torch tensor
        mask2: second binary mask (height, width), numpy array or torch tensor
        agent_position_mask: mask (height, width) whose first 1 marks the agent
    Returns:
        uint8 numpy array (height, width): logical OR of mask1 and mask2 with every
        zero cell that is 4-connected to the agent cell set to 1.
    Exceptions:
        Raises ValueError if agent_position_mask holds no 1 or the agent cell is
        already set in the merged mask.
    """
    # Bring every input to a host-side numpy array
    mask1, mask2, agent_position_mask = (
        m.cpu().numpy() if torch.is_tensor(m) else m
        for m in (mask1, mask2, agent_position_mask)
    )
    combined_mask = np.logical_or(mask1, mask2).astype(np.uint8)
    agent_cells = np.argwhere(agent_position_mask == 1)
    if len(agent_cells) == 0:
        raise ValueError("No 1 found in agent_position_mask")
    seed_row, seed_col = agent_cells[0]
    if combined_mask[seed_row, seed_col] != 0:
        raise ValueError("Agent position is not 0 in combined_mask")
    # Iterative depth-first flood fill over 4-connected zero cells
    reached = np.zeros_like(combined_mask)
    n_rows, n_cols = reached.shape[0], reached.shape[1]
    pending = [(seed_row, seed_col)]
    while pending:
        row, col = pending.pop()
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            continue
        if combined_mask[row, col] != 0 or reached[row, col] == 1:
            continue
        reached[row, col] = 1
        pending.extend(
            ((row + 1, col), (row - 1, col), (row, col + 1), (row, col - 1)))
    combined_mask[reached == 1] = 1
    return combined_mask


def visualize_masks(original, filled):
    """
    Save a side-by-side grayscale comparison of two masks to ./tmp/visualize_filled_masks.png.
    Parameters:
        original: torch tensor (converted with .cpu().numpy()), shown on the left
        filled: array-like accepted by plt.imshow, shown on the right
    """
    original = original.cpu().numpy()
    panels = (
        (121, original, "Original combined mask"),
        (122, filled, "Filled mask"),
    )
    plt.figure(figsize=(12, 6))
    for position, image, title in panels:
        plt.subplot(position)
        plt.imshow(image, cmap='gray')
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    # NOTE(review): the ./tmp directory is not created here; savefig presumably
    # fails if it is missing - confirm it exists at call time.
    plt.savefig("./tmp/visualize_filled_masks.png")
    plt.close()


def filter_candidate_goals(filled_mask, candidate_goals, obstacle_mask, exp_mask, agent_position_mask, k, threshold,
                           dis_thresh, obstacle_radius=10, obstacle_threshold=0.3, trajectory_mask=None,
                           trajectory_radius=10):
    """
    Keep only the candidate goal channels that pass every filter below.
    A candidate (its first pixel equal to 1) is kept when all of these hold:
      1. it does not lie on an obstacle and lies inside the explored area;
      2. the obstacle ratio inside a disk of radius obstacle_radius is <= obstacle_threshold;
      3. the ratio of 0 values of filled_mask inside a disk of radius k is > threshold;
      4. its distance to the robot is >= dis_thresh;
      5. no trajectory pixel lies within trajectory_radius (only checked when both
         trajectory_mask and trajectory_radius are not None).
    Masks may be numpy arrays or torch tensors; tensors are moved to the CPU once.
    :param filled_mask: binary mask (H, W)
    :param candidate_goals: candidate goals (N, H, W); channels without a 1 are dropped
    :param obstacle_mask: obstacle mask (H, W)
    :param exp_mask: explored area mask (H, W)
    :param agent_position_mask: robot position mask (H, W), first 1 is the robot position
    :param k: radius of the circular area used for the 0-value ratio (pixels)
    :param threshold: threshold for proportion of 0 values (0.0-1.0)
    :param dis_thresh: minimum distance threshold from the robot (pixels)
    :param obstacle_radius: radius to check obstacles (pixels), default 10
    :param obstacle_threshold: threshold for obstacle proportion (0.0-1.0), default 0.3
    :param trajectory_mask: agent history trajectory mask (H, W), non-zero at visited positions, default None
    :param trajectory_radius: trajectory detection radius (pixels), default 10
    :return: array of the kept channels, shape (M, H, W); shape (0, *filled_mask.shape) when none pass
    :raises ValueError: if agent_position_mask contains no 1
    """
    def _to_numpy(mask):
        # One device-to-host copy per mask instead of one per candidate.
        if torch.is_tensor(mask):
            return mask.detach().cpu().numpy()
        return np.asarray(mask)

    def _disk(shape, centre_y, centre_x, radius):
        # Boolean disk built from broadcastable open grids.
        rows, cols = np.ogrid[:shape[0], :shape[1]]
        return (rows - centre_y)**2 + (cols - centre_x)**2 <= radius**2

    filled_mask = _to_numpy(filled_mask)
    obstacle_mask = _to_numpy(obstacle_mask)
    exp_mask = _to_numpy(exp_mask)
    agent_position_mask = _to_numpy(agent_position_mask)
    check_trajectory = trajectory_mask is not None and trajectory_radius is not None
    if check_trajectory:
        trajectory_mask = _to_numpy(trajectory_mask)

    # Get robot position (first 1 in the mask)
    robot_positions = np.argwhere(agent_position_mask == 1)
    if len(robot_positions) == 0:
        raise ValueError("No 1 found in agent_position_mask")
    robot_y, robot_x = robot_positions[0]

    valid_channels = []
    for i in range(candidate_goals.shape[0]):
        channel = candidate_goals[i]
        goal_pixels = np.argwhere(channel == 1)
        if len(goal_pixels) == 0:
            continue  # channel holds no goal point
        y, x = goal_pixels[0]
        # 1. goal must be explored and obstacle-free
        if obstacle_mask[y, x] == 1 or exp_mask[y, x] == 0:
            continue
        # 2. obstacle density around the goal
        obstacle_disk = _disk(obstacle_mask.shape, y, x, obstacle_radius)
        obstacle_circle_area = obstacle_disk.sum()
        obstacle_count = np.logical_and(obstacle_mask, obstacle_disk).sum()
        obstacle_ratio = obstacle_count / \
            obstacle_circle_area if obstacle_circle_area > 0 else 0
        if obstacle_ratio > obstacle_threshold:
            continue
        # 3. enough 0 values of filled_mask around the goal
        zero_disk = _disk(filled_mask.shape, y, x, k)
        circle_area = zero_disk.sum()
        zero_count = np.logical_and(filled_mask == 0, zero_disk).sum()
        zero_ratio = zero_count / circle_area if circle_area > 0 else 0
        if not zero_ratio > threshold:
            continue
        # 4. far enough from the robot
        robot_distance = np.sqrt((y - robot_y)**2 + (x - robot_x)**2)
        if not robot_distance >= dis_thresh:
            continue
        # 5. not next to an already visited position
        if check_trajectory:
            trajectory_disk = _disk(
                filled_mask.shape, y, x, trajectory_radius)
            if np.any(np.logical_and(trajectory_mask, trajectory_disk)):
                continue
        valid_channels.append(channel)
    return np.stack(valid_channels) if valid_channels else np.empty((0, *filled_mask.shape))


def sort_candidate_goals_by_distance(candidate_goals, agent_position):
    """
    Reorder candidate goal channels from nearest to farthest from the agent.
    :param candidate_goals: shape (num_candidate, height, width); the first 1 of each channel is its goal
    :param agent_position: shape (height, width); the first 1 is the agent position
    :return: new array like candidate_goals with the channels sorted by ascending
             Euclidean distance; equal distances keep their original relative order
    """
    agent_y, agent_x = np.argwhere(agent_position == 1)[0]

    def _distance_to_agent(goal_map):
        goal_y, goal_x = np.argwhere(goal_map == 1)[0]
        return np.sqrt((goal_y - agent_y)**2 + (goal_x - agent_x)**2)

    goal_distances = [
        _distance_to_agent(candidate_goals[i])
        for i in range(candidate_goals.shape[0])
    ]
    # Stable sort so that ties preserve the input order
    order = np.argsort(goal_distances, kind='stable')
    ranked = np.zeros_like(candidate_goals)
    for slot, source in enumerate(order):
        ranked[slot] = candidate_goals[source]
    return ranked


# def get_final_candidate_goal_map(full_map, candidate_goal_map_list_or_numpy, planner_pose_inputs, args):
#     obstacle_mask = full_map[0, 0, :, :] > 0
#     exp_mask = full_map[0, 1, :, :] > 0
#     agent_position_mask = get_agent_position_mask(
#         planner_pose_inputs, full_map.shape, args)
#     #
#     combined_mask = np.logical_or(
#         obstacle_mask.cpu().numpy(), exp_mask.cpu().numpy()).astype(np.uint8)
#     # filled_mask = fill_agent_position_mask(
#     #     obstacle_mask, exp_mask, agent_position_mask)
#     if type(candidate_goal_map_list_or_numpy) == list:
#         candidate_goal_maps = np.concatenate(
#             candidate_goal_map_list_or_numpy, axis=0)
#     else:
#         candidate_goal_maps = candidate_goal_map_list_or_numpy
#     filtered_candidate_goal_map = filter_candidate_goals(
#         combined_mask, candidate_goal_maps, obstacle_mask, exp_mask, agent_position_mask, k=8, threshold=0.2, obstacle_radius=8, dis_thresh=5, obstacle_threshold=0.5)
#     if filtered_candidate_goal_map.shape[0] == 0:
#         return combined_mask, obstacle_mask, None
#     sorted_candidate_goal_map = sort_candidate_goals_by_distance(
#         filtered_candidate_goal_map, agent_position_mask)
#     return combined_mask, obstacle_mask, sorted_candidate_goal_map


def update_final_condidate_goal_map(full_map, candidate_goal_map_list_or_numpy, NextGoalIterator, planner_pose_inputs, args, trajectory_mask=None):
    """
    Merge new candidate goals with the iterator's remaining ones, then filter and sort them.
    Parameters:
        full_map: tensor indexed as full_map[0, 0] (obstacles) and full_map[0, 1] (explored)
        candidate_goal_map_list_or_numpy: non-empty list of arrays (concatenated along axis 0)
            or anything else, which is forwarded unchanged
        NextGoalIterator: None, or an object exposing get_remaining_goals(); its result
            (when not None) is prepended to the new candidates. Only consulted for a non-empty list.
        planner_pose_inputs: forwarded to get_agent_position_mask
        args: forwarded to get_agent_position_mask
        trajectory_mask: optional mask forwarded to filter_candidate_goals
    Returns:
        the surviving candidate goals, sorted by distance to the agent
    """
    incoming = candidate_goal_map_list_or_numpy
    if (type(incoming) == list) and len(incoming):
        all_new_candidate_goal_map = np.concatenate(incoming, axis=0)
        leftover = None
        if NextGoalIterator is not None:
            leftover = NextGoalIterator.get_remaining_goals()
        if leftover is not None:
            # Goals still pending in the iterator go first
            all_new_candidate_goal_map = np.concatenate(
                (leftover, all_new_candidate_goal_map), axis=0)
    else:
        # NOTE(review): an empty list reaches filter_candidate_goals unchanged,
        # which reads .shape on it - confirm callers never pass [].
        all_new_candidate_goal_map = incoming
    agent_position_mask = get_agent_position_mask(
        planner_pose_inputs, full_map.shape, args)
    obstacle_mask = full_map[0, 0, :, :] > 0
    exp_mask = full_map[0, 1, :, :] > 0
    # Cells that are either obstacle or explored
    known_mask = np.logical_or(
        obstacle_mask.cpu().numpy(), exp_mask.cpu().numpy()).astype(np.uint8)
    survivors = filter_candidate_goals(
        known_mask, all_new_candidate_goal_map, obstacle_mask, exp_mask, agent_position_mask, k=8, threshold=0.2, obstacle_radius=8, dis_thresh=5, obstacle_threshold=0.5, trajectory_mask=trajectory_mask)
    return sort_candidate_goals_by_distance(survivors, agent_position_mask)


def visualize_goals(filled_mask, candidate_goals, obstacle_map, k, threshold, output_path="result.png"):
    """
    Render the map, the candidate goals and their surrounding 0-value percentage to an image file.
    :param filled_mask: binary mask (H, W)
    :param candidate_goals: candidate goal points (N, H, W); the first 1 of each channel is drawn
    :param obstacle_map: obstacle map (H, W)
    :param k: radius (pixels) of the disk used for the percentage
    :param threshold: value the percentage (0-100 scale) is compared against to pick the text colour
    :param output_path: output image path
    :return: the rendered (H, W, 3) uint8 image array
    """
    h, w = filled_mask.shape
    # White canvas
    canvas = np.full((h, w, 3), 255, dtype=np.uint8)
    # filled & not obstacle -> light gray (200); filled & obstacle -> dark gray (100)
    free_rows, free_cols = np.where((filled_mask == 1) & (obstacle_map == 0))
    wall_rows, wall_cols = np.where((filled_mask == 1) & (obstacle_map == 1))
    canvas[free_rows, free_cols] = [200, 200, 200]
    canvas[wall_rows, wall_cols] = [100, 100, 100]
    # Goal positions and the share of 0 values inside the radius-k disk around each
    positions = []
    percentages = []
    for i in range(candidate_goals.shape[0]):
        y, x = np.argwhere(candidate_goals[i] == 1)[0]
        positions.append((x, y))
        # Restrict the computation to the bounding box of the disk, clipped to the image
        x_min, x_max = max(0, x-k), min(w, x+k+1)
        y_min, y_max = max(0, y-k), min(h, y+k+1)
        window = filled_mask[y_min:y_max, x_min:x_max]
        win_y, win_x = np.ogrid[y_min:y_max, x_min:x_max]
        in_disk = ((win_x - x)**2 + (win_y - y)**2) <= k**2
        total_pixels = np.sum(in_disk)
        zero_pixels = np.sum((window == 0) & in_disk)
        if total_pixels > 0:
            percentages.append((zero_pixels / total_pixels) * 100)
        else:
            percentages.append(0)
    # Draw the goals and their labels
    text_positions = []  # label anchors already placed
    for (x, y), percent in zip(positions, percentages):
        # NOTE(review): OpenCV presumably interprets these tuples as BGR when
        # writing, so (0, 0, 255) is red but (255, 0, 0) would be blue - confirm intent.
        cv2.circle(canvas, (x, y), 2, (0, 0, 255), -1)
        text_color = (0, 255, 0) if percent < threshold else (255, 0, 0)
        # Label goes above the point unless an earlier label anchor is close by
        crowded = any(
            abs(x - pos[0]) < 40 and abs(y - pos[1]) < 20
            for pos in text_positions)
        offset_y = 15 if crowded else -15
        text_pos = (x - 15, y + offset_y)
        text_positions.append(text_pos)
        text = f"{percent:.1f}%"
        cv2.putText(canvas, text, text_pos,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, text_color, 1)
    cv2.imwrite(output_path, canvas)
    return canvas


def get_candidate_goal_mask_and_replace_obs(depth, graph, graph2, obs, args, device):
    """
    Detect candidate goals in a depth frame and write their expanded masks into obs.
    Parameters:
        depth: depth input handed to graph.process_depth / graph2.process_depth
        graph: primary detector; its process_depth returns an array whose first axis is the candidate count
        graph2: fallback detector, used only when graph yields zero candidates
        obs: tensor; channels 4 .. 4 + num_candidates - 1 of entry 0 are overwritten in place
        args: provides env_frame_width and frame_width (their integer ratio is the downsampling stride)
        device: target device for the masks written into obs
    Returns:
        (candidate_goal_mask, obs). When there is no candidate, obs is returned
        unmodified together with np.zeros_like of the (empty) candidate mask.
    """
    candidate_goal_mask = graph.process_depth(depth)
    if candidate_goal_mask.shape[0] == 0:
        # Primary detector found nothing: try the fallback
        candidate_goal_mask = graph2.process_depth(depth)
    expanded = expand_masks(candidate_goal_mask)
    if expanded is None:
        # No candidate exploration direction in the current view; leave obs as is
        return np.zeros_like(candidate_goal_mask), obs
    stride = args.env_frame_width // args.frame_width
    if stride != 1:
        # Sub-sample to the observation resolution, taking the centre of each stride cell
        expanded = expanded[:, stride // 2::stride, stride // 2::stride]
    n_goals = expanded.shape[0]
    obs[0, 4:4+n_goals, :, :] = torch.from_numpy(expanded).to(device)
    return candidate_goal_mask, obs


def is_stuck(position_history, current_mask_pos, window_size=5, threshold=0.5):
    """
    Record the current position and report whether the agent has barely moved recently.
    Parameters:
        position_history: mutable sequence (list or deque); current_mask_pos is appended to it
        current_mask_pos: current position, e.g. a (y, x) pair
        window_size: number of most recent positions that are examined
        threshold: the agent counts as stuck when the mean per-coordinate standard
            deviation over the window is below this value
    Returns:
        bool: False until window_size positions are available, afterwards whether
        the spread of the last window_size positions is below threshold.
    """
    position_history.append(current_mask_pos)
    if window_size <= 0 or len(position_history) < window_size:
        return False
    # Use a sliding window over the newest entries so that the check keeps
    # working when the history is a plain list or a deque longer than
    # window_size (an exact length comparison would only fire once).
    recent = np.array(list(position_history)[-window_size:])
    spread = np.std(recent, axis=0).mean()
    return bool(spread < threshold)


def is_goal_blocked(obstacle_map, tmp_goal, k=8, thresh=0.45):
    """
    Tell whether the surroundings of a goal point are mostly obstacles.
    Parameters:
        obstacle_map (np.ndarray or torch.Tensor): 2-D map, cells equal to 1 count as obstacle
        tmp_goal (np.ndarray or torch.Tensor): same shape, exactly one cell equal to 1 (the goal)
        k (int): radius of the disk around the goal (pixels)
        thresh (float): obstacle proportion threshold (between 0 and 1)
    Returns:
        True if the share of obstacle cells among the in-image cells within distance k
        of the goal is strictly greater than thresh, otherwise False
    Exceptions:
        Raises ValueError if the shapes differ or tmp_goal does not hold exactly one 1.
    """
    # Work on host-side numpy arrays regardless of the input type
    obstacle_np = obstacle_map.cpu().numpy() if torch.is_tensor(obstacle_map) else obstacle_map
    tmp_goal_np = tmp_goal.cpu().numpy() if torch.is_tensor(tmp_goal) else tmp_goal
    if obstacle_np.shape != tmp_goal_np.shape:
        raise ValueError("obstacle_map and tmp_goal must have the same shape")
    goal_cells = tmp_goal_np == 1
    if np.sum(goal_cells) != 1:
        raise ValueError("tmp_goal must have exactly one 1")
    goal_y, goal_x = np.argwhere(goal_cells)[0]
    height, width = obstacle_np.shape
    # Euclidean distance of every cell to the goal
    rows, cols = np.ogrid[:height, :width]
    within_radius = np.sqrt((rows - goal_y)**2 + (cols - goal_x)**2) <= k
    n_cells = np.sum(within_radius)
    if n_cells == 0:
        return False  # no cell inside the disk
    n_obstacles = np.sum(obstacle_np[within_radius] == 1)
    return n_obstacles / n_cells > thresh


def get_Node_based_full_map(full_map, NodeDetector, planner_pose_inputs, args, infos):
    """
    Run the node detector on the current map and keep only the nodes reachable from the agent.
    Parameters:
        full_map: tensor indexed as full_map[0, 0] (obstacles) and full_map[0, 1] (explored)
        NodeDetector: object exposing run(obs_exp, trajectory_mask, agent_position_mask)
        planner_pose_inputs: forwarded to get_agent_position_mask
        args: forwarded to get_agent_position_mask
        infos: sequence whose first entry holds 'trajectory_mask'
    Returns:
        array of the reachable node maps stacked along axis 0, or None when the
        detector yields no node or none of them passes is_path_reachable
    """
    obstacle_mask = full_map[0, 0, :, :] > 0
    exp_mask = full_map[0, 1, :, :] > 0
    # Detector input: obstacle and explored layers stacked as two channels
    obs_exp = np.stack(
        [obstacle_mask.cpu().numpy(), exp_mask.cpu().numpy()], axis=0)
    trajectory_mask = infos[0]['trajectory_mask']
    agent_mask = get_agent_position_mask(
        planner_pose_inputs, full_map.shape, args)
    node_map = NodeDetector.run(obs_exp, trajectory_mask, agent_mask)
    reachable_nodes = [
        node_map[i]
        for i in range(node_map.shape[0])
        if is_path_reachable(obstacle_mask, exp_mask, agent_mask, node_map[i])
    ]
    if not reachable_nodes:
        return None
    return np.stack(reachable_nodes, axis=0)


def sample_escape_goal(obstacle_map, exp_map, history_position, initial_radius=20, max_radius=100):
    """
    Sample a traversable escape goal point near the stuck position.
    The search disk starts at initial_radius and grows by a factor of 1.5 (at least
    one pixel per step) until a traversable cell is inside it or max_radius has been tried.
    Parameters:
        obstacle_map (np.ndarray): obstacle map, 1=obstacle
        exp_map (np.ndarray): traversable map, 1=traversable
        history_position (np.ndarray): history position mask, exactly one 1
        initial_radius (int): initial search radius (pixels)
        max_radius (int): maximum search radius (pixels)
    Returns:
        tuple: (y, x) of a uniformly sampled cell that is traversable, not an obstacle
        and within the current radius; None if no such cell exists within max_radius
    """
    # Validate inputs
    assert obstacle_map.shape == exp_map.shape == history_position.shape
    assert np.sum(history_position) == 1, "History position mask should have exactly one 1"
    # Locate stuck position
    center_y, center_x = np.argwhere(history_position == 1)[0]
    height, width = obstacle_map.shape
    # Traversable cells do not depend on the radius: compute them once
    valid_mask = (exp_map == 1) & (obstacle_map == 0)
    # Euclidean distance from every cell to the stuck position
    y_coords, x_coords = np.ogrid[:height, :width]
    distances = np.sqrt((y_coords - center_y)**2 + (x_coords - center_x)**2)
    current_radius = initial_radius
    while current_radius <= max_radius:
        candidate_points = np.argwhere(valid_mask & (distances <= current_radius))
        if len(candidate_points) > 0:
            # Randomly select one candidate point
            chosen_idx = random.randint(0, len(candidate_points)-1)
            return tuple(candidate_points[chosen_idx])
        if current_radius >= max_radius:
            # The largest radius has been tried; without this exit the clamped
            # radius would stay at max_radius and the loop would never end.
            break
        # Grow by 1.5x but always by at least one pixel (int(1 * 1.5) == 1
        # would otherwise stall), never beyond max_radius.
        current_radius = min(
            max(int(current_radius * 1.5), current_radius + 1), max_radius)
    return None  # no valid point found


def check_consistent_positions(history_position, planner_pose_inputs, full_map, args, k=25, min_common=20):
    """
    Record the current agent position and report whether the agent looks stuck.

    The current position mask is obtained from ``get_agent_position_mask`` and
    appended to ``history_position`` (the container is modified in place).
    Once at least ``k`` masks are stored, the agent is considered stuck if any
    single cell is the marked position in at least ``min_common`` of them.
    Parameters:
        history_position (deque): stored position masks; each element is an
            array whose first cell equal to 1 is taken as the agent position.
            Every entry currently held is counted, so any window size limit
            has to come from the container itself
            (presumably a bounded deque - confirm at the call site)
        planner_pose_inputs: forwarded to ``get_agent_position_mask``
        full_map: only ``full_map.shape`` is used, forwarded to
            ``get_agent_position_mask``
        args: forwarded to ``get_agent_position_mask``
        k (int): minimum number of stored masks before a verdict is given
        min_common (int): number of occurrences of one position that counts
            as being stuck
    Returns:
        bool:
            - False while fewer than k masks are stored
            - True if some position occurs at least min_common times
            - False otherwise
    """
    # Push the newest position first so it takes part in the check below
    current_mask = get_agent_position_mask(
        planner_pose_inputs, full_map.shape, args)
    history_position.append(current_mask)
    # Not enough history collected yet
    if len(history_position) < k:
        return False
    # Tally how often each (y, x) cell was the agent position
    occurrences = {}
    for mask in history_position:
        cell = tuple(np.argwhere(mask == 1)[0])
        if cell in occurrences:
            occurrences[cell] += 1
        else:
            occurrences[cell] = 1
    # Stuck as soon as one cell reaches the required number of occurrences
    return any(times >= min_common for times in occurrences.values())


def is_path_reachable(obstacle_map, exp_map, agent_position_mask, tmp_goal_map):
    """
    Determine if there is a traversable path from the agent position to the goal position.

    A cell is traversable when ``exp_map`` is 1 and ``obstacle_map`` is 0.
    Movement is 8-connected (horizontal, vertical and diagonal steps), so the
    goal is reachable exactly when both cells are traversable and belong to
    the same 8-connected component of the traversable map.
    Parameters:
        obstacle_map (np.ndarray | torch.Tensor): obstacle map, shape (height, width), 1 indicates obstacle
        exp_map (np.ndarray | torch.Tensor): traversable area map, shape (height, width), 1 indicates traversable area
        agent_position_mask (np.ndarray | torch.Tensor): agent position mask, shape (height, width), only one 1 indicating agent position
        tmp_goal_map (np.ndarray | torch.Tensor): goal map, shape (height, width), only one 1 indicating goal position
    Returns:
        bool: True if a traversable path exists (or agent and goal share the
            same traversable cell); False if either cell is not traversable
            or no path connects them
    Raises:
        ValueError: if the maps differ in shape, are not 2-D, or a mask does
            not contain exactly one cell equal to 1
    """
    def to_numpy(x):
        # Anything exposing .cpu() is treated as a PyTorch tensor
        if hasattr(x, 'cpu'):
            if hasattr(x, 'detach'):
                # .numpy() refuses tensors that track gradients
                x = x.detach()
            return x.cpu().numpy()
        return x  # already a numpy array

    obstacle_map = to_numpy(obstacle_map)
    exp_map = to_numpy(exp_map)
    agent_position_mask = to_numpy(agent_position_mask)
    tmp_goal_map = to_numpy(tmp_goal_map)
    # Check input validity
    if not (obstacle_map.shape == exp_map.shape
            == agent_position_mask.shape == tmp_goal_map.shape):
        raise ValueError("All maps must have the same shape")
    if obstacle_map.ndim != 2:
        raise ValueError("All maps must be 2-D arrays of shape (height, width)")
    # Traversable condition: exp_map is 1 and obstacle_map is 0
    traversable_map = (exp_map == 1) & (obstacle_map == 0)
    # Find agent position (row, col)
    agent_cells = np.argwhere(agent_position_mask == 1)
    if len(agent_cells) == 0:
        raise ValueError("No agent position found in agent_position_mask")
    if len(agent_cells) > 1:
        raise ValueError("Multiple 1s found in agent_position_mask, should have only one")
    agent_row, agent_col = agent_cells[0]
    # Find goal position (row, col)
    goal_cells = np.argwhere(tmp_goal_map == 1)
    if len(goal_cells) == 0:
        raise ValueError("No goal position found in tmp_goal_map")
    if len(goal_cells) > 1:
        raise ValueError("Multiple 1s found in tmp_goal_map, should have only one")
    goal_row, goal_col = goal_cells[0]
    # Both end points must themselves be traversable
    if not traversable_map[agent_row, agent_col]:
        return False  # agent position is not traversable
    if not traversable_map[goal_row, goal_col]:
        return False  # goal position is not traversable
    # Same cell: trivially reachable, skip the labelling pass
    if agent_row == goal_row and agent_col == goal_col:
        return True
    # Label 8-connected components of the traversable area; a path exists
    # exactly when agent and goal carry the same component label.
    eight_connected = np.ones((3, 3), dtype=bool)
    component_ids, _ = ndimage.label(traversable_map, structure=eight_connected)
    return bool(component_ids[agent_row, agent_col] == component_ids[goal_row, goal_col])


class CandidateGoalIterator:
    """
    Forward-only cursor over a stack of candidate goal channels.

    ``candidate_goals`` is any indexable, sliceable sequence supporting
    ``len`` (e.g. an array of shape (num_candidate, height, width));
    ``current_index`` is the position of the next channel to hand out.
    """

    def __init__(self, candidate_goals=None):
        """
        Initialize candidate goal iterator.
        :param candidate_goals: array of shape (num_candidate, height, width), or None
        """
        self.set_candidate_goals(candidate_goals)

    def set_candidate_goals(self, candidate_goals):
        """
        Replace the candidate goal array and rewind to its first channel.
        :param candidate_goals: new candidate goal array (may be None)
        """
        self.current_index = 0  # start again from the first channel
        self.candidate_goals = candidate_goals

    def next(self):
        """
        Return the channel at the cursor and advance the cursor by one.
        :return: (height, width) array, or None when no array is set or all
                 channels have already been handed out
        """
        goals = self.candidate_goals
        if goals is None:
            return None
        position = self.current_index
        if position >= len(goals):
            return None
        # Read first, advance afterwards, so a failing lookup leaves the cursor alone
        channel = goals[position]
        self.current_index = position + 1
        return channel

    def get_remaining_goals(self):
        """
        Return every channel that has not been handed out yet.
        :return: remaining channels array (remaining_num, height, width), or
                 None when no array is set or nothing is left
        """
        goals = self.candidate_goals
        if goals is None or self.current_index >= len(goals):
            return None
        return goals[self.current_index:]


# def expand_semantic_mask(semantic_mask, up_scale=2, down_scale=1):
#     processed_mask = semantic_mask.cpu().numpy().copy()
#     num_scenes, num_channels, height, width = semantic_mask.shape
#     for scene_idx in range(num_scenes):
#         for class_idx in range(4, num_channels):
#             mask = processed_mask[scene_idx, class_idx]
#             count = np.sum(mask)
#             if 0 < count < 100:
#                 # Find original x boundaries
#                 ys, xs = np.where(mask)
#                 if len(xs) == 0:
#                     continue
#                 x_start, x_end = np.min(xs), np.max(xs)
#                 # Calculate new x boundaries
#                 width_original = x_end - x_start
#                 delta = int(round(0.05 * width_original))
#                 new_x_start = x_start + delta
#                 new_x_end = x_end - delta
#                 # Apply x-axis contraction
#                 new_mask = np.zeros_like(mask)
#                 new_x_start = max(0, new_x_start)
#                 new_x_end = min(width-1, new_x_end)
#                 if new_x_start <= new_x_end:
#                     new_mask[:, new_x_start:new_x_end +
#                              1] = mask[:, new_x_start:new_x_end+1]
#                 # Find y boundaries after contraction
#                 ys_new, _ = np.where(new_mask)
#                 if len(ys_new) == 0:
#                     processed_mask[scene_idx, class_idx] = new_mask
#                     continue
#                 y_start, y_end = np.min(ys_new), np.max(ys_new)
#                 h = y_end - y_start + 1
#                 # Calculate new y boundaries
#                 up_extend = int(round(up_scale * h))
#                 down_extend = int(round(down_scale * h))
#                 y_start_new = max(0, y_start - up_extend)
#                 y_end_new = min(height-1, y_end + down_extend)
#                 # Apply y-axis expansion
#                 final_mask = np.zeros_like(new_mask)
#                 if new_x_start <= new_x_end:
#                     final_mask[y_start_new:y_end_new +
#                                1, new_x_start:new_x_end+1] = 1
#                 processed_mask[scene_idx, class_idx] = final_mask
#     return torch.from_numpy(processed_mask).to(semantic_mask.device)
# def expand_semantic_channels(obs, target_threshold=400, start_channel=4, end_channel=9):
#     """
#     Expand small target regions in semantic channels to a specified threshold size.
#     Args:
#         obs: tensor or numpy array of shape [1, 30, height, width]
#         target_threshold: target threshold, expand if number of non-zero values is less than this value
#         start_channel: starting channel index (inclusive)
#         end_channel: ending channel index (inclusive)
#     Returns:
#         processed array, keeping the same shape and type as input
#     """
#     # Determine input type and convert to numpy for processing
#     is_tensor = torch.is_tensor(obs)
#     if is_tensor:
#         obs_np = obs.clone().cpu().numpy()
#         device = obs.device
#     else:
#         obs_np = obs.copy()
#     # Create output array
#     output = obs_np.copy()
#     # Get dimension information
#     batch_size, num_channels, height, width = obs_np.shape
#     # Validate input shape
#     if batch_size != 1:
#         raise ValueError(f"Expected batch_size to be 1, but got {batch_size}")
#     # Process specified channel range (closed interval)
#     for channel_idx in range(start_channel, end_channel + 1):
#         if channel_idx >= num_channels:
#             continue
#         # Get mask for current channel
#         current_mask = obs_np[0, channel_idx]
#         # Detect non-zero values
#         nonzero_positions = np.argwhere(current_mask != 0)
#         num_nonzero = len(nonzero_positions)
#         # If no non-zero values or number of non-zero values exceeds threshold, skip
#         if num_nonzero == 0 or num_nonzero >= target_threshold:
#             continue
#         print(
#             f"Channel {channel_idx}: detected {num_nonzero} non-zero values, less than threshold {target_threshold}, starting expansion...")
#         # Calculate center coordinates of non-zero values
#         center_y = np.mean(nonzero_positions[:, 0])
#         center_x = np.mean(nonzero_positions[:, 1])
#         cy, cx = int(round(center_y)), int(round(center_x))
#         print(f"  Center coordinates: ({cy}, {cx})")
#         # Calculate maximum radius limited by boundaries
#         top_dist = cy
#         bottom_dist = height - 1 - cy
#         left_dist = cx
#         right_dist = width - 1 - cx
#         max_radius = min(top_dist, bottom_dist, left_dist, right_dist)
#         # Calculate theoretical target radius
#         target_radius = sqrt(target_threshold / pi)
#         # Determine actual radius to use
#         if max_radius >= target_radius:
#             final_radius = target_radius
#             print(f"  Using target radius: {final_radius:.2f}")
#         else:
#             final_radius = max_radius
#             print(f"  Limited by boundaries, using maximum radius: {final_radius:.2f}")
#         # Create grid coordinates
#         Y, X = np.ogrid[:height, :width]
#         # Calculate distance from each point to center
#         dist = np.sqrt((X - cx)**2 + (Y - cy)**2)
#         # Create circular region mask
#         circle_mask = dist <= final_radius
#         # Calculate actual number of pixels after expansion
#         actual_pixels = np.sum(circle_mask)
#         print(f"  Number of pixels after expansion: {actual_pixels}")
#         # Update mask for current channel (keep original non-zero values)
#         new_mask = np.zeros_like(current_mask)
#         new_mask[circle_mask] = 1
#         # Update mask in output array
#         output[0, channel_idx] = new_mask
#     # Convert back to original type
#     if is_tensor:
#         return torch.from_numpy(output).to(device)
#     else:
#         return output