from models.DETR.matcher import build_matcher
import torch.nn.functional as F
import torch
from torch import logit, nn
from criterion import Matcher
from utils.dist import all_reduce_average
from utils.box_util import generalized_box3d_iou


def cal_sim(z_i, z_j, temperature):
    """Return temperature-scaled cosine similarities between ``z_i`` and ``z_j``.

    Both inputs are divided by their L2 norm along their last dimension, then
    ``z_i`` is multiplied with the transpose (``.t()``) of ``z_j`` and the
    product is divided by ``temperature``.

    Args:
        z_i: Query feature(s); normalised along the last dimension.
        z_j: Key features (at most 2-D because of ``.t()``); normalised along
            the last dimension.
        temperature: Scalar or tensor the similarities are divided by
            (broadcast against the product).

    Returns:
        ``(z_i / |z_i|) @ (z_j / |z_j|).t() / temperature``. There is no
        zero-norm guard, so an all-zero feature yields NaN/inf.
    """
    last_i = z_i.dim() - 1
    last_j = z_j.dim() - 1
    unit_i = z_i / z_i.norm(dim=last_i, keepdim=True)
    unit_j = z_j / z_j.norm(dim=last_j, keepdim=True)
    similarity = torch.matmul(unit_i, unit_j.t())
    return similarity / temperature

def CLIP_loss_with_dist(objs, dist, base_temp=0.2):
    """Class-supervised contrastive loss with distance-scaled positive temperatures.

    Each row of ``objs`` is ``[feature..., class_label, score]`` (the score is
    not used here). For every anchor row ``i``:

    * negatives are the rows whose class label differs from the anchor's,
      scored with cosine similarity at ``base_temp``;
    * positives are the *other* rows sharing the anchor's class label, each
      scored at its own temperature ``base_temp * (1/1.1) ** dist[i, j]``;
    * every positive forms one cross-entropy row ``[pos, neg_1, ..., neg_k]``
      with target index 0, averaged over the positives.

    Anchors with no negatives are skipped. Anchors with negatives but no
    positives use a placeholder positive logit of ``1 / base_temp`` (a
    cosine similarity of 1 at ``base_temp``).

    Args:
        objs: 2-D tensor, one object per row; last two columns are the class
            label and a score.
        dist: Pairwise distance tensor indexed as ``dist[i, j]``.
        base_temp: Base softmax temperature.

    Returns:
        Scalar tensor: the summed per-anchor losses divided by
        ``1 + (number of anchors that had real positives)``.
    """
    # dist == 0 keeps base_temp; larger distances shrink the temperature,
    # i.e. sharpen that positive's logit.
    temperature = torch.pow(1 / 1.1, dist) * base_temp

    device = objs.device
    clip_loss = torch.tensor(0, device=device, dtype=torch.float)
    class_col = objs[:, -2]

    # NOTE(review): starting at 1 keeps the final division safe when no anchor
    # contributes, but it also divides the sum by (n + 1) instead of n, and
    # placeholder-positive anchors add to the sum without being counted.
    # Kept as-is to preserve the existing loss scale.
    valid_obj_cnt = 1
    for obj_ind in range(objs.shape[0]):
        obj_feature = objs[obj_ind, :-2]
        obj_cls = objs[obj_ind, -2]

        neg_objs_inds = torch.where(class_col != obj_cls)[0]
        if len(neg_objs_inds) == 0:
            continue
        neg_logits = cal_sim(obj_feature, objs[neg_objs_inds, :-2], base_temp).unsqueeze(0)

        pos_objs_inds = torch.where(class_col == obj_cls)[0]
        # Drop the anchor itself with a tensor mask; a Python loop over 0-d
        # tensors would force one host sync per element.
        pos_objs_inds = pos_objs_inds[pos_objs_inds != obj_ind]

        if len(pos_objs_inds) > 0:
            pos_logits = cal_sim(
                obj_feature,
                objs[pos_objs_inds, :-2],
                temperature[obj_ind, pos_objs_inds],
            ).unsqueeze(0).t()
            valid_obj_cnt += 1
        else:
            pos_logits = torch.tensor([1 / base_temp], device=device, dtype=torch.float).unsqueeze(0)

        # One row per positive: [positive, all negatives]; target is column 0.
        logits = torch.cat([pos_logits, neg_logits.repeat(pos_logits.shape[0], 1)], dim=1)
        labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
        clip_loss += F.cross_entropy(logits, labels)

    clip_loss /= valid_obj_cnt
    return clip_loss


