# # import os
# # import sys

# # # Add the local ov-seg repo to Python path
# # current_dir = os.path.dirname(os.path.abspath(__file__))
# # ovseg_repo = os.path.join(current_dir, "ov-seg")
# # if ovseg_repo not in sys.path:
# #     sys.path.append(ovseg_repo)

# import os
# import sys
# sys.path.append(os.path.join(os.path.dirname(__file__), "ov-seg"))

# from detectron2.config import get_cfg
# from open_vocab_seg import add_ovseg_config
# from detectron2.modeling import build_model

# # from open_vocab_seg.utils.arguments import default_argument_parser
# # from open_vocab_seg.utils.predictor import OVSegDemo  # name may be VisualizationDemo/OVSegDEMO in your repo


# # from ovseg.config import add_ovseg_config

# import os
# import numpy as np
# import yaml
# from PIL import Image

# import cv2
# import argparse
# import json
# import torch
# import gc
# import copy
# import pandas as pd
# from sklearn.cluster import DBSCAN
# from semantic_sam import (
#     prepare_image,
#     plot_results,
#     build_semantic_sam,
#     SemanticSamAutomaticMaskGenerator,
#     SemanticSAMPredictor,
# )
# import networkx as nx
# import psutil

# # --- OVSeg / Detectron2 imports (adjust to your installation) ---
# # from detectron2.config import get_cfg
# # from ovseg.config import add_ovseg_config
# # from ovseg.modeling import build_ovseg_model

# device = "cuda" if torch.cuda.is_available() else "cpu"


# def build_ovseg_model(cfg, device="cuda"):
#     """
#     Build an OVSeg model from an already-prepared cfg.
#     """
#     model = build_model(cfg)
#     model.to(device)
#     model.eval()
#     return model


# def _load_pose(path, idx, dataset):
#     """
#     Load camera pose (4x4 transformation matrix) for a given frame index.
#     Supports Replica and ScanNet datasets.
#     """
#     if dataset == "Replica":
#         path = os.path.join(path, "traj.txt")
#         with open(path, "r") as file:
#             lines = file.readlines()
#             if 0 <= idx < len(lines):
#                 line = lines[idx]
#                 values = [float(val) for val in line.split()]
#                 transformation_matrix = np.array(values).reshape((4, 4))
#                 return transformation_matrix
#     elif dataset == "ScanNet":
#         path = os.path.join(path, str(idx) + ".txt")
#         transformation_matrix = np.loadtxt(path).reshape(4, 4)
#         return transformation_matrix


# def _load_depth_intrinsics(path, dataset):
#     """
#     Load depth camera intrinsics and scale factor for depth values.
#     Different format for Replica vs ScanNet.
#     """
#     if dataset == "Replica":
#         with open(path, "r") as file:
#             data = json.load(file)
#             camera_params = data.get("camera")
#             if camera_params:
#                 fx = camera_params.get("fx")
#                 fy = camera_params.get("fy")
#                 cx = camera_params.get("cx")
#                 cy = camera_params.get("cy")
#                 scale = camera_params.get("scale")
#                 K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
#                 K = np.array(K)
#                 return K, scale
#     elif dataset == "ScanNet":
#         intrinsic_depth = np.loadtxt(path)
#         scale = 1000.0
#         return intrinsic_depth, scale


# def overlap(mask_i, mask_j, iom_threshold):
#     """
#     Compute 'intersection over minimum' (IoM) between mask_i and mask_j.
#     Return:
#       0  -> no significant overlap
#       1  -> mask_i mostly inside mask_j
#       2  -> mask_j mostly inside mask_i
#     """
#     intersection = np.sum(np.multiply(mask_i, mask_j))
#     sum_i = np.sum(mask_i)
#     sum_j = np.sum(mask_j)
#     if sum_i > sum_j:
#         iom = intersection / sum_j
#         if iom > iom_threshold:
#             return 2
#         else:
#             return 0
#     else:
#         iom = intersection / sum_i
#         if iom > iom_threshold:
#             return 1
#         else:
#             return 0


# def remove_overlapped_masks(results, iom_threshold, area_threshold):
#     """
#     Remove overlapping and tiny masks.
#     - results: list of 2D binary masks (numpy arrays).
#     Returns:
#       torch.Tensor of kept masks [N, H, W]
#       torch.Tensor of corresponding boxes [N, 4] (xmin, ymin, xmax, ymax)
#     """
#     mask_shape = results[0].shape
#     boxes = torch.zeros(len(results), 4)
#     new_masks = torch.zeros(len(results), mask_shape[0], mask_shape[1])
#     removed = []
#     for i in range(len(results)):
#         remove_i = []
#         remove_mask_i = []
#         mask_i = results[i].astype(int)

#         area_ratio = np.sum(mask_i) / (np.shape(mask_i)[0] * np.shape(mask_i)[1])

#         if area_ratio < area_threshold:
#             removed.append(i)
#         else:
#             for j in range(len(results)):
#                 if j != i:
#                     mask_j = results[j].astype(int)
#                     index = overlap(mask_i, mask_j, iom_threshold=iom_threshold)
#                     if index == 2:
#                         if j not in remove_i:
#                             remove_i.append(j)
#                             remove_mask_i.append(mask_j)

#             for mask in remove_mask_i:
#                 mask_i = mask_i * (1 - mask)
#             if np.sum(mask_i) > 0:
#                 new_masks[i] = torch.tensor(mask_i)

#                 rows, cols = np.where(mask_i)
#                 min_row, max_row = rows.min(), rows.max()
#                 min_col, max_col = cols.min(), cols.max()
#                 box = [min_col, min_row, max_col, max_row]
#                 boxes[i] = torch.tensor(box)

#     the_masks = []
#     the_boxes = []
#     for i in range(new_masks.size()[0]):
#         mask = new_masks[i]
#         box = boxes[i]

#         if (i not in removed) and torch.sum(mask) > 0:
#             the_masks.append(mask)
#             the_boxes.append(box)

#     return torch.stack(the_masks), torch.stack(the_boxes)


# def remove_tiny_masks(masks, boxes, area_threshold):
#     """
#     Remove masks whose area ratio is below area_threshold.
#     """
#     the_masks = []
#     the_boxes = []
#     for i in range(len(masks)):
#         mask = masks[i]
#         box = boxes[i]
#         area_ratio = torch.sum(mask) / (mask.size()[0] * mask.size()[1])

#         if area_ratio > area_threshold:
#             the_masks.append(mask)
#             the_boxes.append(box)

#     return the_masks, the_boxes


# def dbscan_mask_denoise(mask, eps, min_samples):
#     """
#     Apply 2D DBSCAN on mask pixels to remove noise and split into clusters.
#     Returns list of binary masks (one per cluster).
#     """
#     coords = np.column_stack(np.nonzero(mask))

#     if coords.shape[0] == 0:
#         return mask

#     clustering = DBSCAN(eps=eps, min_samples=min_samples).fit(coords)
#     labels = clustering.labels_

#     cleaned_mask = -1 * np.ones_like(mask)
#     for cluster_id in np.unique(labels):
#         if cluster_id == -1:
#             continue
#         cluster_points = coords[labels == cluster_id]
#         cleaned_mask[cluster_points[:, 0], cluster_points[:, 1]] = cluster_id

#     final_masks = []
#     for i in range(len(np.unique(cleaned_mask))):
#         if i != -1:
#             final_masks.append((cleaned_mask == i).astype(int))

#     return final_masks


# def dbscan_3d(points, colors, eps, min_samples):
#     """
#     Apply DBSCAN in 3D on point cloud to find a dominant cluster.
#     Returns:
#       cluster_points, cluster_colors, threshold (cluster_size / total_size)
#     """
#     points = points.astype(np.float16)
#     dtype = type(points[0])
#     n_samples = np.shape(points)[0]
#     bytes_per_entry = np.dtype(dtype).itemsize
#     total_bytes = n_samples ** 2 * bytes_per_entry
#     if total_bytes / (1024 ** 3) > 600:
#         return [], [], -1

#     clustering = DBSCAN(eps=eps, min_samples=min_samples).fit(points)
#     labels = clustering.labels_

#     cluster_points = points[labels == 0]
#     cluster_colors = colors[labels == 0]
#     threshold = cluster_points.shape[0] / points.shape[0]

#     return cluster_points, cluster_colors, threshold


# def remove_side_masks(masks, boxes, remove_thr, side_thr, ratio_thr):
#     """
#     Remove masks that lie mostly on the image borders.
#     """
#     removed = []
#     mask_shape = masks[0].size()

#     for i in range(len(masks)):
#         mask = masks[i]
#         H, W = mask.shape

#         y_coords = torch.arange(H).view(-1, 1).expand(H, W)
#         x_coords = torch.arange(W).view(1, -1).expand(H, W)

#         dist_top = y_coords
#         dist_bottom = H - 1 - y_coords
#         dist_left = x_coords
#         dist_right = W - 1 - x_coords

#         dist_to_edge = torch.minimum(
#             torch.minimum(dist_top, dist_bottom),
#             torch.minimum(dist_left, dist_right),
#         )

#         edge_mask = (dist_to_edge < side_thr).long()

#         final_mask = mask * edge_mask

#         count = final_mask.sum().item()
#         sum_mask = mask.sum().item()

#         ratio = count / sum_mask
#         ratio1 = sum_mask / (final_mask.size()[0] * final_mask.size()[1])

#         if (ratio > remove_thr) and (ratio1 < ratio_thr):
#             removed.append(i)

#     new_masks = []
#     new_boxes = []
#     for i in range(len(masks)):
#         mask = masks[i]
#         box = boxes[i]
#         if i not in removed:
#             new_masks.append(np.array(mask))
#             new_boxes.append(np.array(box))
#     return new_masks, new_boxes


# def extend_images(image, boxes, masks, extension_ratio, hide_mask=False, hide_others=False):
#     """
#     Used only for visualization in the original CLIP code.
#     We keep it in case you need it, but OVSeg does not use it.
#     """
#     extended_images = []
#     ratios = []
#     for i in range(len(boxes)):
#         new_image = copy.deepcopy(image)
#         if hide_mask:
#             new_image = (np.array(1 - masks[i])[:, :, np.newaxis]) * new_image
#             new_image = new_image.clip(0, 255).astype(np.uint8)
#         if hide_others:
#             new_image = (np.array(masks[i])[:, :, np.newaxis]) * new_image
#             new_image = new_image.clip(0, 255).astype(np.uint8)

#         center_x = (boxes[i][2] + boxes[i][0]) / 2
#         center_y = (boxes[i][3] + boxes[i][1]) / 2
#         width = (boxes[i][2] - boxes[i][0]) / 2
#         height = (boxes[i][3] - boxes[i][1]) / 2

