import os
import cv2
import json
import math
import yaml
import networkx as nx
import numpy as np
import numpy as np
import habitat
from tqdm import tqdm
import matplotlib.pyplot as plt
from itertools import product

def load_nav_graph(connectivity_dir, scan):
    """Build the undirected navigation graph of one scan.

    Reads ``<connectivity_dir>/<scan>_connectivity.json`` and adds an edge
    between every pair of ``included`` viewpoints whose ``unobstructed`` flag is
    set. Each edge is weighted by the Euclidean distance between the two
    viewpoints' translations, taken from elements 3, 7 and 11 of ``pose``.
    Every viewpoint that takes part in at least one edge gets a ``position``
    node attribute holding that translation as a numpy array.

    Raises:
        AssertionError: if a connection is listed in only one direction.
    """

    def _translation(entry):
        # presumably the translation column of a flattened row-major 4x4
        # transform -- TODO confirm against the connectivity file format
        pose = entry['pose']
        return pose[3], pose[7], pose[11]

    def distance(pose1, pose2):
        """Euclidean distance between the translations of two graph entries."""
        (ax, ay, az), (bx, by, bz) = _translation(pose1), _translation(pose2)
        return ((ax - bx)**2 + (ay - by)**2 + (az - bz)**2)**0.5

    json_path = os.path.join(connectivity_dir, '%s_connectivity.json' % scan)
    with open(json_path) as handle:
        viewpoints = json.load(handle)

    graph = nx.Graph()
    node_positions = {}
    for src_idx, src in enumerate(viewpoints):
        if not src['included']:
            continue
        for dst_idx, visible in enumerate(src['unobstructed']):
            # short-circuit: only look up the neighbour when the link is set
            if not (visible and viewpoints[dst_idx]['included']):
                continue
            dst = viewpoints[dst_idx]
            node_positions[src['image_id']] = np.array(list(_translation(src)))
            assert dst['unobstructed'][src_idx], 'Graph should be undirected'
            graph.add_edge(src['image_id'], dst['image_id'],
                           weight=distance(src, dst))

    nx.set_node_attributes(graph, values=node_positions, name='position')
    return graph

def pose_to_coords(cur_pose,
                    map_data,
                    cell_size=0.05,
                    flag_cropped=True):
    """Convert a habitat-environment pose (X, Z) into a map cell location.

    Args:
        cur_pose: ``(X, Z)`` or ``(X, Z, yaw)``. Z is negated before projection.
        map_data: dict with ``pose_range``, ``coords_range`` and ``wh`` entries,
            as produced by ``read_sem_map_npy``.
        cell_size: edge length of one map cell, in the same units as the pose.
        flag_cropped: if True, express the result in the cropped map's frame by
            subtracting ``coords_range[0]`` (columns) and ``coords_range[1]``
            (rows).

    Returns:
        ``(x_coord, z_coord)`` as ints or, when ``cur_pose`` has exactly three
        entries, ``(x_coord, z_coord, map_yaw)`` with ``map_yaw = -yaw``.
    """
    env_x, env_z = cur_pose[:2]
    env_z = -env_z

    pose_range = map_data['pose_range']
    coords_range = map_data['coords_range']
    map_width = map_data['wh'][0]

    column = (env_x - pose_range[0]) / cell_size
    row = map_width - (env_z - pose_range[1]) / cell_size
    if flag_cropped:
        column = column - coords_range[0]
        row = row - coords_range[1]

    cell = (math.floor(column), math.floor(row))
    if len(cur_pose) != 3:
        return cell
    return cell + (-cur_pose[2],)

def read_sem_map_npy(map_npy):
    """Repack a saved semantic-map record into the ``map_data`` dict.

    ``map_npy`` is a mapping, e.g. the ``.item()`` of a pickled ``.npy`` file.
    The returned dict holds:

    * ``semantic_map`` -- the stored map, passed through unchanged;
    * ``pose_range``   -- ``(min_X, min_Z, max_X, max_Z)``;
    * ``coords_range`` -- ``(min_x, min_z, max_x, max_z)``;
    * ``wh``           -- ``(W, H)``.

    Raises:
        KeyError: if any of the expected keys is missing.
    """
    def _pick(*names):
        return tuple(map_npy[name] for name in names)

    return {
        'semantic_map': map_npy['semantic_map'],
        'pose_range': _pick('min_X', 'min_Z', 'max_X', 'max_Z'),
        'coords_range': _pick('min_x', 'min_z', 'max_x', 'max_z'),
        'wh': _pick('W', 'H'),
    }