def CLIP_loss_position(objs, dist, base_temp=0.2):
    """Contrastive loss whose positives and negatives are chosen by distance.

    Each row of ``objs`` is ``[feature..., class_label, score]``; only the
    features are used. For every anchor row ``i``:

    * negatives are the rows with ``dist[i, j] > 0.01``, scored with cosine
      similarity at ``base_temp``;
    * positives are the *other* rows with ``dist[i, j] < 0.01``, scored with
      cosine similarity at temperature 1 (a distance of exactly 0.01 is
      neither);
    * every positive forms one cross-entropy row ``[pos, neg_1, ..., neg_k]``
      with target index 0.

    Anchors with no negatives are skipped; anchors with no positives use a
    placeholder positive logit of ``1 / base_temp``.

    Args:
        objs: 2-D tensor, one object per row.
        dist: Pairwise distance tensor indexed as ``dist[i, j]``.
        base_temp: Temperature for negatives and for the placeholder positive.

    Returns:
        Scalar tensor: the summed per-anchor losses divided by
        ``1 + (number of anchors that had real positives)``.
    """
    device = objs.device
    clip_loss = torch.tensor(0, device=device, dtype=torch.float)

    # NOTE(review): starting at 1 avoids a zero division but divides the sum
    # by (n + 1); kept to preserve the existing loss scale.
    valid_obj_cnt = 1
    for obj_ind in range(objs.shape[0]):
        obj_feature = objs[obj_ind, :-2]
        cur_dist = dist[obj_ind, :]

        neg_objs_inds = torch.where(cur_dist > 0.01)[0]
        if len(neg_objs_inds) == 0:
            continue
        neg_logits = cal_sim(obj_feature, objs[neg_objs_inds, :-2], base_temp).unsqueeze(0)

        pos_objs_inds = torch.where(cur_dist < 0.01)[0]
        # Drop the anchor itself (its self-distance is 0) with a tensor mask.
        pos_objs_inds = pos_objs_inds[pos_objs_inds != obj_ind]

        if len(pos_objs_inds) > 0:
            # NOTE(review): positives use temperature 1 while negatives use
            # base_temp -- confirm this asymmetry is intended.
            pos_logits = cal_sim(obj_feature, objs[pos_objs_inds, :-2], 1).unsqueeze(0).t()
            valid_obj_cnt += 1
        else:
            pos_logits = torch.tensor([1 / base_temp], device=device, dtype=torch.float).unsqueeze(0)

        logits = torch.cat([pos_logits, neg_logits.repeat(pos_logits.shape[0], 1)], dim=1)
        labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
        clip_loss += F.cross_entropy(logits, labels)

    clip_loss /= valid_obj_cnt
    return clip_loss


