import numpy as np
import numpy.linalg as LA
import cv2
import os
import argparse
from modeling.utils.baseline_utils import create_folder
import matplotlib
import matplotlib.pyplot as plt
import math
from math import cos, sin, acos, atan2, pi, floor, degrees
import random
from modeling.utils.navigation_utils import change_brightness, SimpleRLEnv, get_obs_and_pose, get_obs_and_pose_by_action
from modeling.utils.baseline_utils import apply_color_to_map, pose_to_coords, gen_arrow_head_marker, read_map_npy, read_occ_map_npy, plus_theta_fn
from modeling.utils.map_utils_pcd_height import SemanticMap
from modeling.localNavigator_Astar import localNav_Astar
import habitat
import random
from core import cfg
from modeling.localNavigator_slam import localNav_slam

def quaternion_to_yaw(q):
    """Return the yaw angle (radians) encoded by a quaternion.

    ``q`` is unpacked as ``(x, y, z, w)`` and the angle is computed as
    ``atan2(2 * (w*z + x*y), 1 - 2 * (y^2 + z^2))``, i.e. the rotation about
    the quaternion's z axis, in the range ``[-pi, pi]``.

    NOTE(review): Habitat uses a y-up frame; confirm that the pose file's
    rotation about z is really the agent heading.
    """
    qx, qy, qz, qw = q
    sin_term = 2 * (qw * qz + qx * qy)
    cos_term = 1 - 2 * (qy**2 + qz**2)
    return np.arctan2(sin_term, cos_term)

def transform_poses(poses):
    """Convert raw pose rows into ``(x, y, z, heading)`` tuples.

    Each row supplies a position in entries 0-2 and a quaternion in the
    remaining entries. The quaternion is collapsed to a single heading angle
    via ``quaternion_to_yaw``. Returns a list with one tuple per row,
    preserving the input order.
    """
    def _row_to_pose(row):
        # Heading first, so a malformed row fails inside quaternion_to_yaw.
        yaw = quaternion_to_yaw(row[3:])
        return (row[0], row[1], row[2], yaw)

    return list(map(_row_to_pose, poses))

def _coverage_ratio(covered, total):
    """Return ``covered / total`` as a float, or 0.0 when ``total`` is zero."""
    return float(covered) / float(total) if total > 0 else 0.0


def _plot_final_semantic_map(semMap_module, observed_occupancy_map, traverse_lst,
                             pose_range, coords_range, WH, saved_folder):
    """Save a two-panel figure: observed occupancy + trajectory, and the built semantic map."""
    built_semantic_map, observed_area_flag, _ = semMap_module.get_semantic_map()
    color_built_semantic_map = apply_color_to_map(built_semantic_map)
    color_built_semantic_map = change_brightness(color_built_semantic_map,
                                                 observed_area_flag,
                                                 value=60)

    # Convert every visited map pose to pixel coordinates for plotting.
    x_coord_lst, z_coord_lst, theta_lst = [], [], []
    for cur_pose in traverse_lst:
        x_coord, z_coord = pose_to_coords((cur_pose[0], cur_pose[1]),
                                          pose_range, coords_range, WH)
        x_coord_lst.append(x_coord)
        z_coord_lst.append(z_coord)
        theta_lst.append(cur_pose[2])

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(25, 10))

    # Left panel: occupancy map with the final pose (arrow), the start pose
    # (square), and the whole path colored by step index.
    ax[0].imshow(observed_occupancy_map, cmap='gray')
    marker, scale = gen_arrow_head_marker(theta_lst[-1])
    ax[0].scatter(x_coord_lst[-1],
                  z_coord_lst[-1],
                  marker=marker,
                  s=(30 * scale)**2,
                  c='red',
                  zorder=5)
    ax[0].scatter(x_coord_lst[0],
                  z_coord_lst[0],
                  marker='s',
                  s=50,
                  c='red',
                  zorder=5)
    ax[0].scatter(x_coord_lst,
                  z_coord_lst,
                  c=range(len(x_coord_lst)),
                  cmap='viridis',
                  s=np.linspace(5, 2, num=len(x_coord_lst))**2,
                  zorder=3)
    ax[0].get_xaxis().set_visible(False)
    ax[0].get_yaxis().set_visible(False)
    # No frontiers are drawn, so the title names what is actually plotted.
    ax[0].set_title('observed_occ_map + trajectory')

    # Right panel: colored semantic map, brightened where observed.
    ax[1].imshow(color_built_semantic_map)
    ax[1].get_xaxis().set_visible(False)
    ax[1].get_yaxis().set_visible(False)
    ax[1].set_title('built semantic map')

    # plt.title() would target the current axes (ax[1]) and overwrite its
    # title, so the overall title goes on the figure instead.
    fig.suptitle('observed area')
    fig.tight_layout()
    fig.savefig(f'{saved_folder}/final_semmap.jpg')
    plt.close(fig)


