'''
File Created: Monday, 25th November 2019 1:35:30 pm
Author: Dave Zhenyu Chen (zhenyu.chen@tum.de)
'''

import os
import sys
import time
import h5py
import json
import pickle
import random
import numpy as np
import multiprocessing as mp
import torch
import os.path as osp
import math
import scipy.interpolate
import scipy.ndimage
import clip
import open3d as o3d
import imageio
import gc
import sng_parser
from torch.utils.data import Dataset
from glob import glob
from tqdm import tqdm


from config.config import CONF
from utils.pc_utils import random_sampling, rotx, roty, rotz
from data.scannet.model_util_scannet import rotate_aligned_boxes, ScannetDatasetConfig, rotate_aligned_boxes_along_axis
from models.softgroup.ops import voxelization_idx
from models.long_clip.model import longclip
# Dataset-wide ScanNet configuration; its ``type2class`` mapping defines the
# class names/indices used when building the raw-name -> label lookup.
DC = ScannetDatasetConfig()

# Path to the ScanNet v2 combined label table (tab-separated), read by
# ``ReferenceDataset._get_raw2label``.
SCANNET_V2_TSV = os.path.join(CONF.PATH.SCANNET_META, "scannetv2-labels.combined.tsv")

class ReferenceDataset(Dataset):
    """Base class with label-mapping helpers for referring-expression datasets.

    Subclasses must implement ``__len__`` / ``__getitem__`` and are expected to
    set ``self.scanrefer`` (an iterable of annotation dicts with at least the
    keys ``scene_id``, ``object_id``, ``object_name`` and ``ann_id``) before
    calling the helpers below.
    """

    def __init__(self):
        pass

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def _get_raw2label(self):
        """Build the raw-category-name -> class-index lookup from the label TSV.

        Each row of ``SCANNET_V2_TSV`` (after the header) contributes one
        entry: column 1 is the raw category name and column 7 is the name that
        is matched against ``DC.type2class``. Names that are not part of
        ``DC.type2class`` are mapped to the index of the ``'others'`` class.

        Returns:
            dict mapping raw category name (str) to class index (int).
        """
        scannet_labels = DC.type2class.keys()
        scannet2label = {label: i for i, label in enumerate(scannet_labels)}
        # Loop-invariant: build the membership set once instead of per row.
        label_classes_set = set(scannet_labels)

        # Context manager guarantees the file handle is closed.
        with open(SCANNET_V2_TSV) as tsv_file:
            lines = [line.rstrip() for line in tsv_file]

        raw2label = {}
        for line in lines[1:]:  # skip the header row
            elements = line.split('\t')
            raw_name = elements[1]
            nyu40_name = elements[7]
            if nyu40_name not in label_classes_set:
                raw2label[raw_name] = scannet2label['others']
            else:
                raw2label[raw_name] = scannet2label[nyu40_name]
        return raw2label

    def _get_unique_multiple_lookup(self):
        """Flag every annotation as "unique" (0) or "multiple" (1).

        An annotation is "unique" when exactly one annotated object in its
        scene carries the same semantic label as the annotated object, and
        "multiple" otherwise. Object names missing from ``self.raw2label``
        fall back to label 17 (presumably the 'others' class - TODO confirm).

        Returns:
            Nested dict ``lookup[scene_id][object_id][ann_id] -> 0 | 1``.
        """
        fallback_label = 17

        # Pass 1: collect one semantic label per distinct (scene, object).
        all_sem_labels = {}
        seen_objects = {}
        for data in self.scanrefer:
            scene_id = data["scene_id"]
            object_id = data["object_id"]
            object_name = " ".join(data["object_name"].split("_"))

            scene_labels = all_sem_labels.setdefault(scene_id, [])
            scene_seen = seen_objects.setdefault(scene_id, set())
            if object_id not in scene_seen:
                scene_seen.add(object_id)
                scene_labels.append(self.raw2label.get(object_name, fallback_label))

        # convert to numpy array so labels can be compared element-wise
        all_sem_labels = {scene_id: np.array(labels) for scene_id, labels in all_sem_labels.items()}

        # Pass 2: an annotation is unique iff its label occurs once in the scene.
        unique_multiple_lookup = {}
        for data in self.scanrefer:
            scene_id = data["scene_id"]
            object_id = data["object_id"]
            object_name = " ".join(data["object_name"].split("_"))
            ann_id = data["ann_id"]

            sem_label = self.raw2label.get(object_name, fallback_label)
            unique_multiple = 0 if (all_sem_labels[scene_id] == sem_label).sum() == 1 else 1

            # store
            scene_lookup = unique_multiple_lookup.setdefault(scene_id, {})
            scene_lookup.setdefault(object_id, {})[ann_id] = unique_multiple

        return unique_multiple_lookup

    def _load_data(self, dataset_name):
        """Load label mappings, optional pretrained data and the scene list.

        Sets ``self.raw2label`` and ``self.scene_list``; when
        ``CONF.pretrain_model_on`` is falsy it also loads
        ``self.pretrained_data`` from the train or val path selected by
        ``self.split`` (which the subclass must have set beforehand).

        Args:
            dataset_name: Name of the dataset; currently unused here.
        """
        print("loading data...")
        self.raw2label = self._get_raw2label()

        if not CONF.pretrain_model_on:
            pretrained_path = CONF.PATH.PRETRAINED_TRAIN_DATA if self.split == "train" else CONF.PATH.PRETRAINED_VAL_DATA
            self.pretrained_data = torch.load(pretrained_path)

        # add scannet data
        print("loading point cloud data...")
        self.scene_list = sorted(list(set([data["scene_id"] for data in self.scanrefer])))


