import os
import json
import cv2
import yaml
import numpy as np
from tqdm import tqdm
from sklearn.cluster import AgglomerativeClustering
from omegaconf import DictConfig
from hovsg.graph.graph import Graph
from hovsg.graph.navigation_graph import NavigationGraph

def visualize_graph(nodes, top_down_map, nav_graph, name):
    """Draw ``nodes`` as green dots on ``top_down_map`` and save it as a PNG.

    Args:
        nodes: Iterable of positions accepted by ``nav_graph.to_grid``.
        top_down_map: Image array. If its maximum value is <= 1 it is treated
            as normalised and scaled by 255. A 2-D array is treated as
            grayscale and converted to 3-channel BGR before drawing.
        nav_graph: Object exposing ``to_grid(node)``. The result is indexed at
            positions 0 and 2 to get the pixel centre.
        name: File stem. The image is written to ``navigation_graph/<name>.png``.

    Raises:
        OSError: If OpenCV fails to write the image file.
    """
    # ``astype`` always returns a new array, so the caller's map is never modified.
    if np.max(top_down_map) <= 1:
        fig = (top_down_map * 255).astype(np.uint8)
    else:
        fig = top_down_map.astype(np.uint8)
    if fig.ndim == 2:
        fig = cv2.cvtColor(fig, cv2.COLOR_GRAY2BGR)

    for node in nodes:
        grid_pos = np.int32(nav_graph.to_grid(node))
        # Keep grid components 0 and 2 and drop component 1 (presumably the
        # vertical axis; confirm against NavigationGraph.to_grid). Plain ints
        # are used because some OpenCV bindings reject numpy scalars here.
        center = (int(grid_pos[0]), int(grid_pos[2]))
        cv2.circle(fig, center, 3, (0, 255, 0), -1)

    out_dir = "navigation_graph"
    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{name}.png")
    if not cv2.imwrite(out_path, fig):
        raise OSError(f"Failed to write visualization to {out_path}")

def visualize(
    points,
    name,
    *,
    param_path="config/Nav3DSG.yaml",
    graph_path="../HOV-SG/data/scene_graphs/hm3dsem/zsNo4HB9uLZ/graph",
    floor_idx=0,
    cell_size=0.03,
):
    """Load a saved HOV-SG graph and draw ``points`` on one floor's top-down map.

    The keyword-only arguments default to the previously hard-coded values,
    so ``visualize(points, name)`` behaves as before.

    Args:
        points: Node positions forwarded to ``visualize_graph``.
        name: Output file stem forwarded to ``visualize_graph``.
        param_path: YAML config used to construct the ``Graph``.
        graph_path: Directory passed to ``Graph.load_graph``.
        floor_idx: Index into ``hovsg.floors`` of the floor to render.
        cell_size: Grid cell size passed to ``NavigationGraph``.
    """
    with open(param_path, "r", encoding="utf-8") as f:
        params = DictConfig(yaml.safe_load(f))
    hovsg = Graph(params)
    hovsg.load_graph(graph_path)

    floor = hovsg.floors[floor_idx]
    nav_graph = NavigationGraph(floor.pcd, cell_size=cell_size)
    floor_info = {
        "floor_zero_level": floor.floor_zero_level,
        "floor_height": floor.floor_height,
    }
    # The third argument is passed through unchanged (presumably an output
    # directory name; confirm against NavigationGraph.get_top_down_rgb_map).
    top_down_map = nav_graph.get_top_down_rgb_map(floor.pcd, floor_info, "navigation_graph")
    visualize_graph(points, top_down_map, nav_graph, name)

trajs_folder = "../HOV-SG/data/final_traj_list"
# Only plain-text trajectory dumps are processed, in sorted (deterministic) order.
files = sorted(f for f in os.listdir(trajs_folder) if f.endswith(".txt"))

# Linkage distance thresholds for which a clustered viewpoint set is produced.
DISTANCE_THRESHOLDS = (0.5, 1.0, 1.5, 2.0, 3.0, 5.0)


def _cluster_viewpoints(positions, distance_threshold):
    """Return one centroid per agglomerative cluster of ``positions``.

    Complete linkage is used, so any two points in the same cluster are less
    than ``distance_threshold`` apart (Euclidean).

    Args:
        positions: ``(N, 3)`` array of viewpoint positions.
        distance_threshold: Linkage distance at or above which clusters are
            not merged.

    Returns:
        Array of cluster centroids, ordered by cluster label. Inputs with fewer
        than two rows are returned unchanged, because ``AgglomerativeClustering``
        requires at least two samples.
    """
    if len(positions) < 2:
        return positions
    labels = AgglomerativeClustering(
        n_clusters=None,  # let the distance threshold decide the number of clusters
        distance_threshold=distance_threshold,
        linkage="complete",  # bound the maximum pairwise distance within a cluster
        metric="euclidean",
    ).fit_predict(positions)
    return np.array([positions[labels == label].mean(axis=0) for label in np.unique(labels)])


for distance_threshold in tqdm(DISTANCE_THRESHOLDS, position=0, leave=True):
    save_dir = os.path.join("navigation_graph", "traj_graph", f"distance_{distance_threshold}")
    os.makedirs(save_dir, exist_ok=True)
    for file in tqdm(files, total=len(files), position=1, leave=False):
        # ndmin=2 keeps a single-row trajectory 2-D so ``[:, :3]`` still works.
        traj = np.loadtxt(os.path.join(trajs_folder, file), ndmin=2)
        # The first three columns are taken as the viewpoint position.
        positions = traj[:, :3]

        clustered_viewpoints = _cluster_viewpoints(positions, distance_threshold)

        # Integer keys are serialised by json as strings ("0", "1", ...).
        nodes = {idx: viewpoint.tolist() for idx, viewpoint in enumerate(clustered_viewpoints)}
        save_path = os.path.join(save_dir, os.path.splitext(file)[0] + ".json")
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(nodes, f, indent=4)

    
    