#         new_x1 = int(center_x - extension_ratio * width)
#         new_x2 = int(center_x + extension_ratio * width)
#         new_y1 = int(center_y - extension_ratio * height)
#         new_y2 = int(center_y + extension_ratio * height)
#         x_margin = image.shape[1]
#         y_margin = image.shape[0]

#         new_x1 = max(new_x1, 0)
#         new_x2 = min(new_x2, x_margin)
#         new_y1 = max(new_y1, 0)
#         new_y2 = min(new_y2, y_margin)

#         final_image = new_image[new_y1:new_y2, new_x1:new_x2]

#         if width > height:
#             final_image = cv2.resize(final_image, (1200, int(1200 / width * height)))
#         else:
#             final_image = cv2.resize(final_image, (int(800 / height * width), 800))

#         extended_images.append(final_image)
#         ratio = np.sum(masks[i]) / ((new_y2 - new_y1) * (new_x2 - new_x1))
#         ratios.append(ratio.item())

#     return extended_images, ratios


# def point_cloud(depth, scale, camera_intristics, mask, camera_pose, colors, dataset):
#     """
#     Convert depth + intrinsics + camera pose + mask into a 3D point cloud.
#     """
#     mask = mask.to("cuda")
#     camera_matrix = torch.tensor(camera_intristics).to("cuda")
#     depth = torch.tensor(depth, dtype=torch.float32).to("cuda")
#     colors = torch.tensor(colors).to("cuda")
#     camera_pose = torch.tensor(camera_pose).to("cuda")

#     y, x = torch.meshgrid(
#         torch.arange(depth.size()[0]), torch.arange(depth.size()[1]), indexing="ij"
#     )
#     x = x.to("cuda")
#     y = y.to("cuda")

#     if depth.dim() == 3:
#         depth = depth[:, :, 0]

#     depth = depth.float() / scale

#     depth_mask1 = (depth > 0).long()
#     mask = mask * depth_mask1
#     mask = mask > 0

#     X = (x - camera_matrix[0, 2]) * depth / camera_matrix[0, 0]
#     Y = (y - camera_matrix[1, 2]) * depth / camera_matrix[1, 1]
#     Z = depth

#     points = torch.stack(
#         (X.view(-1), Y.view(-1), Z.view(-1), torch.ones_like(X.view(-1))), dim=-1
#     )

#     points = torch.matmul(camera_pose, points.T)
#     points = points.T

#     points = points[mask.view(-1)]
#     colors = colors[mask]
#     colors = colors.view(-1, 3)
#     points = points.view(-1, 4)

#     return points[:, :3].cpu().numpy(), colors.cpu().numpy()


# def points_to_grid(points, resolution):
#     """
#     Snap 3D points to a 3D grid with cell size = resolution.
#     """
#     converted_points = ((points / resolution).astype(int).astype(np.float32)) * resolution
#     return converted_points


# def grid_indices(points, params, res):
#     """
#     Convert 3D points into voxel grid indices.
#     """
#     points = torch.tensor(points).to(device)
#     x_min, y_min, z_min, x_max, y_max, z_max = params
#     indices = torch.zeros_like(torch.tensor(points))
#     indices[:, 0] = ((points[:, 0] - x_min) / res).long()
#     indices[:, 1] = ((points[:, 1] - y_min) / res).long()
#     indices[:, 2] = ((points[:, 2] - z_min) / res).long()
#     return indices.long()


# def voxelize_batch(points, X, Y, Z, params):
#     """
#     Convert a point cloud into a (X,Y,Z) voxel grid.
#     """
#     voxel = torch.zeros((X, Y, Z), device=device)
#     indices = grid_indices(points, params, resolution)
#     voxel[indices[:, 0], indices[:, 1], indices[:, 2]] = 1
#     return voxel


# def geometry_overlap(points, resolution, overlap_thr1, overlap_thr2):
#     """
#     Compute pairwise geometric overlap between all point clouds.
#     """
#     num_points = len(points)
#     adjacency = np.zeros((num_points, num_points))
#     x_min = np.zeros(num_points)
#     y_min = np.zeros(num_points)
#     z_min = np.zeros(num_points)
#     x_max = np.zeros(num_points)
#     y_max = np.zeros(num_points)
#     z_max = np.zeros(num_points)

#     for i in range(num_points):
#         x_min[i] = np.min(points[i][:, 0])
#         y_min[i] = np.min(points[i][:, 1])
#         z_min[i] = np.min(points[i][:, 2])
#         x_max[i] = np.max(points[i][:, 0])
#         y_max[i] = np.max(points[i][:, 1])
#         z_max[i] = np.max(points[i][:, 2])

#     for i in range(0, num_points):
#         print(i, " point clouds processed from ", num_points, "!")

#         for j in range(0, num_points):
#             if not (
#                 (x_min[i] > x_max[j] or x_max[i] < x_min[j])
#                 or (y_min[i] > y_max[j] or y_max[j] < y_min[j])
#                 or (z_min[i] > z_max[j] or z_max[i] < z_min[j])
#             ):
#                 my_x_min = np.min([x_min[i], x_min[j]]) - 0.2
#                 my_y_min = np.min([y_min[i], y_min[j]]) - 0.2
#                 my_z_min = np.min([z_min[i], z_min[j]]) - 0.2
#                 my_x_max = np.max([x_max[i], x_max[j]]) + 0.2
#                 my_y_max = np.max([y_max[i], y_max[j]]) + 0.2
#                 my_z_max = np.max([z_max[i], z_max[j]]) + 0.2
#                 X = int((my_x_max - my_x_min) / resolution)
#                 Y = int((my_y_max - my_y_min) / resolution)
#                 Z = int((my_z_max - my_z_min) / resolution)
#                 params = (my_x_min, my_y_min, my_z_min, my_x_max, my_y_max, my_z_max)

#                 v_i = voxelize_batch(points[i], X, Y, Z, params)
#                 v_j = voxelize_batch(points[j], X, Y, Z, params)

#                 overlaps = torch.sum(v_i * v_j)
#                 overlaps_i = overlaps / torch.sum(v_i)
#                 overlaps_j = overlaps / torch.sum(v_j)

#                 if (overlaps_i > overlap_thr1) and (overlaps_j > overlap_thr1) and (
#                     torch.abs(overlaps_j - overlaps_i) < overlap_thr2
#                 ):
#                     adjacency[i, j] = 1

#                 del v_j, overlaps_i, overlaps_j, overlaps
#                 torch.cuda.empty_cache()

#     return adjacency


# def merge_points(adj_matrix):
#     """
#     Given adjacency matrix, find connected components.
#     """
#     G = nx.from_numpy_array(adj_matrix)
#     components = list(nx.connected_components(G))
#     return components


# # --------- NEW: OVSeg per-mask embedding ---------

# # def ovseg_mask_embeddings(image_bgr, masks, ovseg_model, device="cuda"):
# #     """
# #     Run OVSeg on the full image once, then pool logits over each mask.
# #     image_bgr: HxWx3 uint8 (OpenCV BGR)
# #     masks: list of HxW numpy arrays (0/1)
# #     Returns: numpy array [num_masks, C] where C is #classes in OVSeg logits.
# #     """
# #     if len(masks) == 0:
# #         return np.zeros((0, 0), dtype=np.float32)

# #     # Convert BGR -> RGB because most models expect RGB
# #     image_rgb = image_bgr[:, :, ::-1]
# #     H, W, _ = image_rgb.shape

# #     img_t = torch.from_numpy(image_rgb).permute(2, 0, 1).float().to(device) / 255.0

# #     with torch.no_grad():
# #         pred = ovseg_model([{"image": img_t, "height": H, "width": W}])[0]
# #         sem_logits = pred["sem_seg"]  # [C, H, W]

# #     C = sem_logits.shape[0]
# #     emb_list = []

# #     for m in masks:
# #         m_bool = torch.from_numpy(m.astype(bool)).to(device)
# #         if m_bool.sum() == 0:
# #             emb_list.append(torch.zeros(C, device=device))
# #         else:
# #             # mean logits over the mask region
# #             cls_scores = sem_logits[:, m_bool].mean(dim=1)
# #             emb_list.append(cls_scores)

# #     embs = torch.stack(emb_list, dim=0).cpu().numpy()  # [num_masks, C]
# #     return embs

# def ovseg_mask_embeddings(image_bgr, masks, ovseg_model, device="cuda"):
#     if len(masks) == 0:
#         return np.zeros((0, 0), dtype=np.float32)

#     # Convert BGR → RGB and ensure contiguous
#     image_rgb = image_bgr[:, :, ::-1].copy()
#     H, W, _ = image_rgb.shape

#     # img_t = torch.from_numpy(image_rgb).permute(2, 0, 1).float().to(device) / 255.0
#     img_t = torch.from_numpy(image_rgb).permute(2, 0, 1).float().to(device)

#     with torch.no_grad():
#         batched_inputs = [{
#             "image": img_t,
#             "height": H,
#             "width": W,
#             "meta": {"dataset_name": "ade20k_sem_seg_val"},
#         }]

#         pred = ovseg_model(batched_inputs)[0]
#         sem_logits = pred["sem_seg"]  # [C, H, W]

#     C = sem_logits.shape[0]
#     emb_list = []

#     for m in masks:
#         m_bool = torch.from_numpy(m.astype(bool)).to(device)
#         if m_bool.sum() == 0:
#             emb_list.append(torch.zeros(C, device=device))
#         else:
#             cls_scores = sem_logits[:, m_bool].mean(dim=1).detach()
#             emb_list.append(cls_scores)

#     # Stack + detach
#     embs = torch.stack(emb_list, dim=0).detach().cpu().numpy()

#     return embs



# def object_embeddings(components, points, embeddings, colors):
#     """
#     Aggregate points, colors and embeddings over connected components.
#     """
#     final_points = []
#     final_colors = []
#     final_embeddings = []
#     count = []
#     for i in range(len(components)):
#         new_points = []
#         new_colors = []
#         new_embeddings = []
#         for j in components[i]:
#             new_points.append(points[j])
#             new_embeddings.append(embeddings[j])
#             new_colors.append(colors[j])
#         final_colors.append(np.vstack(new_colors))
#         final_points.append(np.vstack(new_points))
#         final_embeddings.append(np.mean(np.array(new_embeddings), axis=0))
#         count.append(len(new_embeddings))
#     return final_colors, final_points, final_embeddings, count