def CLIP_loss(objs, temperature=0.2):
    """Class-supervised contrastive loss with a single temperature.

    Each row of ``objs`` is ``[feature..., class_label, score]`` (the score is
    not used here). For every anchor row:

    * negatives are the rows whose class label differs from the anchor's;
    * positives are the *other* rows sharing the anchor's class label;
    * both are scored with cosine similarity at ``temperature``, and every
      positive forms one cross-entropy row ``[pos, neg_1, ..., neg_k]`` with
      target index 0.

    Anchors with no negatives are skipped. Anchors with negatives but no
    positives use a placeholder positive logit of ``1 / temperature`` (a
    cosine similarity of 1 at the same temperature).

    Args:
        objs: 2-D tensor, one object per row; last two columns are the class
            label and a score.
        temperature: Softmax temperature for all logits.

    Returns:
        Scalar tensor: the summed per-anchor losses divided by
        ``1 + (number of anchors that had real positives)``.
    """
    device = objs.device
    clip_loss = torch.tensor(0, device=device, dtype=torch.float)
    class_col = objs[:, -2]

    # NOTE(review): starting at 1 avoids a zero division but divides the sum
    # by (n + 1); kept to preserve the existing loss scale.
    valid_obj_cnt = 1
    for obj_ind in range(objs.shape[0]):
        obj_feature = objs[obj_ind, :-2]
        obj_cls = objs[obj_ind, -2]

        neg_objs_inds = torch.where(class_col != obj_cls)[0]
        if len(neg_objs_inds) == 0:
            continue
        neg_logits = cal_sim(obj_feature, objs[neg_objs_inds, :-2], temperature).unsqueeze(0)

        pos_objs_inds = torch.where(class_col == obj_cls)[0]
        # Drop the anchor itself with a tensor mask.
        pos_objs_inds = pos_objs_inds[pos_objs_inds != obj_ind]

        if len(pos_objs_inds) > 0:
            pos_logits = cal_sim(obj_feature, objs[pos_objs_inds, :-2], temperature).unsqueeze(0).t()
            valid_obj_cnt += 1
        else:
            # Placeholder positive at this function's own `temperature`.
            pos_logits = torch.tensor([1 / temperature], device=device, dtype=torch.float).unsqueeze(0)

        logits = torch.cat([pos_logits, neg_logits.repeat(pos_logits.shape[0], 1)], dim=1)
        labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
        clip_loss += F.cross_entropy(logits, labels)

    clip_loss /= valid_obj_cnt
    return clip_loss
    

