import numpy as np
import numpy.linalg as LA
import cv2
import matplotlib
import matplotlib.pyplot as plt
import math
from math import cos, sin, acos, atan2, pi, floor, degrees
import random
from .utils.navigation_utils import change_brightness, SimpleRLEnv, get_obs_and_pose, get_obs_and_pose_by_action
from .utils.baseline_utils import apply_color_to_map, pose_to_coords, gen_arrow_head_marker, read_map_npy, read_occ_map_npy, plus_theta_fn
from .utils.nearfar_map_utils_pcd_height import SemanticMap
from .localNavigator_Astar import localNav_Astar
import habitat
import habitat_sim
import random
from core import cfg
from .utils import frontier_utils as fr_utils
from modeling.localNavigator_slam import localNav_slam
from skimage.morphology import skeletonize
from modeling.utils.UNet import UNet
import torch
from collections import OrderedDict
from ANS_modeling.model import compute_long_term_goal
from ANS_modeling.model import RL_Policy

def nav_ANS(split, env, episode_id, scene_name, scene_height, start_pose, saved_folder, device):
	"""Run one frontier-exploration episode guided by the ANS global policy.

	Starting from ``start_pose`` the agent repeatedly:

	1. integrates the latest observations into a 'near' and a 'far'
	   ``SemanticMap``,
	2. when a new subgoal is needed, extracts reachable frontiers from both
	   maps, preferring a far-map frontier whose centroid is already flagged
	   as observed in the near map and otherwise falling back to the near-map
	   frontier closest to the long-term goal predicted by the global policy,
	3. asks the SLAM local navigator for the next action towards the chosen
	   frontier and executes it.

	The loop ends after ``cfg.NAVI.NUM_STEPS`` steps or as soon as no frontier
	is left.

	Args:
		split: dataset split name; used in the occupancy-map path (small
			setting) and forwarded to ``SemanticMap``.
		env: initialized habitat environment. It must expose
			``semantic_annotations()`` and ``is_navigable()`` and is passed
			to the observation helpers.
		episode_id: not used by this function; kept for interface
			compatibility.
		scene_name: scene identifier used to locate the occupancy map.
		scene_height: value placed in the middle component of the 3D agent
			position.
		start_pose: indexable with three entries; entries 0 and 1 become the
			first and third components of the agent position, entry 2 is
			the start heading angle.
		saved_folder: folder forwarded to ``build_semantic_map`` and used for
			the final figure.
		device: currently ignored; the global policy is always created and
			run on CPU (see the note where the policy is loaded).

	Returns:
		Tuple ``(percent, step, traverse_lst, action_lst, step_cov_pairs)``:
		the final ratio of explored free cells to all ground-truth free
		cells, the number of steps taken, the list of agent map poses, the
		list of actions, and a float32 array with one
		``(step, percent, area)`` row per completed loop iteration.

	Raises:
		ValueError: if ``cfg`` selects an unsupported occupancy-map type, map
			size or HFOV.
		AssertionError: if the start position is not navigable.
	"""

	#============================ get scene ins to cat dict
	scene = env.semantic_annotations()
	ins2cat_dict = {
		int(obj.id.split("_")[-1]): obj.category.index()
		for obj in scene.objects
	}

	#=================================== start original navigation code ========================
	np.random.seed(cfg.GENERAL.RANDOM_SEED)
	random.seed(cfg.GENERAL.RANDOM_SEED)

	occ_map_npy = _load_gt_occ_map_npy(split, scene_name)
	gt_occ_map, pose_range, coords_range, WH = read_occ_map_npy(occ_map_npy)
	# for computing gt skeleton
	# NOTE(review): the skeleton is not used anywhere else in this function; the
	# computation is kept unchanged in case the helpers are relied on for side
	# effects -- confirm and drop if not.
	if cfg.NAVI.D_type == 'Skeleton':
		skeleton = skeletonize(gt_occ_map)
		if cfg.NAVI.PRUNE_SKELETON:
			skeleton = fr_utils.prune_skeleton(gt_occ_map, skeleton)

	#===================================== load modules ==========================================
	# NOTE(review): the ``device`` argument is overridden here, so the policy
	# always runs on CPU regardless of what the caller passes. Kept as-is to
	# preserve existing behavior -- confirm whether this is intentional.
	device = torch.device('cpu')
	g_policy = RL_Policy((8, 240, 240), device,
		base_kwargs={'recurrent': 0,
		'hidden_size': 256,
		'downscaling': 2
		}).to(device)

	# map_location keeps every tensor on the storage it is deserialized to (CPU)
	state_dict = torch.load('trained_weights/model_best.global', map_location=lambda storage, loc: storage)
	g_policy.load_state_dict(state_dict)
	g_policy.eval()

	LN = localNav_Astar(pose_range, coords_range, WH, scene_name)

	LS = localNav_slam(pose_range, coords_range, WH, mark_locs=True, close_small_openings=False, recover_on_collision=False,
	fix_thrashing=False, point_cnt=2)
	LS.reset(gt_occ_map)

	semMap_module_near = SemanticMap(split, scene_name, pose_range, coords_range, WH,
								ins2cat_dict, type='near')  # build the observed sem map
	semMap_module_far = SemanticMap(split, scene_name, pose_range, coords_range, WH,
								ins2cat_dict, type='far')  # build the observed sem map

	traverse_lst = []
	traverse_coords_list = []
	action_lst = []
	step_cov_pairs = []

	#===================================== setup the start location ===============================#

	agent_pos = np.array([start_pose[0], scene_height, start_pose[1]])  # (6.6, -6.9), (3.6, -4.5)
	# check if the start point is navigable
	if not env.is_navigable(agent_pos):
		print('start pose is not navigable ...')
		# Raised explicitly (instead of ``assert 1 == 2``) so the check is not
		# stripped under ``python -O``; the exception type is unchanged.
		raise AssertionError('start pose is not navigable')

	# ================= get the area connected with start pose ==============
	gt_reached_area = LN.get_start_pose_connected_component((agent_pos[0], -agent_pos[2], 0), gt_occ_map)

	hfov = cfg.NAVI.HFOV
	if hfov == 90:
		obs, pose = get_obs_and_pose(env, agent_pos, start_pose[2])
		obs_list, pose_list = [obs], [pose]
	elif hfov == 360:
		obs_list, pose_list = _look_around(env, agent_pos, start_pose[2])
	else:
		raise ValueError(f'unsupported cfg.NAVI.HFOV: {hfov!r} (expected 90 or 360)')

	step = 0
	explore_steps = 0
	MODE_FIND_SUBGOAL = True
	reached_subgoals_near = []
	reached_subgoals_far = []
	chosen_frontier = None
	chosen_from_far = False  # which map the current chosen_frontier came from

	while step < cfg.NAVI.NUM_STEPS:

		#=============================== get agent global pose on habitat env ========================#
		pose = pose_list[-1]
		# map frame negates the second pose component and the heading
		agent_map_pose = (pose[0], -pose[1], -pose[2])
		agent_map_coords = pose_to_coords(agent_map_pose, pose_range, coords_range, WH)
		traverse_lst.append(agent_map_pose)
		traverse_coords_list.append(agent_map_coords)

		# add the observed area
		semMap_module_near.build_semantic_map(obs_list,
											pose_list,
											step=step,
											saved_folder=saved_folder)

		semMap_module_far.build_semantic_map(obs_list,
											pose_list,
											step=step,
											saved_folder=saved_folder)

		if MODE_FIND_SUBGOAL:
			# ---- near map: frontier closest to the policy's long-term goal ----
			observed_occupancy_map_near, _, observed_area_flag_near, _ = \
				semMap_module_near.get_observed_occupancy_map(agent_map_pose)

			long_term_goal_coords_near = compute_long_term_goal(observed_occupancy_map_near, agent_map_coords,
				agent_map_pose[2], traverse_coords_list, g_policy, device)

			frontiers_near = fr_utils.get_frontiers(observed_occupancy_map_near)
			frontiers_near, _ = LN.filter_unreachable_frontiers(frontiers_near, agent_map_pose, observed_occupancy_map_near)
			_drop_reached_frontiers(frontiers_near, reached_subgoals_near)
			chosen_frontier_near = fr_utils.get_the_nearest_frontier_to_the_long_term_goal(frontiers_near, long_term_goal_coords_near)

			# ---- far map: closest frontier lying inside the near observed area ----
			observed_occupancy_map_far, _, _, _ = \
				semMap_module_far.get_observed_occupancy_map(agent_map_pose)

			frontiers_far = fr_utils.get_frontiers(observed_occupancy_map_far)
			frontiers_far, _ = LN.filter_unreachable_frontiers(frontiers_far, agent_map_pose, observed_occupancy_map_far)
			_drop_reached_frontiers(frontiers_far, reached_subgoals_far)
			chosen_frontier_far = _nearest_observed_far_frontier(
				frontiers_far, observed_area_flag_near, agent_map_coords)

			# choose between near and far frontiers.
			# A far frontier is only ever returned when its centroid is flagged
			# in observed_area_flag_near, so whenever one exists it wins; the
			# near frontier is the fallback.
			if chosen_frontier_far is not None:
				chosen_frontier = chosen_frontier_far
				chosen_from_far = True
			elif chosen_frontier_near is not None:
				chosen_frontier = chosen_frontier_near
				chosen_from_far = False
			else:
				chosen_frontier = None

			MODE_FIND_SUBGOAL = False

		#===================================== check if exploration is done ========================
		if chosen_frontier is None:
			# There are no more frontiers to explore. Stop navigation.
			break

		#====================================== take next action ================================
		act, _, _, _ = LS.plan_to_reach_frontier(agent_map_pose, chosen_frontier,
			observed_occupancy_map_near)
		action_lst.append(act)

		step += 1
		if act == -1 or act == 0: # finished navigating to the subgoal
			MODE_FIND_SUBGOAL = True
			# remember the frontier so it is not selected again from the same map
			if chosen_from_far:
				reached_subgoals_far.append(chosen_frontier)
			else:
				reached_subgoals_near.append(chosen_frontier)
		else:
			explore_steps += 1
			obs, pose = get_obs_and_pose_by_action(env, act)
			if hfov == 90:
				obs_list, pose_list = [obs], [pose]
			else:
				# hfov == 360: re-observe all four directions from the new position.
				# output rot is negative of the input angle
				agent_pos = np.array([pose[0], scene_height, pose[1]])
				obs_list, pose_list = _look_around(env, agent_pos, -pose[2])

		#=================== compute percent and explored area ===============================
		# percent is computed on the free space
		explored_free_space = np.logical_and(gt_reached_area, observed_area_flag_near)
		percent = 1. * np.sum(explored_free_space) / np.sum(gt_reached_area)

		# explored area: number of observed cells scaled by .0025
		# (presumably the area of one map cell in m^2 -- confirm)
		area = np.sum(observed_area_flag_near) * .0025
		step_cov_pairs.append((step, percent, area))

		# force a fresh subgoal after a fixed number of movement steps
		if explore_steps == cfg.NAVI.NUM_STEPS_EXPLORE:
			explore_steps = 0
			MODE_FIND_SUBGOAL = True

	#============================================ Finish exploration =============================================
	# Skip the figure when the loop never ran: there is then neither a
	# trajectory nor an observed occupancy map to draw.
	if cfg.NAVI.FLAG_VISUALIZE_FINAL_TRAJ and traverse_lst:
		_save_final_trajectory_figure(semMap_module_near, observed_occupancy_map_near,
			traverse_lst, pose_range, coords_range, WH, saved_folder)

	#====================================== compute statistics =================================
	_, observed_area_flag, _ = semMap_module_near.get_semantic_map()
	explored_free_area_flag = np.logical_and(gt_occ_map, observed_area_flag)

	sum_gt_free_area = np.sum(gt_occ_map > 0)
	sum_explored_free_area = np.sum(explored_free_area_flag > 0)
	percent = sum_explored_free_area * 1. / sum_gt_free_area

	step_cov_pairs = np.array(step_cov_pairs).astype('float32')

	return percent, step, traverse_lst, action_lst, step_cov_pairs