# def masks_detection(
#     extension_ratio_hide,
#     extension_ratio_s,
#     extension_ratio_l,
#     extension_ratio_h,
#     path,
#     last_idx,
#     step,
#     remove_iou,
#     remove_side_thr,
#     side_thr,
#     dataset,
#     area_threshold,
#     ratio_thr,
#     camera_intristics,
#     scale,
#     resolution,
#     ovseg_model,
# ):
#     """
#     Full per-frame pipeline with SemanticSAM masks and OVSeg embeddings.
#     """
#     if dataset == "Replica":
#         prefix_rgb = "results/frame"
#         prefix_depth = "results/depth"
#         pose_path = path
#         image_list = [
#             f
#             for f in os.listdir(path + "results/")
#             if f.startswith("frame") and os.path.isfile(os.path.join(path + "results/", f))
#         ]
#     elif dataset == "ScanNet":
#         prefix_rgb = "color/"
#         prefix_depth = "depth/"
#         pose_path = path + "pose/"
#         image_list = os.listdir(path + prefix_rgb)

#     mask_generator1 = SemanticSamAutomaticMaskGenerator(
#         build_semantic_sam(model_type="L", ckpt="models/swinl_only_sam_many2many.pth"),
#         level=[3],
#     )
#     mask_generator2 = SemanticSamAutomaticMaskGenerator(
#         build_semantic_sam(model_type="L", ckpt="models/swinl_only_sam_many2many.pth"),
#         level=[4],
#     )
#     mask_generator3 = SemanticSamAutomaticMaskGenerator(
#         build_semantic_sam(model_type="L", ckpt="models/swinl_only_sam_many2many.pth"),
#         level=[6],
#     )

#     my_colors = []
#     my_masks = []
#     my_points = []
#     my_embeddings = []

#     for i in range(0, last_idx, step):
#         print("Processing frame {}/{}".format(i, len(image_list)))
#         image_path = image_list[i]

#         if dataset == "Replica":
#             idx_path = str(image_path[image_path.index("e") + 1 : image_path.index(".")])
#         elif dataset == "ScanNet":
#             idx_path = str(image_path[: image_path.index(".")])

#         rgb_i = cv2.imread(path + prefix_rgb + idx_path + ".jpg")
#         if dataset == "Replica":
#             depth_i = cv2.imread(
#                 path + prefix_depth + idx_path + ".png", cv2.IMREAD_UNCHANGED
#             ).astype(np.double)
#         elif dataset == "ScanNet":
#             depth_i = cv2.imread(
#                 path + prefix_depth + idx_path + ".png", cv2.IMREAD_UNCHANGED
#             ).astype(np.float32)

#         camera_pose_i = _load_pose(pose_path, int(idx_path), dataset)

#         if rgb_i.shape[:2] != depth_i.shape[:2]:
#             rgb_i = cv2.resize(
#                 rgb_i,
#                 (depth_i.shape[1], depth_i.shape[0]),
#                 interpolation=cv2.INTER_LINEAR,
#             )

#         results_i = []
#         original_image, input_image = prepare_image(
#             image_pth=path + prefix_rgb + idx_path + ".jpg"
#         )

#         # ----- Level 3 -----
#         the_masks1 = mask_generator1.generate(input_image)
#         remaining_parts = np.ones_like(the_masks1[0]["segmentation"].astype(int))
#         remaining_parts = cv2.resize(
#             remaining_parts,
#             (rgb_i.shape[1], rgb_i.shape[0]),
#             interpolation=cv2.INTER_NEAREST,
#         )
#         new_remaining_parts = copy.deepcopy(remaining_parts)

#         for k in range(len(the_masks1)):
#             mask = the_masks1[k]["segmentation"]
#             mask = cv2.resize(
#                 mask.astype(int),
#                 (rgb_i.shape[1], rgb_i.shape[0]),
#                 interpolation=cv2.INTER_NEAREST,
#             )
#             if np.sum(mask * new_remaining_parts) / np.sum(mask) > 0.65:
#                 results_i.append(mask)

#         the_masks_i, boxes_i = remove_overlapped_masks(results_i, remove_iou, area_threshold)
#         results_i, boxes_i = remove_side_masks(the_masks_i, boxes_i, remove_side_thr, side_thr, ratio_thr)

#         for k in range(the_masks_i.size(0)):
#             mask = the_masks_i[k]
#             new_remaining_parts *= (1 - np.array(mask, dtype=np.int32))

#         del the_masks1
#         del the_masks_i
#         gc.collect()
#         remaining_parts = copy.deepcopy(new_remaining_parts)

#         # ----- Level 4 -----
#         the_masks2 = mask_generator2.generate(input_image)
#         for k in range(len(the_masks2)):
#             mask = the_masks2[k]["segmentation"]
#             mask = cv2.resize(
#                 mask.astype(int),
#                 (rgb_i.shape[1], rgb_i.shape[0]),
#                 interpolation=cv2.INTER_NEAREST,
#             )
#             if np.sum(mask * remaining_parts) / np.sum(mask) > 0.45:
#                 results_i.append(mask)

#         the_masks_i, boxes_i = remove_overlapped_masks(results_i, remove_iou, area_threshold)
#         results_i, boxes_i = remove_side_masks(the_masks_i, boxes_i, remove_side_thr, side_thr, ratio_thr)
#         for k in range(the_masks_i.size(0)):
#             mask = the_masks_i[k]
#             new_remaining_parts *= (1 - np.array(mask, dtype=np.int32))

#         del the_masks2
#         gc.collect()
#         remaining_parts = copy.deepcopy(new_remaining_parts)

#         # ----- Level 6 -----
#         the_masks3 = mask_generator3.generate(input_image)
#         for k in range(len(the_masks3)):
#             mask = the_masks3[k]["segmentation"]
#             mask = cv2.resize(
#                 mask.astype(int),
#                 (rgb_i.shape[1], rgb_i.shape[0]),
#                 interpolation=cv2.INTER_NEAREST,
#             )
#             if np.sum(mask * remaining_parts) / np.sum(mask) > 0.35:
#                 results_i.append(mask)

#         del the_masks3
#         gc.collect()

#         the_masks_i, boxes_i = remove_overlapped_masks(results_i, remove_iou, area_threshold)
#         the_masks_i, boxes_i = remove_side_masks(the_masks_i, boxes_i, remove_side_thr, side_thr, ratio_thr)

#         masks_i = []
#         boxes_i = []

#         for k in range(len(the_masks_i)):
#             mask = np.array(the_masks_i[k])
#             mask = cv2.resize(
#                 mask.astype(int),
#                 (rgb_i.shape[1], rgb_i.shape[0]),
#                 interpolation=cv2.INTER_NEAREST,
#             )
#             new_masks = dbscan_mask_denoise(mask, 7, 80)

#             for r in range(len(new_masks)):
#                 my_mask = new_masks[r]
#                 if np.sum(my_mask) > 0:
#                     masks_i.append(torch.tensor(my_mask))
#                     rows, cols = np.where(my_mask > 0)
#                     min_row, max_row = rows.min(), rows.max()
#                     min_col, max_col = cols.min(), cols.max()
#                     box = [min_col, min_row, max_col, max_row]
#                     boxes_i.append(torch.tensor(box))

#             del new_masks
#             gc.collect()

#         masks_i, boxes_i = remove_tiny_masks(masks_i, boxes_i, area_threshold)

#         new_masks_i = []
#         new_boxes_i = []
#         new_points_i = []
#         new_colors_i = []

#         for k in range(len(masks_i)):
#             points, colors = point_cloud(
#                 depth_i, scale, camera_intristics, masks_i[k], camera_pose_i, rgb_i, dataset
#             )
#             points = points_to_grid(points, resolution)

#             if points.shape[0] > 0:
#                 points, colors, threshold = dbscan_3d(points, colors, 0.15, 200)
#             else:
#                 threshold = 0.0

#             if threshold > 0.8:
#                 the_mask = masks_i[k]
#                 rows, cols = np.where(the_mask > 0)
#                 min_row, max_row = rows.min(), rows.max()
#                 min_col, max_col = cols.min(), cols.max()

#                 box = [min_col, min_row, max_col, max_row]
#                 new_masks_i.append(masks_i[k].cpu().numpy())
#                 new_boxes_i.append(np.array(box))
#                 new_points_i.append(points)
#                 new_colors_i.append(colors)

#             del points
#             del colors
#             gc.collect()

#         # *** THIS IS THE MAIN CHANGE: OVSEG instead of CLIP ***
#         if len(new_masks_i) > 0:
#             # We only send mask & image to OVSeg, no context-aware CLIP tricks.
#             embedding = ovseg_mask_embeddings(rgb_i, new_masks_i, ovseg_model, device=device)
#         else:
#             embedding = np.zeros((0, 0), dtype=np.float32)

#         my_colors.append(new_colors_i)
#         my_masks.append(new_masks_i)
#         my_points.append(new_points_i)
#         my_embeddings.append(embedding)

#     final_masks = []
#     final_points = []
#     final_embeddings = []
#     final_colors = []
#     for i in range(len(my_masks)):
#         for j in range(len(my_masks[i])):
#             final_masks.append(my_masks[i][j])
#             final_points.append(my_points[i][j])
#             final_embeddings.append(my_embeddings[i][j])
#             final_colors.append(my_colors[i][j])
#     return final_colors, final_masks, final_points, final_embeddings

# # def build_ovseg_model(config_path, weight_path, device="cuda"):
# #     """
# #     Build an OVSeg model from a yaml config + weights using detectron2.build_model.
# #     """
# #     cfg = get_cfg()
# #     add_ovseg_config(cfg)
# #     cfg.merge_from_file(config_path)
# #     cfg.MODEL.WEIGHTS = weight_path
# #     cfg.freeze()

# #     model = build_model(cfg)
# #     model.to(device)
# #     model.eval()
# #     return model



# def main():
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--scene", type=str)
#     parser.add_argument("--dataset_path", type=str)
#     parser.add_argument("--path", type=str, default="")
#     parser.add_argument("--dataset", type=str)
#     args = parser.parse_args()

#     path = args.path
#     dataset = args.dataset
#     scene = args.scene
#     dataset_path = args.dataset_path

#     # ----- Build OVSeg model -----
#     # Adjust these to your actual config + checkpoint
#     # OVSEG_CONFIG_PATH = os.path.join(path, "core_configs", "ovseg_config.yaml")
#     OVSEG_CONFIG_PATH = os.path.join("/mnt/data_pool/exploration/CORE-3D/ov-seg/configs/ovseg_swinB_vitL_bs32_120k.yaml")

#     # OVSEG_WEIGHTS_PATH = os.path.join(path, "models", "ovseg.pth")
#     OVSEG_WEIGHTS_PATH ="./ov-seg/models/ovseg_swinbase_vitL14_ft_mpt.pth"