def clip_3_branch_dist_aware(args, img_outputs, pc_outputs, img_ground_truth, pc_ground_truth, imagenet_outputs):
    """Distance-aware contrastive loss over image, point-cloud and ImageNet objects.

    Builds one row ``[query features..., class label, score]`` per matched
    query:

    * image queries matched to image GT by the matcher from ``build_matcher``;
    * point-cloud queries matched to point-cloud GT by a 3D ``Matcher``;
    * one ImageNet query per batch element: the query with the largest
      ``pred_boxes[..., 2] * pred_boxes[..., 3]``, labelled with
      ``bboxes_2d_label_imagenet[b, 0]``.

    Image and point-cloud rows also record a GT box centre. Their pairwise L2
    distances fill the top-left block of the distance matrix passed to
    ``CLIP_loss_with_dist``; the ImageNet rows keep a distance of 1 to
    everything.

    Side effects: writes ``nactual_gt``, ``num_boxes`` and ``num_boxes_replica``
    into ``pc_ground_truth``, and ``gious`` / ``center_dist`` into
    ``pc_outputs["outputs"]``.

    Raises:
        RuntimeError: if the per-sample count of present GT boxes differs from
            ``pseudo_bbox_num + bbox_num``.
    """
    # Sanity check on GT box counts. Image matches below look up their
    # position in the point-cloud GT using the image GT index, so the two must
    # line up (presumably that is what this count check guards).
    box_count_gap = torch.sum(torch.abs(
        torch.sum(pc_ground_truth["gt_box_present"], dim=1)
        - pc_ground_truth["pseudo_bbox_num"]
        - pc_ground_truth["bbox_num"]
    ))
    if box_count_gap != 0:
        raise RuntimeError(
            "Error: Image and Point Cloud do not match at Function: clip_3_branch_dist_aware"
        )

    objs = []
    objs_pos = []

    # ---- Image branch ----
    detr_matcher = build_matcher(args)
    device = img_outputs["img_query_feat"].device
    img_outputs_without_aux = {k: v for k, v in img_outputs.items() if k != 'aux_outputs'}
    img_indices = detr_matcher(img_outputs_without_aux, img_ground_truth)
    img_query_feat = img_outputs["img_query_feat"].permute(1, 0, 2)
    prob = F.softmax(img_outputs['pred_logits'], -1)
    # Max probability excluding the last class (presumably DETR's "no object").
    scores, _ = prob[..., :-1].max(-1)

    for bs in range(len(img_indices)):
        if len(img_indices[bs]) > 0 and len(img_indices[bs][0]) > 0:
            for box_no, t_box_no in zip(img_indices[bs][0], img_indices[bs][1]):
                img_obj = torch.zeros(img_query_feat.shape[2] + 2, device=device)
                img_obj[:-2] = img_query_feat[bs, box_no, :]
                img_obj[-2] = img_ground_truth[bs]["labels"][t_box_no]
                img_obj[-1] = scores[bs, box_no]
                objs.append(img_obj)
                objs_pos.append(pc_ground_truth["gt_box_centers"][bs, t_box_no, :])

    # ---- Point-cloud branch ----
    detr3_matcher = Matcher(
        cost_class=args.matcher_cls_cost,
        cost_giou=args.matcher_giou_cost,
        cost_center=args.matcher_center_cost,
        cost_objectness=args.matcher_objectness_cost,
    )

    nactual_gt = pc_ground_truth["gt_box_present"].sum(axis=1).long()
    num_boxes = torch.clamp(all_reduce_average(nactual_gt.sum()), min=1).item()
    pc_ground_truth["nactual_gt"] = nactual_gt
    pc_ground_truth["num_boxes"] = num_boxes
    # number of boxes on this worker for dist training
    pc_ground_truth["num_boxes_replica"] = nactual_gt.sum().item()

    # GIoU and centre distances are stored on the outputs for the matcher.
    pc_output = pc_outputs["outputs"]
    pc_output["gious"] = generalized_box3d_iou(
        pc_output["box_corners"],
        pc_ground_truth["gt_box_corners"],
        pc_ground_truth["nactual_gt"],
        rotated_boxes=torch.any(pc_ground_truth["gt_box_angles"] > 0).item(),
        needs_grad=False,
    )
    pc_output["center_dist"] = torch.cdist(
        pc_output["center_normalized"], pc_ground_truth["gt_box_centers_normalized"], p=1
    )
    pc_indices = detr3_matcher(pc_output, pc_ground_truth)["assignments"]

    pc_query_feat = pc_outputs["pc_query_feat"].permute(1, 0, 2)
    for bs in range(len(pc_indices)):
        if len(pc_indices[bs]) > 0 and len(pc_indices[bs][0]) > 0:
            for box_no, t_box_no in zip(pc_indices[bs][0], pc_indices[bs][1]):
                pc_obj = torch.zeros(pc_query_feat.shape[2] + 2, device=device)
                pc_obj[:-2] = pc_query_feat[bs, box_no, :]
                pc_obj[-2] = pc_ground_truth["gt_box_sem_cls_label"][bs, t_box_no]
                pc_obj[-1] = pc_output["objectness_prob"][bs, box_no]
                objs.append(pc_obj)
                objs_pos.append(pc_ground_truth["gt_box_centers"][bs, t_box_no, :])

    # ---- ImageNet branch: one row per batch element ----
    imagenet_query_feat = imagenet_outputs["img_query_feat"].permute(1, 0, 2)
    imagenet_pred_bbox = imagenet_outputs["pred_boxes"]
    obj_size = imagenet_pred_bbox[:, :, 2] * imagenet_pred_bbox[:, :, 3]
    selected_ind = torch.argmax(obj_size, dim=1)

    prob = F.softmax(imagenet_outputs["pred_logits"], -1)
    scores, _ = prob[..., :-1].max(-1)
    imagenet_label = pc_ground_truth["bboxes_2d_label_imagenet"]

    for bs in range(imagenet_query_feat.shape[0]):
        imagenet_obj = torch.zeros(imagenet_query_feat.shape[2] + 2, device=device)
        imagenet_obj[:-2] = imagenet_query_feat[bs, selected_ind[bs], :]
        imagenet_obj[-2] = imagenet_label[bs, 0]
        imagenet_obj[-1] = scores[bs, selected_ind[bs]]
        objs.append(imagenet_obj)

    if len(objs) > 0:
        objs = torch.stack(objs, 0)
        obj_num = objs.shape[0]
        # Rows with a position (image + point cloud) come first; the rest
        # (ImageNet rows) keep a distance of 1 to everything.
        obj_dist = torch.ones([obj_num, obj_num], device=device)
        # objs can be non-empty (ImageNet rows) while objs_pos is empty, and
        # torch.stack rejects an empty list.
        if len(objs_pos) > 0:
            objs_pos = torch.stack(objs_pos, 0)
            obj_with_pos_num = objs_pos.shape[0]
            obj_dist[:obj_with_pos_num, :obj_with_pos_num] = torch.cdist(objs_pos, objs_pos, p=2)
        return CLIP_loss_with_dist(objs, obj_dist)
    else:
        # Zero loss that still depends on both feature tensors
        # (original note: avoid NaN gradients).
        loss_pc = torch.sum(pc_query_feat)
        loss_img = torch.sum(img_query_feat)
        return 0 * loss_pc + 0 * loss_img