def nav_with_trajectory(trajectory, split, env, scene_name, saved_folder,
                        depth_stat_folder='../bellman_explore/depth_stat/human'):
    """Replay a recorded trajectory, build a semantic map and report coverage.

    For every pose in ``trajectory``:
      * the observation is rendered from ``env``;
      * it is fused into a ``SemanticMap``;
      * the coverage of the start pose's connected component in the
        ground-truth occupancy map is printed.

    Afterwards the function:
      * writes a figure to ``{saved_folder}/final_semmap.jpg``;
      * saves ``semMap_module.depth_grid`` to
        ``{depth_stat_folder}/{scene_name}.npy``.

    Args:
        trajectory: sequence of ``(x, y, z, heading)`` poses in simulator
            coordinates, e.g. the output of ``transform_poses``.
        split: dataset split, used to locate
            ``output/semantic_map/{split}/{scene_name}/BEV_occupancy_map.npy``.
        env: Habitat simulator exposing ``semantic_annotations()`` and
            ``is_navigable()``.
        scene_name: ``{scene}_{floor}`` identifier.
        saved_folder: output directory for the figure. It is also passed to
            ``build_semantic_map`` as ``saved_folder``.
        depth_stat_folder: directory for the depth statistics file. Defaults
            to the previously hard-coded location.

    Returns:
        ``(step_cov_pairs, final_percent)`` on success:
          * ``step_cov_pairs`` is a float32 array of rows
            ``(step, coverage, explored_area_m2)``;
          * ``final_percent`` is the observed fraction of non-zero cells of
            the ground-truth occupancy map.

        ``(None, None)`` when the trajectory is empty or its start pose is
        not navigable.
    """
    if len(trajectory) == 0:
        print('Trajectory is empty...')
        return None, None

    # Instance id (numeric suffix of obj.id) -> category index.
    scene = env.semantic_annotations()
    ins2cat_dict = {
        int(obj.id.split("_")[-1]): obj.category.index()
        for obj in scene.objects
    }

    # Seed both RNGs so runs are reproducible.
    np.random.seed(cfg.GENERAL.RANDOM_SEED)
    random.seed(cfg.GENERAL.RANDOM_SEED)

    # Load the precomputed ground-truth occupancy map and its frame info.
    occ_map_npy = np.load(
        f'output/semantic_map/{split}/{scene_name}/BEV_occupancy_map.npy',
        allow_pickle=True).item()
    gt_occ_map, pose_range, coords_range, WH = read_occ_map_npy(occ_map_npy)

    LN = localNav_Astar(pose_range, coords_range, WH, scene_name)
    semMap_module = SemanticMap(split, scene_name, pose_range, coords_range, WH, ins2cat_dict)
    traverse_lst = []
    step_cov_pairs = []

    start_pose = trajectory[0]
    start_pos = np.array([start_pose[0], start_pose[1], start_pose[2]])
    if not env.is_navigable(start_pos):
        print('Start pose is not navigable...')
        return None, None

    # Connected component of the start pose in gt_occ_map. The map frame is
    # (x, -z), matching agent_map_pose below.
    gt_reached_area = LN.get_start_pose_connected_component(
        (start_pos[0], -start_pos[2], 0), gt_occ_map)
    reachable_cells = np.sum(gt_reached_area)

    observed_occupancy_map = None
    for step, agent_pose in enumerate(trajectory):
        print(f"Step {step + 1}/{len(trajectory)}: {agent_pose}")
        x, y, z, heading = agent_pose

        # Simulator pose -> map pose: (x, -z) with the heading negated.
        agent_map_pose = (x, -z, -heading)
        traverse_lst.append(agent_map_pose)

        # Render the observation at this pose and fuse it into the map.
        obs, pose = get_obs_and_pose(env, np.array([x, y, z]), heading)
        semMap_module.build_semantic_map([obs], [pose], step=step, saved_folder=saved_folder, get_depth=True)

        # Coverage of the start component, plus total observed area.
        observed_occupancy_map, _, observed_area_flag, _ = semMap_module.get_observed_occupancy_map(agent_map_pose)
        explored_free_space = np.logical_and(gt_reached_area, observed_area_flag)
        percent = _coverage_ratio(np.sum(explored_free_space), reachable_cells)
        area = np.sum(observed_area_flag) * 0.0025  # each cell is 0.05 m * 0.05 m
        print(f"Coverage: {percent:.2%}, Explored Area: {area:.2f} m^2")

        step_cov_pairs.append((step, percent, area))

    _plot_final_semantic_map(semMap_module, observed_occupancy_map, traverse_lst,
                             pose_range, coords_range, WH, saved_folder)

    # Final statistics. NOTE(review): unlike the per-step metric, this one is
    # measured against every non-zero cell of gt_occ_map, not only the start
    # pose's connected component.
    _, observed_area_flag, _ = semMap_module.get_semantic_map()
    explored_free_area_flag = np.logical_and(gt_occ_map, observed_area_flag)
    final_percent = _coverage_ratio(np.sum(explored_free_area_flag > 0),
                                    np.sum(gt_occ_map > 0))
    print(f"Final Coverage: {final_percent:.2%}")
    step_cov_pairs = np.array(step_cov_pairs, dtype='float32')

    # Persist the accumulated depth statistics for this scene.
    os.makedirs(depth_stat_folder, exist_ok=True)
    np.save(os.path.join(depth_stat_folder, f'{scene_name}.npy'), semMap_module.depth_grid)

    return step_cov_pairs, final_percent