#     cfg = get_cfg()
#     add_ovseg_config(cfg)
#     cfg.merge_from_file(OVSEG_CONFIG_PATH)
#     cfg.MODEL.WEIGHTS = OVSEG_WEIGHTS_PATH
#     cfg.freeze()

#     # ovseg_model = build_ovseg_model(cfg,OVSEG_WEIGHTS_PATH).to(device)
#     ovseg_model = build_ovseg_model(cfg, device=device)
#     ovseg_model.eval()

#     # Load config (geometric / mask params)
#     if dataset == "Replica":
#         last_idx = 2000
#         with open(path + "core_configs/config_Replica.yaml", "r") as f:
#             config = yaml.safe_load(f)
#     else:
#         last_idx = len(os.listdir(dataset_path + scene + "/color/"))
#         with open(path + "core_configs/config_ScanNet.yaml", "r") as f:
#             config = yaml.safe_load(f)

#     global resolution
#     geometery_overlap_thr1 = config["geometery_overlap_thr1"]
#     geometery_overlap_thr2 = config["geometery_overlap_thr2"]
#     resolution = config["resolution"]
#     remove_thr = config["remove_thr"]
#     side_thr = config["side_thr"]
#     iou_remove_thr = config["iou_remove_thr"]
#     ratio_thr = config["ratio_thr"]
#     extension_ratio_hide = config["extension_ratio_hide"]
#     extension_ratio_s = config["extension_ratio_s"]
#     extension_ratio_l = config["extension_ratio_l"]
#     extension_ratio_h = config["extension_ratio_h"]
#     area_threshold = config["area_threshold"]

#     scene_path = dataset_path + scene + "/"

#     if dataset == "Replica":
#         camera_intristics, scale = _load_depth_intrinsics(
#             path=dataset_path + "/cam_params.json", dataset=dataset
#         )
#     else:
#         camera_intristics, scale = _load_depth_intrinsics(
#             path=dataset_path + scene + "/intrinsic/intrinsic_depth.txt", dataset=dataset
#         )

#     colors, masks, points, embeddings = masks_detection(
#         extension_ratio_hide=extension_ratio_hide,
#         extension_ratio_s=extension_ratio_s,
#         extension_ratio_l=extension_ratio_l,
#         extension_ratio_h=extension_ratio_h,
#         path=scene_path,
#         last_idx=last_idx,
#         step=15,
#         remove_iou=iou_remove_thr,
#         remove_side_thr=remove_thr,
#         side_thr=side_thr,
#         dataset=dataset,
#         area_threshold=area_threshold,
#         ratio_thr=ratio_thr,
#         camera_intristics=camera_intristics,
#         scale=scale,
#         resolution=resolution,
#         ovseg_model=ovseg_model,
#     )

#     adjacency = geometry_overlap(points, resolution, geometery_overlap_thr1, geometery_overlap_thr2)
#     components = merge_points(adjacency)
#     new_colors, new_points, new_embeddings, count = object_embeddings(
#         components, points, embeddings, colors
#     )

#     df_points_to_ids = pd.DataFrame(columns=["x", "y", "z", "Object id"])
#     df_ids_to_embeddings = {}
#     for i in range(len(new_points)):
#         pts = np.unique(new_points[i], axis=0)
#         temp = pd.DataFrame(columns=["x", "y", "z", "Object id"])
#         temp["x"] = pts[:, 0]
#         temp["y"] = pts[:, 1]
#         temp["z"] = pts[:, 2]
#         temp["Object id"] = i

#         df_ids_to_embeddings[i] = {
#             "embedding": list(new_embeddings[i].astype(float)),
#             "count": count[i],
#         }
#         df_points_to_ids = pd.concat([df_points_to_ids, temp], ignore_index=True)

#     os.makedirs(os.path.join(path, "embeddings"), exist_ok=True)
#     with open(os.path.join(path, "embeddings", scene + "_ids_to_embeddings_ovseg1.json"), "w") as json_file:
#         json.dump(df_ids_to_embeddings, json_file, indent=4)

#     df_points_to_ids.to_csv(
#         os.path.join(path, "embeddings", scene + "_points_to_ids_ovseg1.csv"), index=False
#     )


# if __name__ == "__main__":
#     main()


import os
import numpy as np
import yaml
from PIL import Image

import cv2
import argparse
import json
import torch
import gc
import copy
import pandas as pd
import open_clip
from sklearn.cluster import DBSCAN
from semantic_sam import prepare_image, plot_results, build_semantic_sam, SemanticSamAutomaticMaskGenerator, SemanticSAMPredictor
import networkx as nx
import psutil

# Module-level default compute device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_pose(path, idx, dataset):
    """
    Load camera pose (4x4 transformation matrix) for a given frame index.
    Supports Replica and ScanNet datasets.
    """
    if (dataset == 'Replica'):
        # For Replica, all poses are in a single 'traj.txt' file
        path = os.path.join(path, "traj.txt")
        with open(path, "r") as file:
            lines = file.readlines()
            if 0 <= idx < len(lines):
                line = lines[idx]
                # each line has 16 values (4x4 matrix flattened row-wise)
                values = [float(val) for val in line.split()]
                transformation_matrix = np.array(values).reshape((4, 4))
                return transformation_matrix
    elif (dataset == 'ScanNet'):
        # For ScanNet, each pose is stored in a separate txt file: "<idx>.txt"
        path = os.path.join(path, str(idx) + '.txt')
        transformation_matrix = np.loadtxt(path).reshape(4, 4)
        return transformation_matrix


def _load_depth_intrinsics(path, dataset):
    """
    Load depth camera intrinsics and scale factor for depth values.
    Different format for Replica vs ScanNet.
    """
    if (dataset == 'Replica'):
        # Replica intrinsics and scale come from a JSON file
        with open(path, "r") as file:
            data = json.load(file)
            camera_params = data.get("camera")
            if camera_params:
                w = camera_params.get("w")
                h = camera_params.get("h")
                fx = camera_params.get("fx")
                fy = camera_params.get("fy")
                cx = camera_params.get("cx")
                cy = camera_params.get("cy")
                scale = camera_params.get("scale")
                # Camera intrinsic matrix K
                K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
                K = np.array(K)
                return K, scale
    elif (dataset == 'ScanNet'):
        # ScanNet intrinsics: already in matrix form
        intrinsic_depth = np.loadtxt(path)
        # Depth values are typically in millimeters -> convert to meters by scale=1000
        scale = 1000.0
        return intrinsic_depth, scale


def overlap(mask_i, mask_j, iom_threshold):
    """
    Classify containment between two binary masks by intersection-over-minimum.

    IoM = |mask_i AND mask_j| / min(|mask_i|, |mask_j|).

    Returns:
      0 -> IoM <= iom_threshold, or either mask is empty
      1 -> mask_i is (mostly) inside mask_j  (|mask_i| <= |mask_j|)
      2 -> mask_j is (mostly) inside mask_i  (|mask_i| >  |mask_j|)
    """
    intersection = np.sum(np.multiply(mask_i, mask_j))
    sum_i = np.sum(mask_i)
    sum_j = np.sum(mask_j)

    # The smaller mask is the denominator; on ties mask_i is used (result 1).
    smaller = min(sum_i, sum_j)
    if smaller == 0:
        # An empty mask cannot contain or be contained; previously this path
        # divided 0/0 and relied on NaN comparing False.
        return 0

    iom = intersection / smaller
    if iom <= iom_threshold:
        return 0
    return 2 if sum_i > sum_j else 1


def remove_overlapped_masks(results, iom_threshold, area_threshold):
    """
    Drop tiny masks and carve contained masks out of their containers.

    For every mask whose area ratio (pixels / image pixels) is at least
    `area_threshold`, every OTHER mask that overlap() reports as lying inside
    it (code 2, using the original, uncarved masks) is subtracted from it.
    Masks that end up empty are dropped.

    Args:
      results: list of equally sized 2D binary masks (numpy arrays).
      iom_threshold: IoM threshold forwarded to overlap().
      area_threshold: minimum area ratio for a mask to be kept.

    Returns:
      (masks, boxes): float32 tensors of shape [N, H, W] and [N, 4]
      (xmin, ymin, xmax, ymax, inclusive). When nothing survives, both are
      empty ([0, H, W] and [0, 4]) instead of raising from torch.stack.
    """
    if len(results) == 0:
        return torch.zeros(0, 0, 0), torch.zeros(0, 4)

    height, width = results[0].shape
    # Convert once instead of once per (i, j) pair.
    int_masks = [np.asarray(r).astype(int) for r in results]

    kept_masks = []
    kept_boxes = []
    for i, mask_i in enumerate(int_masks):
        area_ratio = np.sum(mask_i) / (mask_i.shape[0] * mask_i.shape[1])
        if area_ratio < area_threshold:
            continue

        # Containment is decided against the ORIGINAL mask_i; subtraction
        # happens only after all contained masks are collected. Tiny masks
        # still count here, matching the original behaviour.
        contained = [
            mask_j
            for j, mask_j in enumerate(int_masks)
            if j != i and overlap(mask_i, mask_j, iom_threshold=iom_threshold) == 2
        ]
        remaining = mask_i
        for mask_j in contained:
            remaining = remaining * (1 - mask_j)

        if np.sum(remaining) > 0:
            rows, cols = np.where(remaining)
            kept_masks.append(torch.tensor(remaining, dtype=torch.float32))
            kept_boxes.append(torch.tensor(
                [cols.min(), rows.min(), cols.max(), rows.max()], dtype=torch.float32))

    if not kept_masks:
        return torch.zeros(0, height, width), torch.zeros(0, 4)
    return torch.stack(kept_masks), torch.stack(kept_boxes)


def remove_tiny_masks(masks, boxes, area_threshold):
    """
    Keep only masks covering strictly more than `area_threshold` of the image.

    Args:
      masks: sequence of (H, W) tensors.
      boxes: sequence parallel to `masks`; entries are passed through as-is.
      area_threshold: exclusive lower bound on mask_pixels / (H * W).

    Returns:
      (kept_masks, kept_boxes) as two parallel lists.
    """
    kept_masks, kept_boxes = [], []
    for position, candidate in enumerate(masks):
        rows, cols = candidate.size()[0], candidate.size()[1]
        coverage = torch.sum(candidate) / (rows * cols)
        if not coverage > area_threshold:
            continue
        kept_masks.append(candidate)
        kept_boxes.append(boxes[position])
    return kept_masks, kept_boxes