# batch_data_label = pc_ground_truth
def clip_pc_img(args, img_outputs, pc_outputs, img_ground_truth, pc_ground_truth):
    """Class-supervised contrastive loss between matched image and point-cloud queries.

    Image queries are matched to image GT with the matcher from
    ``build_matcher(args)``; point-cloud queries are matched to point-cloud GT
    with a 3D ``Matcher``. Each match becomes a row ``[query features...,
    GT class, confidence]``, and all rows are passed to ``CLIP_loss``.

    Side effects: adds ``nactual_gt``, ``num_boxes`` and ``num_boxes_replica``
    to ``pc_ground_truth`` and ``gious`` / ``center_dist`` to
    ``pc_outputs["outputs"]``.

    Returns:
        The ``CLIP_loss`` value, or, when nothing matched, a zero tensor built
        from both query-feature tensors.
    """
    matcher_2d = build_matcher(args)
    device = img_outputs["img_query_feat"].device
    img_preds = {name: value for name, value in img_outputs.items() if name != 'aux_outputs'}

    rows = []

    # -- image queries --
    matches_2d = matcher_2d(img_preds, img_ground_truth)
    img_feats = img_outputs["img_query_feat"].permute(1, 0, 2)
    img_conf, _ = F.softmax(img_outputs['pred_logits'], -1)[..., :-1].max(-1)
    for b, match in enumerate(matches_2d):
        if len(match) == 0 or len(match[0]) == 0:
            continue
        for q_idx, gt_idx in zip(match[0], match[1]):
            row = torch.zeros(img_feats.shape[2] + 2, device=device)
            row[:-2] = img_feats[b, q_idx, :]
            row[-2] = img_ground_truth[b]["labels"][gt_idx]
            row[-1] = img_conf[b, q_idx]
            rows.append(row)

    # -- point-cloud queries --
    matcher_3d = Matcher(
        cost_class=args.matcher_cls_cost,
        cost_giou=args.matcher_giou_cost,
        cost_center=args.matcher_center_cost,
        cost_objectness=args.matcher_objectness_cost,
    )

    gt_count = pc_ground_truth["gt_box_present"].sum(axis=1).long()
    total_boxes = torch.clamp(all_reduce_average(gt_count.sum()), min=1).item()
    pc_ground_truth["nactual_gt"] = gt_count
    pc_ground_truth["num_boxes"] = total_boxes
    # boxes held by this worker (distributed training)
    pc_ground_truth["num_boxes_replica"] = gt_count.sum().item()

    pc_preds = pc_outputs["outputs"]
    pc_preds["gious"] = generalized_box3d_iou(
        pc_preds["box_corners"],
        pc_ground_truth["gt_box_corners"],
        pc_ground_truth["nactual_gt"],
        rotated_boxes=torch.any(pc_ground_truth["gt_box_angles"] > 0).item(),
        needs_grad=False,
    )
    pc_preds["center_dist"] = torch.cdist(
        pc_preds["center_normalized"], pc_ground_truth["gt_box_centers_normalized"], p=1
    )
    matches_3d = matcher_3d(pc_preds, pc_ground_truth)["assignments"]

    pc_feats = pc_outputs["pc_query_feat"].permute(1, 0, 2)
    for b, match in enumerate(matches_3d):
        if len(match) == 0 or len(match[0]) == 0:
            continue
        for q_idx, gt_idx in zip(match[0], match[1]):
            row = torch.zeros(pc_feats.shape[2] + 2, device=device)
            row[:-2] = pc_feats[b, q_idx, :]
            row[-2] = pc_ground_truth["gt_box_sem_cls_label"][b, gt_idx]
            row[-1] = pc_preds["objectness_prob"][b, q_idx]
            rows.append(row)

    if not rows:
        # Zero loss that still depends on both feature tensors
        # (original note: avoid NaN gradients).
        return 0 * torch.sum(pc_feats) + 0 * torch.sum(img_feats)
    return CLIP_loss(objs=torch.stack(rows, 0))
        
