import os
import cv2
import json
import math
import yaml
import networkx as nx
import numpy as np
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from itertools import product
from omegaconf import DictConfig
from hovsg.graph.graph import Graph
from hovsg.graph.navigation_graph import NavigationGraph

def load_nav_graph(connectivity_dir, scan):
    '''Load the connectivity graph of one scan.

    Reads ``<connectivity_dir>/<scan>_connectivity.json`` (a JSON list of
    viewpoint records) and builds an undirected ``networkx.Graph``:

    * nodes are the ``image_id`` of viewpoints whose ``included`` is set;
    * viewpoints i and j are joined when both are included and
      ``data[i]['unobstructed'][j]`` is truthy; the edge ``weight`` is the
      Euclidean distance between their ``pose[3]``, ``pose[7]``,
      ``pose[11]`` components;
    * every node gets a ``position`` attribute,
      ``np.array([pose[3], pose[7], pose[11]])`` (presumably the translation
      column of a row-major 4x4 pose matrix -- confirm against the dataset).

    Included viewpoints without any included, unobstructed neighbour never
    receive an edge and therefore do not appear in the graph.

    Args:
        connectivity_dir: directory holding the ``*_connectivity.json`` files.
        scan: scan id used to build the file name.

    Returns:
        networkx.Graph with ``weight`` on edges and ``position`` on nodes.

    Raises:
        ValueError: if the ``unobstructed`` relation is not symmetric
            (previously an ``assert``, which is skipped under ``python -O``).
    '''

    def distance(pose1, pose2):
        ''' Euclidean distance between the translation parts of two viewpoint records '''
        return ((pose1['pose'][3] - pose2['pose'][3]) ** 2
                + (pose1['pose'][7] - pose2['pose'][7]) ** 2
                + (pose1['pose'][11] - pose2['pose'][11]) ** 2) ** 0.5

    path = os.path.join(connectivity_dir, '%s_connectivity.json' % scan)
    # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
    with open(path, encoding='utf-8') as f:
        data = json.load(f)

    G = nx.Graph()
    positions = {}
    for i, item in enumerate(data):
        if not item['included']:
            continue
        pose = item['pose']
        # Computed once per viewpoint instead of once per outgoing edge.
        positions[item['image_id']] = np.array([pose[3], pose[7], pose[11]])
        for j, conn in enumerate(item['unobstructed']):
            if not (conn and data[j]['included']):
                continue
            # Explicit check (not assert) so it also runs under `python -O`.
            if not data[j]['unobstructed'][i]:
                raise ValueError(
                    'Graph should be undirected: %s sees %s but not vice versa'
                    % (item['image_id'], data[j]['image_id']))
            G.add_edge(item['image_id'], data[j]['image_id'],
                       weight=distance(item, data[j]))

    # Attach positions only to nodes that actually made it into G (both edge
    # endpoints are included viewpoints, so each has an entry in `positions`).
    nx.set_node_attributes(
        G, values={node: positions[node] for node in G.nodes}, name='position')
    return G

def pose_to_coords(cur_pose,
                    map_data,
                    cell_size=0.05,
                    flag_cropped=True):
    """
    Map a habitat pose (X, Z[, yaw]) to a cell location 'coords' on the map.

    The Z component is negated before projection. The column index comes
    from X, the row index from ``wh[0]`` minus the scaled Z; both are floored.

    Args:
        cur_pose: sequence whose first two entries are X and Z; when it has
            exactly three entries the third is treated as a yaw angle.
        map_data: mapping providing 'pose_range' (pose-space origin),
            'coords_range' (crop offset in cells) and 'wh'.
        cell_size: edge length of one map cell, in pose units.
        flag_cropped: when True, the 'coords_range' offset is subtracted so
            the result indexes the cropped map.

    Returns:
        (x_coord, z_coord) as ints, or (x_coord, z_coord, -yaw) when
        ``cur_pose`` has exactly three entries.
    """
    pos_x, pos_z = cur_pose[:2]
    pos_z = -pos_z

    origin = map_data['pose_range']
    crop_offset = map_data['coords_range']
    wh = map_data['wh']

    # Continuous cell coordinates before flooring.
    col = (pos_x - origin[0]) / cell_size
    row = wh[0] - (pos_z - origin[1]) / cell_size
    if flag_cropped:
        col = col - crop_offset[0]
        row = row - crop_offset[1]

    cell = (math.floor(col), math.floor(row))
    if len(cur_pose) != 3:
        return cell
    # Map yaw is the negated pose yaw.
    return cell + (-cur_pose[2],)

