import os
import numpy as np
import yaml
from PIL import Image

import cv2
import argparse
import json
import torch
import gc
import copy
import pandas as pd
import open_clip
from sklearn.cluster import DBSCAN
from semantic_sam import (
    prepare_image,
    plot_results,
    build_semantic_sam,
    SemanticSamAutomaticMaskGenerator,
    SemanticSAMPredictor,
)
import networkx as nx
import psutil

# Module-wide default torch device: CUDA when torch reports it available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_pose(path, idx, dataset):
    """
    Read the 4x4 camera pose matrix of frame ``idx``.

    - Replica: all poses live in ``<path>/traj.txt``, one frame per line,
      each line holding 16 whitespace-separated values (row-major 4x4).
    - ScanNet: every frame has its own ``<path>/<idx>.txt`` file.

    Returns a (4, 4) numpy array. Returns None when ``dataset`` is neither
    "Replica" nor "ScanNet", or (Replica only) when ``idx`` is outside the
    range of lines in the trajectory file.
    """
    if dataset == "ScanNet":
        pose_file = os.path.join(path, str(idx) + ".txt")
        return np.loadtxt(pose_file).reshape(4, 4)

    if dataset != "Replica":
        return None

    with open(os.path.join(path, "traj.txt"), "r") as handle:
        rows = handle.readlines()

    # Out-of-range frame indices silently yield None.
    if idx < 0 or idx >= len(rows):
        return None

    entries = [float(token) for token in rows[idx].split()]
    return np.array(entries).reshape((4, 4))


def _load_depth_intrinsics(path, dataset):
    """
    Load the depth camera intrinsics and the depth scale factor.

    - Replica: ``path`` is a JSON file whose "camera" entry provides
      fx, fy, cx, cy and scale; the 3x3 matrix K is assembled from them.
    - ScanNet: ``path`` is a text file already holding the intrinsic matrix;
      the scale is fixed to 1000.0.

    Returns ``(K, scale)``. Returns None for an unrecognised dataset or when
    the Replica JSON has no (or an empty) "camera" entry.
    """
    if dataset == "ScanNet":
        # The file content is used as-is; depth is divided by 1000 downstream.
        return np.loadtxt(path), 1000.0

    if dataset != "Replica":
        return None

    with open(path, "r") as handle:
        camera = json.load(handle).get("camera")

    if not camera:
        return None

    fx, fy, cx, cy = (camera.get(key) for key in ("fx", "fy", "cx", "cy"))
    intrinsics = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
    return intrinsics, camera.get("scale")


def overlap(mask_i, mask_j, iom_threshold):
    """
    Classify the containment relation of two masks via intersection over
    minimum (IoM): intersection area divided by the area of the smaller mask.

    Parameters:
      mask_i, mask_j: arrays of identical shape (expected to hold 0/1 values).
      iom_threshold: IoM value that must be strictly exceeded.

    Returns:
      0 -> no significant overlap (also when the smaller mask is empty)
      1 -> mask_i mostly inside mask_j (mask_i is the smaller or equal one)
      2 -> mask_j mostly inside mask_i (mask_j is strictly smaller)
    """
    intersection = np.sum(np.multiply(mask_i, mask_j))
    sum_i = np.sum(mask_i)
    sum_j = np.sum(mask_j)

    # The denominator is the smaller mask; ties are treated as "i inside j".
    if sum_i > sum_j:
        smaller_area, inside_code = sum_j, 2
    else:
        smaller_area, inside_code = sum_i, 1

    # An empty mask cannot be "inside" anything; this also avoids a 0/0
    # division (NaN plus a NumPy runtime warning).
    if smaller_area == 0:
        return 0

    if intersection / smaller_area > iom_threshold:
        return inside_code
    return 0


def remove_overlapped_masks(results, iom_threshold, area_threshold):
    """
    Drop tiny masks and carve nested masks out of the masks that contain them.

    For every mask whose area ratio (relative to the full image) is at least
    ``area_threshold``, each other mask that ``overlap`` reports as lying
    mostly inside it (return code 2) is subtracted from it. Masks that end up
    empty are dropped. Tiny masks are dropped themselves but are still
    subtracted from larger masks that contain them.

    Parameters:
      results: list of 2D binary masks (numpy arrays of identical shape).
      iom_threshold: IoM threshold forwarded to ``overlap``.
      area_threshold: minimum area ratio a mask needs to be kept.

    Returns:
      float tensor of kept masks [N, H, W]
      float tensor of their boxes [N, 4] as (xmin, ymin, xmax, ymax),
      with inclusive pixel coordinates.
      When nothing survives, empty tensors of shape [0, H, W] and [0, 4] are
      returned ([0, 0, 0] and [0, 4] for an empty ``results`` list).
    """
    num_masks = len(results)
    if num_masks == 0:
        return torch.zeros(0, 0, 0), torch.zeros(0, 4)

    mask_shape = results[0].shape
    boxes = torch.zeros(num_masks, 4)
    new_masks = torch.zeros(num_masks, mask_shape[0], mask_shape[1])
    kept = []

    for i in range(num_masks):
        mask_i = results[i].astype(int)

        # Area ratio relative to the full image.
        area_ratio = np.sum(mask_i) / (mask_i.shape[0] * mask_i.shape[1])
        if area_ratio < area_threshold:
            continue

        # The overlap test always uses the untouched mask_i; the subtraction
        # is accumulated separately in `remaining`.
        remaining = mask_i
        for j in range(num_masks):
            if j == i:
                continue
            mask_j = results[j].astype(int)
            if overlap(mask_i, mask_j, iom_threshold=iom_threshold) == 2:
                remaining = remaining * (1 - mask_j)

        if np.sum(remaining) > 0:
            new_masks[i] = torch.tensor(remaining)
            rows, cols = np.where(remaining)
            boxes[i] = torch.tensor(
                [int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())]
            )
            kept.append(i)

    # Index selection (rather than torch.stack) also works when `kept` is empty.
    kept_idx = torch.tensor(kept, dtype=torch.long)
    return new_masks[kept_idx], boxes[kept_idx]


def remove_tiny_masks(masks, boxes, area_threshold):
    """
    Keep only the masks whose area ratio strictly exceeds ``area_threshold``.

    Input:
      masks: sequence of (H, W) tensors
      boxes: sequence of [4] tensors, parallel to ``masks``
    Returns:
      (kept_masks, kept_boxes) as two parallel Python lists.
    """
    kept_masks, kept_boxes = [], []
    for position, candidate in enumerate(masks):
        candidate_box = boxes[position]
        pixel_count = candidate.size()[0] * candidate.size()[1]
        coverage = torch.sum(candidate) / pixel_count
        if coverage > area_threshold:
            kept_masks.append(candidate)
            kept_boxes.append(candidate_box)
    return kept_masks, kept_boxes


