import numpy as np
import numpy.linalg as LA
import cv2
import os
import argparse
from modeling.utils.baseline_utils import create_folder
import matplotlib
import matplotlib.pyplot as plt
import math
from math import cos, sin, acos, atan2, pi, floor, degrees
import random
from modeling.utils.navigation_utils import change_brightness, SimpleRLEnv, get_obs_and_pose, get_obs_and_pose_by_action
from modeling.utils.baseline_utils import apply_color_to_map, pose_to_coords, gen_arrow_head_marker, read_map_npy, read_occ_map_npy, plus_theta_fn
from modeling.utils.map_utils_pcd_height import SemanticMap
from modeling.localNavigator_Astar import localNav_Astar
import habitat
import habitat_sim
import random
from core import cfg
from modeling.utils import frontier_utils as fr_utils
from modeling.localNavigator_slam import localNav_slam
from skimage.morphology import skeletonize
from modeling.utils.UNet import UNet
import torch
from collections import OrderedDict

def quaternion_to_yaw(q):
    """Return the yaw angle (radians) of a quaternion given as ``(x, y, z, w)``.

    Evaluates ``atan2(2*(w*z + x*y), 1 - 2*(y**2 + z**2))``, i.e. the yaw term of the
    ZYX Euler decomposition, so the result lies in ``[-pi, pi]``.

    NOTE(review): Habitat is y-up; confirm this z-axis yaw is the heading the caller expects.
    """
    x, y, z, w = q
    sin_part = 2 * (w * z + x * y)
    cos_part = 1 - 2 * (y**2 + z**2)
    return np.arctan2(sin_part, cos_part)

def transform_poses(poses):
    """Turn raw pose rows into ``(p0, p1, p2, heading)`` tuples.

    Each row must have exactly 7 entries: three position components followed by a
    quaternion ``(x, y, z, w)``, which is reduced to a yaw heading by quaternion_to_yaw.
    Returns a list with one tuple per row, in the same order.
    """
    return [
        (row[0], row[1], row[2], quaternion_to_yaw(row[3:]))
        for row in poses
    ]

