import os
import numpy as np
import yaml
from PIL import Image

import cv2
import argparse
import json
import torch
import gc
import copy
import pandas as pd
import open_clip
from sklearn.cluster import DBSCAN
from semantic_sam import prepare_image, plot_results, build_semantic_sam, SemanticSamAutomaticMaskGenerator,SemanticSAMPredictor
import networkx as nx
import psutil

def _load_pose(path, idx, dataset):

        if(dataset=='Replica'):
            path = os.path.join(path, "traj.txt")
            with open(path, "r") as file:
                lines = file.readlines()
                if 0 <= idx < len(lines):
                    line = lines[idx]
                    values = [float(val) for val in line.split()]
                    # Reshape the 16 values into a 4x4 matrix
                    transformation_matrix = np.array(values).reshape((4, 4))
                    return transformation_matrix
        elif(dataset=='ScanNet'):
            path = os.path.join(path, str(idx)+'.txt')
        
            transformation_matrix = np.loadtxt(path).reshape(4, 4)
            return transformation_matrix
        
def _load_depth_intrinsics(path,dataset):
        if(dataset=='Replica'):
            with open(path, "r") as file:
                data = json.load(file)
                camera_params = data.get("camera")
                if camera_params:
                    w = camera_params.get("w")
                    h = camera_params.get("h")
                    fx = camera_params.get("fx")
                    fy = camera_params.get("fy")
                    cx = camera_params.get("cx")
                    cy = camera_params.get("cy")
                    scale = camera_params.get("scale")
                    # Creating the camera matrix K
                    K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
                    K = np.array(K)
                    return K, scale
        elif(dataset=='ScanNet'):
            intrinsic_depth = np.loadtxt(path)
            # fx = intrinsic_depth[0, 0]
            # fy = intrinsic_depth[1, 1]
            # cx = intrinsic_depth[0, 2]
            # cy = intrinsic_depth[1, 2]  
            scale = 1000.0
            # Creating the camera matrix K
            # K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
            # K = np.array(K)
            return intrinsic_depth, scale
        
def overlap(mask_i, mask_j, iom_threshold):
    """Classify the containment relation between two masks.

    The intersection is divided by the area of the smaller mask
    (intersection-over-minimum). When mask_i is strictly larger, a ratio
    above ``iom_threshold`` returns 2 (mask_j lies mostly inside mask_i).
    Otherwise (mask_i smaller or equal in area) a ratio above the threshold
    returns 1 (mask_i lies mostly inside mask_j). Anything else returns 0.
    """
    shared = np.sum(np.multiply(mask_i, mask_j))
    area_i = np.sum(mask_i)
    area_j = np.sum(mask_j)
    i_is_larger = area_i > area_j
    smaller_area = area_j if i_is_larger else area_i
    if not shared / smaller_area > iom_threshold:
        return 0
    return 2 if i_is_larger else 1
        
def remove_overlapped_masks(results, iom_threshold, area_threshold):
    """Drop tiny masks and carve contained sub-masks out of larger ones.

    For every mask ``i`` whose foreground fraction of the image is at least
    ``area_threshold``, every other mask ``j`` with
    ``overlap(mask_i, mask_j, iom_threshold) == 2`` (mask_j is smaller and
    mostly inside mask_i) is subtracted from mask_i. The containment test
    always uses the unmodified mask_i. Masks that end up empty are dropped.

    Args:
        results: sequence of equally-shaped 2D masks (0/1 values).
        iom_threshold: intersection-over-minimum threshold passed to ``overlap``.
        area_threshold: minimum mask-area / image-area ratio to keep a mask.

    Returns:
        ``(masks, boxes)``:
        - ``masks`` is a float32 tensor of shape (K, H, W) with the surviving
          masks in input order.
        - ``boxes`` is a float32 tensor of shape (K, 4) holding inclusive
          ``[x_min, y_min, x_max, y_max]`` boxes.
        - When nothing survives, both are empty (K == 0) instead of raising.
    """
    if len(results) == 0:
        return torch.zeros((0, 0, 0)), torch.zeros((0, 4))

    height, width = results[0].shape[0], results[0].shape[1]
    # np.asarray avoids copying masks that are already int (the common case);
    # nothing below modifies these arrays in place.
    int_masks = [np.asarray(r, dtype=int) for r in results]

    kept_masks = []
    kept_boxes = []
    for i, mask_i in enumerate(int_masks):
        area_ratio = np.sum(mask_i) / (mask_i.shape[0] * mask_i.shape[1])
        if area_ratio < area_threshold:
            continue

        remainder = mask_i
        for j, mask_j in enumerate(int_masks):
            if j == i:
                continue
            if overlap(mask_i, mask_j, iom_threshold=iom_threshold) == 2:
                remainder = remainder * (1 - mask_j)

        if np.sum(remainder) <= 0:
            continue
        rows, cols = np.where(remainder)
        kept_masks.append(torch.tensor(remainder, dtype=torch.float32))
        kept_boxes.append(torch.tensor([cols.min(), rows.min(), cols.max(), rows.max()],
                                       dtype=torch.float32))

    if not kept_masks:
        # torch.stack([]) raises; return well-shaped empty tensors instead.
        return torch.zeros((0, height, width)), torch.zeros((0, 4))
    return torch.stack(kept_masks), torch.stack(kept_boxes)