def main():
    """Replay recorded trajectories for each requested scene and floor.

    Steps:
      1. Parse ``--config`` (a YAML file under ``configs/``), ``--split``
         and ``--scenes``.
      2. Load the scene/floor dictionary selected by ``cfg.EVAL.SIZE``.
      3. Open one Habitat simulator per scene.
      4. Call ``nav_with_trajectory`` once per floor, writing results under
         ``output/visualize_traj/{scene}_{floor}``.

    Raises:
        ValueError: if ``cfg.EVAL.SIZE`` is neither ``'small'`` nor ``'large'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        required=False,
                        default='exp_90degree_ANS_NAVMESH_MAP_5000STEPS.yaml')
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--scenes', type=str, nargs='+', default=['UwV83HsGsw3'])
    args = parser.parse_args()

    cfg.merge_from_file(f'configs/{args.config}')
    cfg.freeze()

    #=============================== basic setup =======================================
    split = args.split
    if cfg.EVAL.SIZE == 'small':
        floor_dict_file = f'{split}_scene_floor_dict.npy'
    elif cfg.EVAL.SIZE == 'large':
        floor_dict_file = f'large_scale_{split}_scene_floor_dict.npy'
    else:
        # Without this, scene_floor_dict would be unbound and fail later
        # with an obscure NameError.
        raise ValueError(
            f"Unsupported cfg.EVAL.SIZE {cfg.EVAL.SIZE!r}; expected 'small' or 'large'")
    scene_floor_dict = np.load(
        f'{cfg.GENERAL.SCENE_HEIGHTS_DICT_PATH}/{floor_dict_file}',
        allow_pickle=True).item()

    # Loop-invariant; created once rather than once per floor.
    output_folder = "output/visualize_traj"
    create_folder(output_folder)

    for env_scene in args.scenes:

        #================================ load habitat env============================================
        config = habitat.get_config(config_paths=cfg.GENERAL.DATALOADER_CONFIG_PATH)
        config.defrost()
        config.SIMULATOR.SCENE = f'{cfg.GENERAL.HABITAT_SCENE_DATA_PATH}/mp3d/{env_scene}/{env_scene}.glb'
        config.DATASET.SCENES_DIR = cfg.GENERAL.HABITAT_SCENE_DATA_PATH
        config.freeze()
        env = habitat.sims.make_sim(config.SIMULATOR.TYPE, config=config.SIMULATOR)

        # Always release the simulator, even if a floor fails.
        try:
            env.reset()
            scene_dict = scene_floor_dict[env_scene]

            #=============================== traverse each floor ===========================
            for floor_id in list(scene_dict.keys()):
                scene_name = f'{env_scene}_{floor_id}'

                print(f'**********scene_name = {scene_name}***********')

                saved_folder = f'{output_folder}/{scene_name}'
                create_folder(saved_folder)

                # NOTE(review): the same "{env_scene}_1" pose file is used for
                # every floor; confirm this is intended.
                pose_file = f"../bellman/bellman-exploration/output/vlmaps_dataset/{env_scene}_1/poses.txt"
                poses = np.loadtxt(pose_file)
                traj = transform_poses(poses)

                start_pose = traj[0]
                print(f'start_pose = {start_pose}')

                nav_with_trajectory(traj, split, env, scene_name, saved_folder)
        finally:
            env.close()

if __name__ == '__main__':
    # Script entry point: run the trajectory replay pipeline.
    main()