import numpy as np
import numpy.linalg as LA
import json
import os
import time
import argparse
from modeling.utils.baseline_utils import create_folder
import matplotlib
import matplotlib.pyplot as plt
import math
from math import cos, sin, acos, atan2, pi, floor, degrees
import random
from modeling.utils.navigation_utils import change_brightness, SimpleRLEnv, get_obs_and_pose, get_obs_and_pose_by_action
from modeling.utils.baseline_utils import apply_color_to_map, pose_to_coords, gen_arrow_head_marker, read_map_npy, read_occ_map_npy, plus_theta_fn
from modeling.utils.map_utils_pcd_height import SemanticMap
from modeling.localNavigator_Astar import localNav_Astar
import habitat
import habitat_sim
import random
from core import cfg
from modeling.localNavigator_slam import localNav_slam
from tqdm import tqdm

def nav_with_trajectory(start_pose, scene_height, actions, split, env, scene_name,
                        save_dir='../bellman_explore/depth_stat/DSS'):
    """Replay a recorded action sequence and save the semantic map's depth statistic.

    The agent is placed at ``start_pose`` on the floor at ``scene_height``.
    Before each executed action, the current observation is integrated into a
    ``SemanticMap``; after each action, explored-free-space coverage is printed.
    When the trajectory is exhausted, ``semMap_module.depth_grid`` is saved to
    ``<save_dir>/<scene_name>.npy``.

    Args:
        start_pose: ``(x, z, heading)``; the simulator position is built as
            ``[x, scene_height, z]`` and ``heading`` is the initial angle.
        scene_height: y coordinate of the floor being navigated.
        actions: iterable of action ids; ``-1`` and ``0`` entries are skipped.
        split: dataset split name used to locate the precomputed occupancy map.
        env: simulator instance (created via ``habitat.sims.make_sim`` in ``main``).
        scene_name: ``<scene>_<floor>`` identifier used in input/output paths.
        save_dir: directory the depth statistic is written to.

    Returns:
        ``(depth_stat, step_cov_pairs)`` on success, where ``step_cov_pairs`` is
        a list of ``(step, coverage_fraction, explored_area_m2)`` tuples.
        ``(None, None)`` if the start pose is not navigable.
    """
    # Map each semantic instance id (numeric suffix of obj.id) to its category index.
    scene = env.semantic_annotations()
    ins2cat_dict = {
        int(obj.id.split("_")[-1]): obj.category.index()
        for obj in scene.objects
    }

    # Seed both RNGs so repeated runs are reproducible.
    np.random.seed(cfg.GENERAL.RANDOM_SEED)
    random.seed(cfg.GENERAL.RANDOM_SEED)

    # Load the ground-truth occupancy map for this floor.
    occ_map_npy = np.load(
        f'output/semantic_map/{split}/{scene_name}/BEV_occupancy_map.npy',
        allow_pickle=True).item()
    gt_occ_map, pose_range, coords_range, WH = read_occ_map_npy(occ_map_npy)

    LN = localNav_Astar(pose_range, coords_range, WH, scene_name)
    # NOTE(review): LS is never used after reset(); kept in case constructing or
    # resetting it has side effects the rest of the pipeline relies on.
    LS = localNav_slam(pose_range, coords_range, WH, mark_locs=True,
                       close_small_openings=False, recover_on_collision=False,
                       fix_thrashing=False, point_cnt=2)
    LS.reset(gt_occ_map)

    semMap_module = SemanticMap(split, scene_name, pose_range, coords_range, WH, ins2cat_dict)
    step_cov_pairs = []

    agent_pos = np.array([start_pose[0], scene_height, start_pose[1]])
    if not env.is_navigable(agent_pos):
        print('Start pose is not navigable...')
        return None, None

    # Free space connected to the start pose; coverage is measured against it.
    gt_reached_area = LN.get_start_pose_connected_component(
        (agent_pos[0], -agent_pos[2], 0), gt_occ_map)
    # Loop-invariant denominator for the coverage fraction.
    reachable_cells = np.sum(gt_reached_area)
    cell_area_m2 = 0.0025  # each map cell is 0.05 m x 0.05 m

    obs, pose = get_obs_and_pose(env, agent_pos, start_pose[2])
    obs_list, pose_list = [obs], [pose]

    step = 0
    saved_folder = 'output/calculate_map_stat'
    for act in actions:
        if act in (-1, 0):
            continue

        print(f'step = {step}')

        pose = pose_list[-1]
        print(f'agent position = {pose[:2]}, angle = {pose[2]}')
        # Map frame negates the second coordinate and the heading.
        agent_map_pose = (pose[0], -pose[1], -pose[2])

        # Integrate the latest observation into the semantic map.
        semMap_module.build_semantic_map(obs_list, pose_list, step=step,
                                         saved_folder=saved_folder, get_depth=True)
        _, _, observed_area_flag, _ = semMap_module.get_observed_occupancy_map(agent_map_pose)

        step += 1

        # Execute the action; only this newest observation is fed to the map next step.
        # NOTE(review): the observation after the final action is never integrated.
        obs, pose = get_obs_and_pose_by_action(env, act)
        obs_list, pose_list = [obs], [pose]

        explored_free_space = np.logical_and(gt_reached_area, observed_area_flag)
        # Explicit nan instead of a 0/0 RuntimeWarning when nothing is reachable.
        percent = (np.sum(explored_free_space) / reachable_cells
                   if reachable_cells else float('nan'))
        area = np.sum(observed_area_flag) * cell_area_m2
        print(f"Coverage: {percent:.2%}, Explored Area: {area:.2f} m^2")

        step_cov_pairs.append((step, percent, area))

    depth_stat = semMap_module.depth_grid

    # Save the depth statistic for this scene/floor.
    os.makedirs(save_dir, exist_ok=True)
    np.save(os.path.join(save_dir, f'{scene_name}.npy'), depth_stat)
    return depth_stat, step_cov_pairs