def remove_tiny_masks(masks, boxes, area_threshold):
    """Keep masks whose foreground fraction is strictly above ``area_threshold``.

    ``masks`` are 2D tensors and ``boxes`` the matching per-mask boxes. Both
    returned lists preserve input order and hold the original objects.
    """
    kept_masks, kept_boxes = [], []
    for idx, candidate in enumerate(masks):
        box = boxes[idx]
        rows, cols = candidate.size()[0], candidate.size()[1]
        if torch.sum(candidate) / (rows * cols) > area_threshold:
            kept_masks.append(candidate)
            kept_boxes.append(box)
    return kept_masks, kept_boxes

def dbscan_mask_denoise(mask, eps, min_samples):
    """Split a 2D mask into DBSCAN pixel clusters, discarding noise pixels.

    Args:
        mask: 2D array; non-zero entries are foreground pixels.
        eps: DBSCAN neighbourhood radius, in pixels.
        min_samples: DBSCAN minimum neighbourhood size for a core point.

    Returns:
        A list of int masks with ``mask``'s shape, one per DBSCAN cluster,
        in ascending cluster-id order. Pixels DBSCAN labels as noise (-1)
        are in none of them. The list is empty when ``mask`` has no
        foreground or every pixel is noise.
    """
    coords = np.column_stack(np.nonzero(mask))  # (N, 2) row/col of foreground pixels
    if coords.shape[0] == 0:
        return []

    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(coords).labels_

    components = []
    for cluster_id in np.unique(labels):
        if cluster_id == -1:
            continue  # DBSCAN marks noise with -1
        members = coords[labels == cluster_id]
        component = np.zeros(np.shape(mask), dtype=int)
        component[members[:, 0], members[:, 1]] = 1
        components.append(component)
    return components


def dbscan_3d(points, colors, eps, min_samples):
    points = points.astype(np.float16)
    dtype = type(points[0])
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    # print(f"Memory usage: {mem_info.rss / (1024 ** 3):.2f} GB")
    n_samples = np.shape(points)[0]
    bytes_per_entry = np.dtype(dtype).itemsize
    total_bytes = n_samples ** 2 * bytes_per_entry
    if(total_bytes / (1024 ** 3)>600):
        return [], [], -1


    # in GB
    # Run DBSCAN
    clustering = DBSCAN(eps=eps, min_samples=min_samples).fit(points)

    # Get labels (-1 = noise)
    labels = clustering.labels_

    # Filter noise points

    cluster_points = points[labels == 0]
    cluster_colors = colors[labels == 0]
    threshold =  cluster_points.shape[0]/points.shape[0]

    return cluster_points, cluster_colors, threshold

def _edge_band(height, width, side_thr):
    """Return a (height, width) long tensor that is 1 within ``side_thr`` pixels of any image border."""
    rows = torch.arange(height).view(-1, 1)
    cols = torch.arange(width).view(1, -1)
    dist_rows = torch.minimum(rows, height - 1 - rows)
    dist_cols = torch.minimum(cols, width - 1 - cols)
    return (torch.minimum(dist_rows, dist_cols) < side_thr).long()


def remove_side_masks(masks, boxes, remove_thr, side_thr, ratio_thr):
    """Drop small masks that lie mostly along the image border.

    A mask is removed when both conditions hold:
    - more than ``remove_thr`` of its pixels are within ``side_thr`` pixels
      of an image edge;
    - its area is below ``ratio_thr`` of the image.

    Args:
        masks: sequence of 2D tensors.
        boxes: per-mask boxes aligned with ``masks``.
        remove_thr: edge-pixel fraction above which a mask is considered a border mask.
        side_thr: width in pixels of the border band.
        ratio_thr: masks at least this large (as an image fraction) are always kept.

    Returns:
        ``(new_masks, new_boxes)``: lists of numpy arrays for the kept
        masks/boxes, in input order. An empty input gives ``([], [])``.
    """
    band_cache = {}  # the border band depends only on the mask shape
    new_masks = []
    new_boxes = []
    for i, mask in enumerate(masks):
        box = boxes[i]
        height, width = mask.shape
        band = band_cache.get((height, width))
        if band is None:
            band = _edge_band(height, width, side_thr)
            band_cache[(height, width)] = band

        area = mask.sum().item()
        near_edge = (mask * band).sum().item()
        # An all-zero mask has no edge pixels; avoid dividing by zero.
        edge_fraction = near_edge / area if area else 0.0
        area_fraction = area / (height * width)

        if edge_fraction > remove_thr and area_fraction < ratio_thr:
            continue
        new_masks.append(np.array(mask))
        new_boxes.append(np.array(box))
    return new_masks, new_boxes

