import numpy as np
import numpy.linalg as LA
import cv2
import os
import argparse
from modeling.utils.baseline_utils import create_folder, pxl_coords_to_pose
import matplotlib
import matplotlib.pyplot as plt
import math
from math import cos, sin, acos, atan2, pi, floor, degrees
import random
from modeling.utils.navigation_utils import change_brightness, SimpleRLEnv, get_obs_and_pose, get_obs_and_pose_by_action
from modeling.utils.baseline_utils import apply_color_to_map, pose_to_coords, gen_arrow_head_marker, read_map_npy, read_occ_map_npy, plus_theta_fn
from modeling.utils.map_utils_pcd_height import SemanticMap
from modeling.localNavigator_Astar import localNav_Astar
import habitat
import habitat_sim
import random
from core import cfg
from modeling.utils import frontier_utils as fr_utils
from modeling.localNavigator_slam import localNav_slam
from skimage.morphology import skeletonize
from modeling.utils.UNet import UNet
import torch
from collections import OrderedDict
from tqdm import tqdm
from collections import deque
from scipy.spatial import distance

def farthest_point_sampling(points, min_distance=0.25, verbose=True):
    """
    Select a well-spread subset of points using greedy Farthest Point Sampling (FPS).

    Starting from a random point (drawn with ``np.random``), the point farthest from
    the current selection is added repeatedly, until every remaining point lies within
    ``min_distance`` of some selected point.

    Args:
        points (array-like): N x D array of points (N x 3 in this project).
        min_distance (float): Stop once the largest point-to-selection distance drops
            below this value.
        verbose (bool): If True, prints per-iteration statistics and shows a tqdm
            progress bar.

    Returns:
        np.ndarray: Selected points, in selection order. Empty input yields an empty
        array.
    """
    points = np.array(points)
    num_points = len(points)
    if num_points == 0:
        # np.random.randint(0) would raise; there is nothing to sample from.
        return points

    # Treat each point as a flat row vector so distances can be computed in one shot.
    coords = points.reshape(num_points, -1).astype(float)

    start_idx = np.random.randint(num_points)  # Pick a random initial point
    selected_nodes = [points[start_idx]]
    # min_dists[i] is the distance from point i to its nearest selected point. It is
    # updated incrementally instead of being recomputed against every selected point.
    min_dists = np.linalg.norm(coords - coords[start_idx], axis=1)

    iterations = 0  # Counter for monitoring
    pbar = tqdm(total=num_points, desc="Selecting Points") if verbose else None

    while True:
        farthest_idx = int(np.argmax(min_dists))  # Find the farthest point
        max_distance = min_dists[farthest_idx]  # Max distance found

        if verbose:
            print(f"Iteration {iterations}: Selected {len(selected_nodes)} points, Max Dist: {max_distance:.4f}")

        # A zero max distance means every point coincides with a selected one.
        # Stopping here prevents an endless loop of duplicates when min_distance <= 0.
        if max_distance < min_distance or max_distance <= 0:
            break  # Stop when all remaining points are too close

        selected_nodes.append(points[farthest_idx])
        min_dists = np.minimum(
            min_dists, np.linalg.norm(coords - coords[farthest_idx], axis=1))
        iterations += 1
        if pbar is not None:
            pbar.update(1)  # Update progress bar

    if pbar is not None:
        pbar.close()
    return np.array(selected_nodes)

