import os
import numpy as np
import yaml
from PIL import Image

import cv2
import argparse
import json
import torch
import gc
import copy
import pandas as pd
import open_clip
from sklearn.cluster import DBSCAN
import networkx as nx
import psutil

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# Torch device used by the SAM and voxelisation helpers in this module:
# the GPU when CUDA is usable, the CPU otherwise.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


def _load_pose(path, idx, dataset):
    """
    Load camera pose (4x4 transformation matrix) for a given frame index.
    Supports Replica and ScanNet datasets.
    """
    if dataset == 'Replica':
        # For Replica, all poses are in a single 'traj.txt' file
        path = os.path.join(path, "traj.txt")
        with open(path, "r") as file:
            lines = file.readlines()
            if 0 <= idx < len(lines):
                line = lines[idx]
                # each line has 16 values (4x4 matrix flattened row-wise)
                values = [float(val) for val in line.split()]
                transformation_matrix = np.array(values).reshape((4, 4))
                return transformation_matrix
    elif dataset == 'ScanNet':
        # For ScanNet, each pose is stored in a separate txt file: "<idx>.txt"
        path = os.path.join(path, str(idx) + '.txt')
        transformation_matrix = np.loadtxt(path).reshape(4, 4)
        return transformation_matrix


def _load_depth_intrinsics(path, dataset):
    """
    Load depth camera intrinsics and scale factor for depth values.
    Different format for Replica vs ScanNet.
    """
    if dataset == 'Replica':
        # Replica intrinsics and scale come from a JSON file
        with open(path, "r") as file:
            data = json.load(file)
            camera_params = data.get("camera")
            if camera_params:
                fx = camera_params.get("fx")
                fy = camera_params.get("fy")
                cx = camera_params.get("cx")
                cy = camera_params.get("cy")
                scale = camera_params.get("scale")
                # Camera intrinsic matrix K
                K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
                K = np.array(K)
                return K, scale
    elif dataset == 'ScanNet':
        # ScanNet intrinsics: already in matrix form
        intrinsic_depth = np.loadtxt(path)
        # Depth values are typically in millimeters -> convert to meters by scale=1000
        scale = 1000.0
        return intrinsic_depth, scale


def overlap(mask_i, mask_j, iom_threshold):
    """
    Compute 'intersection over minimum' (IoM) between mask_i and mask_j.

    The intersection (sum of the element-wise product) is divided by the
    smaller of the two mask sums and compared against iom_threshold.
    Return:
      0  -> no significant overlap (also when the smaller mask is empty)
      1  -> mask_i mostly inside mask_j
      2  -> mask_j mostly inside mask_i
    """
    sum_i = np.sum(mask_i)
    sum_j = np.sum(mask_j)

    # IoM denominator = smaller mask; the return code says which one it was.
    if sum_i > sum_j:
        smaller_area, code = sum_j, 2  # j is the candidate "inside i"
    else:
        smaller_area, code = sum_i, 1  # i is the candidate "inside j"

    if smaller_area == 0:
        # An empty mask cannot be contained in anything; returning early also
        # avoids a 0/0 division (NaN plus RuntimeWarning).
        return 0

    intersection = np.sum(np.multiply(mask_i, mask_j))
    if intersection / smaller_area > iom_threshold:
        return code
    return 0


def remove_overlapped_masks(results, iom_threshold, area_threshold):
    """
    Remove overlapping and tiny masks.
    - results: list of 2D binary masks (numpy arrays).
    - iom_threshold: IoM threshold for overlap pruning (see overlap()).
    - area_threshold: minimum area ratio (mask sum / image area) to keep a mask.

    For every mask i that is large enough, each other mask j that overlap()
    reports as lying inside i (code 2) is carved out of i. Masks that end up
    empty are dropped.
    Returns:
      torch.Tensor of kept masks [N, H, W] (float32)
      torch.Tensor of corresponding boxes [N, 4] (xmin, ymin, xmax, ymax)
      Both have zero rows when nothing survives; for an empty `results`
      list the mask tensor has shape (0, 0, 0).
    """
    if len(results) == 0:
        # Nothing to inspect: there is no mask to take a shape from.
        return torch.empty(0, 0, 0), torch.empty(0, 4)

    mask_shape = results[0].shape
    # Cast once up front instead of once per (i, j) pair.
    int_masks = [mask.astype(int) for mask in results]

    the_masks = []
    the_boxes = []
    for i, mask_i in enumerate(int_masks):
        # Area ratio relative to full image
        area_ratio = np.sum(mask_i) / (mask_i.shape[0] * mask_i.shape[1])
        if area_ratio < area_threshold:
            # mask too small -> remove
            continue

        # Overlap is always judged against the original mask_i; the carved
        # result is accumulated separately so earlier removals do not
        # influence later overlap decisions.
        remaining = mask_i
        for j, mask_j in enumerate(int_masks):
            if j == i:
                continue
            if overlap(mask_i, mask_j, iom_threshold=iom_threshold) == 2:
                remaining = remaining * (1 - mask_j)

        if np.sum(remaining) <= 0:
            continue

        # Bounding box of what is left of mask_i
        rows, cols = np.where(remaining)
        box = [int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())]
        the_masks.append(torch.tensor(remaining, dtype=torch.float32))
        the_boxes.append(torch.tensor(box, dtype=torch.float32))

    if len(the_masks) == 0:
        return torch.empty(0, *mask_shape), torch.empty(0, 4)

    return torch.stack(the_masks), torch.stack(the_boxes)


