import os
import glob
import json
import hydra
import numpy as np
from collections import defaultdict
import open3d as o3d
from omegaconf import DictConfig
from hovsg.graph.graph import Graph

def _paired_ply_json(directory):
    """Return ``(ply_path, json_path)`` pairs for the files in *directory*.

    The ``*.ply`` and ``*.json`` listings are sorted independently and zipped
    positionally, so the i-th PLY is paired with the i-th JSON file. A count
    mismatch is reported, because ``zip`` would otherwise silently drop the
    unmatched trailing files.
    """
    ply_paths = sorted(glob.glob(os.path.join(directory, "*.ply")))
    json_paths = sorted(glob.glob(os.path.join(directory, "*.json")))
    if len(ply_paths) != len(json_paths):
        print(
            f"Warning: {directory} contains {len(ply_paths)} .ply files but "
            f"{len(json_paths)} .json files; unmatched files are ignored"
        )
    return list(zip(ply_paths, json_paths))


def _load_floors(graph_path, viz_step):
    """Load every floor point cloud and its metadata from ``<graph_path>/floors``.

    Args:
        graph_path: directory of the saved graph.
        viz_step: per-floor translation; floor number ``i`` (in sorted file
            order) gets ``viz_offset = viz_step * i``.

    Returns:
        ``(floor_pcds, floor_infos)``, both keyed by the ``floor_id`` read from
        each JSON file. Each info dict keeps a subset of the JSON keys plus the
        computed ``viz_offset``.
    """
    kept_keys = ("floor_id", "name", "rooms", "floor_height", "floor_zero_level", "vertices")
    floor_pcds = {}
    floor_infos = {}
    pairs = _paired_ply_json(os.path.join(graph_path, "floors"))
    for index, (ply_path, info_path) in enumerate(pairs):
        with open(info_path, "r") as fp:
            floor_info = json.load(fp)
        floor_id = floor_info["floor_id"]
        info = {k: v for k, v in floor_info.items() if k in kept_keys}
        # Offset used to separate floors when their rooms are displayed.
        info["viz_offset"] = viz_step * index
        floor_infos[floor_id] = info
        floor_pcds[floor_id] = o3d.io.read_point_cloud(ply_path)
    return floor_pcds, floor_infos


def _load_rooms(graph_path, floor_infos, ceiling_margin=0.4, color_gain=1.2):
    """Load, crop, shift and brighten every room point cloud in ``<graph_path>/rooms``.

    Args:
        graph_path: directory of the saved graph.
        floor_infos: floor metadata from ``_load_floors``; the room's
            ``floor_id`` selects which ``viz_offset`` is applied.
        ceiling_margin: points at or above
            ``room_zero_level + room_height - ceiling_margin`` on axis 1 are
            dropped (presumably to hide the ceiling).
        color_gain: multiplier applied to RGB colors, clipped to ``[0, 1]``.

    Returns:
        Dict mapping ``room_id`` (from each JSON file) to the processed cloud.

    Raises:
        KeyError: if a room's ``floor_id`` is not present in *floor_infos*.
    """
    kept_keys = ("room_id", "name", "floor_id", "room_height", "room_zero_level", "vertices")
    room_pcds = {}
    for ply_path, info_path in _paired_ply_json(os.path.join(graph_path, "rooms")):
        with open(info_path, "r") as fp:
            room_info = json.load(fp)
        info = {k: v for k, v in room_info.items() if k in kept_keys}

        cloud = o3d.io.read_point_cloud(ply_path)
        xyz = np.asarray(cloud.points)
        # Axis 1 is treated as the vertical axis here (matches room_zero_level).
        ceiling_cutoff = info["room_zero_level"] + info["room_height"] - ceiling_margin
        room_pcd = cloud.select_by_index(np.where(xyz[:, 1] < ceiling_cutoff)[0])
        room_pcd.translate(floor_infos[room_info["floor_id"]]["viz_offset"])
        if room_pcd.has_colors():
            room_pcd.colors = o3d.utility.Vector3dVector(
                np.clip(np.asarray(room_pcd.colors) * color_gain, 0.0, 1.0)
            )
        room_pcds[info["room_id"]] = room_pcd
    return room_pcds


@hydra.main(version_base=None, config_path="../config", config_name="visualize_query_graph")
def main(params: DictConfig):
    """Load a saved graph and visualize it at scene, floor and room level.

    Reads the graph at ``params.main.graph_path`` and names its rooms via
    ``generate_room_names``. It then opens Open3D viewers, one after another:
    first the full point cloud, then each floor, then each room of the graph.
    """
    graph_path = params.main.graph_path
    hovsg = Graph(params)
    hovsg.load_graph(graph_path)
    # generate room names
    # NOTE(review): "utility_room" is the only entry using an underscore instead
    # of a space; confirm whether that label is intentional.
    hovsg.generate_room_names(
        generate_method="view_embedding",
        default_room_types=[
            "bathroom", "bedroom", "closet", "dining room", "entryway",
            "family room", "garage", "hallway", "library", "laundry room",
            "kitchen", "living room", "meeting room", "lounge", "office",
            "porch", "gameroom", "stairs", "toilet", "utility_room",
            "home theater", "gym", "outdoor area", "balcony",
            "bar", "classroom", "dining booth", "spa", "junk",
        ],
    )

    # full_pcd.ply lives in the parent directory of graph_path. normpath drops
    # a trailing separator so dirname yields the parent either way; this replaces
    # a fixed [:-6] slice that only worked for a directory named "graph".
    full_pcd_path = os.path.join(os.path.dirname(os.path.normpath(graph_path)), "full_pcd.ply")
    full_pcd = o3d.io.read_point_cloud(full_pcd_path)
    print("The full point cloud has", len(full_pcd.points), "points")
    o3d.visualization.draw_geometries([full_pcd])

    # Floors are spaced apart by this step for visualization.
    floor_pcds, floor_infos = _load_floors(graph_path, np.array([7.0, 2.5, 4.0]))
    for floor_id, floor_pcd in floor_pcds.items():
        print("Floor ID:", floor_id)
        print("number of rooms in floor {} is {}".format(floor_id, len(hovsg.floors[int(floor_id)].rooms)))
        o3d.visualization.draw_geometries([floor_pcd])

    room_pcds = _load_rooms(graph_path, floor_infos)
    for room in hovsg.rooms:
        print("Room ID:", room.room_id, "Room Name:", room.name, 'Floor ID:', room.floor_id)
        print("number of objects in room {} is {}".format(room.room_id, len(room.objects)))
        room_pcd = room_pcds.get(room.room_id)
        if room_pcd is None:
            # A graph room without a saved PLY used to abort with KeyError.
            print("No point cloud found for room {}; skipping".format(room.room_id))
            continue
        o3d.visualization.draw_geometries([room_pcd])

if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main() builds the
    # DictConfig from ../config/visualize_query_graph and passes it in.
    main()