def filter_maximum_nodes(points, min_distance=0.25, verbose=True):
    """
    Greedily thin a point set so that kept points are at least ``min_distance`` apart.

    Points are visited in input order. A point is kept when its Euclidean distance to
    every previously kept point is >= ``min_distance``.

    Args:
        points (array-like): Sequence of N points (N x 3 in this project).
        min_distance (float): Minimum allowed distance between kept points.
        verbose (bool): If True, wraps the scan in a tqdm progress bar.

    Returns:
        np.ndarray: The kept points in their original order, or ``np.array([])`` for
        empty input.
    """
    points = np.array(points)  # Ensure input is a NumPy array
    if len(points) == 0:
        return np.array([])

    coords = points.reshape(len(points), -1).astype(float)
    kept = np.empty_like(coords)  # kept[:num_kept] holds the accepted points
    num_kept = 0
    kept_idx = []

    indices = range(len(points))
    for i in (tqdm(indices) if verbose else indices):
        # One vectorized distance computation per candidate, instead of a
        # Python-level scipy call for every (candidate, kept point) pair.
        if num_kept == 0 or np.all(
                np.linalg.norm(kept[:num_kept] - coords[i], axis=1) >= min_distance):
            kept[num_kept] = coords[i]
            num_kept += 1
            kept_idx.append(i)

    return points[kept_idx]

def quaternion_to_yaw(q):
    """
    Return the yaw angle (radians) of a quaternion ordered ``(x, y, z, w)``.

    Uses ``atan2(2(wz + xy), 1 - 2(y^2 + z^2))``, i.e. rotation about the z-axis.
    NOTE(review): if poses come from a y-up simulator, confirm z-axis yaw is the
    intended heading.
    """
    x, y, z, w = q
    sin_yaw_term = 2 * (w * z + x * y)
    cos_yaw_term = 1 - 2 * (y**2 + z**2)
    return np.arctan2(sin_yaw_term, cos_yaw_term)

def transform_poses(poses):
    """
    Convert 7-element poses into ``(x, y, z, yaw)`` tuples.

    Each pose holds a position in ``pose[0:3]`` followed by a quaternion in
    ``pose[3:]`` ordered ``(x, y, z, w)``, which ``quaternion_to_yaw`` reduces to a
    yaw angle.

    Returns:
        list[tuple]: One ``(pose[0], pose[1], pose[2], yaw)`` tuple per input pose.
    """
    def _to_xyz_yaw(pose):
        # Yaw is evaluated before the position is read, matching the original order.
        yaw = quaternion_to_yaw(pose[3:])
        return (pose[0], pose[1], pose[2], yaw)

    return list(map(_to_xyz_yaw, poses))