def remove_tiny_masks(masks, boxes, area_threshold):
    """
    Keep only the masks (and their boxes) that cover strictly more than
    area_threshold of the image.
    Input:
      masks: list of (H, W) tensors
      boxes: list of [4] tensors, aligned with masks
    Returns:
      (kept_masks, kept_boxes) as two lists in the original order.
    """
    kept_masks, kept_boxes = [], []
    for position in range(len(masks)):
        candidate, candidate_box = masks[position], boxes[position]
        n_rows, n_cols = candidate.size()[0], candidate.size()[1]
        # Fraction of the image covered by this mask
        coverage = torch.sum(candidate) / (n_rows * n_cols)
        if coverage > area_threshold:
            kept_masks.append(candidate)
            kept_boxes.append(candidate_box)
    return kept_masks, kept_boxes


def dbscan_mask_denoise(mask, eps, min_samples):
    """
    Apply 2D DBSCAN on mask pixels to remove noise and split into clusters.

    - mask: 2D array; its non-zero pixels are clustered by (row, col).
    - eps, min_samples: forwarded to sklearn's DBSCAN.
    Returns list of binary int masks (one per cluster, ordered by cluster
    label). Pixels labelled as noise (-1) appear in none of them. An empty
    mask yields an empty list.
    """
    # Coordinates of non-zero pixels
    coords = np.column_stack(np.nonzero(mask))  # shape: (N, 2)

    if coords.shape[0] == 0:
        # no pixels -> nothing to cluster
        return []

    # Cluster using DBSCAN
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(coords).labels_  # -1 indicates noise

    final_masks = []
    for cluster_id in np.unique(labels):
        if cluster_id == -1:
            continue  # skip noise cluster
        cluster_points = coords[labels == cluster_id]
        # Build each cluster mask directly in an int array. Going through a
        # label image of the input's dtype would break for unsigned masks
        # (-1 is not representable) and could wrap for small integer dtypes.
        cluster_mask = np.zeros(np.shape(mask), dtype=int)
        cluster_mask[cluster_points[:, 0], cluster_points[:, 1]] = 1
        final_masks.append(cluster_mask)

    return final_masks


def dbscan_3d(points, colors, eps, min_samples):
    """
    Apply DBSCAN in 3D on a point cloud and keep the cluster labelled 0.

    - points: (N, 3) numpy array; it is cast to float16 before clustering and
      the returned points keep that dtype.
    - colors: array indexed in parallel with points.
    - eps, min_samples: forwarded to sklearn's DBSCAN.
    Returns:
      cluster_points, cluster_colors, threshold (cluster_size / total_size)
      - empty input: returns the (empty) inputs and threshold 0.0.
      - pairwise-distance memory estimate above ~600 GB: returns [], [], -1.
    NOTE(review): label 0 is simply the first cluster DBSCAN reports, not
    necessarily the largest one — confirm this is the intended selection.
    """
    points = points.astype(np.float16)
    n_samples = points.shape[0]
    if n_samples == 0:
        # Nothing to cluster; mirrors what the frame loop does for empty clouds.
        return points, colors, 0.0

    # Worst-case memory estimate for all pairwise distances, at 8 bytes per
    # entry. The previous code derived the entry size from type(points[0]),
    # which is ndarray rather than the element dtype; the explicit 8 is
    # assumed to be what that resolved to, keeping the skip limit unchanged.
    bytes_per_entry = 8
    total_bytes = n_samples ** 2 * bytes_per_entry
    # If estimated memory exceeds ~600 GB, skip
    if total_bytes / (1024 ** 3) > 600:
        return [], [], -1

    # Run 3D DBSCAN
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(points).labels_

    # Only the cluster with label 0 is kept
    in_cluster = labels == 0
    cluster_points = points[in_cluster]
    cluster_colors = colors[in_cluster]
    threshold = cluster_points.shape[0] / n_samples

    return cluster_points, cluster_colors, threshold


def remove_side_masks(masks, boxes, remove_thr, side_thr, ratio_thr):
    """
    Remove masks that lie mostly on the image borders (side artifacts).
    - masks: sequence of (H, W) torch tensors (or a [N, H, W] tensor).
    - boxes: sequence aligned with masks.
    - remove_thr: ratio of mask pixels near edge needed to remove.
    - side_thr: pixel distance from edge to be considered 'side region'.
    - ratio_thr: overall area ratio threshold; masks covering at least this
      fraction of the image are never removed for being at the border.
    Empty masks are always removed.
    Returns:
      (new_masks, new_boxes): lists of numpy arrays for the kept entries.
    """
    # The border region only depends on the mask size, so build it once per
    # distinct (H, W) instead of once per mask.
    edge_regions = {}

    new_masks = []
    new_boxes = []
    for i in range(len(masks)):
        mask = masks[i]
        H, W = mask.shape

        if (H, W) not in edge_regions:
            rows = torch.arange(H).view(-1, 1)
            cols = torch.arange(W).view(1, -1)
            # Distance of every pixel to its nearest image border
            # (broadcast of the row-wise and column-wise distances).
            dist_to_edge = torch.minimum(
                torch.minimum(rows, H - 1 - rows),
                torch.minimum(cols, W - 1 - cols),
            )
            # Edge region mask (within side_thr pixels from any border)
            edge_regions[(H, W)] = (dist_to_edge < side_thr).long()
        edge_mask = edge_regions[(H, W)]

        # Number of 'edge' pixels inside mask, and total mask size
        count = (mask * edge_mask).sum().item()
        sum_mask = mask.sum().item()

        if sum_mask == 0:
            continue

        ratio = count / sum_mask      # fraction of mask at border
        ratio1 = sum_mask / (H * W)   # total area ratio

        # If object mostly at the border and relatively small, remove it
        if (ratio > remove_thr) and (ratio1 < ratio_thr):
            continue

        new_masks.append(np.array(masks[i]))
        new_boxes.append(np.array(boxes[i]))

    return new_masks, new_boxes


