import os
import networkx as nx
import numpy as np
import json
import habitat_sim
from hovsg.data.hm3dsem.habitat_utils import make_cfg_mp3d
from tqdm import tqdm

def get_sim(scan, root_dataset_dir="../scene_datasets/mp3d"):
    """Create a Habitat simulator for a single Matterport3D scan.

    Args:
        scan: Scan id. The scene mesh is loaded from
            ``<root_dataset_dir>/<scan>/<scan>.glb``.
        root_dataset_dir: Root directory of the MP3D scene dataset. Defaults
            to the path this script previously hard-coded, so existing
            callers are unaffected.

    Returns:
        A ``habitat_sim.Simulator`` built from ``make_cfg_mp3d``. The caller
        owns it and should call ``close()`` when done.
    """
    scene_name = scan
    # Keep the trailing slash: this exact string is passed to make_cfg_mp3d.
    scene_data_dir = f"{root_dataset_dir}/{scene_name}/"
    save_dir = f"metric/mp3d_views/{scene_name}"

    scene_mesh = os.path.join(scene_data_dir, scene_name + ".glb")

    sim_settings = {
        "scene": scene_mesh,
        "default_agent": 0,
        "sensor_height": 1.5,
        "color_sensor": True,
        "depth_sensor": True,
        "semantic_sensor": True,
        "lidar_sensor": False,
        "move_forward": 0.2,
        "move_backward": 0.2,
        "turn_left": 5,
        "turn_right": 5,
        "look_up": 5,
        "look_down": 5,
        "look_left": 5,
        "look_right": 5,
        "width": 1080,
        "height": 720,
        "enable_physics": False,
        "seed": 42,
        "lidar_fov": 360,
        "depth_img_for_lidar_n": 20,
        "img_save_dir": save_dir,
    }
    # Silence Magnum / Habitat logging before the simulator is constructed.
    os.environ["MAGNUM_LOG"] = "quiet"
    os.environ["HABITAT_SIM_LOG"] = "quiet"

    sim_cfg = make_cfg_mp3d(sim_settings, root_dataset_dir, scene_data_dir, scene_name, print_scene=False)
    sim = habitat_sim.Simulator(sim_cfg)

    return sim

def load_nav_graph(connectivity_dir, scan):
    ''' Load the connectivity graph for one scan.

    Reads ``<connectivity_dir>/<scan>_connectivity.json``. A viewpoint becomes a
    node only if it is ``included`` and has at least one ``unobstructed``
    connection to another ``included`` viewpoint. Edges are weighted by the
    Euclidean distance between the two viewpoints' positions.

    Each node gets a ``'position'`` attribute: a numpy array built from
    ``pose[3]``, ``pose[7]`` and ``pose[11]``. Presumably this is the
    translation column of a flattened row-major 4x4 pose matrix — TODO confirm.

    Raises:
        ValueError: if a connection is not symmetric, i.e. the graph is not
            undirected.
    '''

    def distance(pose1, pose2):
        ''' Euclidean distance between two graph poses '''
        return ((pose1['pose'][3]-pose2['pose'][3])**2\
            + (pose1['pose'][7]-pose2['pose'][7])**2\
            + (pose1['pose'][11]-pose2['pose'][11])**2)**0.5

    # Only hold the file open while parsing; graph construction needs no handle.
    with open(os.path.join(connectivity_dir, '%s_connectivity.json' % scan)) as f:
        data = json.load(f)

    G = nx.Graph()
    positions = {}
    for i, item in enumerate(data):
        if not item['included']:
            continue
        pose = item['pose']
        # Computed once per viewpoint rather than once per qualifying edge.
        position = np.array([pose[3], pose[7], pose[11]])
        for j, conn in enumerate(item['unobstructed']):
            if conn and data[j]['included']:
                positions[item['image_id']] = position
                # Explicit check (not ``assert``) so it still runs under ``python -O``.
                if not data[j]['unobstructed'][i]:
                    raise ValueError('Graph should be undirected')
                G.add_edge(item['image_id'], data[j]['image_id'], weight=distance(item, data[j]))
    nx.set_node_attributes(G, values=positions, name='position')
    return G

# For every scan listed in scans.txt, snap each connectivity-graph viewpoint
# onto the Habitat navmesh and write the result to vp2pos/vp2pos_<scan>.json.
connectivity_dir = "../VLN-DUET/datasets/R2R/connectivity"
with open(os.path.join(connectivity_dir, "scans.txt"), "r") as f:
    # Skip blank lines (e.g. a trailing newline) so they are not used as scan ids.
    scans = [line.strip() for line in f if line.strip()]

os.makedirs("vp2pos", exist_ok=True)
for scan in tqdm(scans, desc="Processing scans", position=0):
    G = load_nav_graph(connectivity_dir, scan)
    sim = get_sim(scan)
    # Always release the simulator, even if snapping or writing fails.
    try:
        pathfinder = sim.pathfinder

        vp2pos = {}
        for node in tqdm(G.nodes, desc="Processing nodes", position=1, leave=False):
            pos = G.nodes[node]['position']
            # Axis remap (x, y, z) -> (x, z - 1.5, -y): presumably converts the
            # connectivity frame to Habitat's frame. The 1.5 offset matches
            # sensor_height in get_sim and guesses the floor height below the camera.
            x = pos[0]
            y = pos[2]-1.5
            z = -pos[1]
            test_point = np.array([x, y, z])

            # Snap this point to the navmesh
            new_point = pathfinder.snap_point(test_point)

            # float() guards against numpy scalar types that json cannot serialize.
            # NOTE(review): if snapping fails the coordinates may be NaN, which
            # json.dump writes as non-standard "NaN" — confirm downstream handling.
            vp2pos[node] = [float(new_point[0]), float(new_point[1]), float(new_point[2])]

        with open(f"vp2pos/vp2pos_{scan}.json", "w") as f:
            json.dump(vp2pos, f, indent=4)
    finally:
        sim.close()