import os
import numpy as np
import yaml
from PIL import Image

import cv2
import argparse
import json
import torch
import gc
import copy
import pandas as pd
import open_clip
from sklearn.cluster import DBSCAN
from semantic_sam import prepare_image, plot_results, build_semantic_sam, SemanticSamAutomaticMaskGenerator,SemanticSAMPredictor
import networkx as nx
import psutil

# Default torch device string: "cuda" when torch can see a GPU, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def _load_pose(path, idx, dataset):
    """
    Load camera pose (4x4 transformation matrix) for a given frame index.
    Supports Replica and ScanNet datasets.
    """
    if(dataset=='Replica'):
        # For Replica, all poses are in a single 'traj.txt' file
        path = os.path.join(path, "traj.txt")
        with open(path, "r") as file:
            lines = file.readlines()
            if 0 <= idx < len(lines):
                line = lines[idx]
                # each line has 16 values (4x4 matrix flattened row-wise)
                values = [float(val) for val in line.split()]
                transformation_matrix = np.array(values).reshape((4, 4))
                return transformation_matrix
    elif(dataset=='ScanNet'):
        # For ScanNet, each pose is stored in a separate txt file: "<idx>.txt"
        path = os.path.join(path, str(idx)+'.txt')
        transformation_matrix = np.loadtxt(path).reshape(4, 4)
        return transformation_matrix
        
def _load_depth_intrinsics(path,dataset):
    """
    Load depth camera intrinsics and scale factor for depth values.
    Different format for Replica vs ScanNet.
    """
    if(dataset=='Replica'):
        # Replica intrinsics and scale come from a JSON file
        with open(path, "r") as file:
            data = json.load(file)
            camera_params = data.get("camera")
            if camera_params:
                w = camera_params.get("w")
                h = camera_params.get("h")
                fx = camera_params.get("fx")
                fy = camera_params.get("fy")
                cx = camera_params.get("cx")
                cy = camera_params.get("cy")
                scale = camera_params.get("scale")
                # Camera intrinsic matrix K
                K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
                K = np.array(K)
                return K, scale
    elif(dataset=='ScanNet'):
        # ScanNet intrinsics: already in matrix form
        intrinsic_depth = np.loadtxt(path)
        # Depth values are typically in millimeters -> convert to meters by scale=1000
        scale = 1000.0
        return intrinsic_depth, scale
        
def overlap(mask_i,mask_j,iom_threshold):
    """
    Classify the containment relation of two binary masks by intersection over minimum (IoM).

    IoM = |mask_i * mask_j| / area of the smaller mask. On equal areas,
    mask_i is treated as the smaller one.

    Returns:
      0 -> IoM not above `iom_threshold` (or the smaller mask is empty)
      1 -> mask_i is the smaller mask and lies mostly inside mask_j
      2 -> mask_j is the smaller mask and lies mostly inside mask_i
    """
    intersection = np.sum(np.multiply(mask_i, mask_j))
    sum_i = np.sum(mask_i)
    sum_j = np.sum(mask_j)
    if sum_i > sum_j:
        # Explicit empty-mask guard: the old 0/0 produced a RuntimeWarning and NaN
        # (which compared False, i.e. the same result 0).
        if sum_j > 0 and intersection / sum_j > iom_threshold:
            return 2  # j is inside i
        return 0
    if sum_i > 0 and intersection / sum_i > iom_threshold:
        return 1  # i is inside j
    return 0
        
def remove_overlapped_masks(results,iom_threshold,area_threshold):
    """
    Drop tiny masks and carve contained masks out of the masks that contain them.

    For every mask i whose area ratio (area / image area) is at least
    `area_threshold`, each other mask j with overlap(i, j) == 2 (j is the
    smaller one and mostly inside i) is subtracted from mask i. Masks that
    end up empty are dropped.

    NOTE: masks dropped for being tiny still take part in the overlap test,
    so they are still subtracted from the masks that contain them.

    Args:
      results: list of 2D binary numpy masks of equal shape.
      iom_threshold: IoM threshold forwarded to overlap().
      area_threshold: minimum area ratio for a mask to be kept.

    Returns:
      (masks, boxes): float32 tensors of shape [N, H, W] and [N, 4]
      (xmin, ymin, xmax, ymax), in input order. When nothing survives, N is 0
      (previously torch.stack([]) raised). For an empty `results`, the mask
      tensor has shape (0, 0, 0) because H and W are unknown.
    """
    if len(results) == 0:
        return torch.zeros((0, 0, 0)), torch.zeros((0, 4))
    height, width = results[0].shape[0], results[0].shape[1]

    # Convert once instead of once per (i, j) pair.
    int_masks = [r.astype(int) for r in results]

    kept_masks = []
    kept_boxes = []
    for i, mask_i in enumerate(int_masks):
        area_ratio = np.sum(mask_i) / (mask_i.shape[0] * mask_i.shape[1])
        if area_ratio < area_threshold:
            continue  # too small

        # The overlap test uses the uncarved mask_i; subtraction is applied afterwards.
        carved = mask_i
        for j, mask_j in enumerate(int_masks):
            if j != i and overlap(mask_i, mask_j, iom_threshold=iom_threshold) == 2:
                carved = carved * (1 - mask_j)

        if np.sum(carved) > 0:
            rows, cols = np.where(carved)
            kept_masks.append(torch.tensor(carved, dtype=torch.float32))
            kept_boxes.append(torch.tensor([cols.min(), rows.min(), cols.max(), rows.max()], dtype=torch.float32))

    if not kept_masks:
        return torch.zeros((0, height, width)), torch.zeros((0, 4))
    return torch.stack(kept_masks), torch.stack(kept_boxes)