def extend_images(image, boxes, masks, extension_ratio, hide_mask=False, hide_others=False):
    """
    Crop & resize regions around each bounding box, with optional masking:
      - hide_mask=True: hide the object itself (zero out its region).
      - hide_others=True: keep only the object, hide everything else.
    Each box (xmin, ymin, xmax, ymax) is scaled about its centre by
    'extension_ratio', clamped to the image, cropped, and resized so that the
    longer box side becomes 1200 px (wide boxes) or 800 px (otherwise).
    Degenerate boxes (zero width and/or height) still produce a crop of at
    least one pixel and a target size of at least one pixel per side.
    Returns:
      list of cropped images, list of mask coverage ratios per crop
      (mask sum divided by the crop area).
    """
    extended_images = []
    ratios = []
    y_margin, x_margin = image.shape[0], image.shape[1]

    for i in range(len(boxes)):
        # No defensive copy is needed: the masking below builds new arrays and
        # neither cropping nor cv2.resize writes into `image`.
        new_image = image
        if hide_mask:
            # Keep everything except the mask (set masked areas to zero)
            new_image = (np.array(1 - masks[i])[:, :, np.newaxis]) * new_image
            new_image = new_image.clip(0, 255).astype(np.uint8)
        if hide_others:
            # Keep only the masked object, set others to zero
            new_image = (np.array(masks[i])[:, :, np.newaxis]) * new_image
            new_image = new_image.clip(0, 255).astype(np.uint8)

        # Center and half width/height
        center_x = (boxes[i][2] + boxes[i][0]) / 2
        center_y = (boxes[i][3] + boxes[i][1]) / 2
        width = (boxes[i][2] - boxes[i][0]) / 2
        height = (boxes[i][3] - boxes[i][1]) / 2

        # Extended crop coordinates, clamped to image boundaries
        new_x1 = max(int(center_x - extension_ratio * width), 0)
        new_x2 = min(int(center_x + extension_ratio * width), x_margin)
        new_y1 = max(int(center_y - extension_ratio * height), 0)
        new_y2 = min(int(center_y + extension_ratio * height), y_margin)

        # A box that is a single pixel wide/high collapses to an empty slice,
        # which cv2.resize rejects; widen it to one pixel.
        if new_x2 <= new_x1:
            new_x1 = max(min(new_x1, x_margin - 1), 0)
            new_x2 = new_x1 + 1
        if new_y2 <= new_y1:
            new_y1 = max(min(new_y1, y_margin - 1), 0)
            new_y2 = new_y1 + 1

        # Crop the extended region
        final_image = new_image[new_y1:new_y2, new_x1:new_x2]

        # Resize crop keeping approximate aspect ratio (of the box, not of
        # the clamped crop); never ask for a zero-sized output.
        if width > height:
            target_size = (1200, max(1, int(1200 / width * height)))
        elif height > 0:
            target_size = (max(1, int(800 / height * width)), 800)
        else:
            # width == height == 0: single-pixel box
            target_size = (800, 800)
        final_image = cv2.resize(final_image, target_size)

        extended_images.append(final_image)
        # Ratio of mask pixels relative to the extended crop area
        ratio = np.sum(masks[i]) / ((new_y2 - new_y1) * (new_x2 - new_x1) + 1e-6)
        ratios.append(float(ratio))

    return extended_images, ratios


def grid_embedding(masks, embeddings):
    """
    Collect, for every masked pixel, the embedding of the object it belongs to.
    - masks: [N, H, W] tensor; a pixel counts as masked when its value is > 0.
    - embeddings: [N, D] (tensor or anything torch.as_tensor accepts)
    Returns: [K, D] tensor with one row per masked pixel, ordered by object
    index first and then row-major pixel position within each mask.
    Raises:
      IndexError if masks and embeddings disagree on the number of objects.
    """
    embeddings = torch.as_tensor(embeddings)
    embeddings = embeddings.reshape(-1, embeddings.size()[1])
    if embeddings.size()[0] != masks.size()[0]:
        raise IndexError(
            "masks describe {} objects but {} embeddings were given".format(
                masks.size()[0], embeddings.size()[0])
        )

    # Gather by object index instead of repeating every embedding over the
    # full H x W grid: the result is identical but avoids materialising an
    # [N, H, W, D] tensor.
    object_ids = torch.nonzero(masks > 0)[:, 0]
    return embeddings[object_ids.to(embeddings.device)]