def dbscan_mask_denoise(mask, eps, min_samples):
    """
    Split a 2D mask into spatial clusters with DBSCAN, discarding noise pixels.

    Args:
      mask: 2D array; non-zero pixels are clustered on their (row, col)
        coordinates.
      eps, min_samples: forwarded to sklearn's DBSCAN.

    Returns:
      List of int 0/1 masks (same shape as `mask`), one per DBSCAN cluster,
      in ascending cluster-label order. Pixels labelled -1 (noise) appear in
      none of them. An empty input mask yields an empty list.
    """
    coords = np.column_stack(np.nonzero(mask))  # (N, 2) row/col pairs
    if coords.shape[0] == 0:
        # Return a list consistently (callers iterate the result as masks).
        return []

    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(coords).labels_

    final_masks = []
    # Iterate the actual cluster ids. The previous range(len(np.unique(...)))
    # loop could never see -1 and so emitted one extra all-zero mask.
    for cluster_id in np.unique(labels):
        if cluster_id == -1:
            continue  # noise
        cluster_points = coords[labels == cluster_id]
        cluster_mask = np.zeros(np.shape(mask), dtype=int)
        cluster_mask[cluster_points[:, 0], cluster_points[:, 1]] = 1
        final_masks.append(cluster_mask)
    return final_masks


def dbscan_3d(points, colors, eps, min_samples):
    """
    Cluster a point cloud with DBSCAN and keep its largest (dominant) cluster.

    Args:
      points: (N, 3) array; cast to float16 before clustering (as before).
      colors: (N, ...) array parallel to `points`.
      eps, min_samples: forwarded to sklearn's DBSCAN.

    Returns:
      (cluster_points, cluster_colors, fraction) where fraction is
      len(cluster_points) / N. With no non-noise cluster (or N == 0) the
      returned arrays are empty and fraction is 0.0. If the worst-case
      pairwise-distance estimate exceeds ~600 GiB, returns ([], [], -1).
    """
    points = points.astype(np.float16)
    n_samples = points.shape[0]
    if n_samples == 0:
        return points, colors[:0], 0.0

    # Worst-case memory for a dense pairwise-distance matrix, assuming
    # float64 distances. (The previous code looked up the dtype of the
    # ndarray *row type* rather than the element dtype.)
    total_bytes = n_samples ** 2 * np.dtype(np.float64).itemsize
    if total_bytes / (1024 ** 3) > 600:
        return [], [], -1

    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(points).labels_

    # DBSCAN numbers clusters in discovery order, so label 0 is not
    # necessarily the biggest one; pick the most populous non-noise cluster.
    cluster_ids, sizes = np.unique(labels[labels >= 0], return_counts=True)
    if cluster_ids.size == 0:
        return points[:0], colors[:0], 0.0

    keep = labels == cluster_ids[np.argmax(sizes)]
    cluster_points = points[keep]
    cluster_colors = colors[keep]
    return cluster_points, cluster_colors, cluster_points.shape[0] / n_samples


def remove_side_masks(masks, boxes, remove_thr, side_thr, ratio_thr):
    """
    Drop masks that sit mostly on the image border and are not large.

    A mask is removed when
      (mask pixels closer than `side_thr` px to any image edge) / (mask pixels) > remove_thr
    and
      (mask pixels) / (H * W) < ratio_thr.

    Args:
      masks: sequence (or stacked tensor) of (H, W) torch masks.
      boxes: sequence parallel to `masks`.
      remove_thr, side_thr, ratio_thr: see above.

    Returns:
      (new_masks, new_boxes): lists of numpy arrays for the kept entries.
      Empty input returns ([], []); an all-zero mask is kept (it has no
      border pixels) instead of raising ZeroDivisionError.
    """
    if len(masks) == 0:
        return [], []

    # Edge maps depend only on the image size; build each one once.
    edge_cache = {}

    new_masks = []
    new_boxes = []
    for i in range(len(masks)):
        mask = masks[i]
        box = boxes[i]
        H, W = mask.shape

        edge_mask = edge_cache.get((H, W))
        if edge_mask is None:
            y_coords = torch.arange(H).view(-1, 1).expand(H, W)
            x_coords = torch.arange(W).view(1, -1).expand(H, W)
            dist_to_edge = torch.minimum(
                torch.minimum(y_coords, H - 1 - y_coords),
                torch.minimum(x_coords, W - 1 - x_coords),
            )
            edge_mask = (dist_to_edge < side_thr).long()
            edge_cache[(H, W)] = edge_mask

        sum_mask = mask.sum().item()
        count = (mask * edge_mask).sum().item()  # mask pixels in the border band

        ratio = count / sum_mask if sum_mask else 0.0  # fraction on the border
        ratio1 = sum_mask / (H * W)  # overall area ratio

        if ratio > remove_thr and ratio1 < ratio_thr:
            continue
        new_masks.append(np.array(mask))
        new_boxes.append(np.array(box))
    return new_masks, new_boxes


def extend_images(image, boxes, masks, extension_ratio, hide_mask=False, hide_others=False):
    """
    Crop an extended region around each box (optionally masking pixels) and resize it.

    Args:
      image: (H, W, C) image array.
      boxes: per-object boxes [xmin, ymin, xmax, ymax] in pixel coordinates.
      masks: per-object binary (H, W) masks, parallel to `boxes`.
      extension_ratio: multiplier applied to the box half-width/half-height
        around the box center.
      hide_mask: zero out the object's own pixels before cropping.
      hide_others: zero out every pixel outside the object's mask.

    Returns:
      (crops, ratios): each crop is resized to 1200 px wide when the box is
      wider than tall, otherwise 800 px tall (other side scaled by the box
      aspect ratio); ratios[i] = mask pixel count / crop area before resizing.
      Degenerate boxes (zero width and/or height) now yield at least a
      1-pixel crop and a valid resize size instead of raising.
    """
    extended_images = []
    ratios = []
    y_margin, x_margin = image.shape[0], image.shape[1]
    for i in range(len(boxes)):
        box = boxes[i]
        mask = masks[i]

        # No defensive deepcopy: `image` is never written to. The masking
        # products allocate new arrays and cv2.resize returns a new buffer.
        new_image = image
        if hide_mask:
            # Keep everything except the object.
            new_image = (np.array(1 - mask)[:, :, np.newaxis]) * new_image
            new_image = new_image.clip(0, 255).astype(np.uint8)
        if hide_others:
            # Keep only the object.
            new_image = (np.array(mask)[:, :, np.newaxis]) * new_image
            new_image = new_image.clip(0, 255).astype(np.uint8)

        center_x = (box[2] + box[0]) / 2
        center_y = (box[3] + box[1]) / 2
        width = (box[2] - box[0]) / 2  # half-width
        height = (box[3] - box[1]) / 2  # half-height

        # Extended crop, clamped to the image.
        new_x1 = max(int(center_x - extension_ratio * width), 0)
        new_x2 = min(int(center_x + extension_ratio * width), x_margin)
        new_y1 = max(int(center_y - extension_ratio * height), 0)
        new_y2 = min(int(center_y + extension_ratio * height), y_margin)
        # Single-row/column boxes would give an empty crop (cv2.resize
        # rejects it); keep at least one pixel. new_x1/new_y1 never exceed
        # the box center, so +1 stays inside the image.
        new_x2 = max(new_x2, new_x1 + 1)
        new_y2 = max(new_y2, new_y1 + 1)

        final_image = new_image[new_y1:new_y2, new_x1:new_x2]

        # Resize keeping the box aspect ratio; guard zero-sized targets.
        if width > height:
            dsize = (1200, max(1, int(1200 / width * height)))
        elif height > 0:
            dsize = (max(1, int(800 / height * width)), 800)
        else:
            dsize = (800, 800)  # single-pixel box: both half-extents are 0
        final_image = cv2.resize(final_image, dsize)

        extended_images.append(final_image)
        ratio = np.sum(mask) / ((new_y2 - new_y1) * (new_x2 - new_x1))
        ratios.append(ratio.item())

    return extended_images, ratios


def grid_embedding(masks, embeddings):
    """
    Gather each object's embedding once per masked pixel.

    Args:
      masks: [N, H, W] tensor; pixels > 0 belong to object n.
      embeddings: [N, D] per-object embeddings (array-like or tensor).

    Returns:
      [K, D] tensor, one row per masked pixel, ordered row-major over
      (object, row, col). This is the same result as broadcasting the
      embeddings to [N, H, W, D] and boolean-indexing, but without allocating
      that N*H*W*D intermediate.
    """
    embeddings = torch.as_tensor(embeddings)
    # torch.nonzero returns coordinates in row-major order, matching
    # boolean-mask indexing order.
    object_ids = torch.nonzero(masks > 0)[:, 0]
    return embeddings[object_ids.to(embeddings.device)]


def point_cloud(depth, scale, camera_intristics, mask, camera_pose, colors, dataset):
    """
    Back-project masked, valid-depth pixels to world coordinates.

    Args:
      depth: (H, W) or (H, W, C) raw depth; channel 0 is used if 3D.
      scale: raw depth is divided by this (e.g. 1000.0 for ScanNet).
      camera_intristics: intrinsic matrix; entries [0,0], [1,1], [0,2], [1,2]
        are used as fx, fy, cx, cy.
      mask: (H, W) mask (tensor or array); pixels > 0 are selected.
      camera_pose: 4x4 camera-to-world transform.
      colors: (H, W, 3) image aligned with `depth`.
      dataset: unused; kept for interface compatibility.

    Returns:
      (points, colors) as numpy arrays of shape (K, 3), for the K pixels with
      mask > 0 and depth > 0, in row-major pixel order.
    """
    # Fall back to CPU instead of hard-coding 'cuda'.
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    mask = torch.as_tensor(mask).to(dev)
    camera_matrix = torch.tensor(camera_intristics).to(dev)
    depth = torch.tensor(depth, dtype=torch.float32).to(dev)
    colors = torch.tensor(colors).to(dev)
    camera_pose = torch.tensor(camera_pose).to(dev)

    if depth.dim() == 3:
        depth = depth[:, :, 0]
    depth = depth / scale

    # Only pixels inside the mask with positive depth are back-projected.
    valid = (mask * (depth > 0).long()) > 0
    ys, xs = torch.nonzero(valid, as_tuple=True)
    z = depth[ys, xs]

    # Pinhole back-projection to camera coordinates.
    X = (xs - camera_matrix[0, 2]) * z / camera_matrix[0, 0]
    Y = (ys - camera_matrix[1, 2]) * z / camera_matrix[1, 1]

    # Homogeneous coordinates -> world coordinates via the camera pose.
    homogeneous = torch.stack((X, Y, z, torch.ones_like(X)), dim=-1)
    world = torch.matmul(camera_pose, homogeneous.T).T

    selected_colors = colors[ys, xs].reshape(-1, 3)
    return world[:, :3].cpu().numpy(), selected_colors.cpu().numpy()