def main(folder, target_file, best_idx):
    """Replay the best trajectory stored in one result file and save its depth statistic.

    Args:
        folder: sub-directory of ``output/`` holding the result ``.npy`` files.
        target_file: result file name; ``split('_')[1]`` is used as the scene id
            and the last ``_``-separated token (without extension) as the floor id.
        best_idx: episode index selecting both the action list in the result
            file and the start pose in the scene-floor dictionary.

    Raises:
        ValueError: if ``cfg.EVAL.SIZE`` is neither ``'small'`` nor ``'large'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        required=False,
                        default='exp_90degree_ANS_NAVMESH_MAP_1000STEPS.yaml')
    args = parser.parse_args()

    # NOTE(review): main() is called once per result file, so this merges into a
    # cfg already frozen by the previous call; confirm that is intended.
    cfg.merge_from_file(f'configs/{args.config}')
    cfg.freeze()

    result = np.load(f"output/{folder}/{target_file}", allow_pickle=True).item()
    actions = result[best_idx]['actions']

    #=============================== basic setup =======================================
    split = 'test'
    if cfg.EVAL.SIZE == 'small':
        floor_dict_name = f'{split}_scene_floor_dict.npy'
    elif cfg.EVAL.SIZE == 'large':
        floor_dict_name = f'large_scale_{split}_scene_floor_dict.npy'
    else:
        # Previously this fell through and failed later with an unbound-name error.
        raise ValueError(
            f"Unsupported cfg.EVAL.SIZE={cfg.EVAL.SIZE!r}; expected 'small' or 'large'")
    scene_floor_dict = np.load(
        f'{cfg.GENERAL.SCENE_HEIGHTS_DICT_PATH}/{floor_dict_name}',
        allow_pickle=True).item()

    # Resolve all scene/floor metadata before creating the simulator, so a bad
    # file name or missing key cannot leave a live simulator behind.
    env_scene = target_file.split('_')[1]
    floor_id = int(target_file.split('_')[-1].split('.')[0])
    scene_dict = scene_floor_dict[env_scene]
    height = scene_dict[floor_id]['y']
    start_pose = scene_dict[floor_id]['start_pose'][best_idx]
    scene_name = f'{env_scene}_{floor_id}'

    #================================ load habitat env============================================
    config = habitat.get_config(config_paths=cfg.GENERAL.DATALOADER_CONFIG_PATH)
    config.defrost()
    config.SIMULATOR.SCENE = f'{cfg.GENERAL.HABITAT_SCENE_DATA_PATH}/mp3d/{env_scene}/{env_scene}.glb'
    config.DATASET.SCENES_DIR = cfg.GENERAL.HABITAT_SCENE_DATA_PATH
    config.freeze()
    env = habitat.sims.make_sim(config.SIMULATOR.TYPE, config=config.SIMULATOR)

    # Always release the simulator, even if the replay raises.
    try:
        env.reset()

        print(f'**********scene_name = {scene_name}***********')
        print(f'start_pose = {start_pose}')

        nav_with_trajectory(start_pose, height, actions, split, env, scene_name)
    finally:
        env.close()

if __name__ == "__main__":
    folder = 'TESTING_RESULTS_90degree_ANS_NAVMESH_MAP_5000STEPS_nearfar_iclr_1.0_3.0'
    # os.listdir order is arbitrary; sort so runs process files reproducibly.
    files = sorted(name for name in os.listdir(f"output/{folder}") if name.endswith('.npy'))

    # Nested mapping: best_idx_by_folder[folder][file] -> best episode index.
    with open('output/best_results.json', 'r', encoding='utf-8') as f:
        best_idx_by_folder = json.load(f)
    for file in tqdm(files):
        print(f'Processing {file}...')
        current_best_idx = best_idx_by_folder[folder][file]
        main(folder, file, int(current_best_idx))
        # NOTE(review): pause between files, presumably to let the simulator
        # release its resources before the next one is created.
        time.sleep(5)
        print("=============================================")