def extend_images(image, boxes, masks, extension_ratio, hide_mask=False, hide_others=False):
    """Crop an enlarged window around each box and resize it.

    Args:
        image: (H, W, 3) image array the crops are taken from (not modified).
        boxes: per-object ``[x_min, y_min, x_max, y_max]`` boxes.
        masks: per-object (H, W) masks aligned with ``boxes``.
        extension_ratio: factor applied to the box half-width/half-height
            around the box centre.
        hide_mask: multiply the image by ``1 - mask`` (blank the object) before cropping.
        hide_others: multiply the image by ``mask`` (keep only the object) before cropping.

    Returns:
        ``(crops, ratios)``:
        - Each crop is resized to 1200 px wide when the box is wider than
          tall, otherwise to 800 px tall. The other side follows the box
          aspect ratio and is at least 1 px.
        - ``ratios`` holds, per crop, mask area / crop-window area.
    """
    img_h, img_w = image.shape[0], image.shape[1]
    extended_images = []
    ratios = []
    for i in range(len(boxes)):
        # The source image is never modified, so no copy is needed; the
        # hiding steps below allocate new arrays.
        new_image = image
        if hide_mask:
            new_image = (np.array(1 - masks[i])[:, :, np.newaxis]) * new_image
            new_image = new_image.clip(0, 255).astype(np.uint8)
        if hide_others:
            new_image = (np.array(masks[i])[:, :, np.newaxis]) * new_image
            new_image = new_image.clip(0, 255).astype(np.uint8)

        bx0, by0, bx1, by1 = boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]
        center_x = (bx1 + bx0) / 2
        center_y = (by1 + by0) / 2
        width = (bx1 - bx0) / 2    # half-width
        height = (by1 - by0) / 2   # half-height

        new_x1 = max(int(center_x - extension_ratio * width), 0)
        new_x2 = min(int(center_x + extension_ratio * width), img_w)
        new_y1 = max(int(center_y - extension_ratio * height), 0)
        new_y2 = min(int(center_y + extension_ratio * height), img_h)
        # A one-pixel-thin box gives an empty window, which cv2.resize rejects;
        # widen it to one pixel.
        new_x2 = max(new_x2, new_x1 + 1)
        new_y2 = max(new_y2, new_y1 + 1)

        crop = new_image[new_y1:new_y2, new_x1:new_x2]
        if width > height:
            dsize = (1200, max(1, int(1200 / width * height)))
        elif height > 0:
            dsize = (max(1, int(800 / height * width)), 800)
        else:
            dsize = (800, 800)  # single-pixel box: aspect ratio undefined
        extended_images.append(cv2.resize(crop, dsize))

        ratio = np.sum(masks[i]) / ((new_y2 - new_y1) * (new_x2 - new_x1))
        ratios.append(ratio.item())

    return extended_images, ratios

def grid_embedding(masks, embeddings):
    """Expand per-mask embeddings into one row per foreground pixel.

    ``masks`` is an (N, H, W) tensor and ``embeddings`` an (N, D) array.
    The result is a (P, D) tensor, where P is the total count of
    ``masks > 0`` pixels. Rows are ordered by mask index, then row-major
    pixel order, and each row equals its mask's embedding.
    """
    per_mask = torch.tensor(embeddings)
    pixels_per_mask = (masks > 0).reshape(masks.shape[0], -1).sum(dim=1)
    return torch.repeat_interleave(per_mask, pixels_per_mask, dim=0)


def point_cloud(depth, scale, camera_intristics, mask, camera_pose, colors, dataset, device=None):
    """Back-project the masked pixels of a depth image into transformed 3D points.

    Args:
        depth: (H, W) or (H, W, C) depth image. With 3 dims, only channel 0 is used.
        scale: divisor converting raw depth values to metric depth.
        camera_intristics: intrinsics matrix. Entries [0,0], [1,1], [0,2] and
            [1,2] are used as fx, fy, cx, cy.
        mask: (H, W) tensor or array. Non-zero pixels are back-projected.
        camera_pose: 4x4 matrix applied to homogeneous camera-frame points.
        colors: (H, W, 3) per-pixel colours, gathered for the kept pixels.
        dataset: unused; kept for interface compatibility.
        device: torch device to compute on. Defaults to CUDA when available,
            otherwise CPU (the original always used CUDA).

    Returns:
        ``(points, colors)`` as numpy arrays of shape (M, 3). They cover the
        pixels that are inside ``mask`` and have depth > 0, in row-major
        pixel order.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    mask = torch.as_tensor(mask).to(device)
    camera_matrix = torch.tensor(camera_intristics).to(device)
    depth = torch.tensor(depth, dtype=torch.float32).to(device)
    colors = torch.tensor(colors).to(device)
    camera_pose = torch.tensor(camera_pose).to(device)

    height, width = depth.shape[0], depth.shape[1]
    if depth.dim() == 3:
        depth = depth[:, :, 0]
    depth = depth.float() / scale

    y, x = torch.meshgrid(torch.arange(height, device=device),
                          torch.arange(width, device=device), indexing='ij')

    # Keep only masked pixels with a valid (non-zero) depth.
    keep_2d = (mask * (depth > 0).long()) > 0
    keep = keep_2d.reshape(-1)

    # Pinhole back-projection into the camera frame.
    X = (x - camera_matrix[0, 2]) * depth / camera_matrix[0, 0]
    Y = (y - camera_matrix[1, 2]) * depth / camera_matrix[1, 1]
    Z = depth
    flat_x = X.reshape(-1)
    cam_points = torch.stack((flat_x, Y.reshape(-1), Z.reshape(-1), torch.ones_like(flat_x)), dim=-1)

    # Transform only the kept pixels rather than the whole image.
    world = torch.matmul(camera_pose, cam_points[keep].T).T
    kept_colors = colors[keep_2d].reshape(-1, 3)

    return world[:, :3].cpu().numpy(), kept_colors.cpu().numpy()

def clip_image(images, model, preprocess):
    """Encode each image with ``model.encode_image`` and return numpy features.

    The model is switched to eval mode and moved to CUDA when available,
    otherwise CPU. Each image is wrapped with ``Image.fromarray``, passed
    through ``preprocess``, and encoded as a batch of one without gradient
    tracking. Returns one 1D numpy array per input image, in order.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(target)
    features = []
    with torch.no_grad():
        for picture in images:
            batch = preprocess(Image.fromarray(picture)).unsqueeze(0).to(target)
            features.append(model.encode_image(batch).squeeze(0).cpu().numpy())
    return features