def _load_gt_occ_map_npy(split, scene_name):
	"""Load the ground-truth BEV occupancy-map dict selected by ``cfg``.

	Raises ``ValueError`` for an occupancy-map type or evaluation size this
	module has no file location for.
	"""
	if cfg.NAVI.GT_OCC_MAP_TYPE != 'NAV_MESH':
		raise ValueError(
			f'unsupported cfg.NAVI.GT_OCC_MAP_TYPE: {cfg.NAVI.GT_OCC_MAP_TYPE!r} '
			"(expected 'NAV_MESH')")
	if cfg.EVAL.SIZE == 'small':
		map_path = f'output/semantic_map/{split}/{scene_name}/BEV_occupancy_map.npy'
	elif cfg.EVAL.SIZE == 'large':
		map_path = f'output/large_scale_semantic_map/{scene_name}/BEV_occupancy_map.npy'
	else:
		raise ValueError(
			f"unsupported cfg.EVAL.SIZE: {cfg.EVAL.SIZE!r} (expected 'small' or 'large')")
	# the file stores a pickled object inside a 0-d array; .item() unwraps it
	return np.load(map_path, allow_pickle=True).item()


def _look_around(env, agent_pos, base_heading):
	"""Collect observations at ``agent_pos`` for four headings 90 degrees apart.

	The offsets are applied in the order 90, 180, 270, 0 so the last returned
	pose is the one with zero offset from ``base_heading``.
	"""
	obs_list, pose_list = [], []
	for rot in [90, 180, 270, 0]:
		heading_angle = plus_theta_fn(rot / 180 * np.pi, base_heading)
		obs, pose = get_obs_and_pose(env, agent_pos, heading_angle)
		obs_list.append(obs)
		pose_list.append(pose)
	return obs_list, pose_list