def dbscan_mask_denoise(mask, eps, min_samples):
    """
    Split a 2D mask into spatially connected clusters with DBSCAN and drop
    the pixels DBSCAN labels as noise.

    Parameters:
      mask: 2D array; non-zero entries are the mask pixels.
      eps, min_samples: forwarded to ``sklearn.cluster.DBSCAN`` (run on the
        (row, col) coordinates of the non-zero pixels).

    Returns:
      list of integer 0/1 masks with the shape of ``mask``, one per cluster,
      ordered by cluster label. The list is empty when the mask has no pixels
      or every pixel is classified as noise.
    """
    # (row, col) coordinates of the non-zero pixels, shape (N, 2).
    coords = np.column_stack(np.nonzero(mask))
    if coords.shape[0] == 0:
        return []

    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(coords).labels_

    final_masks = []
    for cluster_id in np.unique(labels):
        if cluster_id == -1:
            continue  # label -1 marks noise points
        members = coords[labels == cluster_id]
        # Build each cluster mask directly in a signed integer array so the
        # result does not depend on the dtype of the input mask.
        cluster_mask = np.zeros(np.shape(mask), dtype=int)
        cluster_mask[members[:, 0], members[:, 1]] = 1
        final_masks.append(cluster_mask)

    return final_masks


def dbscan_3d(points, colors, eps, min_samples):
    """
    Run DBSCAN on a 3D point cloud and keep the points of cluster label 0.

    Parameters:
      points: (N, 3) numpy array; it is cast to float16 before clustering,
        and the returned points keep that dtype.
      colors: array indexed in parallel with ``points``.
      eps, min_samples: forwarded to ``sklearn.cluster.DBSCAN``.

    Returns:
      (cluster_points, cluster_colors, threshold) where ``threshold`` is the
      fraction of input points that belong to cluster 0.
      - Empty input returns the (empty) inputs with threshold 0.0.
      - If the worst-case pairwise-distance memory estimate exceeds 600 GB,
        returns ([], [], -1).

    NOTE(review): label 0 is the first cluster DBSCAN discovers, which is not
    necessarily the largest one — confirm this is the intended "main" cluster.
    """
    points = points.astype(np.float16)
    n_samples = points.shape[0]

    if n_samples == 0:
        # Nothing to cluster; avoids indexing into / dividing by an empty cloud.
        return points, colors, 0.0

    # Worst-case estimate for an all-pairs distance matrix at 8 bytes per
    # entry (the value the previous dtype lookup effectively resolved to).
    bytes_per_entry = 8
    total_bytes = n_samples ** 2 * bytes_per_entry
    if total_bytes / (1024 ** 3) > 600:
        return [], [], -1

    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(points).labels_

    in_cluster = labels == 0
    cluster_points = points[in_cluster]
    cluster_colors = colors[in_cluster]
    threshold = cluster_points.shape[0] / n_samples

    return cluster_points, cluster_colors, threshold


def remove_side_masks(masks, boxes, remove_thr, side_thr, ratio_thr):
    """
    Remove masks that lie mostly in the border band of the image.

    A mask is removed when BOTH hold:
      - the fraction of its pixels within ``side_thr`` pixels of any image
        border is greater than ``remove_thr``;
      - its total area ratio (mask pixels / image pixels) is below
        ``ratio_thr`` (so large masks touching the border are kept).

    Parameters:
      masks: sequence (or [N, H, W] tensor) of 2D CPU torch tensors.
      boxes: sequence of boxes parallel to ``masks``.
      remove_thr, side_thr, ratio_thr: thresholds described above.

    Returns:
      (new_masks, new_boxes): lists of numpy arrays for the kept entries.
      An empty input yields ([], []). A mask without any pixel has a border
      fraction of 0 and is therefore kept.
    """
    new_masks = []
    new_boxes = []
    edge_mask = None
    edge_shape = None

    for i in range(len(masks)):
        mask = masks[i]
        H, W = mask.shape

        # The border band only depends on the image size, so build it once
        # and rebuild only if a mask with a different shape shows up.
        if edge_shape != (H, W):
            y_coords = torch.arange(H).view(-1, 1).expand(H, W)
            x_coords = torch.arange(W).view(1, -1).expand(H, W)
            dist_to_edge = torch.minimum(
                torch.minimum(y_coords, H - 1 - y_coords),
                torch.minimum(x_coords, W - 1 - x_coords),
            )
            edge_mask = (dist_to_edge < side_thr).long()
            edge_shape = (H, W)

        border_pixels = (mask * edge_mask).sum().item()
        sum_mask = mask.sum().item()

        # Guard against empty masks (would otherwise divide by zero).
        border_ratio = border_pixels / sum_mask if sum_mask != 0 else 0.0
        area_ratio = sum_mask / (H * W)

        if (border_ratio > remove_thr) and (area_ratio < ratio_thr):
            continue  # mostly on the border and relatively small -> drop

        new_masks.append(np.array(mask))
        new_boxes.append(np.array(boxes[i]))

    return new_masks, new_boxes