def remove_tiny_masks(masks,boxes,area_threshold):
    """
    Keep only masks whose area covers strictly more than `area_threshold` of the image.

    Args:
      masks: sequence of (H, W) torch tensors.
      boxes: sequence of boxes aligned with `masks`.
      area_threshold: minimum (exclusive) fraction of image pixels.

    Returns:
      (kept_masks, kept_boxes) as lists, in input order.
    """
    kept_masks, kept_boxes = [], []
    for idx, candidate in enumerate(masks):
        box = boxes[idx]
        pixel_count = candidate.size()[0] * candidate.size()[1]
        if torch.sum(candidate) / pixel_count > area_threshold:
            kept_masks.append(candidate)
            kept_boxes.append(box)
    return kept_masks, kept_boxes

def dbscan_mask_denoise(mask, eps, min_samples):
    """
    Split a 2D mask into spatially connected clusters with DBSCAN, dropping noise.

    Args:
      mask: 2D array; non-zero pixels are clustered by their (row, col).
      eps, min_samples: DBSCAN parameters (pixel units).

    Returns:
      list of int binary masks, one per DBSCAN cluster in ascending label
      order. Noise pixels (label -1) are dropped. An empty mask yields [].

    Previously the loop ran over range(len(unique labels)), so its `!= -1`
    test never fired and an extra empty mask was appended. The empty-mask
    case also returned the array itself instead of a list.
    """
    coords = np.column_stack(np.nonzero(mask))  # (N, 2) row/col pairs
    if coords.shape[0] == 0:
        return []

    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(coords).labels_

    final_masks = []
    for cluster_id in np.unique(labels):
        if cluster_id == -1:
            continue  # DBSCAN noise
        cluster_points = coords[labels == cluster_id]
        cluster_mask = np.zeros(np.shape(mask), dtype=int)
        cluster_mask[cluster_points[:, 0], cluster_points[:, 1]] = 1
        final_masks.append(cluster_mask)
    return final_masks


def dbscan_3d(points, colors, eps, min_samples):
    """
    Cluster a point cloud with DBSCAN and keep its largest (dominant) cluster.

    Args:
      points: (N, 3) array; cast to float16 before clustering.
      colors: (N, ...) array aligned with `points`.
      eps, min_samples: DBSCAN parameters.

    Returns:
      (cluster_points, cluster_colors, ratio) where ratio = cluster size / N.
      Returns (empty, empty, 0.0) when there are no points or every point is
      noise. Returns ([], [], -1) when the worst-case pairwise-distance memory
      estimate exceeds 600 GiB.

    Previously the cluster with label 0 was taken. DBSCAN assigns label 0 to
    the first cluster it discovers, not the largest, so dominant objects
    could be rejected.
    """
    points = points.astype(np.float16)
    n_samples = points.shape[0]
    if n_samples == 0:
        return points[:0], colors[:0], 0.0

    # Worst-case dense distance matrix, sized at 8 bytes (float64) per entry.
    total_bytes = n_samples ** 2 * np.dtype(np.float64).itemsize
    if total_bytes / (1024 ** 3) > 600:
        return [], [], -1

    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(points).labels_
    cluster_ids, sizes = np.unique(labels[labels >= 0], return_counts=True)
    if cluster_ids.size == 0:
        # Everything is noise.
        return points[:0], colors[:0], 0.0

    dominant = cluster_ids[np.argmax(sizes)]
    keep = labels == dominant
    cluster_points = points[keep]
    cluster_colors = colors[keep]
    return cluster_points, cluster_colors, cluster_points.shape[0] / n_samples

def remove_side_masks(masks,boxes,remove_thr,side_thr,ratio_thr):
    """
    Drop small masks that lie mostly along the image border.

    A mask is removed when both conditions hold:
      - more than `remove_thr` of its pixels are closer than `side_thr`
        pixels to an image edge;
      - its area is below `ratio_thr` of the image.

    Args:
      masks: sequence (or [N, H, W] tensor) of 2D torch masks.
      boxes: boxes aligned with `masks`.

    Returns:
      (new_masks, new_boxes): lists of numpy arrays for the kept entries, in
      input order. Empty input yields ([], []) (previously IndexError). An
      empty mask is treated as having no border pixels (previously
      ZeroDivisionError).
    """
    edge_cache = {}  # (H, W) -> edge-region mask; masks normally share one shape
    new_masks = []
    new_boxes = []
    for i in range(len(masks)):
        mask = masks[i]
        H, W = mask.shape

        edge_mask = edge_cache.get((H, W))
        if edge_mask is None:
            y_coords = torch.arange(H).view(-1, 1).expand(H, W)
            x_coords = torch.arange(W).view(1, -1).expand(H, W)
            dist_to_edge = torch.minimum(torch.minimum(y_coords, H - 1 - y_coords),
                                         torch.minimum(x_coords, W - 1 - x_coords))
            edge_mask = (dist_to_edge < side_thr).long()
            edge_cache[(H, W)] = edge_mask

        border_count = (mask * edge_mask).sum().item()
        area = mask.sum().item()
        border_fraction = border_count / area if area > 0 else 0.0
        area_fraction = area / (H * W)

        if border_fraction > remove_thr and area_fraction < ratio_thr:
            continue  # mostly on the border and relatively small

        new_masks.append(np.array(mask))
        new_boxes.append(np.array(boxes[i]))
    return new_masks, new_boxes