def get_occ_map(scene, height):
    """Rasterise a scene's navigable area into a 0/1 occupancy map.

    A regular grid covering ``[-50, 50)`` on both X and Z with a cell size of
    0.05 is swept. The centre of every cell, at the fixed vertical coordinate
    ``height``, is tested with the habitat simulator's ``is_navigable``.
    Navigable cells are marked in an array indexed like the uncropped semantic
    map (via ``pose_to_coords(..., flag_cropped=False)``). That array is then
    cropped to ``map_data['coords_range']`` so it lines up with the saved
    semantic map. A preview ``<scene>_occ_map.png`` is written to the working
    directory.

    Args:
        scene: MP3D scene id; used to locate the ``.glb`` and the saved
            ``BEV_semantic_map.npy``.
        height: vertical coordinate used for every ``is_navigable`` query.

    Returns:
        ``(occ_map, map_data)``: the cropped int occupancy map (1 = navigable)
        and the dict returned by ``read_sem_map_npy``.
    """
    cell_size = 0.05

    # ============================= initialize a grid =========================================
    x = np.arange(-50.0, 50.0, cell_size)
    z = np.arange(-50.0, 50.0, cell_size)
    xv, zv = np.meshgrid(x, z)
    grid_H, grid_W = zv.shape

    # Load the semantic map before starting the (expensive) simulator, so a
    # missing/corrupt file fails fast.
    sem_map_npy = np.load(f'../bellman/habitat_tools/output/semantic_map/{scene}_0/BEV_semantic_map.npy', allow_pickle=True).item()
    map_data = read_sem_map_npy(sem_map_npy)

    # =============================== initialize the habitat environment ============================
    config = habitat.get_config(config_paths='../bellman/habitat_tools/configs/habitat_env/build_map_mp3d.yaml')
    config.defrost()
    config.SIMULATOR.SCENE = f'../scene_datasets/mp3d/{scene}/{scene}.glb'
    config.SIMULATOR.SCENE_DATASET = '../scene_datasets/mp3d/mp3d_annotated_basis.scene_dataset_config.json'
    config.freeze()

    # initialize the occupancy grid
    occ_map = np.zeros((grid_H, grid_W), dtype=int)
    occ_rows, occ_cols = occ_map.shape

    env = habitat.sims.make_sim(config.SIMULATOR.TYPE, config=config.SIMULATOR)
    try:
        env.reset()

        # ============================= traverse the environment =========================================
        for grid_z, grid_x in tqdm(product(range(grid_H), range(grid_W)), total=grid_H * grid_W):
            # query the centre of the cell rather than its corner
            cx = xv[grid_z, grid_x] + cell_size / 2.
            cz = zv[grid_z, grid_x] + cell_size / 2.

            if not env.is_navigable(np.array([cx, height, cz])):
                continue

            # convert environment pose to (uncropped) map coordinates
            x_coord, z_coord = pose_to_coords(
                (cx, cz), map_data, cell_size=cell_size, flag_cropped=False)
            # pose_to_coords does not clamp: a negative index would silently
            # wrap to the opposite edge of the array and a too-large one would
            # raise IndexError, so drop cells that fall outside the array.
            if 0 <= z_coord < occ_rows and 0 <= x_coord < occ_cols:
                occ_map[z_coord, x_coord] = 1
    finally:
        # release the simulator even if the sweep is interrupted or fails
        env.close()

    # cut occupancy map and make it same size as the semantic map
    coords_range = map_data['coords_range']
    occ_map = occ_map[coords_range[1]:coords_range[3] + 1,
                      coords_range[0]:coords_range[2] + 1]
    save_occ_map_through_plt(occ_map, f'{scene}_occ_map.png')

    return occ_map, map_data

def save_occ_map_through_plt(img, name):
    """Save ``img`` to the file path ``name`` using matplotlib.

    The image is drawn with ``imshow`` using a grey colormap (the colormap only
    applies to single-channel images) with both axes hidden. The figure is
    closed afterwards so repeated calls do not accumulate open figures.

    Args:
        img: array accepted by ``matplotlib.axes.Axes.imshow``.
        name: output file path; the format is inferred from its extension.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1)
    ax.imshow(img, cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    fig.tight_layout()
    fig.savefig(name)
    # close this specific figure rather than whatever pyplot considers
    # "current", which may differ in interactive sessions
    plt.close(fig)

def visualize_mp3d_graph(scan, occ_map, map_data,
                         connectivity_dir="../VLN-DUET/datasets/R2R/connectivity"):
    """Overlay a scan's navigation graph on its occupancy map and save it.

    Edges are drawn as 1-px green lines and viewpoints as filled red dots of
    radius 4. The result is written to ``<scan>_graph.png`` in the working
    directory.

    Args:
        scan: scan id; ``<connectivity_dir>/<scan>_connectivity.json`` is read.
        occ_map: 2-D 0/1 occupancy map in the cropped map frame described by
            ``map_data`` (e.g. the first value returned by ``get_occ_map``).
        map_data: dict from ``read_sem_map_npy``, used by ``pose_to_coords``.
        connectivity_dir: directory containing the connectivity JSON files.
    """
    G = load_nav_graph(connectivity_dir, scan)

    canvas = cv2.cvtColor(occ_map.astype(np.uint8) * 255, cv2.COLOR_GRAY2BGR)

    def to_pixel(node):
        # Graph positions (x, y, z) are mapped to habitat as (x, z, -y)
        # (presumably z-up -> y-up); only the ground-plane pair (x, -y) is
        # needed to locate the map cell.
        x, y, _ = G.nodes[node]['position']
        return pose_to_coords((x, -y), map_data)

    # Draw edges first so the viewpoint markers stay visible on top of them.
    for u, v in G.edges():
        cv2.line(canvas, to_pixel(u), to_pixel(v), (0, 255, 0), 1)

    for node in G.nodes():
        cv2.circle(canvas, to_pixel(node), 4, (0, 0, 255), -1)

    # OpenCV colours are BGR but matplotlib's imshow expects RGB; without this
    # conversion the red viewpoint markers would be saved as blue.
    canvas = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
    save_occ_map_through_plt(canvas, f'{scan}_graph.png')


def main():
    """Build the occupancy map for one fixed MP3D scene and plot its graph."""
    scan_id = '8WUmhLawc2A'
    # fixed vertical coordinate used for every is_navigable query
    query_height = 0.10065700113773346
    visualize_mp3d_graph(scan_id, *get_occ_map(scan_id, query_height))


if __name__ == '__main__':
    main()