# batch_data_label = pc_ground_truth
def clip_3_branch(args, img_outputs, pc_outputs, img_ground_truth, pc_ground_truth, imagenet_outputs):
    """Class-supervised contrastive loss over image, point-cloud and ImageNet queries.

    Rows of the form ``[query features..., class label, confidence]`` come
    from three sources:

    * image queries matched to image GT by the matcher from ``build_matcher``;
    * point-cloud queries matched to point-cloud GT by a 3D ``Matcher``;
    * query 0 of every ImageNet sample, labelled with
      ``bboxes_2d_label_imagenet[b, 0]``.

    All rows are passed to ``CLIP_loss``.

    Side effects: adds ``nactual_gt``, ``num_boxes`` and ``num_boxes_replica``
    to ``pc_ground_truth`` and ``gious`` / ``center_dist`` to
    ``pc_outputs["outputs"]``.
    """
    matcher_2d = build_matcher(args)
    device = img_outputs["img_query_feat"].device
    img_preds = {name: value for name, value in img_outputs.items() if name != 'aux_outputs'}

    rows = []

    # -- image queries --
    matches_2d = matcher_2d(img_preds, img_ground_truth)
    img_feats = img_outputs["img_query_feat"].permute(1, 0, 2)
    img_conf, _ = F.softmax(img_outputs['pred_logits'], -1)[..., :-1].max(-1)
    for b, match in enumerate(matches_2d):
        if len(match) == 0 or len(match[0]) == 0:
            continue
        for q_idx, gt_idx in zip(match[0], match[1]):
            row = torch.zeros(img_feats.shape[2] + 2, device=device)
            row[:-2] = img_feats[b, q_idx, :]
            row[-2] = img_ground_truth[b]["labels"][gt_idx]
            row[-1] = img_conf[b, q_idx]
            rows.append(row)

    # -- point-cloud queries --
    matcher_3d = Matcher(
        cost_class=args.matcher_cls_cost,
        cost_giou=args.matcher_giou_cost,
        cost_center=args.matcher_center_cost,
        cost_objectness=args.matcher_objectness_cost,
    )

    gt_count = pc_ground_truth["gt_box_present"].sum(axis=1).long()
    total_boxes = torch.clamp(all_reduce_average(gt_count.sum()), min=1).item()
    pc_ground_truth["nactual_gt"] = gt_count
    pc_ground_truth["num_boxes"] = total_boxes
    # boxes held by this worker (distributed training)
    pc_ground_truth["num_boxes_replica"] = gt_count.sum().item()

    pc_preds = pc_outputs["outputs"]
    pc_preds["gious"] = generalized_box3d_iou(
        pc_preds["box_corners"],
        pc_ground_truth["gt_box_corners"],
        pc_ground_truth["nactual_gt"],
        rotated_boxes=torch.any(pc_ground_truth["gt_box_angles"] > 0).item(),
        needs_grad=False,
    )
    pc_preds["center_dist"] = torch.cdist(
        pc_preds["center_normalized"], pc_ground_truth["gt_box_centers_normalized"], p=1
    )
    matches_3d = matcher_3d(pc_preds, pc_ground_truth)["assignments"]

    pc_feats = pc_outputs["pc_query_feat"].permute(1, 0, 2)
    for b, match in enumerate(matches_3d):
        if len(match) == 0 or len(match[0]) == 0:
            continue
        for q_idx, gt_idx in zip(match[0], match[1]):
            row = torch.zeros(pc_feats.shape[2] + 2, device=device)
            row[:-2] = pc_feats[b, q_idx, :]
            row[-2] = pc_ground_truth["gt_box_sem_cls_label"][b, gt_idx]
            row[-1] = pc_preds["objectness_prob"][b, q_idx]
            rows.append(row)

    # -- ImageNet queries: always query 0 of each sample --
    inet_feats = imagenet_outputs["img_query_feat"].permute(1, 0, 2)
    inet_conf, _ = F.softmax(imagenet_outputs["pred_logits"], -1)[..., :-1].max(-1)
    inet_labels = pc_ground_truth["bboxes_2d_label_imagenet"]
    for b in range(inet_feats.shape[0]):
        row = torch.zeros(inet_feats.shape[2] + 2, device=device)
        row[:-2] = inet_feats[b, 0, :]
        row[-2] = inet_labels[b, 0]
        row[-1] = inet_conf[b, 0]
        rows.append(row)

    if not rows:
        # Zero loss that still depends on both feature tensors
        # (original note: avoid NaN gradients).
        return 0 * torch.sum(pc_feats) + 0 * torch.sum(img_feats)
    return CLIP_loss(objs=torch.stack(rows, 0))
        