def extend_images(image, boxes, masks, extension_ratio,hide_mask=False,hide_others=False):
    """
    Crop an enlarged window around each box and resize it to a fixed long side.

    For box i = (xmin, ymin, xmax, ymax), the crop is centred on the box with
    half-sizes scaled by `extension_ratio`, then clamped to the image. It is
    resized to 1200 px wide (box wider than tall) or 800 px tall (otherwise),
    following the box's aspect ratio.

    Args:
      image: (H, W, C) image array; never modified.
      boxes: sequence of boxes.
      masks: sequence of (H, W) binary masks aligned with `boxes`.
      extension_ratio: scale applied to the box half-width/half-height.
      hide_mask: zero out the mask's pixels before cropping.
      hide_others: zero out everything except the mask's pixels. Applied
        after hide_mask when both are set.

    Returns:
      (crops, ratios): resized crops, and for each crop sum(mask) divided by
      the crop area.

    Degenerate boxes (zero width and/or height) now produce a one-pixel-wide
    crop and a valid target size. Previously cv2.resize received an empty
    crop or a zero dimension, or a division by zero occurred.
    """
    extended_images = []
    ratios = []
    img_h, img_w = image.shape[0], image.shape[1]
    for i in range(len(boxes)):
        box = boxes[i]
        mask = np.asarray(masks[i])

        # Masking builds new arrays, so the input image needs no defensive copy.
        new_image = image
        if hide_mask:
            new_image = ((1 - mask)[:, :, np.newaxis] * new_image).clip(0, 255).astype(np.uint8)
        if hide_others:
            new_image = (mask[:, :, np.newaxis] * new_image).clip(0, 255).astype(np.uint8)

        center_x = (box[2] + box[0]) / 2
        center_y = (box[3] + box[1]) / 2
        half_w = (box[2] - box[0]) / 2
        half_h = (box[3] - box[1]) / 2

        # Extended crop window, clamped to the image.
        new_x1 = max(0, int(center_x - extension_ratio * half_w))
        new_x2 = min(img_w, int(center_x + extension_ratio * half_w))
        new_y1 = max(0, int(center_y - extension_ratio * half_h))
        new_y2 = min(img_h, int(center_y + extension_ratio * half_h))
        # Keep at least one row/column so the crop is never empty.
        if new_x2 <= new_x1:
            new_x1 = min(new_x1, img_w - 1)
            new_x2 = new_x1 + 1
        if new_y2 <= new_y1:
            new_y1 = min(new_y1, img_h - 1)
            new_y2 = new_y1 + 1

        final_image = new_image[new_y1:new_y2, new_x1:new_x2]

        # Resize following the box aspect ratio; every dimension is >= 1.
        if half_w > half_h:
            dsize = (1200, max(1, int(1200 / half_w * half_h)))
        elif half_h > 0:
            dsize = (max(1, int(800 / half_h * half_w)), 800)
        else:
            dsize = (800, 800)  # single-pixel box
        extended_images.append(cv2.resize(final_image, dsize))

        ratio = np.sum(mask) / ((new_y2 - new_y1) * (new_x2 - new_x1))
        ratios.append(ratio.item())

    return extended_images, ratios

def grid_embedding(masks,embeddings):
    """
    Gather, for every selected pixel of every mask, that mask's embedding vector.

    Args:
      masks: tensor [N, H, W]; a pixel is selected where mask > 0.
      embeddings: [N, D] array-like, one embedding per mask.

    Returns:
      Tensor [K, D], with K the number of selected pixels over all masks.
      Rows are ordered mask by mask, row-major within each mask.
    """
    embeddings = torch.as_tensor(embeddings)
    height, width = masks.shape[1], masks.shape[2]
    # expand() broadcasts without copying. The previous repeat() materialised a
    # full N x H x W x D tensor before discarding the unmasked entries.
    per_pixel = embeddings.reshape(-1, 1, 1, embeddings.shape[1]).expand(-1, height, width, -1)
    return per_pixel[masks > 0]