def nav_with_trajectory(start_pose, split, env, scene_name, saved_folder, max_points):
    """
    Observe one floor from a sparse set of positions reachable from ``start_pose``.

    The occupancy-map cells connected to ``start_pose`` are converted to 3D positions
    and thinned with ``filter_maximum_nodes(min_distance=0.25)``. If more than
    ``max_points`` positions remain, the agent is placed at each of them facing 8
    headings. Every observation is fed to ``SemanticMap.build_semantic_map``, and the
    module's ``depth_grid`` is saved at the end.

    Args:
        start_pose: ``(x, y, z)`` position; ``start_pose[1]`` is used as the floor height.
        split (str): Dataset split used to locate the saved occupancy map.
        env: Habitat simulator, as created in ``main`` by ``habitat.sims.make_sim``.
        scene_name (str): ``<scene>_<floor>`` identifier used in input/output paths.
        saved_folder (str): Folder passed through to ``build_semantic_map``.
        max_points (int): Largest sparse-point count seen so far for this floor. The
            floor is only processed if it yields strictly more positions than this.

    Returns:
        int: Number of sparse positions processed. Returns ``max_points`` unchanged
        when the start pose is not navigable or too few positions were found.

    Side effects:
        Reseeds ``np.random`` and ``random`` from ``cfg.GENERAL.RANDOM_SEED``. Writes
        ``../bellman_explore/sparse_agent_pos/{scene_name}.npy`` and
        ``../bellman_explore/all_poses/{scene_name}.npy``.
    """
    scene_height = start_pose[1]

    # Map instance id (suffix of obj.id after the last "_") -> category index.
    scene = env.semantic_annotations()
    ins2cat_dict = {
        int(obj.id.split("_")[-1]): obj.category.index()
        for obj in scene.objects
    }

    # Seed both RNGs so runs are reproducible
    np.random.seed(cfg.GENERAL.RANDOM_SEED)
    random.seed(cfg.GENERAL.RANDOM_SEED)

    # Load ground-truth occupancy map
    occ_map_npy = np.load(
        f'output/semantic_map/{split}/{scene_name}/BEV_occupancy_map.npy',
        allow_pickle=True).item()
    gt_occ_map, pose_range, coords_range, WH = read_occ_map_npy(occ_map_npy)

    # Initialize modules
    LN = localNav_Astar(pose_range, coords_range, WH, scene_name)
    # NOTE(review): LS is constructed and reset but not used below; confirm it is still needed.
    LS = localNav_slam(pose_range, coords_range, WH, mark_locs=True, close_small_openings=False,
                       recover_on_collision=False, fix_thrashing=False, point_cnt=2)
    LS.reset(gt_occ_map)

    semMap_module = SemanticMap(split, scene_name, pose_range, coords_range, WH, ins2cat_dict)

    # Get the area connected to the starting position
    agent_pos = np.array([start_pose[0], start_pose[1], start_pose[2]])
    if not env.is_navigable(agent_pos):
        print(f'Start pose {agent_pos} is not navigable...')
        # Skip this start pose rather than calling exit(0), which would end the whole
        # multi-scene run (with a success status) from inside a helper.
        return max_points

    gt_reached_area = LN.get_start_pose_connected_component(
        (agent_pos[0], -agent_pos[2], 0), gt_occ_map)

    # Convert every reachable map cell into a 3D agent position at the floor height.
    possible_coords = np.nonzero(gt_reached_area)
    all_possible_agent_pos = []
    for row, col in zip(possible_coords[0], possible_coords[1]):
        x, z = col, -row
        possible_agent_pos = pxl_coords_to_pose((x, z), pose_range, coords_range, WH)
        all_possible_agent_pos.append(
            np.array([possible_agent_pos[0], scene_height, possible_agent_pos[1]]))

    sparse_agent_pos = filter_maximum_nodes(all_possible_agent_pos, min_distance=0.25)
    print("before filtering: ", len(all_possible_agent_pos), "after filtering: ", len(sparse_agent_pos))
    if len(sparse_agent_pos) > max_points:
        print(f"Current points: {len(sparse_agent_pos)}, larger than max points: {max_points}")
    else:
        print(f"Current points: {len(sparse_agent_pos)}, less than max points: {max_points}, abort")
        return max_points

    # save the sparse_agent_pos
    os.makedirs('../bellman_explore/sparse_agent_pos', exist_ok=True)
    np.save(f'../bellman_explore/sparse_agent_pos/{scene_name}.npy', sparse_agent_pos)

    for idx, sample_pos in tqdm(enumerate(sparse_agent_pos), total=len(sparse_agent_pos)):
        # 8 headings in 45-degree steps: -pi, -3pi/4, ..., 3pi/4.
        for jdx in range(8):
            heading_angle = (jdx - 4) * pi / 4
            obs, pose = get_obs_and_pose(env, sample_pos, heading_angle)
            obs_list, pose_list = [obs], [pose]

            semMap_module.build_semantic_map(obs_list, pose_list, step=idx, saved_folder=saved_folder, get_depth=True)
            # Compute observed area
            agent_map_pose = (pose[0], -pose[1], -pose[2])
            _, _, observed_area_flag, _ = semMap_module.get_observed_occupancy_map(agent_map_pose)

        if idx % 100 == 0:
            explored_free_space = np.logical_and(gt_reached_area, observed_area_flag)
            percent = 1. * np.sum(explored_free_space) / np.sum(gt_reached_area)
            area = np.sum(observed_area_flag) * 0.0025  # Each cell is 0.05m * 0.05m
            print(f"Coverage: {percent:.2%}, Explored Area: {area:.2f} m^2")

    depth_stat = semMap_module.depth_grid
    os.makedirs('../bellman_explore/all_poses', exist_ok=True)
    np.save(f'../bellman_explore/all_poses/{scene_name}.npy', depth_stat)
    return len(sparse_agent_pos)