def visualize_mp3d_graph(scan, top_down_map, nav_graph,
                         connectivity_dir="../VLN-DUET/datasets/R2R/connectivity",
                         output_dir="navigation_graph"):
    """Draw a scan's MP3D/R2R connectivity graph on a top-down map and save it.

    The map is converted to an 8-bit BGR image (rescaled by 255 first when
    its maximum is <= 1). Edges are drawn as red lines and endpoints as green
    dots (BGR colours). The result is written to
    ``<output_dir>/mp3d_rgb.png``.

    Args:
        scan: scan id whose ``<scan>_connectivity.json`` is loaded.
        top_down_map: 2-D (grayscale) or 3-channel map image; not modified.
        nav_graph: object exposing ``to_grid(point)``; its result is indexed
            at [0] and [2] to obtain pixel (x, y) -- presumably grid x/z.
        connectivity_dir: directory containing the connectivity JSON files.
        output_dir: directory for the output image; created if missing.

    Returns:
        The annotated ``uint8`` image that was written to disk.

    Raises:
        OSError: if OpenCV fails to write the image.
    """
    if np.max(top_down_map) <= 1:
        # Values look normalised to [0, 1]; rescale before casting to 8 bit.
        fig = (top_down_map * 255).astype(np.uint8)
    else:
        # astype copies, so the caller's array is never drawn on.
        fig = top_down_map.astype(np.uint8)
    if fig.ndim == 2:
        fig = cv2.cvtColor(fig, cv2.COLOR_GRAY2BGR)

    mp3d_graph = load_nav_graph(connectivity_dir, scan)

    # Each node appears in several edges; convert its position only once.
    pixel_cache = {}

    def to_pixel(node):
        if node not in pixel_cache:
            x, y, z = mp3d_graph.nodes[node]['position']
            # Reorder (x, y, z) -> (x, z, -y) before to_grid; presumably
            # MP3D z-up to habitat y-up -- confirm against NavigationGraph.
            grid = np.int32(nav_graph.to_grid((x, z, -y)))
            # Plain Python ints: OpenCV point parsing is picky about numpy ints.
            pixel_cache[node] = (int(grid[0]), int(grid[2]))
        return pixel_cache[node]

    for n1, n2 in mp3d_graph.edges():
        p1 = to_pixel(n1)
        p2 = to_pixel(n2)
        cv2.line(fig, p1, p2, (0, 0, 255), 1)
        cv2.circle(fig, p1, 2, (0, 255, 0), -1)
        cv2.circle(fig, p2, 2, (0, 255, 0), -1)

    # cv2.imwrite returns False (no exception) when the directory is missing.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, "mp3d_rgb.png")
    if not cv2.imwrite(out_path, fig):
        raise OSError(f"cv2.imwrite failed to write {out_path}")
    return fig

def load_voronoi_graph(graph_path):
    """Load a NetworkX graph stored as node-link JSON.

    Args:
        graph_path: path to a JSON file in node-link format (presumably
            written with ``networkx.node_link_data`` -- confirm with writer).

    Returns:
        The graph rebuilt by ``networkx.node_link_graph``.
    """
    # JSON is UTF-8 by spec; avoid the platform's locale-dependent default.
    with open(graph_path, "r", encoding="utf-8") as f:
        graph_data = json.load(f)

    # Convert the node-link dict back to a NetworkX graph.
    return nx.node_link_graph(graph_data)

def main(scene='zsNo4HB9uLZ',
         scene_graph_root="../HOV-SG/data/scene_graphs/hm3dsem",
         param_path="config/Nav3DSG.yaml"):
    """Overlay navigation graphs of one scene on its first floor's top-down map.

    Loads the HOV-SG scene graph, builds a ``NavigationGraph`` for floor 0,
    renders its top-down RGB map, draws the stored sparse Voronoi graph, and
    finally draws the MP3D/R2R connectivity graph of the same scene.

    Args:
        scene: scene id, used both for the HOV-SG scene-graph directory and
            as the scan id for the connectivity file. Defining it once keeps
            the two from silently drifting apart.
        scene_graph_root: directory containing one sub-directory per scene.
        param_path: YAML configuration passed to ``Graph``.
    """
    out_dir = 'navigation_graph'

    with open(param_path, "r", encoding="utf-8") as f:
        params = yaml.safe_load(f)
    params = DictConfig(params)

    graph_dir = os.path.join(scene_graph_root, scene, "graph")
    hovsg = Graph(params)
    hovsg.load_graph(graph_dir)

    # The drawing helpers are handed this directory name (presumably as their
    # output location), and visualize_mp3d_graph writes mp3d_rgb.png into it;
    # make sure it exists.
    os.makedirs(out_dir, exist_ok=True)

    floor = hovsg.floors[0]
    nav_graph = NavigationGraph(floor.pcd, cell_size=0.03)
    floor_info = {
        "floor_zero_level": floor.floor_zero_level,
        "floor_height": floor.floor_height,
    }
    top_down_map = nav_graph.get_top_down_rgb_map(floor.pcd, floor_info, out_dir)

    # Overlay HOV-SG's sparse Voronoi navigation graph.
    voronoi_path = os.path.join(graph_dir, "nav_graph", "sparse_voronoi_graph.json")
    hovsg_graph = load_voronoi_graph(voronoi_path)
    nav_graph.draw_graph_on_map(top_down_map, hovsg_graph, out_dir, "sparse_vor_rgb")

    # Overlay the MP3D/R2R connectivity graph of the same scene.
    visualize_mp3d_graph(scene, top_down_map, nav_graph)


if __name__ == '__main__':
    # Run the visualisation only when executed as a script, not on import.
    main()