def _frame_id(file_name, dataset):
    """Return the frame-number string embedded in an RGB file name.

    Replica names are expected to look like ``frame000123.jpg`` (the digits
    follow the first ``'e'``); other datasets like ``123.jpg``.
    """
    if dataset == 'Replica':
        return file_name[file_name.index('e') + 1:file_name.index('.')]
    return file_name[:file_name.index('.')]


def _sorted_frames(file_names, dataset):
    """Return the file names that carry a numeric frame id, ordered by that id."""
    numbered = []
    for name in file_names:
        try:
            numbered.append((int(_frame_id(name, dataset)), name))
        except ValueError:
            continue  # not a frame image (no '.' or non-numeric id)
    numbered.sort()
    return [name for _, name in numbered]


def _uncovered_masks(generated, remaining_parts, size_wh, min_uncovered):
    """Resize generated masks to ``size_wh``; keep those whose uncovered-pixel fraction exceeds ``min_uncovered``."""
    kept = []
    for entry in generated:
        mask = cv2.resize(entry['segmentation'].astype(int), size_wh, interpolation=cv2.INTER_NEAREST)
        area = np.sum(mask)
        if area > 0 and np.sum(mask * remaining_parts) / area > min_uncovered:
            kept.append(mask)
    return kept


def _mask_bbox(mask):
    """Return the inclusive ``[x_min, y_min, x_max, y_max]`` box of the pixels where ``mask > 0``."""
    rows, cols = np.where(np.asarray(mask) > 0)
    return [cols.min(), rows.min(), cols.max(), rows.max()]