def point_cloud(depth, scale, camera_intristics,mask,camera_pose,colors,dataset):
    """
    Back-project the masked pixels of a depth image into world-space 3D points.

    Args:
      depth: depth image (H, W) or (H, W, C); only channel 0 is used when 3-D.
      scale: divisor converting raw depth values to metric depth.
      camera_intristics: intrinsics matrix; fx, fy, cx, cy are read from
        entries [0, 0], [1, 1], [0, 2], [1, 2].
      mask: torch tensor (H, W); pixels with mask > 0 are candidates.
      camera_pose: 4x4 transform applied to homogeneous camera-space points.
      colors: per-pixel colour array (H, W, 3).
      dataset: unused; kept for interface compatibility.

    Returns:
      (points, colors): numpy arrays (N, 3) of world coordinates and the
      matching pixel colours. Only pixels with mask > 0 and depth > 0 are
      included, in row-major pixel order.

    Runs on CUDA when available and falls back to CPU. The previous version
    hard-coded 'cuda' and failed on CPU-only machines.
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    mask = mask.to(dev)
    camera_matrix = torch.tensor(camera_intristics).to(dev)
    depth = torch.tensor(depth, dtype=torch.float32).to(dev)
    colors = torch.tensor(colors).to(dev)
    camera_pose = torch.tensor(camera_pose).to(dev)

    # Pixel grid (row y, column x).
    y, x = torch.meshgrid(torch.arange(depth.size()[0], device=dev),
                          torch.arange(depth.size()[1], device=dev), indexing='ij')

    if depth.dim() == 3:
        depth = depth[:, :, 0]
    depth = depth.float() / scale

    # Only pixels inside the mask with a valid (positive) depth.
    valid = (mask * (depth > 0).long()) > 0

    # Pinhole back-projection to camera coordinates.
    X = (x - camera_matrix[0, 2]) * depth / camera_matrix[0, 0]
    Y = (y - camera_matrix[1, 2]) * depth / camera_matrix[1, 1]
    Z = depth

    points = torch.stack((X.view(-1), Y.view(-1), Z.view(-1), torch.ones_like(X.view(-1))), dim=-1)
    points = torch.matmul(camera_pose, points.T).T  # homogeneous -> world

    points = points[valid.view(-1)].view(-1, 4)
    colors = colors[valid].view(-1, 3)
    return points[:, :3].cpu().numpy(), colors.cpu().numpy()

def clip_image(images,model,preprocess):
    """
    Encode images one by one with a CLIP-style model.

    Args:
      images: sequence of image arrays accepted by PIL.Image.fromarray.
      model: object exposing eval(), to(device) and encode_image(batch).
      preprocess: callable turning a PIL image into an input tensor.

    Returns:
      list of numpy arrays, one encoded feature vector per image, in order.
      The model is switched to eval mode and moved to CUDA when available,
      otherwise to CPU.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(target)

    features = []
    for img in images:
        batch = preprocess(Image.fromarray(img)).unsqueeze(0).to(target)
        with torch.no_grad():
            encoded = model.encode_image(batch)
        features.append(encoded.squeeze(0).cpu().numpy())
    return features

def _frame_sort_key(file_name, dataset):
    """
    Sort key ordering RGB frame files by their numeric frame index.

    Expects 'frame<idx>.<ext>' for Replica and '<idx>.<ext>' for ScanNet.
    Names whose index part is not a plain integer sort after all numeric
    frames, by name.
    """
    dot = file_name.find('.')
    start = file_name.find('e') + 1 if dataset == 'Replica' else 0
    digits = file_name[start:dot] if dot >= 0 else ''
    if digits.isdigit():
        return (0, int(digits), file_name)
    return (1, 0, file_name)


def _select_uncovered_masks(generated, remaining_parts, height, width, min_fraction):
    """
    Resize Semantic-SAM masks to (height, width) and keep the mostly-unclaimed ones.

    A mask is kept when more than `min_fraction` of its pixels are still set
    in `remaining_parts`. Empty masks give 0/0 = NaN and are therefore
    rejected.
    """
    selected = []
    for entry in generated:
        mask = cv2.resize(entry['segmentation'].astype(int), (width, height), interpolation=cv2.INTER_NEAREST)
        if np.sum(mask * remaining_parts) / np.sum(mask) > min_fraction:
            selected.append(mask)
    return selected


def _claim_pixels(remaining_parts, accepted_masks):
    """Zero, in place, every pixel of `remaining_parts` covered by one of `accepted_masks`."""
    for mask in accepted_masks:
        remaining_parts *= (1 - np.array(mask, dtype=np.int32))