def clip_image(images, model, preprocess):
    """
    Encode RGB image arrays with a CLIP-style model.

    Each array is wrapped in a PIL image, passed through `preprocess`, and
    encoded with `model.encode_image` as a batch of one, without gradients.
    The model is put in eval mode and moved to CUDA when available, else CPU.

    Returns:
      List of numpy arrays, one embedding per input image, in input order.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(target)
    with torch.no_grad():
        return [
            model.encode_image(
                preprocess(Image.fromarray(picture)).unsqueeze(0).to(target)
            ).squeeze(0).cpu().numpy()
            for picture in images
        ]


def masks_detection(extension_ratio_hide, extension_ratio_s, extension_ratio_l, extension_ratio_h, path, last_idx, step, remove_iou, remove_side_thr, side_thr, dataset, area_threshold, ratio_thr,
                    camera_intristics, scale, resolution, clip_model, preprocess, alpha_h, alpha_l, alpha_o, alpha_m, semantic_sam_path):
    """
    Detect object masks in sampled frames and turn each into a 3D point cloud
    plus a combined CLIP embedding.

    Per sampled frame:
    - Load RGB, depth and camera pose.
    - Run Semantic SAM at granularity levels 3, 4 and 6. Level-4 and level-6
      masks are accepted only if more than 45% / 35% of their pixels are not
      yet claimed by masks kept at earlier levels.
    - Filter masks (overlap pruning, border masks, tiny areas).
    - Split/denoise each mask with 2D DBSCAN (eps=7, min_samples=80).
    - Back-project each mask to 3D, snap to a grid of cell size `resolution`,
      run dbscan_3d (eps=0.15, min_samples=200) and keep the mask only if
      the returned cluster fraction is above 0.8.
    - Build five image crops per kept mask and combine their CLIP embeddings
      via mask_embedding().

    Args:
      extension_ratio_hide, extension_ratio_s, extension_ratio_l,
      extension_ratio_h: box extension ratios for the object-hidden crop and
        the small / large / extra-large context crops (see extend_images).
      path: scene directory. It is joined by plain string concatenation
        (e.g. path + 'results/'), so presumably it must end with a path
        separator. TODO confirm.
      last_idx, step: entries image_list[i] for i in range(0, last_idx, step)
        are processed.
      remove_iou: IoM threshold for remove_overlapped_masks.
      remove_side_thr, side_thr, ratio_thr: forwarded to remove_side_masks.
      dataset: 'Replica' or 'ScanNet'.
      area_threshold: minimum mask area ratio (overlap pruning and tiny-mask
        removal).
      camera_intristics, scale: depth intrinsics and depth scale for
        point_cloud.
      resolution: grid cell size for points_to_grid.
      clip_model, preprocess: CLIP model and its image transform.
      alpha_h, alpha_l, alpha_o, alpha_m: weights forwarded to mask_embedding.
      semantic_sam_path: checkpoint passed to build_semantic_sam.

    Returns:
      (final_colors, final_masks, final_points, final_embeddings): flat,
      parallel lists with one entry per kept object across all frames.
    """

    if (dataset == 'Replica'):
        prefix_rgb = 'results/frame'
        prefix_depth = 'results/depth'
        pose_path = path
        # List all frame files starting with 'frame'
        image_list = [f for f in os.listdir(path + 'results/')
                      if f.startswith('frame') and os.path.isfile(os.path.join(path + 'results/', f))]
    elif (dataset == 'ScanNet'):
        # For ScanNet, 'color' and 'depth' folders are used
        prefix_rgb = 'color/'
        prefix_depth = 'depth/'
        pose_path = path + 'pose/'
        image_list = os.listdir(path + prefix_rgb)
    # NOTE(review): os.listdir order is filesystem-dependent and image_list is
    # never sorted, so image_list[i] below is not necessarily frame i. The pose
    # is looked up by the index parsed from the file name, so RGB/depth/pose
    # stay consistent, but WHICH frames get sampled is not reproducible.
    # Any other `dataset` value leaves these names unbound (NameError below).

    # Build three Semantic SAM mask generators at different levels (scales).
    # Each line builds its own model from the same checkpoint.
    mask_generator1 = SemanticSamAutomaticMaskGenerator(build_semantic_sam(model_type='L', ckpt=semantic_sam_path), level=[3])
    mask_generator2 = SemanticSamAutomaticMaskGenerator(build_semantic_sam(model_type='L', ckpt=semantic_sam_path), level=[4])
    mask_generator3 = SemanticSamAutomaticMaskGenerator(build_semantic_sam(model_type='L', ckpt=semantic_sam_path), level=[6])

    # Per-frame lists of per-object results.
    my_colors = []
    my_masks = []
    my_points = []
    my_embeddings = []

    # Process every 'step'-th entry of image_list up to last_idx.
    # NOTE(review): last_idx > len(image_list) raises IndexError below.
    for i in range(0, last_idx, step):

        print('Processing frame {}/{}'.format(i, len(image_list)))
        image_path = image_list[i]

        if (dataset == 'Replica'):
            # Extract numeric index from 'frameXXXX.jpg': slice from just after
            # the first 'e' (the one in 'frame') up to the first '.'.
            idx_path = str(image_path[image_path.index('e') + 1:image_path.index('.')])
        elif (dataset == 'ScanNet'):
            # For ScanNet, file name before '.' is index
            idx_path = str(image_path[:image_path.index('.')])

        # Read RGB (BGR channel order, as returned by cv2.imread) and raw depth
        rgb_i = cv2.imread(path + prefix_rgb + idx_path + '.jpg')
        if (dataset == 'Replica'):
            depth_i = cv2.imread(path + prefix_depth + idx_path + '.png', cv2.IMREAD_UNCHANGED).astype(np.double)
        elif (dataset == 'ScanNet'):
            depth_i = cv2.imread(path + prefix_depth + idx_path + '.png', cv2.IMREAD_UNCHANGED).astype(np.float32)

        # Load pose
        camera_pose_i = _load_pose(pose_path, int(idx_path), dataset)

        # Ensure RGB & depth have same spatial resolution
        if rgb_i.shape[:2] != depth_i.shape[:2]:
            rgb_i = cv2.resize(rgb_i, (depth_i.shape[1], depth_i.shape[0]), interpolation=cv2.INTER_LINEAR)

        # PIL image resized to the depth resolution.
        # NOTE(review): im_i is not used anywhere below in this function.
        im_i = Image.open(path + prefix_rgb + idx_path + '.jpg')
        im_i = im_i.resize((depth_i.shape[1], depth_i.shape[0]))

        results_i = []
        # original_image is unused here; input_image feeds the mask generators.
        original_image, input_image = prepare_image(image_pth=path + prefix_rgb + idx_path + '.jpg')

        # Level-3 masks
        the_masks1 = mask_generator1.generate(input_image)
        # 'Unclaimed pixels' map, all ones, at the RGB resolution.
        # NOTE(review): the_masks1[0] raises IndexError if level 3 finds nothing.
        remaining_parts = np.ones_like(the_masks1[0]['segmentation'].astype(int))
        remaining_parts = cv2.resize(remaining_parts, (rgb_i.shape[1], rgb_i.shape[0]), interpolation=cv2.INTER_NEAREST)
        new_remaining_parts = copy.deepcopy(remaining_parts)

        # Filter out masks that do not cover enough unseen area
        for k in range(len(the_masks1)):
            mask = the_masks1[k]['segmentation']
            mask = cv2.resize(mask.astype(int), (rgb_i.shape[1], rgb_i.shape[0]), interpolation=cv2.INTER_NEAREST)
            # Condition on overlap with remaining_parts (at least 65% of mask).
            # new_remaining_parts is still all ones here, so every non-empty
            # level-3 mask passes; an empty mask gives 0/0 (NaN) and is skipped.
            if (np.sum(mask * new_remaining_parts) / np.sum(mask) > 0.65):
                results_i.append(mask)

        # Remove overlapping & tiny masks, then side masks.
        # results_i is rebound to the side-filtered masks (numpy arrays); the
        # next levels append to this list.
        the_masks_i, boxes_i = remove_overlapped_masks(results_i, remove_iou, area_threshold)
        results_i, boxes_i = remove_side_masks(the_masks_i, boxes_i, remove_side_thr, side_thr, ratio_thr)

        # Claim the pixels of every mask that survived overlap pruning.
        # NOTE(review): this uses the_masks_i from BEFORE remove_side_masks, so
        # masks dropped as border artifacts still claim their pixels.
        for k in range(the_masks_i.size(0)):
            mask = the_masks_i[k]
            new_remaining_parts *= (1 - np.array(mask, dtype=np.int32))

        del the_masks1
        del the_masks_i
        gc.collect()
        remaining_parts = copy.deepcopy(new_remaining_parts)

        # Level-4 masks: accepted if more than 45% of the mask is unclaimed
        the_masks2 = mask_generator2.generate(input_image)
        for k in range(len(the_masks2)):
            mask = the_masks2[k]['segmentation']
            mask = cv2.resize(mask.astype(int), (rgb_i.shape[1], rgb_i.shape[0]), interpolation=cv2.INTER_NEAREST)
            if (np.sum(mask * remaining_parts) / np.sum(mask) > 0.45):
                results_i.append(mask)

        # Re-prune the accumulated level-3 + level-4 masks and claim pixels again
        the_masks_i, boxes_i = remove_overlapped_masks(results_i, remove_iou, area_threshold)
        results_i, boxes_i = remove_side_masks(the_masks_i, boxes_i, remove_side_thr, side_thr, ratio_thr)
        for k in range(the_masks_i.size(0)):
            mask = the_masks_i[k]
            new_remaining_parts *= (1 - np.array(mask, dtype=np.int32))

        del the_masks2
        gc.collect()
        remaining_parts = copy.deepcopy(new_remaining_parts)

        # Level-6 masks: accepted if more than 35% of the mask is unclaimed
        the_masks3 = mask_generator3.generate(input_image)
        for k in range(len(the_masks3)):
            mask = the_masks3[k]['segmentation']
            mask = cv2.resize(mask.astype(int), (rgb_i.shape[1], rgb_i.shape[0]), interpolation=cv2.INTER_NEAREST)
            if (np.sum(mask * remaining_parts) / np.sum(mask) > 0.35):
                results_i.append(mask)

        # `mask` is whatever the most recent loop above bound it to.
        del mask
        del the_masks3
        gc.collect()

        # Final overlap & side filtering over masks from all three levels
        the_masks_i, boxes_i = remove_overlapped_masks(results_i, remove_iou, area_threshold)
        the_masks_i, boxes_i = remove_side_masks(the_masks_i, boxes_i, remove_side_thr, side_thr, ratio_thr)

        masks_i = []
        boxes_i = []

        # For each mask, denoise with 2D DBSCAN and compute bounding boxes
        for k in range(len(the_masks_i)):
            mask = np.array(the_masks_i[k])
            # Masks were already resized to the RGB size when collected, so this
            # resize keeps the shape unchanged.
            mask = cv2.resize(mask.astype(int), (rgb_i.shape[1], rgb_i.shape[0]), interpolation=cv2.INTER_NEAREST)
            # 2D DBSCAN (eps=7 px, min_samples=80): one mask per spatial cluster
            new_masks = dbscan_mask_denoise(mask, 7, 80)

            for r in range(len(new_masks)):
                my_mask = new_masks[r]
                if (np.sum(my_mask) > 0):
                    masks_i.append(torch.tensor(my_mask))
                    rows, cols = np.where(my_mask > 0)
                    min_row, max_row = rows.min(), rows.max()
                    min_col, max_col = cols.min(), cols.max()
                    box = [min_col, min_row, max_col, max_row]
                    boxes_i.append(torch.tensor(box))

            del new_masks
            gc.collect()

        # Remove very small masks
        masks_i, boxes_i = remove_tiny_masks(masks_i, boxes_i, area_threshold)

        # Parallel per-object lists for the objects kept in this frame
        new_masks_i = []
        new_boxes_i = []
        new_points_i = []
        new_colors_i = []
        thresholds = []

        # For each remaining mask, build 3D point cloud and cluster in 3D
        for k in range(len(masks_i)):

            points, colors = point_cloud(depth_i, scale, camera_intristics, masks_i[k], camera_pose_i, rgb_i, dataset)
            # Snap points to 3D grid (quantization)
            points = points_to_grid(points, resolution)

            if (points.shape[0] > 0):
                # DBSCAN in 3D (eps=0.15, min_samples=200); threshold is the
                # fraction of points in the cluster dbscan_3d returns
                points, colors, threshold = dbscan_3d(points, colors, 0.15, 200)
            else:
                threshold = 0.0

            # Only keep masks whose 3D cluster holds more than 80% of the points
            if (threshold > 0.8):

                thresholds.append(threshold)
                the_mask = masks_i[k]
                rows, cols = np.where(the_mask > 0)
                min_row, max_row = rows.min(), rows.max()
                min_col, max_col = cols.min(), cols.max()

                box = [min_col, min_row, max_col, max_row]
                new_masks_i.append(masks_i[k].cpu().numpy())
                new_boxes_i.append(np.array(box))
                new_points_i.append(points)
                new_colors_i.append(colors)

            del points
            del colors
            gc.collect()

        # Five crops per object for CLIP: small / large / extra-large context,
        # object hidden (hide_mask=True), and a ratio-1.0 crop showing only the
        # object (hide_others=True)
        extended_s_i, ratio_s_i = extend_images(rgb_i, new_boxes_i, new_masks_i, extension_ratio_s)
        extended_l_i, ratio_l_i = extend_images(rgb_i, new_boxes_i, new_masks_i, extension_ratio_l)
        extended_h_i, ratio_h_i = extend_images(rgb_i, new_boxes_i, new_masks_i, extension_ratio_h)
        hide_i, ratios_i = extend_images(rgb_i, new_boxes_i, new_masks_i, extension_ratio_hide, True)
        extended_mask_i, ratio_m_i = extend_images(rgb_i, new_boxes_i, new_masks_i, 1.0, False, True)

        # Combined embedding; presumably one row per kept object, parallel to
        # new_masks_i (indexed that way in the flattening below)
        embedding = mask_embedding(
            extended_s_i,
            extended_l_i,
            extended_h_i,
            hide_i,
            extended_mask_i,
            alpha_h,
            alpha_l,
            alpha_o,
            alpha_m,
            clip_model,
            preprocess
        )

        my_colors.append(new_colors_i)
        my_masks.append(new_masks_i)
        my_points.append(new_points_i)
        my_embeddings.append(embedding)

        # Cleanup
        del embedding
        del extended_s_i
        del extended_l_i
        del extended_h_i
        del hide_i
        del extended_mask_i

    # Flatten per-frame lists into single lists
    final_masks = []
    final_points = []
    final_embeddings = []
    final_colors = []
    for i in range(len(my_masks)):
        for j in range(len(my_masks[i])):
            final_masks.append(my_masks[i][j])
            final_points.append(my_points[i][j])
            final_embeddings.append(my_embeddings[i][j])
            final_colors.append(my_colors[i][j])
    return final_colors, final_masks, final_points, final_embeddings


def grid_indices(points, params, res):
    """
    Convert 3D points to integer voxel indices.

    Args:
      points: (N, >=3) array or tensor; only the first three columns are used.
      params: (x_min, y_min, z_min, x_max, y_max, z_max); only the minima are
        used as the grid origin.
      res: voxel edge length.

    Returns:
      (N, 3) int64 tensor of floor-toward-zero indices (points - origin) / res,
      on CUDA when available, otherwise CPU.

    Computation is done in float64: points reaching here can be float16
    (dbscan_3d casts to it), and half-precision offsets/divisions can shift an
    index by one voxel.
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    x_min, y_min, z_min, _, _, _ = params
    pts = torch.as_tensor(points).to(device=dev, dtype=torch.float64)[:, :3]
    origin = torch.tensor([x_min, y_min, z_min], dtype=torch.float64, device=dev)
    return ((pts - origin) / res).long()