class ScannetReferenceDataset(ReferenceDataset):
       
    def __init__(self,
        scanrefer,
        scanrefer_new,
        scanrefer_all_scene,
        data_root,
        prefix,
        suffix,
        voxel_cfg=None,
        training=True,
        with_label=True,
        repeat=1,
        name="ScanRefer",
        lang_num_max=8,
        augment=False,
        shuffle=False,
        scan2cad_rotation=None):
        """Store the configuration, then load label mappings and scene graphs.

        Args:
            scanrefer: Flat list of annotation dicts (one per description).
            scanrefer_new: List of per-scene groups of annotation dicts; one
                group is one dataset sample.
            scanrefer_all_scene: All scene ids in ``scanrefer`` (stored only).
            data_root: Root directory of the preprocessed scenes.
            prefix: Sub-directory of ``data_root`` holding the scenes; also
                used as the split name.
            suffix: File-name suffix appended to the scene id.
            voxel_cfg: Voxelization settings (``scale``, ``spatial_shape``,
                ``max_npoint``, ``min_npoint`` are read by this class).
            training: Selects the training or the test transform.
            with_label: Stored only.
            repeat: Stored only.
            name: Dataset name (e.g. "ScanRefer", "multi3drefer", "nr3d",
                "sr3d").
            lang_num_max: Number of descriptions per sample (shorter groups
                are padded in ``__getitem__``).
            augment: Stored only.
            shuffle: Whether ``split_scene_new`` / ``shuffle_data`` shuffle.
            scan2cad_rotation: Stored only.
        """
        super().__init__()

        # NOTE only feed the scan2cad_rotation when on the training mode and train split

        self.scanrefer = scanrefer
        self.scanrefer_new = scanrefer_new
        self.scanrefer_new_len = len(scanrefer_new)
        self.scanrefer_all_scene = scanrefer_all_scene # all scene_ids in scanrefer
        self.name = name
        self.lang_num_max = lang_num_max
        self.augment = augment
        self.scan2cad_rotation = scan2cad_rotation

        self.data_root = data_root
        self.prefix = prefix
        # _load_data() reads self.split when CONF.pretrain_model_on is falsy;
        # without this assignment that path raised AttributeError. The prefix
        # is what collate_fn already reports as the split.
        self.split = prefix
        self.suffix = suffix
        self.voxel_cfg = voxel_cfg
        self.training = training
        self.with_label = with_label
        self.repeat = repeat
        self.mode = 'train' if training else 'test'
        # load data
        self._load_data(name)
        self.multiview_data = {}
        self.should_shuffle = shuffle
        self.depth_scale = 1000.0   # divisor applied to raw depth-map values
        self.vis_threshold = 0.05   # max |depth map - point depth| for a point to count as visible
        self.scene_graphs = self.load_scene_graphs()

    def load(self, filename):
        return torch.load(filename)
       
    def __len__(self):
        #return len(self.scanrefer)
        return self.scanrefer_new_len

    def load_scene_graphs(self):
        """Load the cached text scene graphs, or build and cache them.

        The cache is a JSON file under ``CONF.PATH.DATA`` whose location
        depends on ``self.name`` and ``self.prefix``. When it does not exist,
        every annotation in ``self.scanrefer`` is parsed with
        ``Scene_graph_parse`` and the result is written to that file.

        Returns:
            Nested dict ``scene_graphs[scene_id][ann_id]`` (both keys are
            strings) holding the parsed graph of each description.
        """
        if self.name == 'ScanRefer':
            scene_graphs_path = os.path.join(CONF.PATH.DATA, "scanrefer", 'scene_graphs_new_'+self.prefix+'.json')
        else:
            scene_graphs_path = os.path.join(CONF.PATH.DATA, self.name, 'scene_graphs_'+self.prefix+'.json')

        if os.path.exists(scene_graphs_path):
            # Context manager so the handle is closed after reading.
            with open(scene_graphs_path) as graph_file:
                return json.load(graph_file)

        # Datasets whose description is taken from the pre-tokenized 'token'
        # field. 'ScanRefer' is the spelling used by the constructor default
        # and by the path selection above; the previous check only matched
        # 'Scanrefer', so ScanRefer descriptions were never parsed and empty
        # graphs were cached. 'Scanrefer' is kept for backward compatibility.
        token_datasets = ('ScanRefer', 'Scanrefer', 'nr3d', 'sr3d')

        scene_graphs = {}
        print('Begin '+ self.prefix +' text decoupling (scene graphs)...')
        for data in tqdm(self.scanrefer):
            scene_id = str(data['scene_id'])
            ann_id = str(data['ann_id'])

            # NOTE(review): graphs are keyed by (scene_id, ann_id) only; if
            # ann_id is not unique within a scene, later records overwrite
            # earlier ones - confirm against the annotation format.
            scene_entry = scene_graphs.setdefault(scene_id, {})
            if ann_id not in scene_entry:
                scene_entry[ann_id] = {}

            if self.name in token_datasets:
                scene_entry[ann_id] = Scene_graph_parse(' '.join(data['token']))
            elif self.name == 'multi3drefer':
                scene_entry[ann_id] = Scene_graph_parse(data["description"])

        print('Saving '+ self.prefix +' text decoupling (scene graphs)...')
        # Context manager guarantees the cache is flushed and closed.
        with open(scene_graphs_path, 'w') as graph_file:
            json.dump(scene_graphs, graph_file)

        print('Done '+ self.prefix +' text decoupling (scene graphs).')

        return scene_graphs

    def split_scene_new(self,  scanrefer_data):
        scanrefer_train_new = []
        scanrefer_train_new_scene, scanrefer_train_scene = [], []
        scene_id = ''
        lang_num_max = self.lang_num_max
        for data in scanrefer_data:
            if scene_id != data["scene_id"]:
                scene_id = data["scene_id"]
                if len(scanrefer_train_scene) > 0:
                    if self.should_shuffle:
                        random.shuffle(scanrefer_train_scene)
                    # print("scanrefer_train_scene", len(scanrefer_train_scene))
                    for new_data in scanrefer_train_scene:
                        if len(scanrefer_train_new_scene) >= lang_num_max:
                            scanrefer_train_new.append(scanrefer_train_new_scene)
                            scanrefer_train_new_scene = []
                        scanrefer_train_new_scene.append(new_data)
                    if len(scanrefer_train_new_scene) > 0:
                        scanrefer_train_new.append(scanrefer_train_new_scene)
                        scanrefer_train_new_scene = []
                    scanrefer_train_scene = []
            scanrefer_train_scene.append(data)
        if len(scanrefer_train_scene) > 0:
            if self.should_shuffle:
                random.shuffle(scanrefer_train_scene)
            # print("scanrefer_train_scene", len(scanrefer_train_scene))
            for new_data in scanrefer_train_scene:
                if len(scanrefer_train_new_scene) >= lang_num_max:
                    scanrefer_train_new.append(scanrefer_train_new_scene)
                    scanrefer_train_new_scene = []
                scanrefer_train_new_scene.append(new_data)
            if len(scanrefer_train_new_scene) > 0:
                scanrefer_train_new.append(scanrefer_train_new_scene)
                scanrefer_train_new_scene = []
        return scanrefer_train_new

    def shuffle_data(self):
        # if self.shuffled:
        #     return
        # SElf.shuffled = True
        print('shuffle dataset data(lang)', flush=True)
        self.scanrefer_new = self.split_scene_new(self.scanrefer)
        if self.should_shuffle:
            random.shuffle(self.scanrefer_new)
        assert len(self.scanrefer_new) == self.scanrefer_new_len, 'assert scanrefer length right'
        print('shuffle done', flush=True)

    def dataAugment(self, xyz, jitter=False, flip=False, rot=False, scale=False, prob=1.0):
        m = np.eye(3)
        if jitter and np.random.rand() < prob:
            m += np.random.randn(3, 3) * 0.1
        if flip and np.random.rand() < prob:
            m[0][0] *= np.random.randint(0, 2) * 2 - 1
        if rot and np.random.rand() < prob:
            theta = np.random.rand() * 2 * math.pi
            m = np.matmul(m, [[math.cos(theta), math.sin(theta), 0],
                              [-math.sin(theta), math.cos(theta), 0], [0, 0, 1]])

        else:
            # Empirically, slightly rotate the scene can match the results from checkpoint
            theta = 0.35 * math.pi
            m = np.matmul(m, [[math.cos(theta), math.sin(theta), 0],
                              [-math.sin(theta), math.cos(theta), 0], [0, 0, 1]])
        if scale and np.random.rand() < prob:
            scale_factor = np.random.uniform(0.95, 1.05)
            xyz = xyz * scale_factor
        return np.matmul(xyz, m)
    
    def elastic(self, x, gran, mag):
        blur0 = np.ones((3, 1, 1)).astype('float32') / 3
        blur1 = np.ones((1, 3, 1)).astype('float32') / 3
        blur2 = np.ones((1, 1, 3)).astype('float32') / 3

        bb = np.abs(x).max(0).astype(np.int32) // gran + 3
        noise = [np.random.randn(bb[0], bb[1], bb[2]).astype('float32') for _ in range(3)]
        noise = [scipy.ndimage.filters.convolve(n, blur0, mode='constant', cval=0) for n in noise]
        noise = [scipy.ndimage.filters.convolve(n, blur1, mode='constant', cval=0) for n in noise]
        noise = [scipy.ndimage.filters.convolve(n, blur2, mode='constant', cval=0) for n in noise]
        noise = [scipy.ndimage.filters.convolve(n, blur0, mode='constant', cval=0) for n in noise]
        noise = [scipy.ndimage.filters.convolve(n, blur1, mode='constant', cval=0) for n in noise]
        noise = [scipy.ndimage.filters.convolve(n, blur2, mode='constant', cval=0) for n in noise]
        ax = [np.linspace(-(b - 1) * gran, (b - 1) * gran, b) for b in bb]
        interp = [
            scipy.interpolate.RegularGridInterpolator(ax, n, bounds_error=0, fill_value=0)
            for n in noise
        ]

        def g(x_):
            return np.hstack([i(x_)[:, None] for i in interp])

        return x + g(x) * mag

    def crop(self, xyz, step=32):
        xyz_offset = xyz.copy()
        valid_idxs = xyz_offset.min(1) >= 0
        assert valid_idxs.sum() == xyz.shape[0]
        spatial_shape = np.array([self.voxel_cfg.spatial_shape[1]] * 3)
        room_range = xyz.max(0) - xyz.min(0)
        while (valid_idxs.sum() > self.voxel_cfg.max_npoint):
            step_temp = step
            if valid_idxs.sum() > 1e6:
                step_temp = step * 2
            offset = np.clip(spatial_shape - room_range + 0.001, None, 0) * np.random.rand(3)
            xyz_offset = xyz + offset
            valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < spatial_shape).sum(1) == 3)
            spatial_shape[:2] -= step_temp
        return xyz_offset, valid_idxs

    def getInstanceInfo(self, xyz, instance_label, semantic_label):
        # pt_mean = np.ones((xyz.shape[0], 3), dtype=np.float32) * -100.0
        instance_pointnum = []
        instance_cls = []
        # max(instance_num, 0) to support instance_label with no valid instance_id
        instance_num = max(int(instance_label.max()) + 1, 0)
        for i_ in range(instance_num):
            inst_idx_i = np.where(instance_label == i_)
            # xyz_i = xyz[inst_idx_i]
            # pt_mean[inst_idx_i] = xyz_i.mean(0)
            instance_pointnum.append(inst_idx_i[0].size)
            cls_idx = inst_idx_i[0][0]
            instance_cls.append(semantic_label[cls_idx])
        # pt_offset_label = pt_mean - xyz
        # return instance_num, instance_pointnum, instance_cls, pt_offset_label
        return instance_num, instance_pointnum, instance_cls
    
    
    def getCroppedInstLabel(self, instance_label, valid_idxs):
        instance_label = instance_label[valid_idxs]
        j = 0
        while (j < instance_label.max()):
            if (len(np.where(instance_label == j)[0]) == 0):
                instance_label[instance_label == instance_label.max()] = j
            j += 1
        return instance_label

    def transform_train(self, xyz, rgb, semantic_label, instance_label, aug_prob=1.0):
        xyz_middle = self.dataAugment(xyz, True, True, True, aug_prob)
        xyz = xyz_middle * self.voxel_cfg.scale
        if np.random.rand() < aug_prob:
            xyz = self.elastic(xyz, 6, 40.)
            xyz = self.elastic(xyz, 20, 160.)
        # xyz_middle = xyz / self.voxel_cfg.scale
        xyz = xyz - xyz.min(0)
        max_tries = 5
        while (max_tries > 0):
            xyz_offset, valid_idxs = self.crop(xyz)
            if valid_idxs.sum() >= self.voxel_cfg.min_npoint:
                xyz = xyz_offset
                break
            max_tries -= 1
        if valid_idxs.sum() < self.voxel_cfg.min_npoint:
            return None
        xyz = xyz[valid_idxs]
        xyz_middle = xyz_middle[valid_idxs]
        rgb = rgb[valid_idxs]
        semantic_label = semantic_label[valid_idxs]
        instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)
        return xyz, xyz_middle, rgb, semantic_label, instance_label, valid_idxs

    def transform_test(self, xyz, rgb, semantic_label, instance_label):
        xyz_middle = self.dataAugment(xyz, False, False, False, False)
        xyz = xyz_middle * self.voxel_cfg.scale
        xyz -= xyz.min(0)
        valid_idxs = np.ones(xyz.shape[0], dtype=bool)
        instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)
        return xyz, xyz_middle, rgb, semantic_label, instance_label, valid_idxs
    
    def get_ref_mask(self, instance_label, object_ids):
        ref_lbl = torch.zeros_like(instance_label)
        for obj_id in object_ids:
            ref_lbl[instance_label == obj_id] = 1
        gt_mask = ref_lbl.float()
        return gt_mask

    def load_depth_maps(self, depth_maps_paths):
        """Read depth images and stack them into one CUDA tensor.

        Every image is divided by ``self.depth_scale`` after loading.

        Args:
            depth_maps_paths: Iterable of depth-image file paths.

        Returns:
            Tensor on the CUDA device with one depth map per path, stacked
            along a new first dimension.
        """
        scaled_maps = [
            torch.from_numpy(imageio.imread(os.path.join(depth_path)) / self.depth_scale).cuda()
            for depth_path in depth_maps_paths
        ]
        return torch.stack(scaled_maps)
    
    def adjust_intrinsic(self, intrinsic, original_resolution, new_resolution):
        if original_resolution == new_resolution:
            return intrinsic
        
        resize_width = int(math.floor(new_resolution[1] * float(
                        original_resolution[0]) / float(original_resolution[1])))
        
        adapted_intrinsic = intrinsic.copy()
        adapted_intrinsic[0, 0] *= float(resize_width) / float(original_resolution[0])
        adapted_intrinsic[1, 1] *= float(new_resolution[1]) / float(original_resolution[1])
        adapted_intrinsic[0, 2] *= float(new_resolution[0] - 1) / float(original_resolution[0] - 1)
        adapted_intrinsic[1, 2] *= float(new_resolution[1] - 1) / float(original_resolution[1] - 1)
        return adapted_intrinsic
    
    def get_points_visibility(self, points, scene_id):
        """Compute, per camera frame, which points are visible.

        Every 10th frame of the scene is used. A point counts as visible in a
        frame when its projection falls strictly inside the depth image and
        its projected depth differs from the depth map at that pixel by at
        most ``self.vis_threshold``.

        Args:
            points: Tensor of homogeneous point coordinates; the caller in
                this class passes an (N, 4) float64 tensor.
            scene_id: Scene whose data is read from
                ``CONF.PATH.DATA/processed/<scene_id>``.

        Returns:
            Boolean CUDA tensor of shape (num_used_frames, N); entry
            ``[f, p]`` is True when point ``p`` is visible in frame ``f``.
        """
        num_points = points.shape[0]
        # get camera, img
        path_2_scene = osp.join(CONF.PATH.DATA, 'processed', scene_id)

        # extrinsics matrix
        path_2_poses = osp.join(path_2_scene, "pose")
        # Frame ids 0, 10, 20, ... are used: one id per file in the pose
        # directory (presumably only every 10th frame is stored - TODO confirm).
        num_frames = len(os.listdir(path_2_poses)) * 10
        poses = [osp.join(path_2_poses, f"{i}.txt") for i in list(range(num_frames))[::10]]
        # The stored pose matrices are inverted to obtain the extrinsics.
        extrinsics = torch.linalg.inv(torch.from_numpy(np.stack([np.loadtxt(pose) for pose in poses])).cuda())

        # depth
        path_2_depth = osp.join(path_2_scene,"depth")
        depth_maps_paths = [osp.join(path_2_depth, f"{i}.png") for i in list(range(num_frames))[::10]]
        depth_maps = self.load_depth_maps(depth_maps_paths)

        # rgb img
        path_2_color = osp.join(path_2_scene,"color",'0.jpg')

        # height & width
        image_resolution = imageio.imread(path_2_color).shape[:2]
        depth_resolution = imageio.imread(list(depth_maps_paths)[0]).shape
        height = depth_resolution[0]
        width = depth_resolution[1]

        # intrinsic
        # The color intrinsic is rescaled to the depth-image resolution and
        # then replicated once per used frame.
        path_2_intrinsics = osp.join(path_2_scene, "intrinsic", "intrinsic_color.txt")
        intrinsic = torch.from_numpy(self.adjust_intrinsic(np.loadtxt(path_2_intrinsics), image_resolution, depth_resolution)).cuda()
        # intrinsics = torch.from_numpy(np.stack([intrinsic for frame_id in range(len(poses))])).cuda()
        intrinsics = intrinsic.repeat(len(poses), 1, 1)
        points = points.cuda()
        # Project all points into all frames: (intrinsics @ extrinsics) @ points^T,
        # permuted to (frame, point, component).
        word2cam_mat = torch.einsum('bij, jk -> bik', torch.einsum('bij,bjk -> bik', intrinsics, extrinsics), points.T).permute(0,2,1)
        # Release GPU memory that is no longer needed.
        del intrinsic, extrinsics, points
        gc.collect()
        torch.cuda.empty_cache()

        if num_points > 500000:
            # Large scenes: run the visibility test on the CPU and move only
            # the resulting mask back to the GPU.
            word2cam_mat_ = word2cam_mat.cpu()
            depth_maps_ = depth_maps.cpu()
            size = (word2cam_mat_.shape[0], word2cam_mat_.shape[1])
            # NOTE(review): entries with zero depth are filtered out before the
            # division, but the following reshape(size) only succeeds when
            # nothing was filtered - confirm zero depth cannot occur.
            mask = (word2cam_mat_[:, :, 2] != 0).reshape(size[0]*size[1])
            # Perspective division, truncated to integer pixel coordinates (x, y).
            projected_points = torch.stack([(word2cam_mat_[:, :, 0].reshape(size[0]*size[1])[mask]/word2cam_mat_[:, :, 2].reshape(size[0]*size[1])[mask]).reshape(size), 
                            (word2cam_mat_[:, :, 1].reshape(size[0]*size[1])[mask]/word2cam_mat_[:, :, 2].reshape(size[0]*size[1])[mask]).reshape(size)]).permute(1,2,0).long()
            # Pixel must lie strictly inside (0, width) x (0, height).
            inside_mask = ((projected_points[:,:,0] < width)*(projected_points[:,:,0] > 0)*(projected_points[:,:,1] < height)*(projected_points[:,:,1] >0) == 1)
            point_depth = word2cam_mat_[:, :, 2]
            num_frames = depth_maps_.shape[0]
            for frame_id in range(num_frames):
                points_in_frame_mask = inside_mask[frame_id].clone()
                points_in_frame = (projected_points[frame_id][points_in_frame_mask])
                depth_in_frame = point_depth[frame_id][points_in_frame_mask]
                # Occlusion test: measured depth at the pixel vs. projected depth.
                visibility_mask = (torch.abs(depth_maps_[frame_id][points_in_frame[:,1].long(), points_in_frame[:,0].long()]
                                            - depth_in_frame) <= \
                                            self.vis_threshold)
                # Keep only the in-image points that also pass the depth test.
                inside_mask[frame_id][points_in_frame_mask] = visibility_mask
            
            del word2cam_mat, depth_maps
            gc.collect()
            torch.cuda.empty_cache()
            inside_mask = inside_mask.cuda()
        
        else:
            # Same computation as above, performed directly on the GPU tensors.
            size = (word2cam_mat.shape[0], word2cam_mat.shape[1])
            mask = (word2cam_mat[:, :, 2] != 0).reshape(size[0]*size[1])
            projected_points = torch.stack([(word2cam_mat[:, :, 0].reshape(size[0]*size[1])[mask]/word2cam_mat[:, :, 2].reshape(size[0]*size[1])[mask]).reshape(size), 
                            (word2cam_mat[:, :, 1].reshape(size[0]*size[1])[mask]/word2cam_mat[:, :, 2].reshape(size[0]*size[1])[mask]).reshape(size)]).permute(1,2,0).long()
            inside_mask = ((projected_points[:,:,0] < width)*(projected_points[:,:,0] > 0)*(projected_points[:,:,1] < height)*(projected_points[:,:,1] >0) == 1)
            point_depth = word2cam_mat[:, :, 2]
            num_frames = depth_maps.shape[0]
            for frame_id in range(num_frames):
                points_in_frame_mask = inside_mask[frame_id].clone()
                points_in_frame = (projected_points[frame_id][points_in_frame_mask])
                depth_in_frame = point_depth[frame_id][points_in_frame_mask]
                visibility_mask = (torch.abs(depth_maps[frame_id][points_in_frame[:,1].long(), points_in_frame[:,0].long()]
                                            - depth_in_frame) <= \
                                            self.vis_threshold)
                inside_mask[frame_id][points_in_frame_mask] = visibility_mask
            
            del word2cam_mat, projected_points, mask, depth_maps
            gc.collect()
            torch.cuda.empty_cache()
        
        return inside_mask

    def __getitem__(self, idx):
        """Build one sample: a scene point cloud plus its group of descriptions.

        Args:
            idx: Index into ``self.scanrefer_new``; every entry is a list of
                annotation dicts that belong to the same scene.

        Returns:
            ``None`` when the training transform could not produce a valid
            crop; otherwise a dict with the point cloud tensors, the tokenized
            descriptions (padded to ``self.lang_num_max`` entries), the
            ground-truth masks, image-feature paths and the per-frame
            visibility mask.
        """
        start = time.time()
        data_dict = {}
        lang_num = len(self.scanrefer_new[idx])
        scene_id = self.scanrefer_new[idx][0]["scene_id"]
        scene_ids = []
        clip_tokens = []
        longclip_tokens = []
        object_name_list = []
        object_id_list = []
        ann_id_list = []
        ann_ids = []
        meta_datas = []
        view_dependents = []
        target_words, mod_words = [], []
        # get pc
        filename = osp.join(self.data_root, self.prefix, scene_id + self.suffix)
        data = torch.load(filename)
        data = self.transform_train(*data) if self.training else self.transform_test(*data)
        if data is None:
            return None
        xyz, xyz_middle, rgb, semantic_label, instance_label, valid_idx = data

        # coord: voxel-scaled coordinates truncated to integers;
        # coord_float: the augmented coordinates before voxel scaling.
        coord = torch.from_numpy(xyz).long()
        coord_float = torch.from_numpy(xyz_middle)
        feat = torch.from_numpy(rgb).float()
        if self.training:
            # One random offset per feature channel, broadcast over all points.
            feat += torch.randn(feat.size(1)) * 0.1
        instance_label = torch.from_numpy(instance_label)

        # unprocessed points
        # Only the points kept by the transform are used; a column of ones is
        # appended to obtain homogeneous coordinates.
        point_cloud_path = osp.join(self.data_root, self.prefix, scene_id + "_vh_clean_2.ply")
        pcd = o3d.io.read_point_cloud(point_cloud_path)
        points = np.asarray(pcd.points)[valid_idx]
        points = torch.from_numpy(np.append(points, np.ones((points.shape[0], 1)), axis = -1)).to(torch.float64)
        
        # Only the paths are stored here; the files are not read in this method.
        img_path = osp.join(CONF.PATH.DATA, 'processed', 'img_rgb', scene_id + '_rgb.pth')
        data_dict["img_path"] = img_path
        img_feat_path = osp.join(CONF.PATH.DATA, 'processed', 'img_feat', scene_id + '_img_feat.pth')
        data_dict['img_feat_path'] = img_feat_path
        data_dict['clip_img_feat_path'] = osp.join(CONF.PATH.DATA, 'processed', 'img_feat_clip', scene_id + '_img_feat.pth')

        # visibility_mask
        data_dict["visibility_mask"] = self.get_points_visibility(points, scene_id)

        if CONF.nodetect:
            # Precomputed proposals are loaded instead of being detected.
            prop_path = osp.join(CONF.PATH.DATA, 'proposals', self.prefix, scene_id + '_proposal.pth')
            prop, objectness, prop_feat, prop_sem = torch.load(prop_path)
            prop = torch.from_numpy(prop)
            objectness = torch.from_numpy(objectness)
            prop_feat = torch.from_numpy(prop_feat)
            prop_sem = torch.from_numpy(prop_sem)
            # Restrict the proposals to the points kept by the transform.
            data_dict["prop"] = prop[:, valid_idx]
            data_dict["objectness"] = objectness
            data_dict["prop_feat"] = prop_feat
            data_dict["prop_sem"] = prop_sem

        gt_masks = []
        for i in range(self.lang_num_max):
            if i < lang_num:
                # Real description: read the dataset-specific annotation fields.
                if self.name == "multi3drefer":
                    object_id = self.scanrefer_new[idx][i]["object_ids"]
                    meta_data = {key:self.scanrefer_new[idx][i][key] 
                                for key in self.scanrefer_new[idx][i].keys() 
                                if key in ['eval_type','spatial','color','texture','shape']}
                    view_dependent = None
                elif self.name == "nr3d" or self.name == "sr3d":
                    object_id = [int(self.scanrefer_new[idx][i]["object_id"])]
                    meta_data = self.scanrefer_new[idx][i]['meta_data']
                    view_dependent = self.scanrefer_new[idx][i]['view_dependent']
                else:
                    object_id = [int(self.scanrefer_new[idx][i]["object_id"])]
                    meta_data = None
                    view_dependent = None 
                object_name = " ".join(self.scanrefer_new[idx][i]["object_name"].split("_"))
                ann_id = self.scanrefer_new[idx][i]["ann_id"]
                # Target word and modifier text of the first scene-graph node.
                tg_word = self.scene_graphs[scene_id][str(ann_id)]['graph_node'][0]['target']
                mod_word = self.scene_graphs[scene_id][str(ann_id)]['graph_node'][0]['mod_text']
                gt_mask = self.get_ref_mask(instance_label, object_id)
                clip_token = clip.tokenize(self.scanrefer_new[idx][i]["description"].strip(), truncate=True)[0]
                longclip_token = longclip.tokenize(self.scanrefer_new[idx][i]["description"])[0]
                # These lists only receive real (unpadded) entries.
                scene_ids.append(scene_id)
                object_id_list.append(object_id)
                ann_ids.append(int(ann_id))
                meta_datas.append(meta_data)
                view_dependents.append(view_dependent)

            # Executed for every slot: for i >= lang_num the variables still
            # hold the values of the last real description, so the padded
            # slots repeat it.
            clip_tokens.append(clip_token)
            longclip_tokens.append(longclip_token)
            target_words.append(tg_word)
            mod_words.append(mod_word)
            object_name_list.append(object_name)
            ann_id_list.append(int(ann_id))
            gt_masks.append(gt_mask)


        # ------------------------------- LABELS ------------------------------    
        object_cat_list = []
        for i in range(self.lang_num_max):
            # 17 is the fallback class (presumably 'others' - TODO confirm).
            object_cat = 17
            if object_name_list[i] in self.raw2label:
                object_cat = self.raw2label[object_name_list[i]]
            else:
                # Fall back to the individual words of the name; the last
                # word found in raw2label wins.
                for tmp in object_name_list[i].split(" "):
                    if tmp in self.raw2label:
                        object_cat = self.raw2label[tmp]
            object_cat_list.append(object_cat)

        istrain = 0
        if self.mode == "train":
            istrain = 1

        data_dict["istrain"] = istrain
        data_dict["scan_idx"] = torch.from_numpy(np.array(idx).astype(np.int64))
        data_dict["lang_num"] = torch.from_numpy(np.array(lang_num).astype(np.int64))
        # Time spent in this method up to this point.
        data_dict["load_time"] = time.time() - start

        data_dict["object_id_list"] = object_id_list
        data_dict["ann_id"] = ann_ids
        data_dict["target_words"] = target_words
        data_dict["mod_words"] = mod_words
        data_dict["meta_data"] = meta_datas
        data_dict["view_dependent"] = view_dependents
        data_dict["ann_id_list"] = torch.tensor(ann_id_list, dtype=torch.int64)
        data_dict["object_cat_list"] = torch.tensor(object_cat_list, dtype=torch.int64)
        # =====================
        # pc
        data_dict["scene_id"] = scene_id
        data_dict["scan_ids"] = scene_ids
        data_dict["coord"] = coord
        data_dict["coord_float"] = coord_float
        data_dict["feat"] = feat
        # presumably maps colors from [-1, 1] to [0, 1] - TODO confirm the input range
        data_dict["rgb"] = ((torch.from_numpy(rgb).float()) + 1) / 2

        data_dict["gt_masks"] = torch.stack(gt_masks,0)
        data_dict["clip_token"] = torch.stack(clip_tokens, 0)
        data_dict["longclip_token"] = torch.stack(longclip_tokens, 0)
        return data_dict
    
    def collate_fn(self, batch):
        """Merge the samples produced by ``__getitem__`` into one batch dict.

        ``None`` samples are dropped. Point-level tensors are concatenated
        over all scenes (the sample index is prepended to the coordinates as
        column 0), description-level tensors are stacked, and list-valued
        fields are collected or flattened. The merged coordinates are
        voxelized with ``voxelization_idx``.

        Args:
            batch: List of sample dicts or ``None`` entries.

        Returns:
            The collated batch dictionary.

        Raises:
            AssertionError: If every entry of ``batch`` is ``None``.
        """
        start = time.time()

        samples = [sample for sample in batch if sample is not None]
        num_samples = len(samples)
        assert num_samples > 0, 'empty batch'

        def column(key):
            # One entry per sample.
            return [sample[key] for sample in samples]

        def flattened(key):
            # Per-sample lists concatenated into one flat list.
            merged = []
            for sample in samples:
                merged.extend(sample[key])
            return merged

        indexed_coords = []
        batch_offsets = [0]
        for sample_id, sample in enumerate(samples):
            coord = sample["coord"]
            id_column = coord.new_full((coord.size(0), 1), sample_id)
            indexed_coords.append(torch.cat([id_column, coord], 1))
            # Cumulative point count marks where each scene ends.
            batch_offsets.append(batch_offsets[-1] + coord.shape[0])

        # merge all the scenes in the batch
        coords = torch.cat(indexed_coords, 0)  # long (N, 1 + 3), the batch item idx is put in coords[:, 0]
        coords_float = torch.cat(column("coord_float"), 0).to(torch.float32)  # float (N, 3)
        feats = torch.cat(column("feat"), 0)  # float (N, C)
        rgb = torch.cat(column("rgb"), 0)

        spatial_shape = np.clip(
            coords.max(0)[0][1:].numpy() + 1, self.voxel_cfg.spatial_shape[0], None)
        voxel_coords, v2p_map, p2v_map = voxelization_idx(coords, num_samples)

        data_dict = {
            "istrain": column("istrain"),
            "lang_num": torch.tensor(column("lang_num"), dtype=torch.int64),
            "scan_idx": torch.stack(column("scan_idx"), 0),
            "object_id_list": flattened("object_id_list"),
            "ann_ids": flattened("ann_id"),
            "target_words": flattened("target_words"),
            "mod_words": flattened("mod_words"),
            "ann_id_list": torch.stack(column("ann_id_list"), 0),
            "object_cat_list": torch.stack(column("object_cat_list"), 0),
            "meta_datas": flattened("meta_data"),
            "view_dependents": flattened("view_dependent"),
            # pc
            "scene_ids": column("scene_id"),
            "scan_ids": flattened("scan_ids"),
            "coords": coords,
            "coords_float": coords_float,
            "feats": feats,
            "rgb": rgb,
        }

        if CONF.nodetect:
            # Proposal masks stay a per-scene list; the rest is stacked.
            data_dict["prop"] = column("prop")
            data_dict["objectness"] = torch.stack(column("objectness"), 0)
            data_dict["prop_feat"] = torch.stack(column("prop_feat"), 0)
            data_dict["prop_sem"] = torch.stack(column("prop_sem"), 0)

        data_dict["gt_masks"] = column("gt_masks")
        data_dict["clip_token"] = torch.stack(column("clip_token"), 0)
        data_dict["longclip_token"] = torch.stack(column("longclip_token"), 0)

        data_dict["batch_idxs"] = coords[:, 0].int()
        data_dict["voxel_coords"] = voxel_coords
        data_dict["p2v_map"] = p2v_map
        data_dict["v2p_map"] = v2p_map
        data_dict["spatial_shape"] = spatial_shape
        data_dict["batch_size"] = num_samples

        data_dict["batch_offset"] = torch.tensor(batch_offsets, dtype=torch.int)
        # Per-sample load times followed by the time spent collating.
        load_time = column("load_time")
        load_time.append(time.time() - start)
        data_dict["load_time"] = torch.tensor(load_time, dtype=torch.float64)
        data_dict["split"] = self.prefix

        data_dict["img_path"] = column("img_path")
        data_dict['img_feat_path'] = column("img_feat_path")
        data_dict['clip_img_feat_path'] = column('clip_img_feat_path')
        data_dict["visibility_mask"] = column("visibility_mask")
        return data_dict