def masks_detection(extension_ratio_hide,extension_ratio_s,extension_ratio_l,extension_ratio_h,path,last_idx,step,remove_iou,remove_side_thr,side_thr,dataset,area_threshold,ratio_thr,
                    camera_intristics,scale,resolution,clip_model,preprocess,alpha_h,alpha_l,alpha_o,alpha_m,semantic_sam_path):
    """Segment sampled RGB-D frames into object masks, 3D points and CLIP embeddings.

    Frames are sorted by their numeric id (``os.listdir`` order is arbitrary),
    and every ``step``-th one is processed, up to
    ``min(last_idx, number of frames)``. For each frame:

    1. Semantic-SAM runs at granularity levels 3, 4 and 6. Each level keeps
       only masks that mostly cover pixels not yet claimed by earlier levels
       (uncovered fraction > 0.65 / 0.45 / 0.35). Masks are then deduplicated
       (``remove_overlapped_masks``) and border masks are dropped
       (``remove_side_masks``).
    2. Each mask is split into DBSCAN pixel clusters (eps=7, min_samples=80),
       and tiny results are removed.
    3. Each mask is back-projected to 3D, snapped to the ``resolution`` grid
       and clustered with ``dbscan_3d`` (eps=0.15, min_samples=200). It is
       kept only if the returned cluster fraction exceeds 0.8.
    4. Kept masks are cropped at several extension ratios and embedded with
       ``mask_embedding``.

    Returns:
        ``(colors, masks, points, embeddings)``: flat lists over all kept
        masks of all processed frames, in frame order.

    Raises:
        ValueError: for a ``dataset`` other than 'Replica' or 'ScanNet'.
        FileNotFoundError: if a sampled frame's RGB or depth image cannot be read.
    """
    if dataset == 'Replica':
        prefix_rgb = 'results/frame'
        prefix_depth = 'results/depth'
        pose_path = path
        image_list = [f for f in os.listdir(path + 'results/')
                      if f.startswith('frame') and os.path.isfile(os.path.join(path + 'results/', f))]
    elif dataset == 'ScanNet':
        prefix_rgb = 'color/'
        prefix_depth = 'depth/'
        pose_path = path + 'pose/'
        image_list = os.listdir(path + prefix_rgb)
    else:
        raise ValueError(f"Unsupported dataset: {dataset!r}")
    # Sort numerically so that `step` samples every step-th frame in time.
    image_list = _sorted_frames(image_list, dataset)

    mask_generator1 = SemanticSamAutomaticMaskGenerator(build_semantic_sam(model_type='L', ckpt=semantic_sam_path), level=[3])
    mask_generator2 = SemanticSamAutomaticMaskGenerator(build_semantic_sam(model_type='L', ckpt=semantic_sam_path), level=[4])
    mask_generator3 = SemanticSamAutomaticMaskGenerator(build_semantic_sam(model_type='L', ckpt=semantic_sam_path), level=[6])

    final_colors = []
    final_masks = []
    final_points = []
    final_embeddings = []
    for i in range(0, min(last_idx, len(image_list)), step):
        print('Processing frame {}/{}'.format(i, len(image_list)))
        idx_path = _frame_id(image_list[i], dataset)
        rgb_path = path + prefix_rgb + idx_path + '.jpg'
        depth_path = path + prefix_depth + idx_path + '.png'

        rgb_i = cv2.imread(rgb_path)
        depth_raw = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
        if rgb_i is None or depth_raw is None:
            raise FileNotFoundError(f"Could not read frame {idx_path}: {rgb_path} / {depth_path}")
        depth_i = depth_raw.astype(np.double if dataset == 'Replica' else np.float32)

        camera_pose_i = _load_pose(pose_path, int(idx_path), dataset)

        if rgb_i.shape[:2] != depth_i.shape[:2]:
            rgb_i = cv2.resize(rgb_i, (depth_i.shape[1], depth_i.shape[0]), interpolation=cv2.INTER_LINEAR)
        size_wh = (rgb_i.shape[1], rgb_i.shape[0])

        # NOTE(review): cv2.imread yields BGR. The crops built from rgb_i below
        # reach clip_image, which wraps them with PIL Image.fromarray (read as
        # RGB). Confirm whether a BGR->RGB conversion is intended.
        _, input_image = prepare_image(image_pth=rgb_path)

        # --- Coarse-to-fine Semantic-SAM masks ------------------------------
        # `remaining` marks pixels not yet claimed by a mask of a coarser level.
        remaining = np.ones((rgb_i.shape[0], rgb_i.shape[1]), dtype=int)
        results_i = []
        for generator, min_uncovered in ((mask_generator1, 0.65), (mask_generator2, 0.45)):
            generated = generator.generate(input_image)
            results_i = results_i + _uncovered_masks(generated, remaining, size_wh, min_uncovered)
            del generated
            gc.collect()
            deduped, boxes = remove_overlapped_masks(results_i, remove_iou, area_threshold)
            results_i, _ = remove_side_masks(deduped, boxes, remove_side_thr, side_thr, ratio_thr)
            # Claim the pixels of the deduplicated masks (before side filtering).
            for claimed in deduped:
                remaining *= (1 - np.array(claimed, dtype=np.int32))
            del deduped, boxes

        generated = mask_generator3.generate(input_image)
        results_i = results_i + _uncovered_masks(generated, remaining, size_wh, 0.35)
        del generated
        gc.collect()
        deduped, boxes = remove_overlapped_masks(results_i, remove_iou, area_threshold)
        sam_masks, _ = remove_side_masks(deduped, boxes, remove_side_thr, side_thr, ratio_thr)
        del deduped, boxes, results_i

        # --- Split masks into spatially coherent pixel clusters -------------
        masks_i = []
        boxes_i = []
        for sam_mask in sam_masks:
            sam_mask = cv2.resize(np.array(sam_mask).astype(int), size_wh, interpolation=cv2.INTER_NEAREST)
            for component in dbscan_mask_denoise(sam_mask, 7, 80):
                if np.sum(component) > 0:
                    masks_i.append(torch.tensor(component))
                    boxes_i.append(torch.tensor(_mask_bbox(component)))
        masks_i, boxes_i = remove_tiny_masks(masks_i, boxes_i, area_threshold)

        # --- Lift to 3D and keep masks with one dominant 3D cluster ---------
        new_masks_i = []
        new_boxes_i = []
        new_points_i = []
        new_colors_i = []
        for mask_t in masks_i:
            points, colors = point_cloud(depth_i, scale, camera_intristics, mask_t, camera_pose_i, rgb_i, dataset)
            points = points_to_grid(points, resolution)
            if points.shape[0] > 0:
                points, colors, threshold = dbscan_3d(points, colors, 0.15, 200)
            else:
                threshold = 0.0
            if threshold > 0.8:
                new_masks_i.append(mask_t.cpu().numpy())
                new_boxes_i.append(np.array(_mask_bbox(mask_t)))
                new_points_i.append(points)
                new_colors_i.append(colors)

        # --- Multi-crop CLIP embedding --------------------------------------
        extended_s_i, _ = extend_images(rgb_i, new_boxes_i, new_masks_i, extension_ratio_s)
        extended_l_i, _ = extend_images(rgb_i, new_boxes_i, new_masks_i, extension_ratio_l)
        extended_h_i, _ = extend_images(rgb_i, new_boxes_i, new_masks_i, extension_ratio_h)
        hide_i, _ = extend_images(rgb_i, new_boxes_i, new_masks_i, extension_ratio_hide, True)
        extended_mask_i, _ = extend_images(rgb_i, new_boxes_i, new_masks_i, 1.0, False, True)

        embedding = mask_embedding(extended_s_i, extended_l_i, extended_h_i, hide_i, extended_mask_i,
                                   alpha_h, alpha_l, alpha_o, alpha_m, clip_model, preprocess)

        final_colors.extend(new_colors_i)
        final_masks.extend(new_masks_i)
        final_points.extend(new_points_i)
        final_embeddings.extend(embedding[j] for j in range(len(new_masks_i)))

        del embedding, extended_s_i, extended_l_i, extended_h_i, hide_i, extended_mask_i
        gc.collect()

    return final_colors, final_masks, final_points, final_embeddings