def _save_final_semmap_figure(observed_occupancy_map, color_built_semantic_map,
                              traverse_lst, pose_range, coords_range, WH,
                              saved_folder):
    """Write ``{saved_folder}/final_semmap.jpg``: the agent path on the occupancy map beside the semantic map.

    Args:
        observed_occupancy_map: map shown in grayscale in the left panel.
        color_built_semantic_map: colored semantic map shown in the right panel.
        traverse_lst: non-empty list of map-frame poses ``(a, b, theta)`` visited by the agent.
        pose_range, coords_range, WH: map metadata forwarded to ``pose_to_coords``.
        saved_folder: directory the image is written into.
    """
    # Convert each visited pose into map pixel coordinates.
    x_coord_lst, z_coord_lst, theta_lst = [], [], []
    for cur_pose in traverse_lst:
        x_coord, z_coord = pose_to_coords((cur_pose[0], cur_pose[1]),
                                          pose_range, coords_range, WH)
        x_coord_lst.append(x_coord)
        z_coord_lst.append(z_coord)
        theta_lst.append(cur_pose[2])

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(25, 10))
    ax[0].imshow(observed_occupancy_map, cmap='gray')
    # Final pose: arrow head oriented by the last heading.
    marker, scale = gen_arrow_head_marker(theta_lst[-1])
    ax[0].scatter(x_coord_lst[-1],
                  z_coord_lst[-1],
                  marker=marker,
                  s=(30 * scale)**2,
                  c='red',
                  zorder=5)
    # Start pose: red square.
    ax[0].scatter(x_coord_lst[0],
                  z_coord_lst[0],
                  marker='s',
                  s=50,
                  c='red',
                  zorder=5)
    # Whole path, colored by visit order, marker size shrinking along the path.
    ax[0].scatter(x_coord_lst,
                  z_coord_lst,
                  c=range(len(x_coord_lst)),
                  cmap='viridis',
                  s=np.linspace(5, 2, num=len(x_coord_lst))**2,
                  zorder=3)
    ax[0].get_xaxis().set_visible(False)
    ax[0].get_yaxis().set_visible(False)
    ax[0].set_title('improved observed_occ_map + frontiers')

    ax[1].imshow(color_built_semantic_map)
    ax[1].get_xaxis().set_visible(False)
    ax[1].get_yaxis().set_visible(False)
    ax[1].set_title('built semantic map')

    # Figure-level title: plt.title() would target the current (right-hand) axes
    # and overwrite its 'built semantic map' title.
    fig.suptitle('observed area')
    fig.tight_layout()
    fig.savefig(f'{saved_folder}/final_semmap.jpg')
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def nav_with_trajectory(start_pose, split, env, scene_name, saved_folder, steps=10000):
    """Explore a scene with uniformly random actions while building a semantic map.

    Starting at ``start_pose``, the agent takes ``steps`` random actions. Each
    observation is fused into a SemanticMap, coverage of the free space reachable
    from the start is printed and recorded per step, and a visualization is saved
    to ``{saved_folder}/final_semmap.jpg``.

    Args:
        start_pose: ``(p0, p1, p2, heading)``. The first three entries form the
            agent position checked with ``env.is_navigable``; ``heading`` is passed
            to get_obs_and_pose.
        split: dataset split, used in the path
            ``output/semantic_map/{split}/{scene_name}/BEV_occupancy_map.npy``.
        env: simulator providing ``semantic_annotations()`` and ``is_navigable()``.
        scene_name: scene identifier used in that path and by the map modules.
        saved_folder: output directory for figures (also passed to build_semantic_map).
        steps: number of random actions to take.

    Returns:
        ``(percent, step, traverse_lst, action_lst, step_cov_pairs)``:
        final coverage of the reachable area (0.0 when that area is empty), the
        number of steps taken, the visited map-frame poses, the chosen action ids,
        and a float32 ``(N, 3)`` array of ``(step, coverage, explored_area_m2)``.
        If the start position is not navigable, ``(None, 0, [], [], <empty (0, 3)
        float32 array>)`` is returned so callers unpacking five values still work.
    """
    scene = env.semantic_annotations()
    # Debug dump of every annotated object and its axis-aligned bounding box.
    for obj in scene.objects:
        print(f"Object ID: {obj.id}, Category: {obj.category.name()}")
        aabb = obj.aabb
        print(aabb.center)
        print(aabb.sizes)
        print("=======")
    # Instance id (integer suffix after the last '_' of obj.id) -> category index.
    ins2cat_dict = {
        int(obj.id.split("_")[-1]): obj.category.index()
        for obj in scene.objects
    }

    # Seed both RNGs so the random action sequence is reproducible.
    np.random.seed(cfg.GENERAL.RANDOM_SEED)
    random.seed(cfg.GENERAL.RANDOM_SEED)

    # Load the ground-truth bird's-eye-view occupancy map and its metadata.
    occ_map_npy = np.load(
        f'output/semantic_map/{split}/{scene_name}/BEV_occupancy_map.npy',
        allow_pickle=True).item()
    gt_occ_map, pose_range, coords_range, WH = read_occ_map_npy(occ_map_npy)

    LN = localNav_Astar(pose_range, coords_range, WH, scene_name)
    # NOTE(review): LS is only constructed and reset here; nothing below uses it.
    LS = localNav_slam(pose_range, coords_range, WH, mark_locs=True,
                       close_small_openings=False, recover_on_collision=False,
                       fix_thrashing=False, point_cnt=2)
    LS.reset(gt_occ_map)

    semMap_module = SemanticMap(split, scene_name, pose_range, coords_range, WH, ins2cat_dict)
    traverse_lst = []
    action_lst = []
    step_cov_pairs = []

    agent_pos = np.array([start_pose[0], start_pose[1], start_pose[2]])
    if not env.is_navigable(agent_pos):
        print('Start pose is not navigable...')
        # Same arity as the success path so callers can always unpack five values.
        return None, 0, [], [], np.zeros((0, 3), dtype='float32')

    # Free space connected to the start; all coverage numbers are relative to it.
    # The start is given to the map as (p0, -p2, 0).
    gt_reached_area = LN.get_start_pose_connected_component(
        (agent_pos[0], -agent_pos[2], 0), gt_occ_map)
    reachable_cells = np.sum(gt_reached_area)

    heading_angle = start_pose[3]
    obs, pose = get_obs_and_pose(env, agent_pos, heading_angle)
    obs_list, pose_list = [obs], [pose]

    observed_occupancy_map = None
    step = 0
    while step < steps:
        print(f'step = {step}')

        pose = pose_list[-1]
        print(f'agent position = {pose[:2]}, angle = {pose[2]}')
        # Map-frame pose: second coordinate and heading are negated.
        agent_map_pose = (pose[0], -pose[1], -pose[2])
        traverse_lst.append(agent_map_pose)

        # Fuse the latest observation into the semantic map.
        semMap_module.build_semantic_map(obs_list, pose_list, step=step, saved_folder=saved_folder)

        # Compute observed area.
        observed_occupancy_map, _, observed_area_flag, _ = \
            semMap_module.get_observed_occupancy_map(agent_map_pose)

        step += 1

        # Uniformly random action id; presumably habitat's discrete actions — TODO confirm mapping.
        act = random.choice([1, 2, 3])
        action_lst.append(act)

        # Output rot is negative of the input angle.
        obs, pose = get_obs_and_pose_by_action(env, act)
        obs_list, pose_list = [obs], [pose]

        # Coverage uses the map state from before this action's observation is fused.
        explored_free_space = np.logical_and(gt_reached_area, observed_area_flag)
        if reachable_cells > 0:
            percent = 1. * np.sum(explored_free_space) / reachable_cells
        else:
            percent = 0.0
        area = np.sum(observed_area_flag) * 0.0025  # Each cell is 0.05m * 0.05m
        print(f"Coverage: {percent:.2%}, Explored Area: {area:.2f} m^2")

        step_cov_pairs.append((step, percent, area))

    # Final semantic map; observed_area_flag is reused for the final statistics.
    built_semantic_map, observed_area_flag, _ = semMap_module.get_semantic_map()
    color_built_semantic_map = apply_color_to_map(built_semantic_map)
    color_built_semantic_map = change_brightness(color_built_semantic_map,
                                                 observed_area_flag,
                                                 value=60)

    # Nothing to draw if no step was taken (steps <= 0).
    if traverse_lst:
        _save_final_semmap_figure(observed_occupancy_map, color_built_semantic_map,
                                  traverse_lst, pose_range, coords_range, WH,
                                  saved_folder)

    # Final statistics, using the same reachable-area denominator as the per-step coverage.
    explored_free_area_flag = np.logical_and(gt_reached_area, observed_area_flag)
    if reachable_cells > 0:
        percent = np.sum(explored_free_area_flag) / reachable_cells
    else:
        percent = 0.0
    print(f"Final Coverage: {percent:.2%}")
    # reshape keeps the (N, 3) layout even when no step was taken.
    step_cov_pairs = np.array(step_cov_pairs, dtype='float32').reshape(-1, 3)

    return percent, step, traverse_lst, action_lst, step_cov_pairs

