import math
from copy import deepcopy
import torch
import os
import cv2
import json
import torch.nn as nn
import numpy as np
import pprint
from gym import spaces
import habitat
from waypoint_pred.TRM_net import BinaryDistPredictor_TRM
from waypoint_pred.utils import nms
from resnet_encoders import VlnResnetDepthEncoder
from habitat.sims import make_sim
from habitat.utils.geometry_utils import quaternion_rotate_vector, quaternion_from_coeff
from habitat.tasks.utils import cartesian_to_polar
from scipy.spatial.transform import Rotation as R

# All models in this script run on the GPU.
device = torch.device('cuda')

# Candidate-waypoint predictor: the weights are read on the CPU first and the
# module is moved to `device` afterwards.
cwp_fn = '../ETPNav/data/wp_pred/check_cwp_bestdist_hfov90'
waypoint_predictor = BinaryDistPredictor_TRM(device=device)
waypoint_predictor.load_state_dict(
    torch.load(cwp_fn, map_location=torch.device('cpu'))['predictor']['state_dict']
)
# The predictor is only used for inference, so none of its weights need gradients.
for param in waypoint_predictor.parameters():
    param.requires_grad = False
waypoint_predictor.eval()
waypoint_predictor.to(device)

# Observation space handed to the depth encoder: 256x256 single-channel
# float32 values bounded to [0, 1].
observation_space = spaces.Box(low=0.0, high=1.0, shape=(256, 256, 1), dtype=np.float32)

# ResNet-50 depth encoder initialised from the checkpoint path below and used
# in evaluation mode on `device`.
depth_encoder = VlnResnetDepthEncoder(
    observation_space,
    backbone="resnet50",
    checkpoint="../VLN-CE/data/ddppo-models/gibson-2plus-resnet50.pth",
    output_size=256,
    spatial_output=False,
)
depth_encoder.eval()
depth_encoder.to(device)

def angle_feature_torch(headings: torch.Tensor):
    """Encode heading angles as ``[sin, cos, sin(0), cos(0)]`` features.

    The last two components are computed from an all-zero tensor shaped like
    ``headings``, so they are the constants 0 and 1.

    Args:
        headings: Tensor of angles in radians.

    Returns:
        Float32 tensor holding the four components stacked along a new
        leading axis and then transposed with ``.T``; for a 1-D input of
        length N the result has shape ``(N, 4)``.
    """
    zero_angles = torch.zeros_like(headings)
    components = (
        headings.sin(),
        headings.cos(),
        zero_angles.sin(),
        zero_angles.cos(),
    )
    return torch.stack(components).float().T

def get_waypoint_prediction(depths, depth_encoder):
    """Predict a candidate-waypoint heatmap from a ring of depth views.

    Runs entirely under ``torch.no_grad()`` because the result is only used
    for inference.

    Args:
        depths: Mapping of view name to batched depth tensor (leading
            dimension is the batch). Every key containing ``'depth'`` is
            treated as one view and ``'depth_0'`` must be present. Views are
            consumed in the mapping's iteration order, so that order has to
            match the order in which the views were captured.
        depth_encoder: Callable taking ``{'depth': tensor}`` and returning
            the embedding fed to the module-level ``waypoint_predictor``.

    Returns:
        The heatmap after non-maximum suppression with the wrap-around
        padding removed, indexed as ``[batch, angle, distance]`` with
        ``NUM_ANGLES`` angle bins and ``NUM_CLASSES`` distance bins
        (assuming ``nms`` preserves the shape of its input).
    """
    NUM_ANGLES = 120    # 120 angles 3 degrees each
    NUM_IMGS = 12
    NUM_CLASSES = 12    # 12 distances at each sector
    # Derived from the input so the reshapes below agree with the number of
    # samples actually packed into depth_batch.
    batch_size = depths['depth_0'].size(0)

    with torch.no_grad():
        ''' encoding depth at all directions ----------------------------- '''
        depth_batch = torch.zeros_like(depths['depth_0']).repeat(NUM_IMGS, 1, 1, 1)

        # Reverse the order of input images to clockwise. Sample `bi` of a
        # view lands in row `slot + bi * NUM_IMGS`, i.e. every NUM_IMGS-th
        # row starting at `slot`.
        view_idx = 0
        for key, view in depths.items():
            if 'depth' in key:  # relies on the mapping's key order
                slot = (NUM_IMGS - view_idx) % NUM_IMGS
                depth_batch[slot::NUM_IMGS] = view
                view_idx += 1
        depth_embedding = depth_encoder({'depth': depth_batch})

        ''' waypoint prediction ----------------------------- '''
        waypoint_heatmap_logits = waypoint_predictor(depth_embedding)

        # From heatmap to points: softmax jointly over all (angle, distance)
        # cells, then restore the 2-D layout.
        batch_x_norm = torch.softmax(
            waypoint_heatmap_logits.reshape(batch_size, NUM_ANGLES * NUM_CLASSES),
            dim=1,
        ).reshape(batch_size, NUM_ANGLES, NUM_CLASSES)

        # Pad one angle bin on each side with its wrap-around neighbour
        # before NMS; the padding is stripped again below.
        batch_x_norm_wrap = torch.cat(
            (batch_x_norm[:, -1:, :], batch_x_norm, batch_x_norm[:, :1, :]),
            dim=1,
        )
        batch_output_map = nms(
            batch_x_norm_wrap.unsqueeze(1),
            max_predictions=5,
            sigma=(7.0, 5.0),
        )

        # Predicted waypoints before sampling.
        batch_output_map = batch_output_map.squeeze(1)[:, 1:-1, :]

    return batch_output_map