def grid_indices(points, params, res, device=None):
    """Map 3D points to integer voxel indices of a grid.

    Args:
        points: (N, 3+) array or tensor. Only the first three columns are used.
        params: ``(x_min, y_min, z_min, x_max, y_max, z_max)``. The first
            three values are the grid origin.
        res: voxel edge length.
        device: torch device. Defaults to CUDA when available, otherwise CPU.
            The original read an undefined module-level ``device`` and
            raised NameError.

    Returns:
        (N, 3) long tensor ``floor((p - origin) / res)``, computed in float64.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    pts = torch.as_tensor(points, dtype=torch.float64, device=device)[:, :3]
    origin = torch.tensor(params[:3], dtype=torch.float64, device=device)
    return torch.floor((pts - origin) / res).long()


def voxelize_batch(points, X, Y, Z, params, resolution=None, device=None):
    """Rasterise one point cloud into an (X, Y, Z) occupancy grid.

    Args:
        points: (N, 3+) array or tensor. Only the first three columns are used.
        X, Y, Z: grid dimensions in voxels.
        params: ``(x_min, y_min, z_min, x_max, y_max, z_max)``. The first
            three values are the grid origin.
        resolution: voxel edge length. When omitted, falls back to a
            module-level ``resolution`` if one has been set (the original
            read that global, which this file never defines).
        device: torch device. Defaults to CUDA when available, otherwise CPU.

    Returns:
        float32 tensor of shape (X, Y, Z), 1 where at least one point falls
        and 0 elsewhere. Points outside the grid are ignored.

    Raises:
        ValueError: if no resolution is available.
    """
    if resolution is None:
        resolution = globals().get('resolution')
        if resolution is None:
            raise ValueError("voxelize_batch needs a voxel size; pass resolution=...")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    voxel = torch.zeros((X, Y, Z), device=device)
    pts = torch.as_tensor(points, dtype=torch.float64, device=device)[:, :3]
    origin = torch.tensor(params[:3], dtype=torch.float64, device=device)
    idx = torch.floor((pts - origin) / resolution).long()
    # Drop out-of-range indices: negative ones would silently wrap around.
    dims = torch.tensor([X, Y, Z], device=device)
    inside = ((idx >= 0) & (idx < dims)).all(dim=1)
    idx = idx[inside]
    voxel[idx[:, 0], idx[:, 1], idx[:, 2]] = 1
    return voxel

def _occupied_voxels(cloud, origin, upper, resolution):
    """Return sorted unique linear ids of voxels occupied by ``cloud`` in the grid spanning [origin, upper]."""
    extent = np.floor((upper - origin) / resolution).astype(np.int64) + 1
    idx = np.floor((cloud - origin) / resolution).astype(np.int64)
    keys = (idx[:, 0] * extent[1] + idx[:, 1]) * extent[2] + idx[:, 2]
    return np.unique(keys)


def geometry_overlap(points, resolution, overlap_thr1, overlap_thr2):
    """Build an adjacency matrix linking point clouds with similar 3D footprints.

    For each pair whose axis-aligned bounding boxes intersect, both clouds
    are voxelised at ``resolution``. The grid origin is the pair's joint
    minimum minus 0.2 on every axis. Two clouds are linked when:
    - the shared-voxel count exceeds ``overlap_thr1`` as a fraction of each
      cloud's occupied voxels; and
    - those two fractions differ by less than ``overlap_thr2``.

    Occupied voxels are compared as integer id sets instead of dense grids.
    This gives the same counts as the dense-grid original without GPU grid
    allocations, and without its reliance on undefined globals.

    Args:
        points: list of (N_k, 3+) non-empty arrays. Only the first three columns are used.
        resolution: voxel edge length.
        overlap_thr1: minimum per-cloud overlap fraction.
        overlap_thr2: maximum difference between the two overlap fractions.

    Returns:
        Symmetric float (num_points, num_points) 0/1 matrix. The diagonal is
        1 when ``1 > overlap_thr1`` and ``0 < overlap_thr2``, as in the
        original self-comparison.
    """
    num_points = len(points)
    adjacency = np.zeros((num_points, num_points))
    clouds = [np.asarray(cloud, dtype=np.float64)[:, :3] for cloud in points]
    mins = np.array([c.min(axis=0) for c in clouds]).reshape(-1, 3)
    maxs = np.array([c.max(axis=0) for c in clouds]).reshape(-1, 3)

    for i in range(num_points):
        print(i,' point clouds processed from ',num_points,'!')
        # The criterion is symmetric in (i, j), so only j >= i is evaluated.
        for j in range(i, num_points):
            if np.any(mins[i] > maxs[j]) or np.any(maxs[i] < mins[j]):
                continue  # bounding boxes are disjoint along some axis

            origin = np.minimum(mins[i], mins[j]) - 0.2
            upper = np.maximum(maxs[i], maxs[j]) + 0.2
            occ_i = _occupied_voxels(clouds[i], origin, upper, resolution)
            occ_j = _occupied_voxels(clouds[j], origin, upper, resolution)
            shared = np.intersect1d(occ_i, occ_j, assume_unique=True).size
            frac_i = shared / occ_i.size
            frac_j = shared / occ_j.size
            if frac_i > overlap_thr1 and frac_j > overlap_thr1 and abs(frac_j - frac_i) < overlap_thr2:
                adjacency[i, j] = 1
                adjacency[j, i] = 1

    return adjacency

def merge_points(adj_matrix):
    """Return the connected components of the graph described by ``adj_matrix``.

    Each component is a set of node (row/column) indices, listed in the
    order ``networkx.connected_components`` yields them.
    """
    return list(nx.connected_components(nx.from_numpy_array(adj_matrix)))

def mask_embedding(object_extend_s,object_extend_l,object_extend_h,object_extend_hide,object_extend_mask,alpha_h,alpha_l,alpha_o,alpha_m,clip_model,clip_path):
    """Combine CLIP features of several crops of each object into one embedding.

    The result is
    ``alpha_h*h + alpha_l*l + s - alpha_o*hide + alpha_m*mask``, where each
    term is the ``clip_image`` feature matrix of the corresponding crop list.
    Despite its name, ``clip_path`` is forwarded to ``clip_image`` as its
    ``preprocess`` argument. Returns a numpy array.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def _encode(crops):
        return torch.tensor(clip_image(crops, clip_model, clip_path)).to(device)

    # Encoding order matches the original: s, l, h, mask, hide.
    emb_small = _encode(object_extend_s)
    emb_large = _encode(object_extend_l)
    emb_h = _encode(object_extend_h)
    emb_masked = _encode(object_extend_mask)
    emb_hidden = _encode(object_extend_hide)

    combined = alpha_h * emb_h + alpha_l * emb_large + emb_small - alpha_o * emb_hidden + alpha_m * emb_masked
    return combined.cpu().numpy()