def voxelize_batch(points, X, Y, Z, params, res=None):
    """
    Rasterize a point cloud into a binary (X, Y, Z) occupancy grid.

    Args:
      points: (N, 3) points.
      X, Y, Z: grid dimensions.
      params: (x_min, y_min, z_min, x_max, y_max, z_max) grid bounds.
      res: voxel edge length. Defaults to None, which keeps the legacy
        behaviour of reading a module-level `resolution` (this raises
        NameError unless something sets that global). Pass it explicitly so
        it matches the size used to compute X, Y, Z.

    Returns:
      float tensor of shape (X, Y, Z) with 1 at occupied voxels, on the same
      device as the indices from grid_indices().
    """
    if res is None:
        res = resolution  # legacy global lookup
    indices = grid_indices(points, params, res)
    voxel = torch.zeros((X, Y, Z), device=indices.device)
    voxel[indices[:, 0], indices[:, 1], indices[:, 2]] = 1
    return voxel


def _voxelize_at(points, X, Y, Z, params, res):
    """Binary (X, Y, Z) occupancy grid for `points` using explicit voxel size `res`."""
    indices = grid_indices(points, params, res)
    voxel = torch.zeros((X, Y, Z), device=indices.device)
    voxel[indices[:, 0], indices[:, 1], indices[:, 2]] = 1
    return voxel


def geometry_overlap(points, resolution, overlap_thr1, overlap_thr2):
    """
    Pairwise voxel-overlap adjacency between point clouds.

    For each pair whose axis-aligned bounding boxes intersect, both clouds are
    voxelized into a grid of cell size `resolution` covering their union box
    (padded by 0.2 on every side). With
      frac_i = shared voxels / occupied voxels of i  (and frac_j likewise),
    the pair is adjacent when frac_i > overlap_thr1, frac_j > overlap_thr1 and
    |frac_i - frac_j| < overlap_thr2.

    Args:
      points: list of (N_k, 3) point arrays.
      resolution: voxel edge length. It is used both for the grid dimensions
        and for the voxel indices (no longer read from a global).
      overlap_thr1, overlap_thr2: thresholds as above.

    Returns:
      (n, n) numpy adjacency matrix (symmetric; the diagonal is evaluated
      with the same criterion).
    """
    num_points = len(points)
    adjacency = np.zeros((num_points, num_points))
    if num_points == 0:
        return adjacency

    # Per-cloud axis-aligned bounding boxes, shape (n, 3).
    lo = np.array([np.min(p[:, :3], axis=0) for p in points], dtype=np.float64)
    hi = np.array([np.max(p[:, :3], axis=0) for p in points], dtype=np.float64)

    for i in range(num_points):
        print(i, ' point clouds processed from ', num_points, '!')

        # The criterion is symmetric in (i, j), so each pair is evaluated once.
        for j in range(i, num_points):
            # Skip pairs whose boxes are disjoint on ANY axis. The previous
            # y/z tests compared cloud j against itself.
            if np.any(lo[i] > hi[j]) or np.any(hi[i] < lo[j]):
                continue

            box_lo = np.minimum(lo[i], lo[j]) - 0.2
            box_hi = np.maximum(hi[i], hi[j]) + 0.2
            X, Y, Z = (int(v) for v in (box_hi - box_lo) / resolution)
            params = (box_lo[0], box_lo[1], box_lo[2], box_hi[0], box_hi[1], box_hi[2])

            v_i = _voxelize_at(points[i], X, Y, Z, params, resolution)
            v_j = _voxelize_at(points[j], X, Y, Z, params, resolution)

            shared = torch.sum(v_i * v_j)
            frac_i = shared / torch.sum(v_i)
            frac_j = shared / torch.sum(v_j)

            if (frac_i > overlap_thr1) and (frac_j > overlap_thr1) and (
                    torch.abs(frac_j - frac_i) < overlap_thr2):
                adjacency[i, j] = 1
                adjacency[j, i] = 1

            del v_i, v_j, shared, frac_i, frac_j
            torch.cuda.empty_cache()

    return adjacency


def merge_points(adj_matrix):
    """
    Group point clouds into connected components of the adjacency graph.

    Returns a list of sets of node indices, in the order networkx yields them.
    """
    graph = nx.from_numpy_array(adj_matrix)
    return [component for component in nx.connected_components(graph)]


def mask_embedding(object_extend_s, object_extend_l, object_extend_h, object_extend_hide, object_extend_mask,
                   alpha_h, alpha_l, alpha_o, alpha_m, clip_model, clip_path):
    """
    Combine CLIP embeddings of five crop lists into one embedding per object.

        final = alpha_h*E_h + alpha_l*E_l + E_s - alpha_o*E_hide + alpha_m*E_mask

    Each E_* is the stack of clip_image() embeddings for the matching crop list
    (small, large, extra-large, mask-only, object-hidden). Note that
    `clip_path` is passed to clip_image() as its `preprocess` argument,
    despite its name.

    Returns:
      numpy array of the combined embeddings.
    """
    target = 'cuda' if torch.cuda.is_available() else 'cpu'

    def _encode(crops):
        return torch.tensor(clip_image(crops, clip_model, clip_path)).to(target)

    # Encoding order matches the original: s, l, h, mask, hide.
    small = _encode(object_extend_s)
    large = _encode(object_extend_l)
    huge = _encode(object_extend_h)
    masked = _encode(object_extend_mask)
    hidden = _encode(object_extend_hide)

    combined = alpha_h * huge + alpha_l * large + small - alpha_o * hidden + alpha_m * masked
    return combined.cpu().numpy()