# Only paired image and point-cloud objects have positions.
def position_clip(args, img_outputs, pc_outputs, img_ground_truth, pc_ground_truth):
    """Contrastive loss between image and point-cloud queries using GT-centre distance.

    Matched image queries (matcher from ``build_matcher``) and matched
    point-cloud queries (3D ``Matcher``) become rows ``[query features...,
    class label, score]``. Each row also records the corresponding
    ``gt_box_centers`` entry. The pairwise L2 distances between those centres
    are passed with the rows to ``CLIP_loss_position``.

    Side effects: writes ``nactual_gt``, ``num_boxes`` and ``num_boxes_replica``
    into ``pc_ground_truth``, and ``gious`` / ``center_dist`` into
    ``pc_outputs["outputs"]``.

    Raises:
        RuntimeError: if the per-sample count of present GT boxes differs from
            ``pseudo_bbox_num + bbox_num``.
    """
    # Sanity check on GT box counts. Image matches below index the
    # point-cloud GT centres with image GT indices, so the two must line up
    # (presumably what this count check guards).
    box_count_gap = torch.sum(torch.abs(
        torch.sum(pc_ground_truth["gt_box_present"], dim=1)
        - pc_ground_truth["pseudo_bbox_num"]
        - pc_ground_truth["bbox_num"]
    ))
    if box_count_gap != 0:
        raise RuntimeError(
            "Error: Image and Point Cloud do not match at Function: position_clip"
        )

    objs = []
    objs_pos = []

    # ---- Image branch ----
    detr_matcher = build_matcher(args)
    device = img_outputs["img_query_feat"].device
    img_outputs_without_aux = {k: v for k, v in img_outputs.items() if k != 'aux_outputs'}
    img_indices = detr_matcher(img_outputs_without_aux, img_ground_truth)
    img_query_feat = img_outputs["img_query_feat"].permute(1, 0, 2)
    prob = F.softmax(img_outputs['pred_logits'], -1)
    # Max probability excluding the last class (presumably DETR's "no object").
    scores, _ = prob[..., :-1].max(-1)

    for bs in range(len(img_indices)):
        if len(img_indices[bs]) > 0 and len(img_indices[bs][0]) > 0:
            for box_no, t_box_no in zip(img_indices[bs][0], img_indices[bs][1]):
                img_obj = torch.zeros(img_query_feat.shape[2] + 2, device=device)
                img_obj[:-2] = img_query_feat[bs, box_no, :]
                img_obj[-2] = img_ground_truth[bs]["labels"][t_box_no]
                img_obj[-1] = scores[bs, box_no]
                objs.append(img_obj)
                objs_pos.append(pc_ground_truth["gt_box_centers"][bs, t_box_no, :])

    # ---- Point-cloud branch ----
    detr3_matcher = Matcher(
        cost_class=args.matcher_cls_cost,
        cost_giou=args.matcher_giou_cost,
        cost_center=args.matcher_center_cost,
        cost_objectness=args.matcher_objectness_cost,
    )

    nactual_gt = pc_ground_truth["gt_box_present"].sum(axis=1).long()
    num_boxes = torch.clamp(all_reduce_average(nactual_gt.sum()), min=1).item()
    pc_ground_truth["nactual_gt"] = nactual_gt
    pc_ground_truth["num_boxes"] = num_boxes
    # number of boxes on this worker for dist training
    pc_ground_truth["num_boxes_replica"] = nactual_gt.sum().item()

    # GIoU and centre distances are stored on the outputs for the matcher.
    pc_output = pc_outputs["outputs"]
    pc_output["gious"] = generalized_box3d_iou(
        pc_output["box_corners"],
        pc_ground_truth["gt_box_corners"],
        pc_ground_truth["nactual_gt"],
        rotated_boxes=torch.any(pc_ground_truth["gt_box_angles"] > 0).item(),
        needs_grad=False,
    )
    pc_output["center_dist"] = torch.cdist(
        pc_output["center_normalized"], pc_ground_truth["gt_box_centers_normalized"], p=1
    )
    pc_indices = detr3_matcher(pc_output, pc_ground_truth)["assignments"]

    pc_query_feat = pc_outputs["pc_query_feat"].permute(1, 0, 2)
    for bs in range(len(pc_indices)):
        if len(pc_indices[bs]) > 0 and len(pc_indices[bs][0]) > 0:
            for box_no, t_box_no in zip(pc_indices[bs][0], pc_indices[bs][1]):
                pc_obj = torch.zeros(pc_query_feat.shape[2] + 2, device=device)
                pc_obj[:-2] = pc_query_feat[bs, box_no, :]
                pc_obj[-2] = pc_ground_truth["gt_box_sem_cls_label"][bs, t_box_no]
                pc_obj[-1] = pc_output["objectness_prob"][bs, box_no]
                objs.append(pc_obj)
                objs_pos.append(pc_ground_truth["gt_box_centers"][bs, t_box_no, :])

    if len(objs) > 0:
        objs = torch.stack(objs, 0)
        # Every row here has a position (objs and objs_pos grow in lockstep).
        objs_pos = torch.stack(objs_pos, 0)

        obj_num = objs.shape[0]
        obj_with_pos_num = objs_pos.shape[0]

        obj_dist = torch.ones([obj_num, obj_num], device=device)
        obj_dist[:obj_with_pos_num, :obj_with_pos_num] = torch.cdist(objs_pos, objs_pos, p=2)
        return CLIP_loss_position(objs, obj_dist)
    else:
        # Zero loss that still depends on both feature tensors
        # (original note: avoid NaN gradients).
        loss_pc = torch.sum(pc_query_feat)
        loss_img = torch.sum(img_query_feat)
        return 0 * loss_pc + 0 * loss_img