def extend_images(image, boxes, masks, extension_ratio, hide_mask=False, hide_others=False):
    """
    Crop and resize a region around each bounding box, with optional masking.

    Parameters:
      image: (H, W, C) image array.
      boxes: sequence of (xmin, ymin, xmax, ymax) boxes.
      masks: sequence of (H, W) 0/1 masks, parallel to ``boxes``.
      extension_ratio: factor applied to the box half-extent around its centre
        (1.0 keeps the box size, larger values add context).
      hide_mask: if True, zero out the object's own pixels before cropping.
      hide_others: if True, zero out everything except the object's pixels.

    Returns:
      (crops, ratios): the resized crops (long side 1200 px for boxes wider
      than tall, otherwise height 800 px) and, per crop, the number of mask
      pixels divided by the crop area before resizing.

    Degenerate boxes (zero width and/or height) are widened to at least one
    pixel, and resize targets are clamped to at least one pixel, so they no
    longer produce an empty crop or an invalid resize size.

    NOTE(review): the crop uses ``[x1:x2)`` / ``[y1:y2)`` slicing; if the box
    coordinates are inclusive pixel indices, the last row/column is left out
    at extension_ratio 1.0 — confirm whether that is intended.
    """
    img_h, img_w = image.shape[0], image.shape[1]
    extended_images = []
    ratios = []

    for i in range(len(boxes)):
        box = boxes[i]
        mask = masks[i]

        # The masking multiplications allocate new arrays, so the source
        # image is never modified and no defensive copy is required.
        new_image = image
        if hide_mask:
            new_image = (np.array(1 - mask)[:, :, np.newaxis]) * new_image
            new_image = new_image.clip(0, 255).astype(np.uint8)
        if hide_others:
            new_image = (np.array(mask)[:, :, np.newaxis]) * new_image
            new_image = new_image.clip(0, 255).astype(np.uint8)

        # Box centre and half extents.
        center_x = (box[2] + box[0]) / 2
        center_y = (box[3] + box[1]) / 2
        width = (box[2] - box[0]) / 2
        height = (box[3] - box[1]) / 2

        # Extended crop window, clamped to the image.
        new_x1 = max(int(center_x - extension_ratio * width), 0)
        new_x2 = min(int(center_x + extension_ratio * width), img_w)
        new_y1 = max(int(center_y - extension_ratio * height), 0)
        new_y2 = min(int(center_y + extension_ratio * height), img_h)

        # Make sure the window contains at least one pixel per axis.
        if new_x2 <= new_x1:
            new_x1 = min(new_x1, img_w - 1)
            new_x2 = new_x1 + 1
        if new_y2 <= new_y1:
            new_y1 = min(new_y1, img_h - 1)
            new_y2 = new_y1 + 1

        final_image = new_image[new_y1:new_y2, new_x1:new_x2]

        # Resize following the box aspect ratio (not the clamped crop's).
        if width > height:
            target_size = (1200, max(1, int(1200 / width * height)))
        elif height > 0:
            target_size = (max(1, int(800 / height * width)), 800)
        else:
            # Single-pixel box: no aspect ratio to follow.
            target_size = (800, 800)
        final_image = cv2.resize(final_image, target_size)

        extended_images.append(final_image)
        # Mask pixels relative to the area of the (pre-resize) crop window.
        ratio = np.sum(mask) / ((new_y2 - new_y1) * (new_x2 - new_x1))
        ratios.append(ratio.item())

    return extended_images, ratios


def grid_embedding(masks, embeddings):
    """
    Collect, for every masked pixel, the embedding of the object it belongs to.

    Parameters:
      masks: tensor [N, H, W]; entries > 0 mark the pixels of object n.
      embeddings: array-like [N, D], one embedding per object.

    Returns:
      tensor [K, D] with one row per masked pixel, ordered object by object
      and row-major within each mask (K = number of entries with mask > 0).

    Raises:
      IndexError: if ``masks`` and ``embeddings`` describe a different
        number of objects.
    """
    embeddings = torch.as_tensor(embeddings)
    if embeddings.size(0) != masks.size(0):
        raise IndexError(
            "masks and embeddings disagree on the number of objects: "
            "{} vs {}".format(masks.size(0), embeddings.size(0))
        )

    # Look up each masked pixel's object index instead of materialising an
    # [N, H, W, D] copy of the embeddings just to index into it.
    object_ids = torch.nonzero(masks > 0)[:, 0]
    return embeddings[object_ids.to(embeddings.device)]


def point_cloud(depth, scale, camera_intristics, mask, camera_pose, colors, dataset):
    """
    Back-project the masked pixels of a depth image into a world-space
    point cloud.

    Parameters:
      depth: (H, W) or (H, W, C) raw depth image (only channel 0 is used).
      scale: divisor that converts raw depth values to metric depth.
      camera_intristics: matrix whose [0,0], [1,1], [0,2], [1,2] entries are
        fx, fy, cx, cy.
      mask: (H, W) mask; entries > 0 select pixels.
      camera_pose: 4x4 matrix applied to homogeneous camera-space points.
      colors: (H, W, 3) colour image aligned with ``depth``.
      dataset: unused; kept for interface compatibility.

    Returns:
      points: (N, 3) numpy array in world coordinates
      colors: (N, 3) numpy array of the corresponding pixel colours
    Only pixels with mask > 0 and depth > 0 are kept.

    The computation runs on CUDA when available and falls back to the CPU
    otherwise. Intrinsics and pose are handled in float64.
    """
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    mask = torch.as_tensor(mask).to(dev)
    camera_matrix = torch.as_tensor(camera_intristics, dtype=torch.float64).to(dev)
    depth = torch.as_tensor(depth, dtype=torch.float32).to(dev)
    colors = torch.as_tensor(colors).to(dev)
    camera_pose = torch.as_tensor(camera_pose, dtype=torch.float64).to(dev)

    # If depth carries an extra channel axis, keep the first channel only.
    if depth.dim() == 3:
        depth = depth[:, :, 0]

    # Pixel grid: `rows` holds y indices, `cols` holds x indices.
    rows, cols = torch.meshgrid(
        torch.arange(depth.size()[0], device=dev),
        torch.arange(depth.size()[1], device=dev),
        indexing="ij",
    )

    # Convert raw depth using scale (e.g., to meters).
    depth = depth / scale

    # Keep only masked pixels that have a valid (positive) depth.
    valid = (mask * (depth > 0).long()) > 0

    # Pinhole back-projection into camera coordinates (float64).
    X = (cols.to(torch.float64) - camera_matrix[0, 2]) * depth / camera_matrix[0, 0]
    Y = (rows.to(torch.float64) - camera_matrix[1, 2]) * depth / camera_matrix[1, 1]
    Z = depth.to(X.dtype)

    # Homogeneous coordinates, one row per pixel.
    points = torch.stack(
        (X.reshape(-1), Y.reshape(-1), Z.reshape(-1), torch.ones_like(X).reshape(-1)),
        dim=-1,
    )

    # Transform to world coordinates via the camera pose.
    points = torch.matmul(camera_pose, points.T).T

    points = points[valid.reshape(-1)].reshape(-1, 4)
    colors = colors[valid].reshape(-1, 3)

    return points[:, :3].cpu().numpy(), colors.cpu().numpy()