def get_env(scene):
    """Create a Habitat simulator for one MP3D scan.

    Loads ``./navigation_graph/config.yaml``, clears the task sensors, sets a
    0.25 forward step, disables sliding and points the simulator at the
    scan's ``.glb`` file.

    Args:
        scene: Scan id; it is used both as the directory name and as the file
            stem under ``../scene_datasets/mp3d``.

    Returns:
        The simulator returned by ``make_sim``.
    """
    cfg = habitat.get_config('./navigation_graph/config.yaml')
    cfg.defrost()

    cfg.TASK.SENSORS = []

    sim_cfg = cfg.SIMULATOR
    sim_cfg.FORWARD_STEP_SIZE = 0.25
    sim_cfg.HABITAT_SIM_V0.ALLOW_SLIDING = False
    sim_cfg.SCENE = '../scene_datasets/mp3d/{scan}/{scan}.glb'.format(scan=scene)

    return make_sim(id_sim=sim_cfg.TYPE, config=sim_cfg)

def estimate_cand_pos(pos, ori, ang, dis):
    """Project (angle, distance) candidates from an agent pose to positions.

    Args:
        pos: Agent position, indexable as ``(x, y, z)``.
        ori: Agent orientation, passed to ``heading_from_quaternion``.
        ang: Sequence of candidate angles in radians, relative to the agent
            heading.
        dis: Sequence of candidate distances, one per angle.

    Returns:
        Array of shape ``(len(ang), 3)``. Each row is offset from ``pos`` in
        x and z; the y component is copied from ``pos`` unchanged.
    """
    rel_angles = np.array(ang)
    ranges = np.array(dis)

    # Absolute heading of every candidate, wrapped into [0, 2*pi).
    abs_angles = (heading_from_quaternion(ori) + rel_angles) % (2 * np.pi)

    cand_pos = np.zeros([len(ang), 3])
    cand_pos[:, 0] = pos[0] - ranges * np.sin(abs_angles)    # x
    cand_pos[:, 1] = pos[1]                                  # y
    cand_pos[:, 2] = pos[2] - ranges * np.cos(abs_angles)    # z
    return cand_pos

def heading_from_quaternion(quat: np.array):
    """Return the heading angle encoded by a quaternion, in ``[0, 2*pi)``.

    Follows the heading computation in habitat-lab v0.1.7:
    https://github.com/facebookresearch/habitat-lab/blob/v0.1.7/habitat/tasks/nav/nav.py#L356

    Args:
        quat: Quaternion coefficients as accepted by
            ``quaternion_from_coeff``.

    Returns:
        The polar angle of the rotated reference axis, modulo ``2*pi``.
    """
    rotation = quaternion_from_coeff(quat)
    reference_axis = np.array([0, 0, -1])
    heading_vector = quaternion_rotate_vector(rotation.inverse(), reference_axis)
    polar = cartesian_to_polar(-heading_vector[2], heading_vector[0])
    phi = polar[1]
    return phi % (2 * np.pi)

def get_cand_pos(wp_pred, position, rotation):
    """Convert a waypoint heatmap into candidate positions.

    Only the first sample of ``wp_pred`` is used. Every non-zero cell
    ``(angle_idx, distance_idx)`` becomes one candidate.

    Args:
        wp_pred: Heatmap tensor indexed as ``[batch, angle, distance]``.
        position: Agent position, forwarded to ``estimate_cand_pos``.
        rotation: Agent orientation, forwarded to ``estimate_cand_pos``.

    Returns:
        The array produced by ``estimate_cand_pos``, one row per candidate.
    """
    num_angles = 120      # angle bins covering a full turn
    distance_step = 0.25  # distance covered by one distance bin

    # Locate the non-zero cells once; both index columns come from this result.
    hits = wp_pred[0].nonzero()
    angle_idxes = hits[:, 0]
    distance_idxes = hits[:, 1]

    # The bin index is subtracted from a full turn to get the angle that
    # estimate_cand_pos adds to the agent heading.
    angle_rad_cc = 2*math.pi - angle_idxes.float()/num_angles*2*math.pi
    cand_angle = angle_rad_cc.tolist()
    # Distance bin 0 corresponds to one step, not zero distance.
    cand_distance = ((distance_idxes + 1)*distance_step).tolist()

    return estimate_cand_pos(position, rotation, cand_angle, cand_distance)