def main():
    """Run random-action exploration on every floor of the configured scene(s).

    Parses ``--config`` (a YAML file under ``configs/``), merges it into the global
    ``cfg``, loads the scene/floor dictionary, and for each floor runs
    nav_with_trajectory and saves a result dict to
    ``output/random_traj/results_<env_scene>_<floor_id>.npy``.

    Raises:
        ValueError: if ``cfg.EVAL.SIZE`` is neither ``'small'`` nor ``'large'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        required=False,
                        default='exp_90degree_ANS_NAVMESH_MAP_1000STEPS.yaml')
    args = parser.parse_args()

    cfg.merge_from_file(f'configs/{args.config}')
    cfg.freeze()

    #=============================== basic setup =======================================
    split = 'train'
    if cfg.EVAL.SIZE == 'small':
        floor_dict_file = f'{split}_scene_floor_dict.npy'
    elif cfg.EVAL.SIZE == 'large':
        floor_dict_file = f'large_scale_{split}_scene_floor_dict.npy'
    else:
        # Fail fast: otherwise scene_floor_dict would be unbound below.
        raise ValueError(
            f"Unsupported cfg.EVAL.SIZE={cfg.EVAL.SIZE!r}; expected 'small' or 'large'.")
    scene_floor_dict = np.load(
        f'{cfg.GENERAL.SCENE_HEIGHTS_DICT_PATH}/{floor_dict_file}',
        allow_pickle=True).item()

    # Shared output folder, created once rather than once per floor.
    output_folder = "output/random_traj"
    create_folder(output_folder)

    for env_scene in ['gTV8FGcVJC9']:

        #================================ load habitat env============================================
        config = habitat.get_config(config_paths=cfg.GENERAL.DATALOADER_CONFIG_PATH)
        config.defrost()
        config.SIMULATOR.SCENE = f'{cfg.GENERAL.HABITAT_SCENE_DATA_PATH}/mp3d/{env_scene}/{env_scene}.glb'
        config.DATASET.SCENES_DIR = cfg.GENERAL.HABITAT_SCENE_DATA_PATH
        config.freeze()
        env = habitat.sims.make_sim(config.SIMULATOR.TYPE, config=config.SIMULATOR)

        # try/finally so the simulator is released even if a floor fails.
        try:
            env.reset()
            scene_dict = scene_floor_dict[env_scene]

            # The pose file depends only on the scene, so read it once.
            pose_file = f"../HOV-SG/hovsg/data/hm3dsem/metadata/poses/{env_scene}.txt"
            traj = transform_poses(np.loadtxt(pose_file))
            # NOTE(review): every floor starts from the first pose of this scene's
            # trajectory; on multi-floor scenes it may not lie on the floor explored.
            start_pose = traj[0]

            #=============================== traverse each floor ===========================
            for floor_id in scene_dict:
                scene_name = f'{env_scene}_{floor_id}'
                print(f'**********scene_name = {scene_name}***********')

                scene_output_folder = f'{output_folder}/{scene_name}'
                create_folder(scene_output_folder)
                print(f'start_pose = {start_pose}')

                outcome = nav_with_trajectory(start_pose, split, env, scene_name, scene_output_folder)
                # A None coverage means the start pose was not navigable: nothing to save.
                if outcome[0] is None:
                    print(f'No result for {scene_name}; skipping.')
                    continue
                covered_area_percent, steps, trajectory, action_lst, step_cov_pairs = outcome

                result = {
                    'steps': steps,
                    'covered_area': covered_area_percent,
                    'trajectory': trajectory,
                    'actions': action_lst,
                    'step_cov_pairs': step_cov_pairs,
                }
                # Saved per floor so each floor gets its own results file.
                np.save(f'{output_folder}/results_{scene_name}.npy', result)
        finally:
            env.close()

if __name__ == "__main__":
    # Run the experiment only when executed as a script, not on import.
    main()