def clip_image(images, model, preprocess):
    """
    Embed each image with the given CLIP-style model.

    The model is switched to eval mode and moved to CUDA when available
    (CPU otherwise). Every image is converted to a PIL image, run through
    ``preprocess`` and encoded on its own with gradients disabled.

    Returns: list with one numpy embedding per input image.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(target)

    features = []
    for frame in images:
        batch = preprocess(Image.fromarray(frame))
        batch = batch.unsqueeze(0).to(target)
        with torch.no_grad():
            encoded = model.encode_image(batch)
        features.append(encoded.squeeze(0).cpu().numpy())
    return features


def mask_embedding(
    object_extend_s,
    object_extend_l,
    object_extend_h,
    object_extend_hide,
    object_extend_mask,
    alpha_h,
    alpha_l,
    alpha_o,
    alpha_m,
    clip_model,
    clip_path,
):
    """
    Compute the weighted CLIP embedding of every object from its five crops:

        E = alpha_h * E_h + alpha_l * E_l + E_s - alpha_o * E_hide + alpha_m * E_mask

    (Original combined embedding function; kept for compatibility, though
    ablations now use mask_embedding_components instead.)

    Parameters:
      object_extend_s / _l / _h / _hide / _mask: lists of crops, one entry
        per object, forwarded to ``clip_image``.
      alpha_h, alpha_l, alpha_o, alpha_m: scalar weights of the formula above.
      clip_model: model forwarded to ``clip_image``.
      clip_path: despite its name, this is forwarded to ``clip_image`` as the
        ``preprocess`` argument.

    Returns:
      numpy array with one combined embedding per object.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def _encode(crops):
        # Stack into a single ndarray first: building a tensor straight from
        # a list of ndarrays is the slow conversion path in torch.
        features = clip_image(crops, clip_model, clip_path)
        if len(features) == 0:
            return torch.tensor(features).to(device)
        return torch.tensor(np.stack(features, axis=0)).to(device)

    embeddings_s = _encode(object_extend_s)
    embeddings_l = _encode(object_extend_l)
    embeddings_h = _encode(object_extend_h)
    embeddings_mask = _encode(object_extend_mask)
    embeddings_hide = _encode(object_extend_hide)

    final_embedding = (
        alpha_h * embeddings_h
        + alpha_l * embeddings_l
        + embeddings_s
        - alpha_o * embeddings_hide
        + alpha_m * embeddings_mask
    )

    return final_embedding.cpu().numpy()


def mask_embedding_components(
    object_extend_s,
    object_extend_l,
    object_extend_h,
    object_extend_hide,
    object_extend_mask,
    clip_model,
    clip_path,
):
    """
    Compute the 5 base CLIP embeddings for each object (no weighting yet):
      - E_s    : 'small' crop (extension_ratio_s)
      - E_l    : 'large' crop (extension_ratio_l)
      - E_h    : 'huge' crop (extension_ratio_h)
      - E_hide : surroundings crop (object removed)
      - E_mask : object-only crop (mask)

    ``clip_path`` is forwarded to ``clip_image`` as its ``preprocess``
    argument (the name is historical).

    Returns:
      (E_s, E_l, E_h, E_hide, E_mask) as numpy arrays of shape [N, D]

    Raises:
      ValueError: from ``np.stack`` when a crop list is empty.
    """
    crop_sets = (
        object_extend_s,
        object_extend_l,
        object_extend_h,
        object_extend_hide,
        object_extend_mask,
    )
    # clip_image already returns host-side numpy vectors, so stacking them is
    # all that is needed; no transfer to the GPU and back.
    return tuple(
        np.stack(clip_image(crops, clip_model, clip_path), axis=0)
        for crops in crop_sets
    )


def points_to_grid(points, resolution):
    """
    Quantise points onto a regular grid with cell size ``resolution``.

    Each coordinate is divided by the resolution, truncated toward zero via
    an integer cast, and scaled back, so the result is a float32 multiple of
    ``resolution``.
    """
    cell_index = (points / resolution).astype(int)
    snapped = cell_index.astype(np.float32) * resolution
    return snapped