def point_cloud(depth, scale, camera_intristics, mask, camera_pose, colors, dataset):
    """
    Convert depth + intrinsics + camera pose + mask into a 3D point cloud.

    - depth: (H, W) or (H, W, C) raw depth image; only channel 0 is used.
    - scale: divisor applied to the raw depth values.
    - camera_intristics: 3x3 (or larger) intrinsic matrix; fx, fy, cx, cy are
      read from [0, 0], [1, 1], [0, 2], [1, 2].
    - mask: (H, W) torch tensor; pixels with mask > 0 are selected.
    - camera_pose: 4x4 matrix applied to the homogeneous camera-space points.
    - colors: (H, W, 3) image sampled at the selected pixels, channel order
      unchanged.
    - dataset: accepted for interface compatibility; not used here.
    Returns:
      points: (N, 3) float64 numpy array of transformed points
      colors: (N, 3) numpy array of the matching colour values
    Only points where mask>0 and depth>0 are kept.
    """
    cam_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    mask = mask.to(cam_device)
    # Intrinsics and pose are forced to float64 so the back-projection and
    # the pose multiplication always share one dtype, whatever the caller
    # passes (ints, float32, float64).
    camera_matrix = torch.as_tensor(camera_intristics, dtype=torch.float64).to(cam_device)
    camera_pose = torch.as_tensor(camera_pose, dtype=torch.float64).to(cam_device)
    depth = torch.tensor(depth, dtype=torch.float32).to(cam_device)
    colors = torch.tensor(colors).to(cam_device)

    # If depth has extra channel, squeeze it
    if depth.dim() == 3:
        depth = depth[:, :, 0]

    # Pixel grid (y, x)
    y, x = torch.meshgrid(
        torch.arange(depth.size()[0], device=cam_device),
        torch.arange(depth.size()[1], device=cam_device),
        indexing='ij'
    )
    x = x.to(torch.float64)
    y = y.to(torch.float64)

    # Convert raw depth using scale (division stays in float32 as before),
    # then widen for the float64 geometry below.
    depth = (depth.float() / scale).to(torch.float64)

    # Mask out invalid depths
    valid = (mask * (depth > 0).long()) > 0

    # Back-project to 3D camera coordinates
    X = (x - camera_matrix[0, 2]) * depth / camera_matrix[0, 0]
    Y = (y - camera_matrix[1, 2]) * depth / camera_matrix[1, 1]
    Z = depth

    # Homogeneous coordinates, one row per pixel in row-major order
    points = torch.stack(
        (X.reshape(-1), Y.reshape(-1), Z.reshape(-1), torch.ones_like(X).reshape(-1)),
        dim=-1
    )

    # Apply the camera pose
    points = torch.matmul(camera_pose, points.T).T

    # Keep only masked points
    points = points[valid.reshape(-1)]
    colors = colors[valid].view(-1, 3)

    return points[:, :3].cpu().numpy(), colors.cpu().numpy()