def points_to_grid(points, resolution):
    """Snap coordinates down to the nearest multiple of ``resolution``.

    Uses ``floor`` so every cell spans ``[k*res, (k+1)*res)``. The previous
    ``astype(int)`` truncated toward zero, which made the cell around 0 twice
    as wide and shifted all negative coordinates by one cell.

    Returns a float32 array with the same shape as ``points``.
    """
    return np.floor(np.asarray(points) / resolution).astype(np.float32) * resolution

def object_embeddings(components, points, embeddings, colors):
    """Merge per-mask data into per-object data.

    For each component (an iterable of mask indices):
    - the member point and colour arrays are stacked vertically;
    - the member embeddings are averaged;
    - the member count is recorded.

    Returns ``(colors, points, embeddings, counts)`` lists aligned with ``components``.
    """
    merged_colors, merged_points, merged_embeddings, sizes = [], [], [], []
    for component in components:
        members = list(component)
        merged_colors.append(np.vstack([colors[m] for m in members]))
        merged_points.append(np.vstack([points[m] for m in members]))
        merged_embeddings.append(np.mean(np.array([embeddings[m] for m in members]), axis=0))
        sizes.append(len(members))
    return merged_colors, merged_points, merged_embeddings, sizes

def objects_class(embeddings, classes):
    """Return, per score vector, the class whose score sorts last under ``np.argsort``."""
    labels = np.array(classes)
    return [labels[np.argsort(scores)[-1]] for scores in embeddings]