def masks_detection(extension_ratio_hide,extension_ratio_s,extension_ratio_l,extension_ratio_h,path,last_idx,step,remove_iou,remove_side_thr,side_thr,dataset,area_threshold,ratio_thr,
                    camera_intristics,scale,resolution,clip_model,preprocess,alpha_h,alpha_l,alpha_o,alpha_m,semantic_sam_path):
    """
    Detect object masks in a subsampled RGB-D sequence and lift them to 3D.

    For every processed frame:
      - load RGB, depth and camera pose;
      - run Semantic SAM at levels 3, 4 and 6. Each level only contributes
        masks whose pixels are mostly not yet claimed by masks accepted at
        the previous levels (> 65 %, > 45 %, > 35 %);
      - filter masks: overlap/tiny removal, border removal, 2D DBSCAN
        split, then another tiny-mask pass;
      - back-project each mask to a point cloud and snap it to a grid of
        cell size `resolution`. Keep the mask when dbscan_3d reports a
        cluster ratio above 0.8;
      - build several crops per kept mask and embed them with mask_embedding().

    Frames are processed in numeric frame order, every `step`-th one, below
    `last_idx` (capped at the number of frames found). Previously the
    arbitrary os.listdir order was used.

    Returns:
      (final_colors, final_masks, final_points, final_embeddings): flat,
      index-aligned lists with one entry per kept mask over all frames.

    Raises:
      ValueError: if `dataset` is neither 'Replica' nor 'ScanNet'.
    """
    if dataset == 'Replica':
        prefix_rgb = 'results/frame'
        prefix_depth = 'results/depth'
        pose_path = path
        # RGB frames are the files starting with 'frame'.
        image_list = [f for f in os.listdir(path + 'results/')
                      if f.startswith('frame') and os.path.isfile(os.path.join(path + 'results/', f))]
    elif dataset == 'ScanNet':
        prefix_rgb = 'color/'
        prefix_depth = 'depth/'
        pose_path = path + 'pose/'
        image_list = os.listdir(path + prefix_rgb)
    else:
        raise ValueError(f"Unsupported dataset: {dataset!r}")

    # os.listdir order is arbitrary; sort so that `step` sub-sampling is deterministic.
    image_list.sort(key=lambda name: _frame_sort_key(name, dataset))

    # Three Semantic SAM generators at different granularity levels.
    mask_generator1 = SemanticSamAutomaticMaskGenerator(build_semantic_sam(model_type='L',ckpt=semantic_sam_path),level=[3])
    mask_generator2 = SemanticSamAutomaticMaskGenerator(build_semantic_sam(model_type='L',ckpt=semantic_sam_path),level=[4])
    mask_generator3 = SemanticSamAutomaticMaskGenerator(build_semantic_sam(model_type='L',ckpt=semantic_sam_path),level=[6])

    my_colors = []
    my_masks = []
    my_points = []
    my_embeddings = []

    for i in range(0, min(last_idx, len(image_list)), step):

        print('Processing frame {}/{}'.format(i,len(image_list)))
        image_path = image_list[i]

        if dataset == 'Replica':
            # 'frameXXXX.jpg' -> 'XXXX'
            idx_path = str(image_path[image_path.index('e')+1:image_path.index('.')])
        else:
            # ScanNet: 'XXXX.jpg' -> 'XXXX'
            idx_path = str(image_path[:image_path.index('.')])

        rgb_path = path + prefix_rgb + idx_path + '.jpg'
        rgb_i = cv2.imread(rgb_path)
        depth_i = cv2.imread(path + prefix_depth + idx_path + '.png', cv2.IMREAD_UNCHANGED)
        depth_i = depth_i.astype(np.double if dataset == 'Replica' else np.float32)

        camera_pose_i = _load_pose(pose_path, int(idx_path), dataset)

        # RGB and depth must share the same spatial resolution.
        if rgb_i.shape[:2] != depth_i.shape[:2]:
            rgb_i = cv2.resize(rgb_i, (depth_i.shape[1], depth_i.shape[0]), interpolation=cv2.INTER_LINEAR)
        height, width = rgb_i.shape[0], rgb_i.shape[1]

        _, input_image = prepare_image(image_pth=rgb_path)

        # 1 = pixel not yet claimed by an accepted mask.
        remaining_parts = np.ones((height, width), dtype=np.int32)

        # Level 3 (coarsest).
        results_i = _select_uncovered_masks(mask_generator1.generate(input_image), remaining_parts, height, width, 0.65)
        the_masks_i, boxes_i = remove_overlapped_masks(results_i,remove_iou,area_threshold)
        results_i, boxes_i = remove_side_masks(the_masks_i,boxes_i,remove_side_thr,side_thr,ratio_thr)
        # Claimed pixels come from the masks that passed the overlap filter (before border removal).
        _claim_pixels(remaining_parts, the_masks_i)
        del the_masks_i
        gc.collect()

        # Level 4.
        results_i += _select_uncovered_masks(mask_generator2.generate(input_image), remaining_parts, height, width, 0.45)
        the_masks_i, boxes_i = remove_overlapped_masks(results_i,remove_iou,area_threshold)
        results_i, boxes_i = remove_side_masks(the_masks_i,boxes_i,remove_side_thr,side_thr,ratio_thr)
        _claim_pixels(remaining_parts, the_masks_i)
        gc.collect()

        # Level 6 (finest).
        results_i += _select_uncovered_masks(mask_generator3.generate(input_image), remaining_parts, height, width, 0.35)
        gc.collect()

        # Final overlap & border filtering (boxes are recomputed below).
        the_masks_i, boxes_i = remove_overlapped_masks(results_i,remove_iou,area_threshold)
        the_masks_i, _ = remove_side_masks(the_masks_i,boxes_i,remove_side_thr,side_thr,ratio_thr)

        # Split every mask into spatially connected parts with 2D DBSCAN.
        masks_i = []
        boxes_i = []
        for kept in the_masks_i:
            mask = cv2.resize(np.array(kept).astype(int), (width, height), interpolation=cv2.INTER_NEAREST)
            for cluster in dbscan_mask_denoise(mask, 7, 80):
                if np.sum(cluster) > 0:
                    masks_i.append(torch.tensor(cluster))
                    rows, cols = np.where(cluster > 0)
                    boxes_i.append(torch.tensor([cols.min(), rows.min(), cols.max(), rows.max()]))
        gc.collect()

        masks_i, boxes_i = remove_tiny_masks(masks_i,boxes_i,area_threshold)

        new_masks_i = []
        new_boxes_i = []
        new_points_i = []
        new_colors_i = []

        # Lift each mask to 3D and keep those whose points form one dominant cluster.
        for mask_t in masks_i:
            points, colors = point_cloud(depth_i, scale, camera_intristics, mask_t, camera_pose_i, rgb_i, dataset)
            points = points_to_grid(points, resolution)

            if points.shape[0] > 0:
                points, colors, threshold = dbscan_3d(points, colors, 0.15, 200)
            else:
                threshold = 0.0

            if threshold > 0.8:
                rows, cols = np.where(mask_t > 0)
                new_masks_i.append(mask_t.cpu().numpy())
                new_boxes_i.append(np.array([cols.min(), rows.min(), cols.max(), rows.max()]))
                new_points_i.append(points)
                new_colors_i.append(colors)

            del points
            del colors
        gc.collect()

        # Several crops per object for the combined CLIP embedding.
        extended_s_i, _ = extend_images(rgb_i,new_boxes_i,new_masks_i,extension_ratio_s)
        extended_l_i, _ = extend_images(rgb_i,new_boxes_i,new_masks_i,extension_ratio_l)
        extended_h_i, _ = extend_images(rgb_i,new_boxes_i,new_masks_i,extension_ratio_h)
        hide_i, _ = extend_images(rgb_i,new_boxes_i,new_masks_i,extension_ratio_hide,True)
        extended_mask_i, _ = extend_images(rgb_i,new_boxes_i,new_masks_i,1.0,False,True)

        embedding = mask_embedding(extended_s_i,extended_l_i,extended_h_i,hide_i,extended_mask_i,alpha_h,alpha_l,alpha_o,alpha_m,clip_model,preprocess)

        my_colors.append(new_colors_i)
        my_masks.append(new_masks_i)
        my_points.append(new_points_i)
        my_embeddings.append(embedding)

        del embedding, extended_s_i, extended_l_i, extended_h_i, hide_i, extended_mask_i

    # Flatten the per-frame lists into index-aligned flat lists.
    final_masks = []
    final_points = []
    final_embeddings = []
    final_colors = []
    for frame_colors, frame_masks, frame_points, frame_embeddings in zip(my_colors, my_masks, my_points, my_embeddings):
        for j in range(len(frame_masks)):
            final_masks.append(frame_masks[j])
            final_points.append(frame_points[j])
            final_embeddings.append(frame_embeddings[j])
            final_colors.append(frame_colors[j])
    return final_colors,final_masks,final_points,final_embeddings