def clip_image(images, model, preprocess):
    """
    Run every image through `preprocess` and `model.encode_image`, one image
    at a time, with gradients disabled.
    - images: sequence of arrays accepted by PIL's Image.fromarray.
    - model: module exposing encode_image(); it is switched to eval mode and
      moved to CUDA when available (CPU otherwise).
    Returns: list of numpy arrays (one embedding per image).
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(dev)

    def encode_one(array):
        # Single-image batch: add the batch axis, encode, drop it again.
        batch = preprocess(Image.fromarray(array)).unsqueeze(0).to(dev)
        with torch.no_grad():
            features = model.encode_image(batch)
        return features.squeeze(0).cpu().numpy()

    return [encode_one(images[position]) for position in range(len(images))]


def masks_detection(
        extension_ratio_hide,
        extension_ratio_s,
        extension_ratio_l,
        extension_ratio_h,
        path,
        last_idx,
        step,
        remove_iou,
        remove_side_thr,
        side_thr,
        dataset,
        area_threshold,
        ratio_thr,
        camera_intristics,
        scale,
        resolution,
        clip_model,
        preprocess,
        alpha_h,
        alpha_l,
        alpha_o,
        alpha_m,
        sam_checkpoint
):
    """
    Main per-frame processing pipeline:
    - Load RGB, depth, pose
    - Run SAM to get masks
    - Filter masks (overlap, sides, tiny areas)
    - Refine masks with 2D DBSCAN
    - Build 3D point cloud for each mask, cluster with 3D DBSCAN
    - Crop extended regions, compute CLIP embeddings via mask_embedding()

    Parameters (grouped):
    - extension_ratio_s / _l / _h / _hide: box extension ratios for the
      different crops handed to mask_embedding().
    - path: scene directory ('results/' for Replica; 'color/', 'depth/',
      'pose/' for ScanNet).
    - last_idx, step: every step-th frame of the sorted frame list is
      processed, up to min(last_idx, number of frames).
    - remove_iou, area_threshold: forwarded to remove_overlapped_masks().
    - remove_side_thr, side_thr, ratio_thr: forwarded to remove_side_masks().
    - dataset: 'Replica' or 'ScanNet'.
    - camera_intristics, scale: forwarded to point_cloud().
    - resolution: grid size forwarded to points_to_grid().
    - clip_model, preprocess, alpha_h/l/o/m: forwarded to mask_embedding().
    - sam_checkpoint: path of the SAM 'vit_h' checkpoint.
    Returns aggregated (flattened over frames):
      my_colors, my_masks, my_points, my_embeddings
    Raises:
      ValueError for an unknown dataset.
      FileNotFoundError when a frame's colour or depth image cannot be read.
    """

    # Dataset-specific prefixes
    if dataset == 'Replica':
        prefix_rgb = 'results/frame'
        prefix_depth = 'results/depth'
        pose_path = path
        results_dir = os.path.join(path, 'results')
        # List all frame files starting with 'frame'. os.listdir gives no
        # ordering guarantee, so sort to make the step-wise sampling
        # deterministic.
        image_list = sorted(
            f for f in os.listdir(results_dir)
            if f.startswith('frame') and os.path.isfile(os.path.join(results_dir, f))
        )
    elif dataset == 'ScanNet':
        prefix_rgb = 'color/'
        prefix_depth = 'depth/'
        pose_path = os.path.join(path, 'pose/')

        def _numeric_order(name):
            # ScanNet frames are named '<idx>.jpg'; order them by number so
            # '10' follows '9'. Anything else is kept, but placed last.
            stem = name.split('.')[0]
            return (0, int(stem), name) if stem.isdigit() else (1, 0, name)

        image_list = sorted(os.listdir(os.path.join(path, prefix_rgb)), key=_numeric_order)
    else:
        raise ValueError("Unknown dataset type")

    # Build SAM once
    sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
    sam.to(device)
    mask_generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=32,
        pred_iou_thresh=0.9,
        stability_score_thresh=0.92,
        crop_n_layers=0,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100
    )

    my_colors = []
    my_masks = []
    my_points = []
    my_embeddings = []

    # Process every 'step'-th frame up to last_idx
    for i in range(0, min(last_idx, len(image_list)), step):

        print('Processing frame {}/{}'.format(i, len(image_list)))
        image_path = image_list[i]

        if dataset == 'Replica':
            # Extract numeric index from 'frameXXXX.jpg'
            idx_path = str(image_path[image_path.index('e') + 1:image_path.index('.')])
            rgb_file = os.path.join(path, prefix_rgb + idx_path + '.jpg')
            depth_file = os.path.join(path, prefix_depth + idx_path + '.png')
            depth_dtype = np.double
        else:
            # For ScanNet, file name before '.' is index
            idx_path = str(image_path[:image_path.index('.')])
            rgb_file = os.path.join(path, prefix_rgb, idx_path + '.jpg')
            depth_file = os.path.join(path, prefix_depth, idx_path + '.png')
            depth_dtype = np.float32

        # Read RGB and depth. cv2.imread returns None instead of raising when
        # a file is missing or unreadable, so check explicitly.
        rgb_i = cv2.imread(rgb_file)
        if rgb_i is None:
            raise FileNotFoundError("Could not read colour image: {}".format(rgb_file))
        depth_raw = cv2.imread(depth_file, cv2.IMREAD_UNCHANGED)
        if depth_raw is None:
            raise FileNotFoundError("Could not read depth image: {}".format(depth_file))
        depth_i = depth_raw.astype(depth_dtype)

        # Load pose
        camera_pose_i = _load_pose(pose_path, int(idx_path), dataset)

        # Ensure RGB & depth have same spatial resolution
        if rgb_i.shape[:2] != depth_i.shape[:2]:
            rgb_i = cv2.resize(rgb_i, (depth_i.shape[1], depth_i.shape[0]), interpolation=cv2.INTER_LINEAR)

        # cv2.imread yields BGR; SAM and CLIP both expect RGB input.
        sam_input = cv2.cvtColor(rgb_i, cv2.COLOR_BGR2RGB)
        sam_masks = mask_generator.generate(sam_input)  # list of dicts

        results_i = []
        for m in sam_masks:
            mask = m["segmentation"].astype(np.uint8)
            if mask.shape[:2] != rgb_i.shape[:2]:
                mask = cv2.resize(mask, (rgb_i.shape[1], rgb_i.shape[0]), interpolation=cv2.INTER_NEAREST)
            results_i.append(mask)

        if len(results_i) == 0:
            continue

        # Remove overlapping & tiny masks, then side masks
        the_masks_i, boxes_i = remove_overlapped_masks(results_i, remove_iou, area_threshold)
        if the_masks_i.numel() == 0:
            continue
        results_i, boxes_i = remove_side_masks(the_masks_i, boxes_i, remove_side_thr, side_thr, ratio_thr)
        if len(results_i) == 0:
            continue

        masks_i = []
        boxes_i_list = []

        # For each mask, denoise with 2D DBSCAN and compute bounding boxes
        for k in range(len(results_i)):
            mask = np.array(results_i[k])
            mask = cv2.resize(mask.astype(int), (rgb_i.shape[1], rgb_i.shape[0]), interpolation=cv2.INTER_NEAREST)
            new_masks = dbscan_mask_denoise(mask, 7, 80)

            for r in range(len(new_masks)):
                my_mask = new_masks[r]
                if np.sum(my_mask) > 0:
                    masks_i.append(torch.tensor(my_mask))
                    rows, cols = np.where(my_mask > 0)
                    min_row, max_row = rows.min(), rows.max()
                    min_col, max_col = cols.min(), cols.max()
                    box = [min_col, min_row, max_col, max_row]
                    boxes_i_list.append(torch.tensor(box))

            del new_masks
            gc.collect()

        if len(masks_i) == 0:
            continue

        # Remove very small masks
        masks_i, boxes_i_list = remove_tiny_masks(masks_i, boxes_i_list, area_threshold)
        if len(masks_i) == 0:
            continue

        new_masks_i = []
        new_boxes_i = []
        new_points_i = []
        new_colors_i = []

        # For each remaining mask, build 3D point cloud and cluster in 3D
        for k in range(len(masks_i)):
            # NOTE(review): the per-point colours are sampled from the image
            # as loaded by OpenCV (BGR order); left unchanged because the
            # consumers of these colours are outside this function.
            points, colors = point_cloud(depth_i, scale, camera_intristics, masks_i[k], camera_pose_i, rgb_i, dataset)
            # Snap points to 3D grid (quantization)
            points = points_to_grid(points, resolution)

            if points.shape[0] > 0:
                # DBSCAN in 3D
                points_db, colors_db, threshold = dbscan_3d(points, colors, 0.15, 200)
            else:
                points_db, colors_db, threshold = points, colors, 0.0

            # Only keep clusters that are large enough (threshold > 0.8)
            if threshold > 0.8:
                the_mask = masks_i[k]
                rows, cols = np.where(the_mask.numpy() > 0)
                min_row, max_row = rows.min(), rows.max()
                min_col, max_col = cols.min(), cols.max()

                box = [min_col, min_row, max_col, max_row]
                new_masks_i.append(masks_i[k].cpu().numpy())
                new_boxes_i.append(np.array(box))
                new_points_i.append(points_db)
                new_colors_i.append(colors_db)

            del points_db
            del colors_db
            gc.collect()

        if len(new_masks_i) == 0:
            continue

        # Build different visual crops (small/large/hide/etc.) to feed CLIP.
        # The crops are taken from the RGB-ordered image: clip_image() wraps
        # them with PIL, which interprets the channels as RGB.
        extended_s_i, _ = extend_images(sam_input, new_boxes_i, new_masks_i, extension_ratio_s)
        extended_l_i, _ = extend_images(sam_input, new_boxes_i, new_masks_i, extension_ratio_l)
        extended_h_i, _ = extend_images(sam_input, new_boxes_i, new_masks_i, extension_ratio_h)
        hide_i, _ = extend_images(sam_input, new_boxes_i, new_masks_i, extension_ratio_hide, True)
        extended_mask_i, _ = extend_images(sam_input, new_boxes_i, new_masks_i, 1.0, False, True)

        # Get combined embedding per object
        embedding = mask_embedding(
            extended_s_i, extended_l_i, extended_h_i,
            hide_i, extended_mask_i,
            alpha_h, alpha_l, alpha_o, alpha_m,
            clip_model, preprocess
        )

        my_colors.append(new_colors_i)
        my_masks.append(new_masks_i)
        my_points.append(new_points_i)
        my_embeddings.append(embedding)

        # Cleanup
        del embedding
        del extended_s_i, extended_l_i, extended_h_i, hide_i, extended_mask_i

    # Flatten per-frame lists into single lists
    final_masks = []
    final_points = []
    final_embeddings = []
    final_colors = []
    for i in range(len(my_masks)):
        for j in range(len(my_masks[i])):
            final_masks.append(my_masks[i][j])
            final_points.append(my_points[i][j])
            final_embeddings.append(my_embeddings[i][j])
            final_colors.append(my_colors[i][j])

    return final_colors, final_masks, final_points, final_embeddings


def grid_indices(points, params, res):
    """
    Convert 3D points into voxel grid indices.

    - points: (N, C) array-like with x, y, z in the first three columns.
    - params: (x_min, y_min, z_min, x_max, y_max, z_max); only the minima are
      used, as the grid origin.
    - res: voxel edge length.
    Returns:
      long tensor on the module-level `device`, same shape as the input, with
      columns 0..2 holding the truncated (coordinate - min) / res values and
      any further columns zero.
    """
    pts = torch.tensor(points).to(device)
    if pts.dtype in (torch.float16, torch.bfloat16):
        # Half-precision arithmetic is too coarse for index computation;
        # do the subtraction and division in float32.
        pts = pts.float()

    x_min, y_min, z_min, _, _, _ = params

    # Write straight into an integer buffer. Staging the indices in a buffer
    # of the input's float dtype would round large indices.
    indices = torch.zeros_like(pts, dtype=torch.long)
    for axis, lower in enumerate((x_min, y_min, z_min)):
        indices[:, axis] = ((pts[:, axis] - lower) / res).long()
    return indices


def voxelize_batch(points, X, Y, Z, params, res=None):
    """
    Convert a point cloud into a (X,Y,Z) voxel grid.

    - points: point cloud forwarded to grid_indices().
    - X, Y, Z: grid dimensions.
    - params: (x_min, y_min, z_min, x_max, y_max, z_max) of the grid.
    - res: voxel size. When omitted, the module-level `resolution` global
      (assigned in main()) is used, which is the previous behaviour.
    Output:
      voxel: binary occupancy grid on the module-level `device`.
    """
    if res is None:
        # Backward-compatible fallback; only defined once main() has run.
        res = resolution
    voxel = torch.zeros((X, Y, Z), device=device)
    indices = grid_indices(points, params, res)
    voxel[indices[:, 0], indices[:, 1], indices[:, 2]] = 1
    return voxel


def geometry_overlap(points, resolution, overlap_thr1, overlap_thr2):
    """
    Compute pairwise geometric overlap between all point clouds.
    - points: list of point clouds (each (N_i, 3))
    - resolution: voxel size used for the shared occupancy grids
    - overlap_thr1: minimum overlap ratio per object
    - overlap_thr2: max allowed difference between overlaps_i and overlaps_j
    Returns:
      adjacency: NxN matrix, adjacency[i,j]=1 if two objects overlap enough.
      The criterion is symmetric in i and j, so each unordered pair is
      evaluated once and written to both [i, j] and [j, i].
    """
    num_points = len(points)
    adjacency = np.zeros((num_points, num_points))

    # Axis-aligned bounding boxes, stored as float64 whatever the clouds' dtype
    lower = np.zeros((num_points, 3))
    upper = np.zeros((num_points, 3))
    for i in range(num_points):
        lower[i] = np.min(points[i][:, :3], axis=0)
        upper[i] = np.max(points[i][:, :3], axis=0)

    def _voxelize(cloud, dims, params):
        # Binary occupancy grid built with this function's own `resolution`,
        # so the result does not depend on a module-level global.
        voxel = torch.zeros(dims, device=device)
        idx = grid_indices(cloud, params, resolution)
        voxel[idx[:, 0], idx[:, 1], idx[:, 2]] = 1
        return voxel

    # Compare all unordered pairs (including i == j, as before)
    for i in range(num_points):
        print(i, ' point clouds processed from ', num_points, '!')
        for j in range(i, num_points):
            # Quickly discard non-overlapping AABBs
            if np.any(lower[i] > upper[j]) or np.any(upper[i] < lower[j]):
                continue

            # Build union bounding box with small padding
            grid_min = np.minimum(lower[i], lower[j]) - 0.2
            grid_max = np.maximum(upper[i], upper[j]) + 0.2
            X, Y, Z = (int(extent) for extent in (grid_max - grid_min) / resolution)
            params = (*grid_min, *grid_max)

            if X <= 0 or Y <= 0 or Z <= 0:
                continue

            # Voxelize both point clouds into same grid
            v_i = _voxelize(points[i], (X, Y, Z), params)
            v_j = _voxelize(points[j], (X, Y, Z), params)

            overlaps = torch.sum(v_i * v_j)
            overlaps_i = overlaps / (torch.sum(v_i) + 1e-6)
            overlaps_j = overlaps / (torch.sum(v_j) + 1e-6)

            # If both overlap ratios are high and similar -> mark as adjacent
            if (overlaps_i > overlap_thr1) and (overlaps_j > overlap_thr1) and \
                    (torch.abs(overlaps_j - overlaps_i) < overlap_thr2):
                adjacency[i, j] = 1
                adjacency[j, i] = 1

            del v_i, v_j, overlaps_i, overlaps_j, overlaps
            torch.cuda.empty_cache()

    return adjacency


def merge_points(adj_matrix):
    """
    Group objects that are linked through the adjacency matrix.

    The matrix is turned into a networkx graph and its connected components
    are returned as a list of sets of node indices.
    """
    return list(nx.connected_components(nx.from_numpy_array(adj_matrix)))


def mask_embedding(object_extend_s,
                   object_extend_l,
                   object_extend_h,
                   object_extend_hide,
                   object_extend_mask,
                   alpha_h,
                   alpha_l,
                   alpha_o,
                   alpha_m,
                   clip_model,
                   clip_preprocess):
    """
    Compute a final embedding per object by combining multiple views:
      final = alpha_h*E_h + alpha_l*E_l + E_s - alpha_o*E_hide + alpha_m*E_mask

    Each object_extend_* argument is a list of crops (one per object, same
    order in every list) that is encoded with clip_image().
    Returns:
      numpy array with one combined embedding row per object.
    """
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    def _encode(crops):
        features = clip_image(crops, clip_model, clip_preprocess)
        if not features:
            return torch.tensor(features).to(dev)
        # Stack in numpy first: torch.tensor() on a list of ndarrays converts
        # element by element, which PyTorch itself warns is extremely slow.
        return torch.tensor(np.array(features)).to(dev)

    embeddings_s = _encode(object_extend_s)
    embeddings_l = _encode(object_extend_l)
    embeddings_h = _encode(object_extend_h)
    embeddings_mask = _encode(object_extend_mask)
    embeddings_hide = _encode(object_extend_hide)

    final_embedding = alpha_h * embeddings_h + alpha_l * embeddings_l + \
        embeddings_s - alpha_o * embeddings_hide + alpha_m * embeddings_mask

    return final_embedding.cpu().numpy()


def points_to_grid(points, resolution):
    """
    Quantise 3D points onto a grid with cell size = resolution.

    Coordinates are divided by the resolution, truncated toward zero to
    integer cell numbers, and scaled back (via float32). An empty input is
    returned unchanged.
    """
    if points.shape[0] != 0:
        cell_numbers = (points / resolution).astype(int)
        return cell_numbers.astype(np.float32) * resolution
    return points


def object_embeddings(components, points, embeddings, colors):
    """
    Merge per-segment data into one entry per connected component.
    - components: list of collections of indices into points / embeddings /
      colors
    Returns:
      final_colors: colors of all members, stacked row-wise, per component
      final_points: points of all members, stacked row-wise, per component
      final_embeddings: mean of the members' embeddings per component
      count: number of members merged into each component
    """
    final_colors, final_points, final_embeddings, count = [], [], [], []
    for component in components:
        # Freeze the iteration order once so all three gathers agree.
        members = list(component)
        member_embeddings = [embeddings[m] for m in members]

        final_colors.append(np.vstack([colors[m] for m in members]))
        final_points.append(np.vstack([points[m] for m in members]))
        final_embeddings.append(np.mean(np.array(member_embeddings), axis=0))
        count.append(len(member_embeddings))
    return final_colors, final_points, final_embeddings, count


def objects_class(embeddings, classes):
    """
    Pick one class label per object.

    Each entry of `embeddings` is a vector of per-class scores aligned with
    `classes`; the label whose score sorts last (i.e. the highest score) is
    selected.
    Returns: list with one label per object.
    """
    class_names = np.array(classes)
    return [class_names[np.argsort(scores)][-1] for scores in embeddings]


def main():
    """
    Main entry point:
    - parse command-line arguments
    - load the CLIP model and the dataset-specific YAML config
    - run mask detection and 3D grouping
    - save two maps under <path>/embeddings:
        <scene>_ids_to_embeddings_sam_scannet.json  (object id -> embedding, count)
        <scene>_points_to_ids_sam_scannet.csv       (x, y, z -> object id)

    Side effects: sets the module-level global `resolution` from the config
    and creates the output directory if it does not exist.
    """

    parser = argparse.ArgumentParser()
    # --scene and --dataset_path are joined into filesystem paths below; without
    # them os.path.join would fail with a TypeError, so require them up front.
    parser.add_argument("--scene", type=str, required=True)
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--clip_model", type=str, default='EVA02-L-14-336')
    parser.add_argument("--path", type=str, default='')
    parser.add_argument("--dataset", type=str)
    args = parser.parse_args()

    path = args.path
    dataset = args.dataset
    scene = args.scene
    dataset_path = args.dataset_path
    clip_name = args.clip_model

    # CLIP model weights and SAM checkpoint
    clip_path = os.path.join(path, 'models/open_clip_pytorch_model.bin')
    sam_checkpoint = os.path.join(path, 'models/sam_vit_h_4b8939.pth')  # change name if needed

    # Create CLIP model & transforms
    clip_model, _, preprocess = open_clip.create_model_and_transforms(clip_name, pretrained=None)
    # map_location='cpu' lets a checkpoint saved with CUDA tensors load on a
    # CPU-only host; load_state_dict copies the values into the model's own
    # parameters either way.
    clip_model.load_state_dict(torch.load(clip_path, map_location='cpu'), strict=False)

    # Load config file and last_idx (#frames).
    # Any dataset value other than 'Replica' takes the ScanNet branch.
    if dataset == 'Replica':
        last_idx = 2000
        config_file = "core_configs/config_Replica.yaml"
    else:
        last_idx = len(os.listdir(os.path.join(dataset_path, scene, 'color')))
        config_file = "core_configs/config_ScanNet.yaml"
    with open(os.path.join(path, config_file), "r") as f:
        config = yaml.safe_load(f)

    global resolution
    # Read important thresholds and parameters from YAML
    geometery_overlap_thr1 = config['geometery_overlap_thr1']
    geometery_overlap_thr2 = config['geometery_overlap_thr2']
    resolution = config['resolution']
    remove_thr = config['remove_thr']
    side_thr = config['side_thr']
    iou_remove_thr = config['iou_remove_thr']
    ratio_thr = config['ratio_thr']
    extension_ratio_hide = config['extension_ratio_hide']
    extension_ratio_s = config['extension_ratio_s']
    extension_ratio_l = config['extension_ratio_l']
    extension_ratio_h = config['extension_ratio_h']
    area_threshold = config['area_threshold']
    alpha_h = config['alpha_h']
    alpha_l = config['alpha_l']
    alpha_o = config['alpha_o']
    alpha_m = config['alpha_m']

    # Path to the scene data
    scene_path = os.path.join(dataset_path, scene) + '/'

    # Load intrinsics and depth scale
    if dataset == 'Replica':
        intrinsics_path = os.path.join(dataset_path, 'cam_params.json')
    else:
        intrinsics_path = os.path.join(dataset_path, scene, 'intrinsic/intrinsic_depth.txt')
    camera_intristics, scale = _load_depth_intrinsics(path=intrinsics_path, dataset=dataset)

    # Ensure embeddings directory exists
    output_dir = os.path.join(path, 'embeddings')
    os.makedirs(output_dir, exist_ok=True)

    # Run the full detection+embedding pipeline
    colors, masks, points, embeddings = masks_detection(
        extension_ratio_hide=extension_ratio_hide,
        extension_ratio_s=extension_ratio_s,
        extension_ratio_l=extension_ratio_l,
        extension_ratio_h=extension_ratio_h,
        path=scene_path,
        last_idx=last_idx,
        step=15,
        remove_iou=iou_remove_thr,
        remove_side_thr=remove_thr,
        side_thr=side_thr,
        dataset=dataset,
        area_threshold=area_threshold,
        ratio_thr=ratio_thr,
        camera_intristics=camera_intristics,
        scale=scale,
        resolution=resolution,
        clip_model=clip_model,
        preprocess=preprocess,
        alpha_h=alpha_h,
        alpha_l=alpha_l,
        alpha_o=alpha_o,
        alpha_m=alpha_m,
        sam_checkpoint=sam_checkpoint
    )

    if len(points) == 0:
        print("No points found – nothing to save.")
        return

    # Compute adjacency matrix based on 3D geometric overlap
    adjacency = geometry_overlap(points, resolution, geometery_overlap_thr1, geometery_overlap_thr2)

    # Merge objects into connected components
    components = merge_points(adjacency)

    # Aggregate embeddings & points for merged objects
    new_colors, new_points, new_embeddings, count = object_embeddings(components, points, embeddings, colors)

    columns = ['x', 'y', 'z', 'Object id']

    # One frame per object (3D points -> object id); concatenated once after
    # the loop instead of growing a DataFrame on every iteration.
    point_frames = []

    # Dictionary mapping object id -> embedding and count
    df_ids_to_embeddings = {}
    for obj_id, (obj_points, obj_embedding, obj_count) in enumerate(
            zip(new_points, new_embeddings, count)):
        # Drop duplicate points within the same object
        pts = np.unique(obj_points, axis=0)
        point_frames.append(pd.DataFrame({
            'x': pts[:, 0],
            'y': pts[:, 1],
            'z': pts[:, 2],
            'Object id': obj_id,
        }))

        df_ids_to_embeddings[obj_id] = {
            'embedding': obj_embedding.astype(float).tolist(),
            'count': obj_count
        }

    if point_frames:
        df_points_to_ids = pd.concat(point_frames, ignore_index=True)
    else:
        # No merged objects: still write a CSV containing only the header
        df_points_to_ids = pd.DataFrame(columns=columns)

    # Save embeddings and point-to-id mapping
    with open(os.path.join(output_dir, scene + '_ids_to_embeddings_sam_scannet.json'), "w") as json_file:
        json.dump(df_ids_to_embeddings, json_file, indent=4)

    df_points_to_ids.to_csv(os.path.join(output_dir, scene + '_points_to_ids_sam_scannet.csv'), index=False)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