def masks_detection(
    extension_ratio_hide,
    extension_ratio_s,
    extension_ratio_l,
    extension_ratio_h,
    path,
    last_idx,
    step,
    remove_iou,
    remove_side_thr,
    side_thr,
    dataset,
    area_threshold,
    ratio_thr,
    camera_intristics,
    scale,
    resolution,
    clip_model,
    preprocess,
    alpha_h,
    alpha_l,
    alpha_o,
    alpha_m,
    semantic_sam_path,
):
    """
    Main per-frame processing pipeline:
    - Load RGB, depth, pose
    - Run Semantic SAM at multiple levels to get masks
    - Filter masks (overlap, sides, tiny areas)
    - Refine masks with 2D DBSCAN
    - Build 3D point cloud for each mask, cluster with 3D DBSCAN
    - Crop extended regions, compute CLIP embeddings via mask_embedding_components

    Parameters:
      extension_ratio_hide / _s / _l / _h: box extension ratios passed to
        extend_images for the 'hide', 'small', 'large' and 'huge' crops.
      path: dataset root; sub-paths are built by plain string concatenation,
        so it is expected to end with a path separator.
      last_idx, step: frames are visited as range(0, last_idx, step), used as
        positions into the directory listing.
      remove_iou: forwarded to remove_overlapped_masks as its IoM threshold.
      remove_side_thr, side_thr, ratio_thr: forwarded to remove_side_masks.
      dataset: "Replica" or "ScanNet" (selects file layout and depth dtype).
      area_threshold: minimum mask area ratio (remove_overlapped_masks and
        remove_tiny_masks).
      camera_intristics, scale: forwarded to point_cloud.
      resolution: grid cell size forwarded to points_to_grid.
      clip_model, preprocess: forwarded to mask_embedding_components.
      alpha_h, alpha_l, alpha_o, alpha_m: weights of the combined embedding
        alpha_h*E_h + alpha_l*E_l + E_s - alpha_o*E_hide + alpha_m*E_mask.
      semantic_sam_path: checkpoint passed to build_semantic_sam.

    Returns aggregated (flattened over frames, one entry per kept object):
      final_colors, final_masks, final_points, final_embeddings,
      final_E_s, final_E_l, final_E_h, final_E_hide, final_E_mask
    """

    if dataset == "Replica":
        prefix_rgb = "results/frame"
        prefix_depth = "results/depth"
        pose_path = path
        # List all frame files starting with 'frame'
        image_list = [
            f
            for f in os.listdir(path + "results/")
            if f.startswith("frame") and os.path.isfile(os.path.join(path + "results/", f))
        ]
    elif dataset == "ScanNet":
        # For ScanNet, 'color' and 'depth' folders are used
        prefix_rgb = "color/"
        prefix_depth = "depth/"
        pose_path = path + "pose/"
        image_list = os.listdir(path + prefix_rgb)
    # NOTE(review): os.listdir returns entries in arbitrary order and the list
    # is not sorted here, so position i below need not be frame number i.
    # Any other `dataset` value leaves image_list undefined.

    # Build three Semantic SAM mask generators at different levels (scales)
    # NOTE(review): build_semantic_sam is called three times with the same
    # arguments; whether one model instance could be shared between the
    # generators depends on the semantic_sam API — confirm before changing.
    mask_generator1 = SemanticSamAutomaticMaskGenerator(
        build_semantic_sam(model_type="L", ckpt=semantic_sam_path), level=[3]
    )
    mask_generator2 = SemanticSamAutomaticMaskGenerator(
        build_semantic_sam(model_type="L", ckpt=semantic_sam_path), level=[4]
    )
    mask_generator3 = SemanticSamAutomaticMaskGenerator(
        build_semantic_sam(model_type="L", ckpt=semantic_sam_path), level=[6]
    )

    # Per-frame accumulators: each entry holds the data of one processed frame.
    my_colors = []
    my_masks = []
    my_points = []
    my_embeddings = []

    # NEW: lists for base CLIP components (per-frame collection)
    my_E_s = []
    my_E_l = []
    my_E_h = []
    my_E_hide = []
    my_E_mask = []

    # Process every 'step'-th frame up to last_idx
    # (image_list[i] raises IndexError if last_idx exceeds the listing length)
    for i in range(0, last_idx, step):
        print("Processing frame {}/{}".format(i, len(image_list)))
        image_path = image_list[i]

        if dataset == "Replica":
            # Extract numeric index from 'frameXXXX.jpg':
            # the text between the first 'e' and the first '.'
            idx_path = str(image_path[image_path.index("e") + 1 : image_path.index(".")])
        elif dataset == "ScanNet":
            # For ScanNet, file name before '.' is index
            idx_path = str(image_path[: image_path.index(".")])

        # Read RGB and depth
        rgb_i = cv2.imread(path + prefix_rgb + idx_path + ".jpg")
        if dataset == "Replica":
            depth_i = cv2.imread(path + prefix_depth + idx_path + ".png", cv2.IMREAD_UNCHANGED).astype(
                np.double
            )
        elif dataset == "ScanNet":
            depth_i = cv2.imread(path + prefix_depth + idx_path + ".png", cv2.IMREAD_UNCHANGED).astype(
                np.float32
            )

        # Load pose
        camera_pose_i = _load_pose(pose_path, int(idx_path), dataset)

        # Ensure RGB & depth have same spatial resolution
        if rgb_i.shape[:2] != depth_i.shape[:2]:
            rgb_i = cv2.resize(
                rgb_i,
                (depth_i.shape[1], depth_i.shape[0]),
                interpolation=cv2.INTER_LINEAR,
            )

        # PIL image resized to the depth resolution.
        # NOTE(review): im_i is not referenced again in this function.
        im_i = Image.open(path + prefix_rgb + idx_path + ".jpg")
        im_i = im_i.resize((depth_i.shape[1], depth_i.shape[0]))

        results_i = []
        # Semantic SAM input is prepared from the image file on disk;
        # original_image is not referenced again in this function.
        original_image, input_image = prepare_image(image_pth=path + prefix_rgb + idx_path + ".jpg")

        # Level-3 masks
        the_masks1 = mask_generator1.generate(input_image)
        # Start with a full 1 mask for 'remaining_parts'
        # (1 = pixel not yet claimed by an accepted mask)
        remaining_parts = np.ones_like(the_masks1[0]["segmentation"].astype(int))
        remaining_parts = cv2.resize(
            remaining_parts,
            (rgb_i.shape[1], rgb_i.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
        new_remaining_parts = copy.deepcopy(remaining_parts)

        # Filter out masks that do not cover enough unseen area
        for k in range(len(the_masks1)):
            mask = the_masks1[k]["segmentation"]
            mask = cv2.resize(
                mask.astype(int),
                (rgb_i.shape[1], rgb_i.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
            # Condition on overlap with remaining_parts (more than 65% of mask)
            if np.sum(mask * new_remaining_parts) / np.sum(mask) > 0.65:
                results_i.append(mask)

        # Remove overlapping & tiny masks, then side masks
        # (results_i is replaced by the side-filtered masks as numpy arrays)
        the_masks_i, boxes_i = remove_overlapped_masks(results_i, remove_iou, area_threshold)
        results_i, boxes_i = remove_side_masks(the_masks_i, boxes_i, remove_side_thr, side_thr, ratio_thr)

        # Update remaining_parts using accepted masks
        # (uses the masks from before the side filtering)
        for k in range(the_masks_i.size(0)):
            mask = the_masks_i[k]
            new_remaining_parts *= (1 - np.array(mask, dtype=np.int32))

        del the_masks1
        del the_masks_i
        gc.collect()
        remaining_parts = copy.deepcopy(new_remaining_parts)

        # Level-4 masks: accepted when more than 45% lies in unclaimed area
        the_masks2 = mask_generator2.generate(input_image)
        for k in range(len(the_masks2)):
            mask = the_masks2[k]["segmentation"]
            mask = cv2.resize(
                mask.astype(int),
                (rgb_i.shape[1], rgb_i.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
            if np.sum(mask * remaining_parts) / np.sum(mask) > 0.45:
                results_i.append(mask)

        the_masks_i, boxes_i = remove_overlapped_masks(results_i, remove_iou, area_threshold)
        results_i, boxes_i = remove_side_masks(the_masks_i, boxes_i, remove_side_thr, side_thr, ratio_thr)
        for k in range(the_masks_i.size(0)):
            mask = the_masks_i[k]
            new_remaining_parts *= (1 - np.array(mask, dtype=np.int32))

        del the_masks2
        gc.collect()
        remaining_parts = copy.deepcopy(new_remaining_parts)

        # Level-6 masks: accepted when more than 35% lies in unclaimed area
        the_masks3 = mask_generator3.generate(input_image)
        for k in range(len(the_masks3)):
            mask = the_masks3[k]["segmentation"]
            mask = cv2.resize(
                mask.astype(int),
                (rgb_i.shape[1], rgb_i.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
            if np.sum(mask * remaining_parts) / np.sum(mask) > 0.35:
                results_i.append(mask)

        del mask
        del the_masks3
        gc.collect()

        # Final overlap & side filtering
        # (after this, the_masks_i is the list of numpy masks from remove_side_masks)
        the_masks_i, boxes_i = remove_overlapped_masks(results_i, remove_iou, area_threshold)
        the_masks_i, boxes_i = remove_side_masks(the_masks_i, boxes_i, remove_side_thr, side_thr, ratio_thr)

        masks_i = []
        boxes_i = []

        # For each mask, denoise with 2D DBSCAN and compute bounding boxes
        for k in range(len(the_masks_i)):
            mask = np.array(the_masks_i[k])
            mask = cv2.resize(
                mask.astype(int),
                (rgb_i.shape[1], rgb_i.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
            # 2D DBSCAN with eps=7, min_samples=80
            new_masks = dbscan_mask_denoise(mask, 7, 80)

            # Every non-empty cluster becomes its own mask with its own box
            for r in range(len(new_masks)):
                my_mask = new_masks[r]
                if np.sum(my_mask) > 0:
                    masks_i.append(torch.tensor(my_mask))
                    rows, cols = np.where(my_mask > 0)
                    min_row, max_row = rows.min(), rows.max()
                    min_col, max_col = cols.min(), cols.max()
                    box = [min_col, min_row, max_col, max_row]
                    boxes_i.append(torch.tensor(box))

            del new_masks
            gc.collect()

        # Remove very small masks
        masks_i, boxes_i = remove_tiny_masks(masks_i, boxes_i, area_threshold)

        new_masks_i = []
        new_boxes_i = []
        new_points_i = []
        new_colors_i = []
        # NOTE(review): thresholds is filled below but not read again here.
        thresholds = []

        # For each remaining mask, build 3D point cloud and cluster in 3D
        for k in range(len(masks_i)):
            points, colors = point_cloud(
                depth_i,
                scale,
                camera_intristics,
                masks_i[k],
                camera_pose_i,
                rgb_i,
                dataset,
            )
            # Snap points to 3D grid (quantization)
            points = points_to_grid(points, resolution)

            if points.shape[0] > 0:
                # DBSCAN in 3D (eps=0.15, min_samples=200)
                points, colors, threshold = dbscan_3d(points, colors, 0.15, 200)
            else:
                threshold = 0.0

            # Keep the object only if the cluster returned by dbscan_3d holds
            # more than 80% of the mask's points (threshold > 0.8)
            if threshold > 0.8:
                thresholds.append(threshold)
                the_mask = masks_i[k]
                rows, cols = np.where(the_mask > 0)
                min_row, max_row = rows.min(), rows.max()
                min_col, max_col = cols.min(), cols.max()

                box = [min_col, min_row, max_col, max_row]
                new_masks_i.append(masks_i[k].cpu().numpy())
                new_boxes_i.append(np.array(box))
                new_points_i.append(points)
                new_colors_i.append(colors)

            del points
            del colors
            gc.collect()

        if len(new_boxes_i) == 0:
            # no valid objects in this frame
            continue

        # Build different visual crops (small/large/hide/etc.) to feed CLIP.
        # The coverage ratios returned alongside the crops are not used below.
        extended_s_i, ratio_s_i = extend_images(
            rgb_i, new_boxes_i, new_masks_i, extension_ratio_s
        )
        extended_l_i, ratio_l_i = extend_images(
            rgb_i, new_boxes_i, new_masks_i, extension_ratio_l
        )
        extended_h_i, ratio_h_i = extend_images(
            rgb_i, new_boxes_i, new_masks_i, extension_ratio_h
        )
        # hide_mask=True: the object itself is blanked out
        hide_i, ratios_i = extend_images(
            rgb_i, new_boxes_i, new_masks_i, extension_ratio_hide, True
        )
        # hide_others=True with ratio 1.0: object-only crop of the plain box
        extended_mask_i, ratio_m_i = extend_images(
            rgb_i, new_boxes_i, new_masks_i, 1.0, False, True
        )

        # === NEW: compute base components first ===
        E_s_i, E_l_i, E_h_i, E_hide_i, E_mask_i = mask_embedding_components(
            extended_s_i,
            extended_l_i,
            extended_h_i,
            hide_i,
            extended_mask_i,
            clip_model,
            preprocess,
        )

        # Combine them ONCE using the alphas passed to this function
        E_s_tensor = torch.tensor(E_s_i)
        E_l_tensor = torch.tensor(E_l_i)
        E_h_tensor = torch.tensor(E_h_i)
        E_hide_tensor = torch.tensor(E_hide_i)
        E_mask_tensor = torch.tensor(E_mask_i)

        combined_i = (
            alpha_h * E_h_tensor
            + alpha_l * E_l_tensor
            + E_s_tensor
            - alpha_o * E_hide_tensor
            + alpha_m * E_mask_tensor
        ).cpu().numpy()  # [N_obj, D]

        # Store per-frame lists
        my_colors.append(new_colors_i)
        my_masks.append(new_masks_i)
        my_points.append(new_points_i)
        my_embeddings.append(combined_i)

        my_E_s.append(E_s_i)
        my_E_l.append(E_l_i)
        my_E_h.append(E_h_i)
        my_E_hide.append(E_hide_i)
        my_E_mask.append(E_mask_i)

        # Cleanup (drops local names only; the data stays alive in the my_* lists)
        del combined_i
        del E_s_i, E_l_i, E_h_i, E_hide_i, E_mask_i
        del extended_s_i, extended_l_i, extended_h_i, hide_i, extended_mask_i
        gc.collect()

    # Flatten per-frame lists into single lists
    final_masks = []
    final_points = []
    final_embeddings = []
    final_colors = []
    final_E_s = []
    final_E_l = []
    final_E_h = []
    final_E_hide = []
    final_E_mask = []

    # Entry j of frame i lines up across all per-frame lists.
    num_frames = len(my_masks)
    for i in range(num_frames):
        num_objs_i = len(my_masks[i])
        for j in range(num_objs_i):
            final_masks.append(my_masks[i][j])
            final_points.append(my_points[i][j])
            final_embeddings.append(my_embeddings[i][j])
            final_colors.append(my_colors[i][j])

            final_E_s.append(my_E_s[i][j])
            final_E_l.append(my_E_l[i][j])
            final_E_h.append(my_E_h[i][j])
            final_E_hide.append(my_E_hide[i][j])
            final_E_mask.append(my_E_mask[i][j])

    return (
        final_colors,
        final_masks,
        final_points,
        final_embeddings,
        final_E_s,
        final_E_l,
        final_E_h,
        final_E_hide,
        final_E_mask,
    )


def grid_indices(points, params, res):
    """
    Convert 3D points into integer voxel-grid indices.

    Parameters:
      points: array-like or tensor whose first three columns are x, y, z.
      params: (x_min, y_min, z_min, x_max, y_max, z_max); only the minima
        are used, as the origin of the grid.
      res: voxel edge length.

    Returns:
      long tensor with the same shape as ``points``; columns 0..2 hold
      ``(coordinate - minimum) / res`` truncated to an integer, any further
      columns are zero.

    NOTE: the tensor is moved to the module-level ``device``.
    """
    # as_tensor avoids re-wrapping (and copying) inputs that already are tensors.
    points = torch.as_tensor(points).to(device)
    x_min, y_min, z_min, x_max, y_max, z_max = params

    indices = torch.zeros_like(points)
    indices[:, 0] = ((points[:, 0] - x_min) / res).long()
    indices[:, 1] = ((points[:, 1] - y_min) / res).long()
    indices[:, 2] = ((points[:, 2] - z_min) / res).long()
    return indices.long()


def voxelize_batch(points, X, Y, Z, params, res=None):
    """
    Convert a point cloud into a (X, Y, Z) voxel grid.

    Inputs:
      - points:  (N, 3) point cloud
      - X, Y, Z: grid dimensions (number of voxels per axis)
      - params:  (x_min, y_min, z_min, x_max, y_max, z_max) of the grid
      - res:     voxel edge length; when omitted, the module-level
                 'resolution' (assigned in main()) is used, as before
    Output:
      voxel: binary occupancy grid (1 where at least one point falls).
    NOTE: uses the global 'device', and the global 'resolution' when `res`
    is not given.
    """
    if res is None:
        # Backward-compatible fallback for callers that rely on the global.
        res = resolution

    voxel = torch.zeros((X, Y, Z), device=device)
    indices = grid_indices(points, params, res)
    voxel[indices[:, 0], indices[:, 1], indices[:, 2]] = 1

    return voxel


def _occupancy_grid(cloud, dims, params, res):
    """Rasterize one point cloud into a binary occupancy grid of shape `dims`."""
    grid = torch.zeros(dims, device=device)
    idx = grid_indices(cloud, params, res)
    grid[idx[:, 0], idx[:, 1], idx[:, 2]] = 1
    return grid


def geometry_overlap(points, resolution, overlap_thr1, overlap_thr2):
    """
    Compute pairwise geometric overlap between all point clouds.
    - points: list of point clouds (each (N_i, 3))
    - resolution: voxel size (used both for the grid size and the indices)
    - overlap_thr1: minimum overlap ratio per object
    - overlap_thr2: max allowed difference between overlaps_i and overlaps_j
    Returns:
      adjacency: NxN matrix, adjacency[i,j]=1 if two objects overlap enough.

    The overlap criterion is symmetric in (i, j), so each unordered pair is
    voxelized once and the result is written to both adjacency[i, j] and
    adjacency[j, i].
    NOTE: uses the global 'device' (through the voxelization helpers).
    """
    num_points = len(points)
    adjacency = np.zeros((num_points, num_points))
    if num_points == 0:
        return adjacency

    # Padding added on every side of the union bounding box.
    pad = 0.2

    # Axis-aligned bounding box of each point cloud, one (x, y, z) row each.
    lows = np.array([np.min(p[:, :3], axis=0) for p in points], dtype=np.float64)
    highs = np.array([np.max(p[:, :3], axis=0) for p in points], dtype=np.float64)

    # Compare all unordered pairs (including i == j)
    for i in range(num_points):
        print(i, " point clouds processed from ", num_points, "!")

        for j in range(i, num_points):
            # Quickly discard pairs whose AABBs are separated along any axis
            if np.any(lows[i] > highs[j]) or np.any(highs[i] < lows[j]):
                continue

            # Build union bounding box with small padding
            lo = np.minimum(lows[i], lows[j]) - pad
            hi = np.maximum(highs[i], highs[j]) + pad
            dims = tuple(int(extent) for extent in (hi - lo) / resolution)
            params = (*lo, *hi)

            # Voxelize both point clouds into the same grid
            v_i = _occupancy_grid(points[i], dims, params, resolution)
            v_j = _occupancy_grid(points[j], dims, params, resolution)

            overlaps = torch.sum(v_i * v_j)
            overlaps_i = overlaps / torch.sum(v_i)
            overlaps_j = overlaps / torch.sum(v_j)

            # If both overlap ratios are high and similar -> mark as adjacent
            if (
                overlaps_i > overlap_thr1
                and overlaps_j > overlap_thr1
                and (torch.abs(overlaps_j - overlaps_i) < overlap_thr2)
            ):
                adjacency[i, j] = 1
                adjacency[j, i] = 1

            # Release both grids before the next pair allocates new ones
            del v_i, v_j, overlaps_i, overlaps_j, overlaps
            torch.cuda.empty_cache()

    return adjacency


def merge_points(adj_matrix):
    """
    Group object indices into clusters: build a graph from the adjacency
    matrix and return its connected components as a list of index sets.
    """
    graph = nx.from_numpy_array(adj_matrix)
    return list(nx.connected_components(graph))


def object_embeddings(
    components,
    points,
    embeddings,
    colors,
    E_s,
    E_l,
    E_h,
    E_hide,
    E_mask,
):
    """
    Merge per-object data over each connected component.

    For every component the member point clouds and colors are stacked
    row-wise, while the combined embedding and each base embedding (E_*) are
    averaged element-wise over the members.

    Inputs:
      - components: iterable of collections of object indices to be merged
      - points:      list of (Ni, 3)
      - embeddings:  list of (D,) final combined embeddings
      - colors:      list of (Ni, 3)
      - E_*:         list of (D,) base embeddings per object
    Returns:
      final_colors, final_points, final_embeddings, count,
      agg_E_s, agg_E_l, agg_E_h, agg_E_hide, agg_E_mask
      (count[k] is the number of objects merged into component k)
    """

    def _mean_of(source, members):
        # Element-wise average of the selected entries of `source`.
        return np.mean(np.stack([source[k] for k in members], axis=0), axis=0)

    merged_colors, merged_points, merged_embeddings = [], [], []
    mean_s, mean_l, mean_h, mean_hide, mean_mask = [], [], [], [], []
    sizes = []

    for component in components:
        # Freeze the iteration order once so every field is gathered alike.
        members = list(component)

        merged_colors.append(np.vstack([colors[k] for k in members]))
        merged_points.append(np.vstack([points[k] for k in members]))
        merged_embeddings.append(_mean_of(embeddings, members))

        mean_s.append(_mean_of(E_s, members))
        mean_l.append(_mean_of(E_l, members))
        mean_h.append(_mean_of(E_h, members))
        mean_hide.append(_mean_of(E_hide, members))
        mean_mask.append(_mean_of(E_mask, members))

        sizes.append(len(members))

    return (
        merged_colors,
        merged_points,
        merged_embeddings,
        sizes,
        mean_s,
        mean_l,
        mean_h,
        mean_hide,
        mean_mask,
    )


def objects_class(embeddings, classes):
    """
    Pick one class label per object: each entry of `embeddings` is a vector
    of per-class scores, and the label at the last position of its ascending
    argsort (i.e. a highest-scoring class) is returned for that object.
    """
    labels = np.array(classes)
    return [labels[np.argsort(scores)][-1] for scores in embeddings]


def main():
    """
    Main entry point:
    - parse args
    - load the dataset-specific YAML config
    - run mask detection and 3D grouping
    - save maps: point -> object_id (CSV) and object_id -> embeddings with
      their base components and merge count (JSON)

    Side effects: assigns the module-level 'resolution' and writes two files
    under `<path>embeddings/`.
    """

    parser = argparse.ArgumentParser()
    # scene and dataset_path are concatenated into file paths below, so a
    # missing value is reported by argparse instead of a later TypeError.
    parser.add_argument("--scene", type=str, required=True)
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--clip_model", type=str, default="EVA02-L-14-336")
    parser.add_argument("--path", type=str, default="")
    parser.add_argument("--dataset", type=str)
    args = parser.parse_args()

    path = args.path
    dataset = args.dataset
    scene = args.scene
    dataset_path = args.dataset_path
    clip_name = args.clip_model

    # CLIP model weights and Semantic SAM weights
    clip_path = path + "models/open_clip_pytorch_model.bin"
    # NOTE(review): unlike clip_path this is not prefixed with `path` - confirm intended.
    semantic_sam_path = "models/swinl_only_sam_many2many.pth"

    # Create CLIP model & transforms
    clip_model, _, preprocess = open_clip.create_model_and_transforms(clip_name, pretrained=None)
    # map_location="cpu" lets a checkpoint saved on a GPU load on any host;
    # load_state_dict copies the values into the model's own parameters.
    clip_model.load_state_dict(torch.load(clip_path, map_location="cpu"), strict=False)

    # Load config file and last_idx (#frames)
    if dataset == "Replica":
        last_idx = 2000
        with open(path + "core_configs/config_Replica.yaml", "r") as f:
            config = yaml.safe_load(f)
    else:
        last_idx = len(os.listdir(dataset_path + scene + "/color/"))
        with open(path + "core_configs/config_ScanNet.yaml", "r") as f:
            config = yaml.safe_load(f)

    # voxelize_batch reads 'resolution' from module scope
    global resolution
    # Read important thresholds and parameters from YAML
    geometery_overlap_thr1 = config["geometery_overlap_thr1"]
    geometery_overlap_thr2 = config["geometery_overlap_thr2"]
    resolution = config["resolution"]
    remove_thr = config["remove_thr"]
    side_thr = config["side_thr"]
    iou_remove_thr = config["iou_remove_thr"]
    ratio_thr = config["ratio_thr"]
    extension_ratio_hide = config["extension_ratio_hide"]
    extension_ratio_s = config["extension_ratio_s"]
    extension_ratio_l = config["extension_ratio_l"]
    extension_ratio_h = config["extension_ratio_h"]
    area_threshold = config["area_threshold"]
    alpha_h = config["alpha_h"]
    alpha_l = config["alpha_l"]
    alpha_o = config["alpha_o"]
    alpha_m = config["alpha_m"]

    # Path to the scene data
    scene_path = dataset_path + scene + "/"

    # Load intrinsics and depth scale
    if dataset == "Replica":
        camera_intristics, scale = _load_depth_intrinsics(
            path=dataset_path + "/cam_params.json",
            dataset=dataset,
        )
    else:
        camera_intristics, scale = _load_depth_intrinsics(
            path=dataset_path + scene + "/intrinsic/intrinsic_depth.txt",
            dataset=dataset,
        )

    # Run the full detection+embedding pipeline (returns base components too)
    (
        colors,
        masks,
        points,
        embeddings,
        E_s,
        E_l,
        E_h,
        E_hide,
        E_mask,
    ) = masks_detection(
        extension_ratio_hide=extension_ratio_hide,
        extension_ratio_s=extension_ratio_s,
        extension_ratio_l=extension_ratio_l,
        extension_ratio_h=extension_ratio_h,
        path=scene_path,
        last_idx=last_idx,
        step=15,
        remove_iou=iou_remove_thr,
        remove_side_thr=remove_thr,
        side_thr=side_thr,
        dataset=dataset,
        area_threshold=area_threshold,
        ratio_thr=ratio_thr,
        camera_intristics=camera_intristics,
        scale=scale,
        resolution=resolution,
        clip_model=clip_model,
        preprocess=preprocess,
        alpha_h=alpha_h,
        alpha_l=alpha_l,
        alpha_o=alpha_o,
        alpha_m=alpha_m,
        semantic_sam_path=semantic_sam_path,
    )

    # Compute adjacency matrix based on 3D geometric overlap
    adjacency = geometry_overlap(points, resolution, geometery_overlap_thr1, geometery_overlap_thr2)

    # Merge objects into connected components
    components = merge_points(adjacency)

    # Aggregate embeddings & points for merged objects (also base components)
    (
        new_colors,
        new_points,
        new_embeddings,
        count,
        agg_E_s,
        agg_E_l,
        agg_E_h,
        agg_E_hide,
        agg_E_mask,
    ) = object_embeddings(
        components,
        points,
        embeddings,
        colors,
        E_s,
        E_l,
        E_h,
        E_hide,
        E_mask,
    )

    point_columns = ["x", "y", "z", "Object id"]

    # One frame per object (3D points -> object id); they are concatenated
    # once after the loop so the table is not re-copied for every object.
    per_object_frames = []

    # Dictionary mapping object id -> embeddings and count
    ids_to_embeddings = {}
    for i, object_points in enumerate(new_points):
        pts = np.unique(object_points, axis=0)
        per_object_frames.append(
            pd.DataFrame(
                {
                    "x": pts[:, 0],
                    "y": pts[:, 1],
                    "z": pts[:, 2],
                    "Object id": i,
                }
            )
        )

        ids_to_embeddings[i] = {
            "embedding_default": new_embeddings[i].astype(float).tolist(),
            "E_s": agg_E_s[i].astype(float).tolist(),
            "E_l": agg_E_l[i].astype(float).tolist(),
            "E_h": agg_E_h[i].astype(float).tolist(),
            "E_hide": agg_E_hide[i].astype(float).tolist(),
            "E_mask": agg_E_mask[i].astype(float).tolist(),
            "count": count[i],
        }

    if per_object_frames:
        df_points_to_ids = pd.concat(per_object_frames, ignore_index=True)
    else:
        # No objects: still emit a CSV that carries only the header row.
        df_points_to_ids = pd.DataFrame(columns=point_columns)

    # Save embeddings and point-to-id mapping
    # NOTE(review): the "_ctx_scannet" suffix is used for every dataset - confirm intended.
    os.makedirs(path + "embeddings/", exist_ok=True)
    with open(path + "embeddings/" + scene + "_ids_to_embeddings_ctx_scannet.json", "w") as json_file:
        json.dump(ids_to_embeddings, json_file, indent=4)

    df_points_to_ids.to_csv(path + "embeddings/" + scene + "_points_to_ids_ctx_scannet.csv", index=False)


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
