import numpy as np
import numpy.linalg as LA
import cv2
import matplotlib
import matplotlib.pyplot as plt
import math
from math import cos, sin, acos, atan2, pi, floor, degrees
import random
from .utils.navigation_utils import change_brightness, SimpleRLEnv, get_obs_and_pose, get_obs_and_pose_by_action
from .utils.baseline_utils import apply_color_to_map, pose_to_coords, gen_arrow_head_marker, read_map_npy, read_occ_map_npy, plus_theta_fn
from .utils.nearfar_map_utils_pcd_height import SemanticMap
from .localNavigator_Astar import localNav_Astar
import habitat
import habitat_sim
import random
from core import cfg
from .utils import frontier_utils as fr_utils
from modeling.localNavigator_slam import localNav_slam
from skimage.morphology import skeletonize
from modeling.utils.UNet import UNet
import torch
from collections import OrderedDict

def _get_panorama_obs(env, agent_pos, base_heading):
	"""Capture four views at `agent_pos`, one per rotation offset of 90/180/270/0 degrees.

	Each heading is `plus_theta_fn(offset_in_radians, base_heading)`. Used for the
	360-degree HFOV setting.

	Returns:
		(obs_list, pose_list) in capture order; the last entry is the rot = 0 view.
	"""
	obs_list, pose_list = [], []
	for rot in [90, 180, 270, 0]:
		heading_angle = plus_theta_fn(rot / 180 * np.pi, base_heading)
		obs, pose = get_obs_and_pose(env, agent_pos, heading_angle)
		obs_list.append(obs)
		pose_list.append(pose)
	return obs_list, pose_list


def _select_frontier(sem_map_module, agent_map_pose, old_frontiers, prev_chosen_frontier,
		visited_frontier, LN, skeleton, unet_model, device, steps_left):
	"""Detect, score and pick a frontier on one (near or far) observed map.

	Args:
		sem_map_module: SemanticMap whose observed occupancy map is queried.
		agent_map_pose: agent pose in map convention (x, -z, -theta).
		old_frontiers: frontier set from the previous subgoal search on this map, or None.
		prev_chosen_frontier: frontier chosen on this map in the previous search, or None;
			forwarded to fr_utils.update_frontier_set.
		visited_frontier: set of already-reached frontiers, removed from the candidates.
		LN: localNav_Astar used for reachability filtering and DP selection.
		skeleton: GT skeleton for 'Potential' perception (None unless D_type == 'Skeleton').
		unet_model, device: model and device for 'UNet_Potential' perception (unused otherwise).
		steps_left: remaining step budget handed to the DP selection.

	Returns:
		(frontiers, chosen_frontier, observed_occupancy_map, observed_area_flag); chosen_frontier
		is whatever fr_utils.get_frontier_with_DP returns (None means no candidate).
	"""
	observed_occupancy_map, gt_occupancy_map, observed_area_flag, built_semantic_map = \
		sem_map_module.get_observed_occupancy_map(agent_map_pose)

	frontiers = fr_utils.get_frontiers(observed_occupancy_map)
	frontiers = frontiers - visited_frontier

	frontiers, dist_occupancy_map = LN.filter_unreachable_frontiers(
		frontiers, agent_map_pose, observed_occupancy_map)

	if cfg.NAVI.PERCEPTION == 'UNet_Potential':
		frontiers = fr_utils.compute_frontier_potential(frontiers, observed_occupancy_map, gt_occupancy_map,
			observed_area_flag, built_semantic_map, None, unet_model, device, LN, agent_map_pose)
	elif cfg.NAVI.PERCEPTION == 'Potential':
		# skeleton is None unless cfg.NAVI.D_type == 'Skeleton'
		frontiers = fr_utils.compute_frontier_potential(frontiers, observed_occupancy_map, gt_occupancy_map,
			observed_area_flag, built_semantic_map, skeleton)

	if old_frontiers is not None:
		frontiers = fr_utils.update_frontier_set(old_frontiers, frontiers, max_dist=5,
			chosen_frontier=prev_chosen_frontier)

	top_frontiers = fr_utils.select_top_frontiers(frontiers, top_n=6)
	chosen_frontier = fr_utils.get_frontier_with_DP(top_frontiers, agent_map_pose, dist_occupancy_map,
		steps_left, LN)
	return frontiers, chosen_frontier, observed_occupancy_map, observed_area_flag