def grid_indices(points,params,res):
    """
    Convert 3D points into integer voxel indices.

    Args:
      points: (N, 3) array-like of coordinates.
      params: (x_min, y_min, z_min, x_max, y_max, z_max); only the minima are used.
      res: voxel edge length.

    Returns:
      long tensor (N, 3) = floor((p - min) / res) for points inside the box.
      It lives on CUDA when available, otherwise on CPU.

    Indices are computed in float64. Points may arrive as float16, and
    float16 rounds values above 2048 and overflows above 65504.
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    pts = torch.as_tensor(points).to(device=dev, dtype=torch.float64)
    origin = torch.tensor([params[0], params[1], params[2]], dtype=torch.float64, device=dev)
    return ((pts[:, :3] - origin) / res).long()


def voxelize_batch(points,X,Y,Z,params,res=None):
    """
    Rasterise a point cloud into a binary (X, Y, Z) occupancy grid.

    Args:
      points: (N, 3) coordinates.
      X, Y, Z: grid dimensions; must exceed every computed index.
      params: grid box as accepted by grid_indices().
      res: voxel size. When omitted, falls back to a global `resolution`
        (the old behaviour, kept for backward compatibility).

    Returns:
      float tensor (X, Y, Z) with 1 at occupied voxels, on the indices' device.
    """
    if res is None:
        res = resolution
    indices = grid_indices(points, params, res)
    voxel = torch.zeros((X, Y, Z), device=indices.device)
    voxel[indices[:, 0], indices[:, 1], indices[:, 2]] = 1
    return voxel


def geometry_overlap(points,resolution,overlap_thr1,overlap_thr2):
    """
    Build an adjacency matrix of point clouds that overlap in voxel space.

    Two clouds i, j are adjacent when, voxelised on a shared grid of cell
    size `resolution`, both overlap fractions exceed `overlap_thr1` and
    differ by less than `overlap_thr2`. Overlap fractions are shared
    voxels / own voxels. Pairs whose axis-aligned bounding boxes are
    disjoint on any axis are skipped.

    Args:
      points: list of (N_i, 3) arrays (each non-empty).
      resolution: voxel size.
      overlap_thr1: minimum overlap fraction for both clouds.
      overlap_thr2: maximum allowed difference between the two fractions.

    Returns:
      symmetric numpy array (N, N) with 1 for adjacent pairs. The diagonal is
      set when a cloud passes the test against itself.

    Fixes versus the previous version:
      - the y-axis AABB rejection compared y_max[j] with y_min[j];
      - grid dimensions could be smaller than the largest index;
      - voxelize_batch read a global `resolution` instead of this argument;
      - each symmetric pair was evaluated twice.
    """
    num_points = len(points)
    adjacency = np.zeros((num_points, num_points))
    if num_points == 0:
        return adjacency

    # Per-cloud axis-aligned bounding boxes, shape (num_points, 3).
    mins = np.array([np.min(p[:, :3], axis=0) for p in points], dtype=np.float64).reshape(-1, 3)
    maxs = np.array([np.max(p[:, :3], axis=0) for p in points], dtype=np.float64).reshape(-1, 3)

    for i in range(num_points):
        print(i,' point clouds processed from ',num_points,'!')

        # The test is symmetric in (i, j), so only j >= i is evaluated.
        for j in range(i, num_points):
            if np.any(mins[i] > maxs[j]) or np.any(maxs[i] < mins[j]):
                continue  # disjoint bounding boxes

            # Union bounding box with 0.2 padding on every side.
            lo = np.minimum(mins[i], mins[j]) - 0.2
            hi = np.maximum(maxs[i], maxs[j]) + 0.2
            # +1 guarantees room for the largest index even when resolution exceeds the padding.
            X, Y, Z = (int(extent) + 1 for extent in (hi - lo) / resolution)
            params = (lo[0], lo[1], lo[2], hi[0], hi[1], hi[2])

            v_i = voxelize_batch(points[i], X, Y, Z, params, res=resolution)
            v_j = voxelize_batch(points[j], X, Y, Z, params, res=resolution)

            shared = torch.sum(v_i * v_j)
            ratio_i = shared / torch.sum(v_i)
            ratio_j = shared / torch.sum(v_j)

            if (ratio_i > overlap_thr1) and (ratio_j > overlap_thr1) and (torch.abs(ratio_j - ratio_i) < overlap_thr2):
                adjacency[i, j] = 1
                adjacency[j, i] = 1

            del v_i, v_j, shared, ratio_i, ratio_j
            torch.cuda.empty_cache()

    return adjacency

def merge_points(adj_matrix):
    """
    Group objects into connected components of the adjacency graph.

    Args:
      adj_matrix: square numpy adjacency matrix (non-zero = edge).

    Returns:
      list of sets of node indices, one set per connected component.
    """
    graph = nx.from_numpy_array(adj_matrix)
    return [component for component in nx.connected_components(graph)]

def mask_embedding(object_extend_s,object_extend_l,object_extend_h,object_extend_hide,object_extend_mask,alpha_h,alpha_l,alpha_o,alpha_m,clip_model,clip_path):
    """
    Combine CLIP embeddings of several views of each object into one vector.

    Views: small crop (s), large crop (l), heavily extended crop (h), crop
    with the object hidden (hide), and crop keeping only the object (mask).
    For each object:
        alpha_h*E_h + alpha_l*E_l + E_s - alpha_o*E_hide + alpha_m*E_mask

    `clip_path` is forwarded to clip_image() as its `preprocess` callable.

    Returns:
      numpy array with one combined embedding per object.
    """
    target = 'cuda' if torch.cuda.is_available() else 'cpu'

    def encode(views):
        return torch.tensor(clip_image(views, clip_model, clip_path)).to(target)

    small = encode(object_extend_s)
    large = encode(object_extend_l)
    huge = encode(object_extend_h)
    object_only = encode(object_extend_mask)
    hidden = encode(object_extend_hide)

    combined = alpha_h*huge + alpha_l*large + small - alpha_o*hidden + alpha_m*object_only
    return combined.cpu().numpy()


def points_to_grid(points,resolution):
    """
    Snap points to the lower corner of their grid cell of size `resolution`.

    Uses floor, so every cell is [k*res, (k+1)*res). The previous
    astype(int) truncated toward zero, which made the cell around 0 twice
    as wide and shifted negative coordinates up by one cell.

    Returns:
      float32 array with the same shape as `points`.
    """
    return np.floor(points / resolution).astype(np.float32) * resolution

def object_embeddings(components,points,embeddings,colors):
    """
    Collapse each connected component of detections into a single object.

    Args:
        components: sequence of iterables of indices into ``points``,
            ``embeddings`` and ``colors`` (e.g. the sets from merge_points).
        points: per-detection point arrays, stacked row-wise per component.
        embeddings: per-detection embedding vectors, averaged per component.
        colors: per-detection color arrays, stacked row-wise per component.

    Returns:
        (final_colors, final_points, final_embeddings, count), lists with one
        entry per component: stacked colors, stacked points, the mean
        embedding, and how many detections were merged into it.
    """
    final_colors, final_points, final_embeddings, count = [], [], [], []
    for component in components:
        members = list(component)
        final_colors.append(np.vstack([colors[idx] for idx in members]))
        final_points.append(np.vstack([points[idx] for idx in members]))
        member_embeddings = np.array([embeddings[idx] for idx in members])
        final_embeddings.append(np.mean(member_embeddings, axis=0))
        count.append(len(members))
    return final_colors, final_points, final_embeddings, count

def objects_class(embeddings,classes):
    """
    Pick the highest-scoring class label for every object.

    Args:
        embeddings: sequence of per-object score vectors; entry ``k`` of each
            vector is the score of ``classes[k]``.
        classes: class labels aligned with the score vectors.

    Returns:
        list with one label per object, taken from ``np.array(classes)``.
        Ties resolve to the first class with the maximal score.
    """
    # Convert once instead of once per object.
    class_array = np.array(classes)
    # argmax is O(k) and breaks ties deterministically (first maximum); the
    # previous "last element of argsort" sorted every row and depended on
    # the unstable default sort for ties.
    return [class_array[int(np.argmax(scores))] for scores in embeddings]

def main():
    """
    Command-line entry point that builds the object map for one scene.

    Steps:
      1. Parse CLI arguments and load the CLIP weights.
      2. Load the dataset-specific YAML config and depth intrinsics.
      3. Run ``masks_detection`` over the scene (with ``step=15``).
      4. Merge detections that overlap in 3D (``geometry_overlap`` +
         ``merge_points``) and aggregate them with ``object_embeddings``.
      5. Write ``<path>embeddings/<scene>_ids_to_embeddings_scannet.json``
         (object id -> {'embedding', 'count'}) and
         ``<path>embeddings/<scene>_points_to_ids_scannet.csv``
         (unique point -> object id).

    Paths are built by plain string concatenation, so ``--path`` and
    ``--dataset_path`` are expected to end with '/' (or be empty).
    """
    # NOTE(review): published as a module global, presumably because helpers
    # elsewhere in this file read ``resolution``; must precede any use here.
    global resolution

    parser = argparse.ArgumentParser()
    parser.add_argument("--scene", type=str, required=True,
                        help="Scene name (a sub-directory of --dataset_path).")
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="Dataset root directory, with a trailing '/'.")
    parser.add_argument("--clip_model", type=str, default='EVA02-L-14-336',
                        help="open_clip architecture name.")
    parser.add_argument("--path", type=str, default='',
                        help="Project prefix for models/, core_configs/ and embeddings/ (trailing '/').")
    parser.add_argument("--dataset", type=str, required=True,
                        help="'Replica' or 'ScanNet'.")
    args = parser.parse_args()

    path = args.path
    dataset = args.dataset
    scene = args.scene
    dataset_path = args.dataset_path
    clip_name = args.clip_model

    # Both checkpoints are resolved against the same project prefix; the
    # Semantic-SAM path used to ignore ``path`` and depend on the cwd.
    clip_path = path + 'models/open_clip_pytorch_model.bin'
    semantic_sam_path = path + 'models/swinl_only_sam_many2many.pth'

    # Create the CLIP model and its preprocessing transforms, then load weights.
    clip_model, _, preprocess = open_clip.create_model_and_transforms(clip_name, pretrained=None)
    # map_location='cpu' lets CUDA-saved checkpoints load on CPU-only hosts;
    # load_state_dict copies the values into the model's own parameters.
    clip_model.load_state_dict(torch.load(clip_path, map_location='cpu'), strict=False)

    # Frame count and config file depend on the dataset.
    if dataset == 'Replica':
        # NOTE(review): Replica frame count is hard-coded.
        last_idx = 2000
        config_file = path + "core_configs/config_Replica.yaml"
    else:
        last_idx = len(os.listdir(dataset_path + scene + '/color/'))
        config_file = path + "core_configs/config_ScanNet.yaml"
    with open(config_file, "r") as f:
        config = yaml.safe_load(f)

    # Thresholds and parameters from the YAML config (key names kept verbatim).
    geometery_overlap_thr1 = config['geometery_overlap_thr1']
    geometery_overlap_thr2 = config['geometery_overlap_thr2']
    resolution = config['resolution']
    remove_thr = config['remove_thr']
    side_thr = config['side_thr']
    iou_remove_thr = config['iou_remove_thr']
    ratio_thr = config['ratio_thr']
    extension_ratio_hide = config['extension_ratio_hide']
    extension_ratio_s = config['extension_ratio_s']
    extension_ratio_l = config['extension_ratio_l']
    extension_ratio_h = config['extension_ratio_h']
    area_threshold = config['area_threshold']
    alpha_h = config['alpha_h']
    alpha_l = config['alpha_l']
    alpha_o = config['alpha_o']
    alpha_m = config['alpha_m']

    scene_path = dataset_path + scene + '/'

    # Depth intrinsics and depth scale; file location differs per dataset.
    if dataset == 'Replica':
        camera_intristics, scale = _load_depth_intrinsics(path=dataset_path+'/cam_params.json',dataset=dataset)
    else:
        camera_intristics, scale = _load_depth_intrinsics(path=dataset_path+scene+'/intrinsic/intrinsic_depth.txt',dataset=dataset)

    # Full detection + embedding pipeline.
    colors, masks, points, embeddings = masks_detection(
        extension_ratio_hide=extension_ratio_hide,
        extension_ratio_s=extension_ratio_s,
        extension_ratio_l=extension_ratio_l,
        extension_ratio_h=extension_ratio_h,
        path=scene_path,
        last_idx=last_idx,
        step=15,
        remove_iou=iou_remove_thr,
        remove_side_thr=remove_thr,
        side_thr=side_thr,
        dataset=dataset,
        area_threshold=area_threshold,
        ratio_thr=ratio_thr,
        camera_intristics=camera_intristics,
        scale=scale,
        resolution=resolution,
        clip_model=clip_model,
        preprocess=preprocess,
        alpha_h=alpha_h,
        alpha_l=alpha_l,
        alpha_o=alpha_o,
        alpha_m=alpha_m,
        semantic_sam_path=semantic_sam_path
    )

    # Merge detections that overlap in 3D into objects.
    adjacency = geometry_overlap(points,resolution,geometery_overlap_thr1,geometery_overlap_thr2)
    components = merge_points(adjacency)
    new_colors, new_points, new_embeddings, count = object_embeddings(components, points, embeddings, colors)

    columns = ['x', 'y', 'z', 'Object id']
    point_frames = []
    df_ids_to_embeddings = {}
    for object_id, object_points in enumerate(new_points):
        # Merged detections may share points; keep each point once per object.
        unique_points = np.unique(object_points, axis=0)
        point_frames.append(pd.DataFrame({
            'x': unique_points[:, 0],
            'y': unique_points[:, 1],
            'z': unique_points[:, 2],
            'Object id': object_id,
        }, columns=columns))
        df_ids_to_embeddings[object_id] = {
            'embedding': list(new_embeddings[object_id].astype(float)),
            'count': count[object_id],
        }

    # Concatenate once: concat inside the loop was quadratic, and concatenating
    # onto an empty frame is deprecated in recent pandas.
    if point_frames:
        df_points_to_ids = pd.concat(point_frames, ignore_index=True)
    else:
        df_points_to_ids = pd.DataFrame(columns=columns)

    out_dir = path + 'embeddings/'
    os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): the "_scannet" suffix is used for every dataset, Replica
    # included; kept as-is in case other scripts expect these file names.
    with open(out_dir + scene + '_ids_to_embeddings_scannet.json', "w") as json_file:
        json.dump(df_ids_to_embeddings, json_file, indent=4)

    df_points_to_ids.to_csv(out_dir + scene + '_points_to_ids_scannet.csv', index=False)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