def main():
    """Command-line entry point: build per-object points and embeddings for one scene.

    Loads the CLIP model and the dataset config, then:
    1. runs ``masks_detection`` on the scene;
    2. links masks with overlapping 3D footprints (``geometry_overlap``) and
       groups them into objects (``merge_points``);
    3. averages each object's embeddings (``object_embeddings``).

    Writes two files under ``<path>embeddings/`` (the directory is created
    if missing):
    - ``<scene>_ids_to_embeddings.json``: object id -> {embedding, count}
    - ``<scene>_points_to_ids.csv``: unique x, y, z per object id
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--scene", type=str)
    parser.add_argument("--dataset_path", type=str)
    parser.add_argument("--clip_model", type=str, default='EVA02-L-14-336')
    parser.add_argument("--path", type=str, default='')
    parser.add_argument("--dataset", type=str)
    args = parser.parse_args()

    path = args.path
    dataset = args.dataset
    scene = args.scene
    dataset_path = args.dataset_path
    clip_name = args.clip_model

    clip_path = path + 'models/open_clip_pytorch_model.bin'
    # Resolve the SAM checkpoint relative to --path like the CLIP one, and
    # fall back to the previous CWD-relative location.
    semantic_sam_path = path + 'models/swinl_only_sam_many2many.pth'
    if not os.path.exists(semantic_sam_path):
        semantic_sam_path = 'models/swinl_only_sam_many2many.pth'

    clip_model, _, preprocess = open_clip.create_model_and_transforms(clip_name, pretrained=None)
    # map_location='cpu' lets the checkpoint load on CPU-only hosts; clip_image
    # moves the model to the compute device later.
    # NOTE(review): strict=False silently ignores missing/unexpected keys;
    # confirm the checkpoint matches clip_name.
    clip_model.load_state_dict(torch.load(clip_path, map_location='cpu'), strict=False)

    if dataset == 'Replica':
        last_idx = 2000
        config_path = path + "configs/config_Replica.yaml"
    else:
        last_idx = len(os.listdir(dataset_path + scene + '/color/'))
        config_path = path + "configs/config_ScanNet.yaml"
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Read every key up front so a missing one fails before the long run.
    geometery_overlap_thr1 = config['geometery_overlap_thr1']
    geometery_overlap_thr2 = config['geometery_overlap_thr2']
    resolution = config['resolution']
    remove_thr = config['remove_thr']
    side_thr = config['side_thr']
    iou_remove_thr = config['iou_remove_thr']
    ratio_thr = config['ratio_thr']
    extension_ratio_hide = config['extension_ratio_hide']
    extension_ratio_s = config['extension_ratio_s']
    extension_ratio_l = config['extension_ratio_l']
    extension_ratio_h = config['extension_ratio_h']
    area_threshold = config['area_threshold']
    alpha_h = config['alpha_h']
    alpha_l = config['alpha_l']
    alpha_o = config['alpha_o']
    alpha_m = config['alpha_m']

    scene_path = dataset_path + scene + '/'
    if dataset == 'Replica':
        camera_intristics, scale = _load_depth_intrinsics(path=dataset_path + '/cam_params.json', dataset=dataset)
    else:
        camera_intristics, scale = _load_depth_intrinsics(path=dataset_path + scene + '/intrinsic/intrinsic_depth.txt', dataset=dataset)

    colors, masks, points, embeddings = masks_detection(
        extension_ratio_hide=extension_ratio_hide,
        extension_ratio_s=extension_ratio_s, extension_ratio_l=extension_ratio_l, extension_ratio_h=extension_ratio_h,
        path=scene_path, last_idx=last_idx, step=15, remove_iou=iou_remove_thr, remove_side_thr=remove_thr,
        side_thr=side_thr, dataset=dataset, area_threshold=area_threshold, ratio_thr=ratio_thr,
        camera_intristics=camera_intristics, scale=scale, resolution=resolution,
        clip_model=clip_model, preprocess=preprocess, alpha_h=alpha_h, alpha_l=alpha_l, alpha_o=alpha_o,
        alpha_m=alpha_m, semantic_sam_path=semantic_sam_path)

    adjacency = geometry_overlap(points, resolution, geometery_overlap_thr1, geometery_overlap_thr2)
    components = merge_points(adjacency)
    new_colors, new_points, new_embeddings, count = object_embeddings(components, points, embeddings, colors)

    # Build per-object frames and concatenate once (concat in a loop is quadratic).
    frames = []
    ids_to_embeddings = {}
    for obj_id, (obj_points, obj_embedding, obj_count) in enumerate(zip(new_points, new_embeddings, count)):
        unique_points = np.unique(obj_points, axis=0)
        frames.append(pd.DataFrame({'x': unique_points[:, 0],
                                    'y': unique_points[:, 1],
                                    'z': unique_points[:, 2],
                                    'Object id': obj_id}))
        ids_to_embeddings[obj_id] = {'embedding': list(obj_embedding.astype(float)), 'count': obj_count}
    if frames:
        df_points_to_ids = pd.concat(frames, ignore_index=True)
    else:
        df_points_to_ids = pd.DataFrame(columns=['x', 'y', 'z', 'Object id'])

    out_dir = path + 'embeddings/'
    os.makedirs(out_dir, exist_ok=True)
    with open(out_dir + scene + '_ids_to_embeddings.json', "w") as json_file:
        json.dump(ids_to_embeddings, json_file, indent=4)
    df_points_to_ids.to_csv(out_dir + scene + '_points_to_ids.csv', index=False)


if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not when imported.
    main()