def _visualize_final_traj(semMap_module, traverse_lst, observed_occupancy_map, pose_range, coords_range, WH,
		saved_folder):
	"""Save the trajectory and the built semantic map to `{saved_folder}/final_semmap.jpg`.

	Left panel: `observed_occupancy_map` with every pose in `traverse_lst` (colored by order),
	the start pose (red square) and the last pose/heading (red arrow). Right panel: the colored
	semantic map of `semMap_module` after change_brightness(..., value=60) with its observed-area flag.
	Requires a non-empty `traverse_lst`.
	"""
	built_semantic_map, observed_area_flag, _ = semMap_module.get_semantic_map()
	color_built_semantic_map = apply_color_to_map(built_semantic_map)
	color_built_semantic_map = change_brightness(color_built_semantic_map, observed_area_flag, value=60)

	# convert map-frame poses to pixel coordinates
	x_coord_lst, z_coord_lst, theta_lst = [], [], []
	for cur_pose in traverse_lst:
		x_coord, z_coord = pose_to_coords((cur_pose[0], cur_pose[1]), pose_range, coords_range, WH)
		x_coord_lst.append(x_coord)
		z_coord_lst.append(z_coord)
		theta_lst.append(cur_pose[2])

	fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(25, 10))
	ax[0].imshow(observed_occupancy_map, cmap='gray')
	marker, scale = gen_arrow_head_marker(theta_lst[-1])
	ax[0].scatter(x_coord_lst[-1], z_coord_lst[-1], marker=marker, s=(30 * scale)**2, c='red', zorder=5)
	ax[0].scatter(x_coord_lst[0], z_coord_lst[0], marker='s', s=50, c='red', zorder=5)
	ax[0].scatter(x_coord_lst, z_coord_lst, c=range(len(x_coord_lst)), cmap='viridis',
		s=np.linspace(5, 2, num=len(x_coord_lst))**2, zorder=3)
	ax[0].get_xaxis().set_visible(False)
	ax[0].get_yaxis().set_visible(False)
	ax[0].set_title('improved observed_occ_map + frontiers')

	ax[1].imshow(color_built_semantic_map)
	ax[1].get_xaxis().set_visible(False)
	ax[1].get_yaxis().set_visible(False)
	ax[1].set_title('built semantic map')

	# plt.title() would target the current axes (ax[1]) and overwrite its title; use a figure title
	fig.suptitle('observed area')
	fig.tight_layout()
	fig.savefig(f'{saved_folder}/final_semmap.jpg')
	plt.close(fig)