#########################
# BRIEF Text decoupling #
#########################
# Ordered (old, new) substitutions fixing known errors/typos in ScanRefer
# captions. Order matters: every rule runs on the output of the previous one.
_SCANREFER_CAPTION_FIXES = (
    ("'m", "am"),
    ("'s", "is"),
    ("2-tiered", "2 - tiered"),
    ("4-drawers", "4 - drawers"),
    ("5-drawer", "5 - drawer"),
    ("8-hole", "8 - hole"),
    ("7-shaped", "7 - shaped"),
    ("2-door", "2 - door"),
    ("3-compartment", "3 - compartment"),
    ("computer/", "computer /"),
    ("3-tier", "3 - tier"),
    ("3-seater", "3 - seater"),
    ("4-seat", "4 - seat"),
    ("theses", "these"),
)

# Reference only (disabled): extra (old, new) rules that were used for NR3D
# captions. To enable them, apply them in this order right after the ScanRefer
# rules, using the same replace-then-normalise step.
#   ('.', ' .'), (';', ' ; '), ('-', ' '), ('"', ' '), ('?', ' '),
#   ("*", " "), (':', ' '), ('$', ' '), ("#", " "), ("/", " / "),
#   ("you're", "you are"), ("isn't", "is not"), ("thats", "that is"),
#   ("theres", "there is"), ("doesn't", "does not"), ("doesnt", "does not"),
#   ("itis", "it is"), ("left-hand", "left - hand"),
#   ("[", " [ "), ("]", " ] "), ("(", " ( "), (")", " ) "),
#   ("wheel-chair", "wheel - chair"), (";s", "is"), ("tha=e", "the"),
#   ("it’s", "it is"), ("’s", " is"), ("isnt", "is not"),
#   ("Don't", "Do not"), ("arent", "are not"), ("cant", "can not"),
#   ("you’re", "you are"), ('!', ' !'), ('id the', ' , the'),
#   ('youre', 'you are'), ("'", ' '), ("``", ' ')