def _drop_reached_frontiers(frontiers, reached_subgoals):
	"""Remove, in place, every already-reached subgoal from ``frontiers``."""
	for reached_subgoal in reached_subgoals:
		if reached_subgoal in frontiers:
			frontiers.remove(reached_subgoal)


def _nearest_observed_far_frontier(frontiers_far, observed_area_flag_near, agent_map_coords):
	"""Return the far-map frontier nearest to the agent, or ``None``.

	Only frontiers whose (integer) centroid is flagged in
	``observed_area_flag_near`` are considered. Ties in distance are broken
	by the larger ``hash`` of the frontier object.
	"""
	min_L = 10000000
	min_frontier = None
	for fron in frontiers_far:
		# centroid is indexed as (row, col); coords are stored as (col, row)
		fron_centroid_coords = (int(fron.centroid[1]),
								int(fron.centroid[0]))
		if not observed_area_flag_near[fron_centroid_coords[1], fron_centroid_coords[0]]:
			continue
		L = fr_utils._eucl_dist(fron_centroid_coords, agent_map_coords)
		if L < min_L:
			min_L = L
			min_frontier = fron
		elif L == min_L and hash(fron) > hash(min_frontier):
			min_frontier = fron
	return min_frontier


def _save_final_trajectory_figure(semMap_module, observed_occupancy_map, traverse_lst,
		pose_range, coords_range, WH, saved_folder):
	"""Save ``final_semmap.jpg``: trajectory on the occupancy map next to the semantic map.

	``traverse_lst`` must be non-empty.
	"""
	built_semantic_map, observed_area_flag, _ = semMap_module.get_semantic_map()

	color_built_semantic_map = apply_color_to_map(built_semantic_map)
	color_built_semantic_map = change_brightness(color_built_semantic_map,
													observed_area_flag,
													value=60)

	#=================================== visualize the agent pose as red nodes =======================
	x_coord_lst, z_coord_lst, theta_lst = [], [], []
	for cur_pose in traverse_lst:
		x_coord, z_coord = pose_to_coords((cur_pose[0], cur_pose[1]),
										pose_range, coords_range, WH)
		x_coord_lst.append(x_coord)
		z_coord_lst.append(z_coord)
		theta_lst.append(cur_pose[2])

	fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(25, 10))
	ax[0].imshow(observed_occupancy_map, cmap='gray')
	# final pose: arrow head oriented by the last heading
	marker, scale = gen_arrow_head_marker(theta_lst[-1])
	ax[0].scatter(x_coord_lst[-1],
					z_coord_lst[-1],
					marker=marker,
					s=(30 * scale)**2,
					c='red',
					zorder=5)
	# start pose: red square
	ax[0].scatter(x_coord_lst[0],
					z_coord_lst[0],
					marker='s',
					s=50,
					c='red',
					zorder=5)
	# whole trajectory, colored by visiting order with shrinking markers
	ax[0].scatter(x_coord_lst,
				z_coord_lst,
				c=range(len(x_coord_lst)),
				cmap='viridis',
				s=np.linspace(5, 2, num=len(x_coord_lst))**2,
				zorder=3)
	ax[0].get_xaxis().set_visible(False)
	ax[0].get_yaxis().set_visible(False)
	ax[0].set_title('improved observed_occ_map + frontiers')

	ax[1].imshow(color_built_semantic_map)
	ax[1].get_xaxis().set_visible(False)
	ax[1].get_yaxis().set_visible(False)
	ax[1].set_title('built semantic map')

	fig.tight_layout()
	# NOTE(review): plt.title acts on the current axes, which here is presumably
	# ax[1], so it likely replaces the 'built semantic map' title -- confirm
	# whether a figure-level title was intended.
	plt.title('observed area')
	fig.savefig(f'{saved_folder}/final_semmap.jpg')
	# close this specific figure so figures do not accumulate across episodes
	plt.close(fig)