def main():
    """
    Sample sparse reachable positions for every configured test scene and floor.

    Parses ``--config`` (a YAML file under ``configs/``), loads the per-floor start
    poses for ``cfg.EVAL.SIZE`` ('small' or 'large'), creates one Habitat simulator per
    scene, and calls ``nav_with_trajectory`` for each navigable start pose.

    Raises:
        ValueError: If ``cfg.EVAL.SIZE`` is neither 'small' nor 'large'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        required=False,
                        default='exp_90degree_ANS_NAVMESH_MAP_1000STEPS.yaml')
    args = parser.parse_args()

    cfg.merge_from_file(f'configs/{args.config}')
    cfg.freeze()

    #=============================== basic setup =======================================
    split = 'test'
    if cfg.EVAL.SIZE == 'small':
        floor_dict_path = f'{cfg.GENERAL.SCENE_HEIGHTS_DICT_PATH}/{split}_scene_floor_dict.npy'
    elif cfg.EVAL.SIZE == 'large':
        floor_dict_path = f'{cfg.GENERAL.SCENE_HEIGHTS_DICT_PATH}/large_scale_{split}_scene_floor_dict.npy'
    else:
        # Fail fast; otherwise scene_floor_dict is unbound and the loop hits a NameError.
        raise ValueError(
            f"Unsupported cfg.EVAL.SIZE: {cfg.EVAL.SIZE!r} (expected 'small' or 'large')")
    scene_floor_dict = np.load(floor_dict_path, allow_pickle=True).item()

    for env_scene in cfg.MAIN.TEST_SCENE_NO_FLOOR_LIST:

        #================================ load habitat env============================================
        config = habitat.get_config(config_paths=cfg.GENERAL.DATALOADER_CONFIG_PATH)
        config.defrost()
        config.SIMULATOR.SCENE = f'{cfg.GENERAL.HABITAT_SCENE_DATA_PATH}/mp3d/{env_scene}/{env_scene}.glb'
        config.DATASET.SCENES_DIR = cfg.GENERAL.HABITAT_SCENE_DATA_PATH
        config.freeze()
        env = habitat.sims.make_sim(config.SIMULATOR.TYPE, config=config.SIMULATOR)

        # Release the simulator even if processing a floor raises.
        try:
            env.reset()
            scene_dict = scene_floor_dict[env_scene]

            #=============================== traverse each floor ===========================
            for floor_id in list(scene_dict.keys()):
                height = scene_dict[floor_id]['y']
                scene_name = f'{env_scene}_{floor_id}'

                print(f'**********scene_name = {scene_name}***********')

                output_folder = "output/all_positions"
                # NOTE(review): called without clean_up=False (unlike saved_folder below);
                # confirm create_folder's default does not wipe earlier floors' outputs.
                create_folder(output_folder)
                scene_output_folder = f'{output_folder}/{scene_name}'
                create_folder(scene_output_folder)

                testing_data = scene_dict[floor_id]['start_pose']
                # Largest sparse-point count so far on this floor. nav_with_trajectory only
                # processes a start pose that yields more points than this.
                max_points = 0
                for idx, start_pose in enumerate(testing_data):
                    print(
                        f'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA EPS {idx} BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB'
                    )
                    print(f'start_pose = {start_pose}')

                    # start_pose[0] / start_pose[1] are used as x / z; y is the floor height.
                    agent_pos = np.array([start_pose[0], height, start_pose[1]])
                    if not env.is_navigable(agent_pos):
                        print(f'start pose is not navigable ...')
                        continue
                    print(f'start pose of {agent_pos} is navigable ...')

                    saved_folder = f'{scene_output_folder}'
                    create_folder(saved_folder, clean_up=False)

                    sample_points = nav_with_trajectory(agent_pos, split, env, scene_name, saved_folder, max_points)
                    max_points = max(max_points, sample_points)
        finally:
            env.close()

if __name__ == "__main__":

    main()