def nav_DP(split, env, episode_id, scene_name, scene_height, start_pose, saved_folder, device):
	"""Major function for navigation.

	Takes in an initialized habitat environment and a start location, then explores:
	detect frontiers on a 'near' and a 'far' semantic map, pick one with DP, and use the
	local navigator to reach it. Stops after cfg.NAVI.NUM_STEPS movement steps or when
	no frontier is left.

	Args:
		split, scene_name: select the saved BEV occupancy map and the SemanticMap data.
		env: habitat environment (semantic_annotations / is_navigable are used here).
		episode_id: unused.
		scene_height: height used for the agent's 3D position.
		start_pose: (x, z, heading) start location.
		saved_folder: output folder for semantic-map dumps and the final visualization.
		device: torch device for the UNet model.

	Returns:
		(percent, step, traverse_lst, action_lst, step_cov_pairs): final explored fraction of
		the GT free space, number of movement steps taken, visited map-frame poses, executed
		actions, and a float32 array of (step, percent, area) rows.

	Raises:
		ValueError: unsupported GT_OCC_MAP_TYPE / EVAL.SIZE / HFOV or no matching UNet checkpoint.
		AssertionError: the start pose is not navigable.
	"""
	act_dict = {-1: 'Done', 0: 'stop', 1: 'forward', 2: 'left', 3: 'right'}

	#============================ get scene ins to cat dict
	scene = env.semantic_annotations()
	ins2cat_dict = {
		int(obj.id.split("_")[-1]): obj.category.index()
		for obj in scene.objects
	}

	#=================================== start original navigation code ========================
	np.random.seed(cfg.GENERAL.RANDOM_SEED)
	random.seed(cfg.GENERAL.RANDOM_SEED)

	if cfg.NAVI.GT_OCC_MAP_TYPE != 'NAV_MESH':
		raise ValueError(f'unsupported cfg.NAVI.GT_OCC_MAP_TYPE: {cfg.NAVI.GT_OCC_MAP_TYPE}')
	if cfg.EVAL.SIZE == 'small':
		occ_map_path = f'output/semantic_map/{split}/{scene_name}/BEV_occupancy_map.npy'
	elif cfg.EVAL.SIZE == 'large':
		occ_map_path = f'output/large_scale_semantic_map/{scene_name}/BEV_occupancy_map.npy'
	else:
		raise ValueError(f'unsupported cfg.EVAL.SIZE: {cfg.EVAL.SIZE}')
	occ_map_npy = np.load(occ_map_path, allow_pickle=True).item()
	gt_occ_map, pose_range, coords_range, WH = read_occ_map_npy(occ_map_npy)

	# gt skeleton, only needed by 'Potential' perception with D_type == 'Skeleton'
	skeleton = None
	if cfg.NAVI.D_type == 'Skeleton':
		skeleton = skeletonize(gt_occ_map)
		if cfg.NAVI.PRUNE_SKELETON:
			skeleton = fr_utils.prune_skeleton(gt_occ_map, skeleton)

	#===================================== load modules ==========================================
	unet_model = None
	checkpoint = None
	if cfg.NAVI.PERCEPTION == 'UNet_Potential':
		unet_model = UNet(n_channel_in=cfg.PRED.PARTIAL_MAP.INPUT_CHANNEL,
			n_class_out=cfg.PRED.PARTIAL_MAP.OUTPUT_CHANNEL).to(device)
		if cfg.NAVI.STRATEGY == 'Greedy':
			if cfg.PRED.PARTIAL_MAP.INPUT == 'occ_and_sem':
				checkpoint = torch.load(f'{cfg.PRED.PARTIAL_MAP.SAVED_FOLDER}/{cfg.PRED.PARTIAL_MAP.INPUT}/best_checkpoint.pth.tar', map_location=device)
		elif cfg.NAVI.STRATEGY == 'DP':
			if cfg.PRED.PARTIAL_MAP.INPUT == 'occ_and_sem':
				checkpoint = torch.load(f'{cfg.PRED.PARTIAL_MAP.SAVED_FOLDER}/{cfg.PRED.PARTIAL_MAP.INPUT}/best_checkpoint.pth.tar', map_location=device)
			elif cfg.PRED.PARTIAL_MAP.INPUT == 'occ_only':
				checkpoint = torch.load('run/MP3D/unet/experiment_5/checkpoint.pth.tar', map_location=device)
		if checkpoint is None:
			raise ValueError(f'no UNet checkpoint configured for STRATEGY={cfg.NAVI.STRATEGY}, '
				f'INPUT={cfg.PRED.PARTIAL_MAP.INPUT}')

		# checkpoints saved through nn.DataParallel prefix every key with 'module.'; strip it only when present
		prefix = 'module.'
		new_state_dict = OrderedDict(
			(k[len(prefix):] if k.startswith(prefix) else k, v)
			for k, v in checkpoint['state_dict'].items())
		unet_model.load_state_dict(new_state_dict)
		unet_model.eval()

	LN = localNav_Astar(pose_range, coords_range, WH, scene_name)

	LS = localNav_slam(pose_range, coords_range, WH, mark_locs=True, close_small_openings=False,
		recover_on_collision=False, fix_thrashing=False, point_cnt=2)
	LS.reset(gt_occ_map)

	semMap_module_near = SemanticMap(split, scene_name, pose_range, coords_range, WH,
		ins2cat_dict, type='near')  # build the observed sem map
	semMap_module_far = SemanticMap(split, scene_name, pose_range, coords_range, WH,
		ins2cat_dict, type='far')  # build the observed sem map
	traverse_lst = []
	action_lst = []
	step_cov_pairs = []

	#===================================== setup the start location ===============================#
	agent_pos = np.array([start_pose[0], scene_height, start_pose[1]])
	if not env.is_navigable(agent_pos):
		print('start pose is not navigable ...')
		# explicit raise (not `assert`) so the check survives `python -O`; the type is kept for callers
		raise AssertionError(f'start pose {agent_pos} is not navigable')

	# ================= get the area connected with start pose ==============
	gt_reached_area = LN.get_start_pose_connected_component((agent_pos[0], -agent_pos[2], 0), gt_occ_map)

	if cfg.NAVI.HFOV == 90:
		obs, pose = get_obs_and_pose(env, agent_pos, start_pose[2])
		obs_list, pose_list = [obs], [pose]
	elif cfg.NAVI.HFOV == 360:
		obs_list, pose_list = _get_panorama_obs(env, agent_pos, start_pose[2])
	else:
		raise ValueError(f'unsupported cfg.NAVI.HFOV: {cfg.NAVI.HFOV}')

	step = 0
	MODE_FIND_SUBGOAL = True
	explore_steps = 0
	visited_frontier = set()
	chosen_frontier = None
	chosen_frontier_near = None
	chosen_frontier_far = None
	old_frontiers_near = None
	old_frontiers_far = None
	frontiers_near = None
	frontiers_far = None
	observed_occupancy_map_near = None
	observed_area_flag_near = None

	far_num = 0
	near_num = 0

	while step < cfg.NAVI.NUM_STEPS:
		print(f'step = {step}')

		#=============================== get agent global pose on habitat env ========================#
		pose = pose_list[-1]
		print(f'agent position = {pose[:2]}, angle = {pose[2]}')
		agent_map_pose = (pose[0], -pose[1], -pose[2])
		traverse_lst.append(agent_map_pose)

		# add the observed area
		semMap_module_near.build_semantic_map(obs_list, pose_list, step=step, saved_folder=saved_folder)
		semMap_module_far.build_semantic_map(obs_list, pose_list, step=step, saved_folder=saved_folder)

		if MODE_FIND_SUBGOAL:
			steps_left = cfg.NAVI.NUM_STEPS - step

			# get near frontier
			if frontiers_near is not None:
				old_frontiers_near = frontiers_near
			frontiers_near, chosen_frontier_near, observed_occupancy_map_near, observed_area_flag_near = \
				_select_frontier(semMap_module_near, agent_map_pose, old_frontiers_near, chosen_frontier_near,
					visited_frontier, LN, skeleton, unet_model, device, steps_left)

			# get far frontier
			if frontiers_far is not None:
				old_frontiers_far = frontiers_far
			frontiers_far, chosen_frontier_far, _, _ = \
				_select_frontier(semMap_module_far, agent_map_pose, old_frontiers_far, chosen_frontier_far,
					visited_frontier, LN, skeleton, unet_model, device, steps_left)

			# choose between near and far frontiers
			if chosen_frontier_near is None and chosen_frontier_far is None:
				chosen_frontier = None
			else:
				if chosen_frontier_near is None:
					use_far = True
				elif chosen_frontier_far is None:
					use_far = False
				else:
					# take the far frontier only if its centroid lies in the area already observed on the near map
					far_row = int(chosen_frontier_far.centroid[1])
					far_col = int(chosen_frontier_far.centroid[0])
					use_far = bool(observed_area_flag_near[far_row, far_col])

				if use_far:
					chosen_frontier = chosen_frontier_far
					print('chosen frontier is far')
					far_num += 1
				else:
					chosen_frontier = chosen_frontier_near
					print('chosen frontier is near')
					near_num += 1

		#===================================== check if exploration is done ========================
		if chosen_frontier is None:
			print('There are no more frontiers to explore. Stop navigation.')
			break

		#==================================== leave subgoal-search mode ===========================
		if MODE_FIND_SUBGOAL:
			MODE_FIND_SUBGOAL = False
			explore_steps = 0

		#====================================== take next action ================================
		act, _, subgoal_coords, _ = LS.plan_to_reach_frontier(agent_map_pose, chosen_frontier,
			observed_occupancy_map_near)
		print(f'subgoal_coords = {subgoal_coords}')
		print(f'action = {act_dict[act]}')
		action_lst.append(act)

		if act in (-1, 0):  # finished navigating to the subgoal
			print('reached the subgoal')
			MODE_FIND_SUBGOAL = True
			visited_frontier.add(chosen_frontier)
		else:
			step += 1
			explore_steps += 1
			# output rot is negative of the input angle
			if cfg.NAVI.HFOV == 90:
				obs, pose = get_obs_and_pose_by_action(env, act)
				obs_list, pose_list = [obs], [pose]
			elif cfg.NAVI.HFOV == 360:
				_, next_pose = get_obs_and_pose_by_action(env, act)
				agent_pos = np.array([next_pose[0], scene_height, next_pose[1]])
				obs_list, pose_list = _get_panorama_obs(env, agent_pos, -next_pose[2])

		#=================== compute percent and explored area ===============================
		# percent is computed on the free space reachable from the start pose
		# NOTE(review): observed_area_flag_near is only re-fetched during a subgoal search; unless
		# get_observed_occupancy_map returns a live view, coverage here can lag behind — confirm.
		explored_free_space = np.logical_and(gt_reached_area, observed_area_flag_near)
		percent = 1. * np.sum(explored_free_space) / np.sum(gt_reached_area)

		# explored area; .0025 presumably is the m^2 per map cell (0.05 m cells) — TODO confirm
		area = np.sum(observed_area_flag_near) * .0025
		print(f'step = {step}, percent = {percent}, area = {area} meter^2')
		step_cov_pairs.append((step, percent, area))

		if explore_steps == cfg.NAVI.NUM_STEPS_EXPLORE:
			explore_steps = 0
			MODE_FIND_SUBGOAL = True

	#============================================ Finish exploration =============================================
	# traverse_lst is empty when the loop never ran (NUM_STEPS <= 0); nothing to draw then
	if cfg.NAVI.FLAG_VISUALIZE_FINAL_TRAJ and traverse_lst:
		_visualize_final_traj(semMap_module_near, traverse_lst, observed_occupancy_map_near,
			pose_range, coords_range, WH, saved_folder)

	#====================================== compute statistics =================================
	# drop references to the model/checkpoint (both None when UNet perception is off)
	del unet_model
	del checkpoint

	_, observed_area_flag, _ = semMap_module_near.get_semantic_map()
	# NOTE(review): this final percent is normalized by all nonzero cells of gt_occ_map, while the
	# per-step percent above uses gt_reached_area (component connected to the start) — confirm intended.
	explored_free_area_flag = np.logical_and(gt_occ_map, observed_area_flag)

	sum_gt_free_area = np.sum(gt_occ_map > 0)
	sum_explored_free_area = np.sum(explored_free_area_flag > 0)
	percent = sum_explored_free_area * 1. / sum_gt_free_area
	print(f'********percent = {percent}, step = {step}')

	step_cov_pairs = np.array(step_cov_pairs).astype('float32')

	print(f'far_num = {far_num}, near_num = {near_num}')

	return percent, step, traverse_lst, action_lst, step_cov_pairs