def points_to_grid(points, resolution):
    """
    Snap 3D points to the lower corner of their grid cell of size `resolution`.

    Uses floor rather than truncation: `astype(int)` rounds toward zero, which
    maps the whole interval (-resolution, resolution) to 0 and makes the cells
    straddling the origin twice as wide as the rest.

    Returns:
      float32 array with the same shape as `points`.
    """
    return np.floor(points / resolution).astype(np.float32) * resolution


def object_embeddings(components, points, embeddings, colors):
    """
    Merge per-mask data into one record per connected component.

    Args:
        components: sequence of collections (e.g. sets) of indices into
            ``points``, ``embeddings`` and ``colors``.
        points: per-index point arrays; they are stacked row-wise with ``np.vstack``.
        embeddings: per-index embedding vectors; they are averaged.
        colors: per-index color arrays; they are stacked row-wise with ``np.vstack``.

    Returns:
        (final_colors, final_points, final_embeddings, count), each a list with one
        entry per component:
        - final_colors: the component's stacked colors.
        - final_points: the component's stacked points.
        - final_embeddings: the element-wise mean of the component's embeddings.
        - count: how many sub-objects were merged into the component.
    """
    merged_colors, merged_points, merged_embeddings, sizes = [], [], [], []
    for members in components:
        merged_points.append(np.vstack([points[idx] for idx in members]))
        merged_colors.append(np.vstack([colors[idx] for idx in members]))
        merged_embeddings.append(np.array([embeddings[idx] for idx in members]).mean(axis=0))
        sizes.append(len(members))
    return merged_colors, merged_points, merged_embeddings, sizes


def objects_class(embeddings, classes):
    """
    Pick the highest-scoring class label for every object.

    Args:
        embeddings: iterable of per-object score vectors (e.g. similarity scores);
            entry k of each vector is the score for ``classes[k]``.
        classes: sequence of class labels, one per score entry.

    Returns:
        List with one label per object. Labels are taken from
        ``np.array(classes)``, so string labels come back as numpy ``str_``
        values. On ties the first maximal class is chosen.
    """
    # Convert the labels once rather than once per object.
    class_array = np.array(classes)
    # argmax is O(n) and breaks ties deterministically (first maximum). A full
    # argsort would cost O(n log n) and leave the tie order unspecified.
    return [class_array[np.argmax(scores)] for scores in embeddings]


def _parse_cli_args():
    """Parse and return the command-line arguments of this script."""
    parser = argparse.ArgumentParser()
    # --scene and --dataset_path are concatenated into file paths below. A
    # missing value would otherwise surface later as an opaque TypeError.
    parser.add_argument("--scene", type=str, required=True)
    parser.add_argument("--dataset_path", type=str, required=True)
    # open_clip architecture name. It must match the checkpoint stored at
    # <path>models/ovseg_clip.pth.
    parser.add_argument("--clip_model", type=str, default='ViT-L-14')
    parser.add_argument("--path", type=str, default='')
    parser.add_argument("--dataset", type=str)
    return parser.parse_args()


def _load_scene_config(path, dataset, dataset_path, scene):
    """
    Load the dataset-specific YAML config and the number of frames to process.

    Replica always uses 2000 frames. Any other dataset counts the files in
    ``<dataset_path><scene>/color/``.

    Returns:
        (config dict, last_idx)
    """
    if dataset == 'Replica':
        last_idx = 2000
        config_file = "core_configs/config_Replica.yaml"
    else:
        last_idx = len(os.listdir(dataset_path + scene + '/color/'))
        config_file = "core_configs/config_ScanNet.yaml"
    with open(path + config_file, "r") as f:
        config = yaml.safe_load(f)
    return config, last_idx


def _load_ovseg_clip_model(clip_name, clip_path):
    """
    Build an open_clip model with its preprocessing transform and load the OVSeg CLIP weights.

    Returns:
        (clip_model in eval mode, preprocess transform)
    """
    # open_clip accepts a checkpoint file path as `pretrained`. Using the same
    # path as the explicit load below keeps a single source of weights.
    clip_model, _, preprocess = open_clip.create_model_and_transforms(clip_name, clip_path)

    ckpt = torch.load(clip_path, map_location='cpu')
    # Handle both plain and wrapped ({'state_dict': ...}) checkpoint formats.
    if isinstance(ckpt, dict) and 'state_dict' in ckpt:
        state_dict = ckpt['state_dict']
    else:
        state_dict = ckpt
    incompatible = clip_model.load_state_dict(state_dict, strict=False)
    # strict=False silently skips mismatched keys. Report them so that a
    # checkpoint/architecture mismatch does not go unnoticed.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"[warning] CLIP checkpoint {clip_path}: "
              f"{len(incompatible.missing_keys)} missing keys, "
              f"{len(incompatible.unexpected_keys)} unexpected keys")
    # The model is only used for inference.
    clip_model.eval()
    return clip_model, preprocess


def _build_object_tables(new_points, new_embeddings, count):
    """
    Build the two output tables for the merged objects.

    Returns:
        points_to_ids: DataFrame with columns x, y, z, 'Object id'. It holds one
            row per unique point of each object.
        ids_to_embeddings: dict mapping object id -> {'embedding': list of
            floats, 'count': number of merged sub-objects}.
    """
    columns = ['x', 'y', 'z', 'Object id']
    frames = []
    ids_to_embeddings = {}
    for obj_id, (obj_points, obj_embedding, obj_count) in enumerate(zip(new_points, new_embeddings, count)):
        pts = np.unique(obj_points, axis=0)
        frames.append(pd.DataFrame(
            {'x': pts[:, 0], 'y': pts[:, 1], 'z': pts[:, 2], 'Object id': obj_id},
            columns=columns,
        ))
        ids_to_embeddings[obj_id] = {'embedding': list(obj_embedding.astype(float)), 'count': obj_count}

    # A single concat avoids the quadratic cost of growing a DataFrame in a loop.
    # pd.concat([]) raises, so the "no objects" case gets an empty table.
    if frames:
        points_to_ids = pd.concat(frames, ignore_index=True)
    else:
        points_to_ids = pd.DataFrame(columns=columns)
    return points_to_ids, ids_to_embeddings


def main():
    """
    Main entry point.

    - Parse arguments, load the config and the OVSeg CLIP model.
    - Run mask detection and embedding (``masks_detection``).
    - Merge masks into objects by 3D geometric overlap.
    - Save two maps: point -> object id (CSV) and object id -> embedding (JSON).
    """
    # main() rebinds the module-level `resolution`. The declaration must precede
    # any use of the name inside this function.
    global resolution

    args = _parse_cli_args()
    path = args.path
    dataset = args.dataset
    scene = args.scene
    dataset_path = args.dataset_path

    # All model weights are resolved relative to --path, like the configs and outputs.
    clip_path = path + 'models/ovseg_clip.pth'
    semantic_sam_path = path + 'models/swinl_only_sam_many2many.pth'

    # Read the config before the expensive model load, so that bad paths or
    # missing keys fail fast.
    config, last_idx = _load_scene_config(path, dataset, dataset_path, scene)
    geometry_overlap_thr1 = config['geometery_overlap_thr1']
    geometry_overlap_thr2 = config['geometery_overlap_thr2']
    resolution = config['resolution']
    remove_thr = config['remove_thr']
    side_thr = config['side_thr']
    iou_remove_thr = config['iou_remove_thr']
    ratio_thr = config['ratio_thr']
    extension_ratio_hide = config['extension_ratio_hide']
    extension_ratio_s = config['extension_ratio_s']
    extension_ratio_l = config['extension_ratio_l']
    extension_ratio_h = config['extension_ratio_h']
    area_threshold = config['area_threshold']
    alpha_h = config['alpha_h']
    alpha_l = config['alpha_l']
    alpha_o = config['alpha_o']
    alpha_m = config['alpha_m']

    clip_model, preprocess = _load_ovseg_clip_model(args.clip_model, clip_path)

    # Path to the scene data
    scene_path = dataset_path + scene + '/'

    # Load intrinsics and depth scale
    if dataset == 'Replica':
        camera_intristics, scale = _load_depth_intrinsics(path=dataset_path + '/cam_params.json', dataset=dataset)
    else:
        camera_intristics, scale = _load_depth_intrinsics(path=dataset_path + scene + '/intrinsic/intrinsic_depth.txt',
                                                          dataset=dataset)

    # Run the full detection + embedding pipeline
    colors, _, points, embeddings = masks_detection(
        extension_ratio_hide=extension_ratio_hide,
        extension_ratio_s=extension_ratio_s,
        extension_ratio_l=extension_ratio_l,
        extension_ratio_h=extension_ratio_h,
        path=scene_path,
        last_idx=last_idx,
        step=15,
        remove_iou=iou_remove_thr,
        remove_side_thr=remove_thr,
        side_thr=side_thr,
        dataset=dataset,
        area_threshold=area_threshold,
        ratio_thr=ratio_thr,
        camera_intristics=camera_intristics,
        scale=scale,
        resolution=resolution,
        clip_model=clip_model,
        preprocess=preprocess,
        alpha_h=alpha_h,
        alpha_l=alpha_l,
        alpha_o=alpha_o,
        alpha_m=alpha_m,
        semantic_sam_path=semantic_sam_path
    )

    # Adjacency from 3D geometric overlap, then merge into connected components.
    adjacency = geometry_overlap(points, resolution, geometry_overlap_thr1, geometry_overlap_thr2)
    components = merge_points(adjacency)

    # Aggregate embeddings and points for the merged objects.
    _, new_points, new_embeddings, count = object_embeddings(components, points, embeddings, colors)
    df_points_to_ids, df_ids_to_embeddings = _build_object_tables(new_points, new_embeddings, count)

    # Save embeddings and the point-to-id mapping. The '_ov_scannet' suffix is
    # used for every dataset. NOTE(review): presumably downstream scripts expect
    # these exact names; confirm before renaming.
    out_dir = path + 'embeddings/'
    os.makedirs(out_dir, exist_ok=True)
    with open(out_dir + scene + '_ids_to_embeddings_ov_scannet.json', "w") as json_file:
        json.dump(df_ids_to_embeddings, json_file, indent=4)

    df_points_to_ids.to_csv(out_dir + scene + '_points_to_ids_ov_scannet.csv', index=False)


if __name__ == '__main__':
    main()