def Scene_graph_parse(caption):
    """Clean up a caption and parse it into a scene graph with ``sng_parser``.

    The caption is first normalised: commas are padded with spaces, a fixed
    list of known ScanRefer typos is rewritten, runs of whitespace are
    collapsed, and one leading / one trailing single quote is stripped.
    The result is handed to ``sng_parser.parse``. If the parser returns no
    node, or its first node does not have ``node_id == 0``, the sentence
    "This is an object . " is prepended and the caption is parsed again.

    Args:
        caption: Raw description string. Empty or whitespace-only captions
            are accepted and simply take the fallback path described above.

    Returns:
        dict with the keys
            "graph_node":  nodes returned by ``sng_parser.parse``,
            "graph_edge":  edges returned by ``sng_parser.parse``,
            "auxi_entity": first node with ``node_id != 0`` and
                           ``node_type == "Object"``, or ``None`` if absent,
            "utterance":   the cleaned (and possibly prefixed) caption that
                           was actually parsed.
    """
    # Pad commas so they become stand-alone tokens, then collapse whitespace.
    caption = ' '.join(caption.replace(',', ' , ').split())

    # some error or typo in ScanRefer.
    for old, new in _SCANREFER_CAPTION_FIXES:
        caption = ' '.join(caption.replace(old, new).split())

    # Strip one wrapping quote on each side. startswith/endswith are used
    # instead of indexing so that an empty caption (or a lone "'") does not
    # raise IndexError.
    if caption.startswith("'"):
        caption = caption[1:]
    if caption.endswith("'"):
        caption = caption[:-1]

    # text parsing
    graph_node, graph_edge = sng_parser.parse(caption)

    # NOTE If no node is parsed (or the first node is not node 0), add
    # "This is an object ." at the beginning of the sentence and parse again.
    if len(graph_node) == 0 or graph_node[0]["node_id"] != 0:
        caption = "This is an object . " + caption
        graph_node, graph_edge = sng_parser.parse(caption)

    # auxi object: first "Object" node other than node 0, if any.
    auxi_entity = next(
        (
            node
            for node in graph_node
            if node["node_id"] != 0 and node["node_type"] == "Object"
        ),
        None,
    )

    return {
        "graph_node": graph_node,
        "graph_edge": graph_edge,
        "auxi_entity": auxi_entity,
        "utterance": caption,
    }