def get_depth(position, rotation, sim):
    """Render 12 depth views around a pose, 30 degrees of yaw apart.

    Args:
        position: Position passed to ``sim.get_observations_at``.
        rotation: Base orientation as a scalar-last ``(x, y, z, w)``
            quaternion, the layout ``Rotation.from_quat`` expects.
        sim: Simulator exposing ``get_observations_at(position, rotation)``
            whose observations contain a ``'depth'`` array.

    Returns:
        Dict mapping ``"depth_0"`` .. ``"depth_11"`` to float tensors on the
        CUDA device, each with a leading batch dimension of 1. Insertion
        order follows increasing yaw.
    """
    # The base rotation does not change between views, so build it once.
    base_rot = R.from_quat(rotation)

    depths = {}
    for i in range(12):
        # Yaw about the y axis, composed onto the base rotation.
        yaw = R.from_euler('y', np.radians(i * 30), degrees=False)
        view_quat = (base_rot * yaw).as_quat()

        obs = sim.get_observations_at(position, view_quat)
        depths[f"depth_{i}"] = obs['depth']

    return {
        k: torch.from_numpy(v).unsqueeze(0).float().to('cuda')
        for k, v in depths.items()
    }

def check_cand(cand_pos, vps, sim):
    """Classify a candidate position against the viewpoints found so far.

    Args:
        cand_pos: Candidate position as a numpy array.
        vps: Mapping of viewpoint id to position.
        sim: Object exposing ``is_navigable(position)``.

    Returns:
        ``None`` when the candidate is not navigable; the id of the nearest
        viewpoint when it lies within 0.5 of the candidate; otherwise the
        string ``"self"``, meaning the candidate is a viewpoint of its own.
    """
    if not sim.is_navigable(cand_pos):
        return None

    nearest_id, nearest_dist = None, 10000
    for vp_id, vp_pos in vps.items():
        offset = vp_pos - cand_pos
        dist = (offset**2).sum()**0.5
        if dist < nearest_dist:
            nearest_id, nearest_dist = vp_id, dist

    # Nothing close enough: the candidate becomes a new viewpoint.
    if nearest_dist > 0.5:
        return "self"
    return nearest_id

def _edge_key(a, b):
    """Return the ``"<low>_<high>"`` key of the undirected edge between two numeric ids."""
    low, high = sorted((a, b), key=int)
    return f"{low}_{high}"


def main():
    """Grow a waypoint graph for one scene breadth-first and save it as JSON.

    Starting from a fixed pose, each viewpoint is expanded by rendering depth
    views, predicting waypoints and classifying every candidate with
    ``check_cand``. Non-navigable candidates are dropped, candidates close to
    an existing viewpoint add an edge to it, and all others become new
    viewpoints that are queued for expansion. The result is written to the
    first unused ``navigation_graph/wp_graph/<scene>_<n>.json``.
    """
    # Function-scope import: the FIFO frontier is the only user of deque.
    from collections import deque

    target_scene = "8WUmhLawc2A"
    # Alternative start pose:
    # start_position = [15.068599700927734, 0.17162801325321198, -4.4848198890686035]
    # start_rotation = [0.        , -0.07214095,  0.        ,  0.99739444]
    start_position = [-6.850029945373535, 0.10065700113773346, -4.315700054168701]
    start_rotation = [0, 0.868719458580017, 0, -0.495304495096207]

    vps = {"0": start_position}
    edges = set()

    sim = get_env(target_scene)
    try:
        frontier = deque([("0", start_position, start_rotation)])
        while frontier:
            vp_id, position, rotation = frontier.popleft()
            depths = get_depth(position, rotation, sim)
            wp_pred = get_waypoint_prediction(depths, depth_encoder)
            for cand in get_cand_pos(wp_pred, position, rotation):
                status = check_cand(cand, vps, sim)
                if status is None:
                    continue
                if status == vp_id:
                    # The candidate snapped back onto the viewpoint being
                    # expanded; an edge here would be a self-loop.
                    continue
                if status == "self":
                    # New viewpoint: register it and queue it for expansion,
                    # keeping the rotation of the viewpoint it came from.
                    status = str(len(vps))
                    vps[status] = cand.tolist()
                    frontier.append((status, cand, rotation))
                edges.add(_edge_key(vp_id, status))
    finally:
        # Release the simulator even when graph construction fails.
        sim.close()

    # Numeric ordering keeps the saved file identical across runs.
    edges = sorted(edges, key=lambda e: tuple(int(part) for part in e.split("_")))
    print(f"Number of VPs: {len(vps)}")
    print(f"Number of Edges: {len(edges)}")

    result = {
        "start_position": start_position,
        "start_rotation": start_rotation,
        "vps": vps,
        "edges": edges
    }

    os.makedirs("navigation_graph/wp_graph", exist_ok=True)
    # Pick the first numbered file name that does not exist yet.
    num = 0
    while True:
        save_path = f"navigation_graph/wp_graph/{target_scene}_{num}.json"
        if not os.path.exists(save_path):
            break
        num += 1

    with open(save_path, "w") as f:
        json.dump(result, f, indent=4)

# Build and save the graph only when this file is executed as a script.
if __name__ == "__main__":
    main()
