"""Dataset and data loader for ReferIt3D."""

from copy import deepcopy
import csv
import h5py
import json
import multiprocessing as mp
import os
import random

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import RobertaTokenizerFast
import wandb

from models.losses import _iou3d_par, box_cxcyczwhd_to_xyzxyz
from scannet.model_util_scannet import ScannetDatasetConfig
from scannet.scannet_utils import read_label_mapping
from .scannet_classes import REL_ALIASES, UNIQUE_VIEW_DEP_RELS, VIEW_DEP_RELS
from .scannet_dataset import unpickle_data, get_positive_map
from sunrgbd.sunrgbd_utils import extract_pc_in_box3d

import ipdb
st = ipdb.set_trace

# Number of object classes in the label vocabulary passed to the config.
NUM_CLASSES = 485
# Padded length of every per-object array built by the dataset.
MAX_NUM_OBJ = 132
# Shared dataset configuration (class/id lookup tables such as
# nyu40id2class and class2type), built for NUM_CLASSES.
DC = ScannetDatasetConfig(NUM_CLASSES)


class Joint3DDataset(Dataset):
    """Dataset utilities for ReferIt3D."""

    def __init__(self, dataset_dict=None,
                 test_dataset='sr3d',
                 split='train', num_points=50000,
                 use_color=False, use_height=False, overfit=False,
                 detect_intermediate=False,
                 filter_relations=False,
                 visualize=False,
                 augment=True, use_oriented_boxes=False, rotate_pc=False,
                 use_multiview=False, train_viewpoint_module=False,
                 use_predicted_viewpoint=False,
                 butd=False, butd_gt=False, butd_cls=False,
                 run_on_target_phrases=False, hide_target=False,
                 hide_anchors=False, append_anchors=False):
        """Initialize dataset (here for ReferIt3D utterances).

        Args:
            dataset_dict: used when ``split == 'train'``. Maps a dataset name
                ('nr3d', 'sr3d', 'sr3d+', 'scanrefer', 'scannet') to how many
                times its annotations are repeated. Counts <= 0 skip the
                dataset. ``None`` (default) means
                ``{'sr3d': 1, 'scannet': 10}``.
            test_dataset: dataset whose annotations are loaded for any split
                other than 'train'.
            split: 'train', 'train100', 'val' or 'test'.
            num_points: number of points sampled per scene.
            use_color / use_height / use_multiview: append RGB, height above
                the floor, or multi-view 2D features to every point.
            overfit: keep only a short prefix of each annotation list.
            detect_intermediate: also detect anchor objects. Ignored when
                ``run_on_target_phrases`` is set.
            filter_relations: keep only sr3d utterances whose relation is in
                ``rel_filter_list``.
            visualize: initialise a wandb run for debugging visualisations.
            augment: enable train-time augmentation. It may be forced off by
                ``train_viewpoint_module`` or oriented-box / rotate_pc modes.
            use_oriented_boxes / rotate_pc: without ``butd``, these enable
                detected-box loading and disable augmentation. ``rotate_pc``
                also rotates view-dependent sr3d scenes into the anchor's frame.
            train_viewpoint_module: restrict to UNIQUE_VIEW_DEP_RELS
                utterances and disable augmentation.
            use_predicted_viewpoint: stored as an attribute.
            butd / butd_gt / butd_cls: detection-box sources. ``butd_gt`` uses
                the scene ground-truth boxes. ``butd_cls`` uses ground-truth
                boxes with classes from ``cls_results.json``.
            run_on_target_phrases: replace each utterance by its target name.
            hide_target / hide_anchors / append_anchors: edit the detected
                boxes per sample (see ``__getitem__``).
        """
        if dataset_dict is None:
            # Fresh dict per instance: a mutable default argument would be
            # shared by every instance that keeps a reference to it.
            dataset_dict = {'sr3d': 1, 'scannet': 10}
        self.dataset_dict = dataset_dict
        self.test_dataset = test_dataset
        self.split = split
        self.num_points = num_points
        self.use_color = use_color
        self.use_height = use_height
        self.overfit = overfit
        self.rotate_pc = rotate_pc
        self.use_predicted_viewpoint = use_predicted_viewpoint
        self.detect_intermediate = detect_intermediate and not run_on_target_phrases
        self.filter_relations = filter_relations
        self.use_oriented_boxes = use_oriented_boxes
        self.augment = augment
        self.use_multiview = use_multiview
        self.train_viewpoint_module = train_viewpoint_module
        self.rel_filter_list = ['on']
        self.data_path = './dataset/language_grounding/'
        self.visualize = visualize
        self.butd = butd
        self.butd_gt = butd_gt
        self.butd_cls = butd_cls
        self.load_detected_boxes = False
        self.run_on_target_phrases = run_on_target_phrases
        self.hide_target = hide_target
        self.hide_anchors = hide_anchors
        self.append_anchors = append_anchors

        if self.butd:
            self.load_detected_boxes = True

        if self.train_viewpoint_module:
            self.filter_relations = True
            self.augment = False
            self.rel_filter_list = UNIQUE_VIEW_DEP_RELS

        # do not augment
        if (self.use_oriented_boxes or rotate_pc) and not self.butd:
            self.load_detected_boxes = True
            self.augment = False

        self.mean_rgb = np.array([109.8, 97.2, 83.8]) / 256
        self.label_map = read_label_mapping(
            'scannet/meta_data/scannetv2-labels.combined.tsv',
            label_from='raw_category',
            label_to='nyu40id' if NUM_CLASSES == 18 else 'id'
        )
        with open(self.data_path + 'extra/object_oriented_bboxes/object_oriented_bboxes_aligned_scans.json') as fid:
            self.oriented_bboxes_mapping = json.load(fid)
        self.multiview_path = os.path.join(
            f'{self.data_path}/scanrefer_2d_feats',
            "enet_feats_maxpool.hdf5"
        )
        self.multiview_data = {}
        self.tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
        # Per-scan class predictions used by butd_cls. Always defined so the
        # attribute exists even when the file is absent.
        self.cls_results = {}
        if os.path.exists('cls_results.json'):
            with open('cls_results.json') as fid:
                self.cls_results = json.load(fid)  # {scan_id: [class0, ...]}

        # load
        print('Loading %s files, take a breath!' % split)
        if split == 'train100':
            split = 'train'
        _, self.scans = unpickle_data(f'{self.data_path}/{split}_v2scans.pkl')
        if self.split != 'train':
            self.annos = self.load_annos(test_dataset)
        else:
            self.annos = []
            for dset, cnt in dataset_dict.items():
                if cnt > 0:
                    _annos = self.load_annos(dset)
                    self.annos += (_annos * cnt)

        if self.visualize:
            wandb.init(project="vis", name="debug")

    def load_annos(self, dset):
        """Load annotations of given dataset."""
        loaders = {
            'nr3d': self.load_nr3d_annos,
            'sr3d': self.load_sr3d_annos,
            'sr3d+': self.load_sr3dplus_annos,
            'scanrefer': self.load_scanrefer_annos,
            'scannet': self.load_scannet_annos
        }
        annos = loaders[dset]()
        if self.run_on_target_phrases:
            for anno in annos:
                anno['utterance'] = anno['target']
        return annos

    def load_sr3dplus_annos(self):
        """Load annotations of sr3d/sr3d+."""
        return self.load_sr3d_annos(dset='sr3d+')

    def load_sr3d_annos(self, dset='sr3d'):
        """Load annotations of sr3d/sr3d+."""
        split = self.split
        if split == 'train100':
            split = 'train'
        elif split == 'val':
            split = 'test'
        with open('%s/extra/sr3d_%s_scans.txt' % (self.data_path, split)) as f:
            scan_ids = set(eval(f.read()))
        with open(self.data_path + 'refer_it_3d/%s.csv' % dset) as f:
            csv_reader = csv.reader(f)
            headers = next(csv_reader)
            headers = {header: h for h, header in enumerate(headers)}
            annos = [
                {
                    'scan_id': line[headers['scan_id']],
                    'target_id': int(line[headers['target_id']]),
                    'distractor_ids': eval(line[headers['distractor_ids']]),
                    'utterance': line[headers['utterance']],
                    'target': line[headers['instance_type']],
                    'anchors': eval(line[headers['anchors_types']]),
                    'anchor_ids': eval(line[headers['anchor_ids']]),
                    'dataset': dset
                }
                for line in csv_reader
                if line[headers['scan_id']] in scan_ids
                and
                str(line[headers['mentions_target_class']]).lower() == 'true'
                and
                (
                    not self.filter_relations or
                    self._find_rel(line[headers['utterance']]) in self.rel_filter_list
                )
            ]
        if self.overfit:
            annos = annos[:256]
        elif self.split == 'train100':
            annos = annos[:100]
        return annos

    def load_nr3d_annos(self):
        """Load annotations of nr3d."""
        split = self.split
        if split == 'train100':
            split = 'train'
        elif split == 'val':
            split = 'test'
        with open('%s/extra/nr3d_%s_scans.txt' % (self.data_path, split)) as f:
            scan_ids = set(eval(f.read()))
        with open(self.data_path + 'refer_it_3d/nr3d.csv') as f:
            csv_reader = csv.reader(f)
            headers = next(csv_reader)
            headers = {header: h for h, header in enumerate(headers)}
            annos = [
                {
                    'scan_id': line[headers['scan_id']],
                    'target_id': int(line[headers['target_id']]),
                    'target': line[headers['instance_type']],
                    'utterance': line[headers['utterance']],
                    'anchor_ids': [],
                    'anchors': [],
                    'dataset': 'nr3d'
                }
                for line in csv_reader
                if line[headers['scan_id']] in scan_ids
                and
                str(line[headers['mentions_target_class']]).lower() == 'true'
            ]
        # Add distractor info
        for anno in annos:
            anno['distractor_ids'] = [
                ind
                for ind in range(len(self.scans[anno['scan_id']].three_d_objects))
                if self.scans[anno['scan_id']].get_object_instance_label(ind)
                == anno['target']
            ]
        # Filter out sentences that do not explicitly mention the target class
        annos = [anno for anno in annos if anno['target'] in anno['utterance']]
        if self.overfit:
            annos = annos[:256]
        elif self.split == 'train100':
            annos = annos[:100]
        return annos

    def load_scanrefer_annos(self):
        """Load annotations of ScanRefer."""
        _path = './dataset/datasets/scanrefer/ScanRefer_filtered'
        split = self.split
        if split == 'train100':
            split = 'train'
        elif split in ('val', 'test'):
            split = 'val'
        with open(_path + '_%s.txt' % split) as f:
            scan_ids = [line.rstrip().strip('\n') for line in f.readlines()]
        with open(_path + '_%s.json' % split) as f:
            reader = json.load(f)
        annos = [
            {
                'scan_id': anno['scene_id'],
                'target_id': int(anno['object_id']),
                'distractor_ids': [],
                'utterance': ' '.join(anno['token']),
                'target': ' '.join(str(anno['object_name']).split('_')),
                'anchors': [],
                'anchor_ids': [],
                'dataset': 'scanrefer'
            }
            for anno in reader
            if anno['scene_id'] in scan_ids
        ]
        # Fix missing target reference
        for anno in annos:
            if anno['target'] not in anno['utterance']:
                anno['utterance'] = (
                    ' '.join(anno['utterance'].split(' . ')[0].split()[:-1])
                    + ' ' + anno['target'] + ' . '
                    + ' . '.join(anno['utterance'].split(' . ')[1:])
                )
        # Add distractor info
        for anno in annos:
            anno['distractor_ids'] = [
                ind
                for ind in range(len(self.scans[anno['scan_id']].three_d_objects))
                if self.scans[anno['scan_id']].get_object_instance_label(ind)
                == anno['target']
            ][:32]
        if self.overfit:
            annos = annos[:256]
        elif self.split == 'train100':
            annos = annos[:100]
        return annos

    def load_scannet_annos(self):
        """Build one placeholder annotation per usable ScanNet scan.

        Scan ids come from ``scannet/meta_data/scannetv2_<train|val>.txt``;
        the train list is used only when ``self.split == 'train'``. A scan is
        kept when at least one of its objects maps, through ``label_map``,
        to a key of ``DC.nyu40id2class``. The target/utterance fields are
        empty here; ``__getitem__`` fills them on every access. Outside
        training, annotations at a fixed set of list positions are dropped.
        ``overfit`` keeps the first 50.
        """
        list_split = 'train' if self.split == 'train' else 'val'
        with open('scannet/meta_data/scannetv2_%s.txt' % list_split) as fid:
            scan_list = [row.rstrip() for row in fid]

        annos = []
        for scan_id in scan_list:
            scan = self.scans[scan_id]
            in_vocab = np.array([
                self.label_map[scan.get_object_instance_label(obj_idx)]
                in DC.nyu40id2class
                for obj_idx in range(len(scan.three_d_objects))
            ])
            # Ignore scans that have no object in our vocabulary
            if not in_vocab.any():
                continue
            # Placeholder fields, re-populated randomly by __getitem__
            annos.append(dict(
                scan_id=scan_id,
                target_id=[],
                distractor_ids=[],
                utterance='',
                target=[],
                anchors=[],
                anchor_ids=[],
                dataset='scannet',
            ))

        if self.split != 'train':
            # NOTE(review): these are list positions, not scan ids; why they
            # are excluded from evaluation is not visible here — confirm.
            skipped_positions = {37, 38, 77, 78, 79, 86, 207, 208, 209}
            annos = [
                anno for pos, anno in enumerate(annos)
                if pos not in skipped_positions
            ]
        return annos[:50] if self.overfit else annos

    def _sample_classes(self, scan_id):
        """Pick class names for a ScanNet detection prompt.

        Collects the in-vocabulary classes present in the scan: labels that
        ``label_map`` sends to a key of ``DC.nyu40id2class``. In training,
        when more than 10 such classes are present, a random subset of 8 is
        kept (the code samples 8, not 10). The names of the kept classes are
        returned. Outside training, a fixed list of 18 class names is
        returned regardless of the scan.
        """
        scan = self.scans[scan_id]
        present_ids = {
            self.label_map[scan.get_object_instance_label(obj_idx)]
            for obj_idx in range(len(scan.three_d_objects))
        }
        vocab_ids = list(present_ids & set(DC.nyu40id2class))
        if self.split != 'train':
            return [
                'cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
                'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain',
                'refrigerator', 'shower curtain', 'toilet', 'sink', 'bathtub',
                'trash bin'
            ]
        if len(vocab_ids) > 10:
            vocab_ids = random.sample(vocab_ids, 8)
        return [DC.class2type[DC.nyu40id2class[nyu_id]] for nyu_id in vocab_ids]

    def _create_scannet_utterance(self, sampled_classes):
        neg_names = []
        while len(neg_names) < 4 and self.split == 'train':
            _ind = np.random.randint(0, len(DC.class2type))
            if DC.class2type[_ind] not in neg_names + sampled_classes:
                neg_names.append(DC.class2type[_ind])
        if self.split == 'train':
            mixed_names = sorted(list(set(sampled_classes + neg_names)))
            random.shuffle(mixed_names)
        else:
            mixed_names = sampled_classes
        utterance = ' . '.join(mixed_names) + ' .'
        return utterance

    def _load_detected_boxes(
        self, split, scan_id,
        all_detected_bboxes, all_detected_bbox_label_mask,
        detected_class_ids,
        augmentations
    ):
        """Fill padded detection arrays with a scan's stored detections.

        Reads ``<data_path>group_free_pred_bboxes_<split>/<scan_id>.npy``, a
        pickled dict with 'box' and 'class' entries. Each box row is treated
        as two corners (columns 0-2 and 3-5), as implied by the conversion
        below, and turned into (cx, cy, cz, w, h, d). Class names map to
        class ids through ``label_map`` and ``DC.nyu40id2class``.

        When training with augmentation, the box corners are transformed with
        the parameters recorded by ``_augment``. The order matches how
        ``_augment`` transforms the point cloud: optional flips, then
        rotations about z, x and y, then shift and scale. Per-point noise is
        not applied.

        Returns:
            (all_detected_bboxes, all_detected_bbox_label_mask,
            detected_class_ids). The first ``num_objs`` slots hold the
            detections.
        """
        detected_dict = np.load(
            f'{self.data_path}group_free_pred_bboxes_{split}/{scan_id}.npy',
            allow_pickle=True
        ).item()

        all_bboxes_ = np.array(detected_dict['box'])
        classes = detected_dict['class']
        cid = np.array([DC.nyu40id2class[
            self.label_map[c]] for c in detected_dict['class']
        ])

        all_bboxes_ = np.concatenate((
            (all_bboxes_[:, :3] + all_bboxes_[:, 3:]) * 0.5,
            all_bboxes_[:, 3:] - all_bboxes_[:, :3]
        ), 1)

        assert len(classes) < MAX_NUM_OBJ
        assert len(classes) == all_bboxes_.shape[0]

        num_objs = len(classes)
        all_detected_bboxes[:num_objs] = all_bboxes_
        all_detected_bbox_label_mask[:num_objs] = True
        detected_class_ids[:num_objs] = cid
        if self.augment and self.split == 'train':
            all_detected_pts = box2points(all_detected_bboxes).reshape(-1, 3)
            # Flip first: _augment flips the point cloud before rotating it,
            # and flips do not commute with rotations.
            if augmentations.get('yz_flip', False):
                all_detected_pts[:, 0] = -all_detected_pts[:, 0]
            if augmentations.get('xz_flip', False):
                all_detected_pts[:, 1] = -all_detected_pts[:, 1]
            all_detected_pts = rot_z(all_detected_pts, augmentations['theta_z'])
            all_detected_pts = rot_x(all_detected_pts, augmentations['theta_x'])
            all_detected_pts = rot_y(all_detected_pts, augmentations['theta_y'])
            all_detected_pts += augmentations['shift']
            all_detected_pts *= augmentations['scale']
            all_detected_bboxes = points2box(all_detected_pts.reshape(-1, 8, 3))
        return (
            all_detected_bboxes, all_detected_bbox_label_mask,
            detected_class_ids
        )

    def _load_multiview(self, scan_id):
        """Load multi-view data of given scan-id."""
        pid = mp.current_process().pid
        if pid not in self.multiview_data:
            self.multiview_data[pid] = h5py.File(
                self.multiview_path, "r", libver="latest"
            )
        return self.multiview_data[pid][scan_id]

    def _rotate_based_on_anchor(self, scan_id, anchor_id):
        """Rotate pc for the front viewpoint of a given anchor."""
        id_ = str(scan_id) + "_" + str(anchor_id)
        assert id_ in self.oriented_bboxes_mapping
        oriented_data = self.oriented_bboxes_mapping[id_]
        box = oriented_data['obj_bbox']
        orot = np.array(oriented_data['obj_rot'])

        # make pc face the object
        pc = np.copy(self.scans[scan_id].pc)
        pc = pc - np.array(box[:3])[None]
        pc = np.concatenate([pc, np.ones((pc.shape[0], 1))], 1)
        pc = (np.linalg.inv(orot) @ pc.T).T[:, :3]
        return pc

    def _augment(self, pc, color, rotate):
        """Randomly augment a point cloud (and its colors) in place.

        Args:
            pc: point array. Only columns 0-2 are transformed. Modified in
                place and also returned.
            color: per-point color array or None. ``_get_pc`` passes colors
                with ``mean_rgb`` already subtracted. Modified in place.
            rotate: if True, the z-angle includes a random multiple of 90
                and the x and y coordinates may each be negated. Otherwise
                only a small z jitter is sampled.

        Returns:
            (pc, color, augmentations). ``augmentations`` records the sampled
            parameters: 'yz_flip'/'xz_flip' (only when ``rotate``), 'theta_z',
            'theta_x', 'theta_y', 'noise', 'shift' and 'scale'. Other geometry
            (e.g. detected boxes) can then be transformed to match.

        The order of the random draws below fixes which values are sampled
        for a given RNG state, so reordering statements changes results.
        Angle units depend on rot_x/rot_y/rot_z (not visible here;
        presumably degrees given the 90 * k term — confirm).
        """
        augmentations = {}

        # Rotate/flip only if we don't have a view_dep sentence
        if rotate:
            # Random quarter turn plus +-5 jitter
            theta_z = 90*np.random.randint(0, 4) + (2*np.random.rand() - 1) * 5
            # Flipping along the YZ plane
            augmentations['yz_flip'] = np.random.random() > 0.5
            if augmentations['yz_flip']:
                pc[:, 0] = -pc[:, 0]
            # Flipping along the XZ plane
            augmentations['xz_flip'] = np.random.random() > 0.5
            if augmentations['xz_flip']:
                pc[:, 1] = -pc[:, 1]
        else:
            # Jitter only, so view-dependent relations stay valid
            theta_z = (2*np.random.rand() - 1) * 5
        augmentations['theta_z'] = theta_z
        pc[:, :3] = rot_z(pc[:, :3], theta_z)
        # Rotate around x
        theta_x = (2*np.random.rand() - 1) * 2.5
        augmentations['theta_x'] = theta_x
        pc[:, :3] = rot_x(pc[:, :3], theta_x)
        # Rotate around y
        theta_y = (2*np.random.rand() - 1) * 2.5
        augmentations['theta_y'] = theta_y
        pc[:, :3] = rot_y(pc[:, :3], theta_y)

        # Add noise (uniform in [0, 5e-3) per coordinate)
        noise = np.random.rand(len(pc), 3) * 5e-3
        augmentations['noise'] = noise
        pc[:, :3] = pc[:, :3] + noise

        # Translate/shift by a vector uniform in [-0.5, 0.5)^3
        augmentations['shift'] = np.random.random((3,))[None, :] - 0.5
        pc[:, :3] += augmentations['shift']

        # Scale uniformly by a factor in [0.98, 1.02)
        augmentations['scale'] = 0.98 + 0.04*np.random.random()
        pc[:, :3] *= augmentations['scale']

        # Color: jitter each channel multiplicatively in raw (un-centred) space
        if color is not None:
            color += self.mean_rgb
            color *= 0.98 + 0.04*np.random.random((len(color), 3))
            color -= self.mean_rgb
        return pc, color, augmentations

    def _get_pc(self, anno, scan):
        """Return a point cloud representation of current scene."""
        scan_id = anno['scan_id']
        rel_name = "none"
        if anno['dataset'].startswith('sr3d'):
            rel_name = self._find_rel(anno['utterance'])

        # a. Rotate based on anchor front viewpoint
        if self.rotate_pc and rel_name in UNIQUE_VIEW_DEP_RELS:
            # load anchor
            anchor_id = anno['anchor_ids'][0]
            scan.pc = self._rotate_based_on_anchor(scan_id, anchor_id)

        # b. Color
        color = None
        if self.use_color:
            color = np.copy(scan.color - self.mean_rgb)

        # c .Height
        height = None
        if self.use_height:
            floor_height = np.percentile(scan.pc[:, 2], 0.99)
            height = np.expand_dims(scan.pc[:, 2] - floor_height, 1)

        # d. Multi-view 2d features
        multiview_data = None
        if self.use_multiview:
            multiview_data = self._load_multiview(scan_id)

        # e. Augmentations
        augmentations = {}
        if self.split == 'train' and self.augment:
            rotate_natural = (
                anno['dataset'] in ('nr3d', 'scanrefer')
                and self._augment_nr3d(anno['utterance'])
            )
            rotate_sr3d = (
                anno['dataset'].startswith('sr3d')
                and rel_name not in VIEW_DEP_RELS
            )
            rotate_else = anno['dataset'] == 'scannet'
            rotate = rotate_sr3d or rotate_natural or rotate_else
            pc, color, augmentations = self._augment(scan.pc, color, rotate)
            scan.pc = pc

        # f. Concatenate representations
        point_cloud = scan.pc
        if color is not None:
            point_cloud = np.concatenate((point_cloud, color), 1)
        if height is not None:
            point_cloud = np.concatenate([point_cloud, height], 1)
        if multiview_data is not None:
            point_cloud = np.concatenate([point_cloud, multiview_data], 1)

        # g. Subsample point cloud
        if self.split != 'train':
            np.random.seed(1184)
        choices = np.random.choice(
            point_cloud.shape[0],
            self.num_points,
            replace=len(point_cloud) < self.num_points
        )
        point_cloud = point_cloud[choices]
        og_color = scan.color[choices]  # original color for visualizations
        augmentations['choices'] = choices
        return point_cloud, augmentations, og_color

    def _get_token_positive_map(self, anno):
        """Return correspondence of boxes to tokens.

        Each name to ground is given a character span in the normalized
        utterance (commas split off, whitespace collapsed). The names are the
        target name(s), plus the anchor names when ``detect_intermediate``.
        The spans are mapped to tokens with ``get_positive_map``.

        Returns:
            tokens_positive: (MAX_NUM_OBJ, 2) start/end character offsets,
                one row per name in order. Unused rows are 0.
            positive_map: (MAX_NUM_OBJ, 256) token map. Rows past the names
                are 0.
        """
        normalized = ' '.join(anno['utterance'].replace(',', ' ,').split())
        # Pad with spaces so whole-word searches can anchor on ' name '
        caption = ' ' + normalized + ' '
        tokens_positive = np.zeros((MAX_NUM_OBJ, 2))
        # Copy the list: extending a list taken straight from anno['target']
        # would mutate the annotation itself.
        if isinstance(anno['target'], list):
            cat_names = list(anno['target'])
        else:
            cat_names = [anno['target']]
        if self.detect_intermediate:
            cat_names += anno['anchors']
        for c, cat_name in enumerate(cat_names):
            # 1) whole-word match
            start_span = caption.find(' ' + cat_name + ' ')
            len_ = len(cat_name)
            if start_span < 0:
                # 2) match at a word start; the span covers the whole word
                #    (e.g. a plural form)
                start_span = caption.find(' ' + cat_name)
                len_ = len(caption[start_span+1:].split()[0])
            if start_span < 0:
                # 3) substring anywhere; widen to the enclosing word
                start_span = caption.find(cat_name)
                orig_start_span = start_span
                while caption[start_span - 1] != ' ':
                    start_span -= 1
                len_ = len(cat_name) + orig_start_span - start_span
                while caption[len_ + start_span] != ' ':
                    len_ += 1
            end_span = start_span + len_
            assert start_span > -1, caption
            assert end_span > 0, caption
            tokens_positive[c][0] = start_span
            tokens_positive[c][1] = end_span

        # Positive map (for soft token prediction)
        tokenized = self.tokenizer.batch_encode_plus(
            [normalized],
            padding="longest", return_tensors="pt"
        )
        positive_map = np.zeros((MAX_NUM_OBJ, 256))
        gt_map = get_positive_map(tokenized, tokens_positive[:len(cat_names)])
        positive_map[:len(cat_names)] = gt_map
        return tokens_positive, positive_map

    def _get_target_boxes(self, anno, scan, augmentations):
        """Return gt boxes to detect.

        Targets are ``anno['target_id']``: a list for ScanNet annotations, an
        int otherwise. For int targets with ``detect_intermediate`` set,
        ``anno['anchor_ids']`` are appended.

        Returns:
            bboxes: (MAX_NUM_OBJ, 6) as (cx, cy, cz, w, h, d). Slots past the
                targets have zero size and centre coordinates of 1000.
            box_label_mask: (MAX_NUM_OBJ,) with 1 for real targets.
            point_instance_label: for each sampled point, the index into the
                target list of the object owning it, or -1.
        """
        if isinstance(anno['target_id'], list):  # scannet
            target_ids = anno['target_id']
        else:  # referit dataset
            target_ids = [anno['target_id']]
            if self.detect_intermediate:
                target_ids += anno.get('anchor_ids', [])
        n_targets = len(target_ids)

        labels = np.full(len(scan.pc), -1.0)
        for slot, obj_id in enumerate(target_ids):
            labels[scan.three_d_objects[obj_id]['points']] = slot
        point_instance_label = labels[augmentations['choices']]

        corners = np.zeros((MAX_NUM_OBJ, 6))
        corners[:n_targets] = np.stack([
            scan.get_object_bbox(obj_id).reshape(-1) for obj_id in target_ids
        ])
        centres = (corners[:, :3] + corners[:, 3:]) * 0.5
        sizes = corners[:, 3:] - corners[:, :3]
        bboxes = np.hstack((centres, sizes))
        if self.split == 'train' and self.augment:
            # Multiplicative jitter of both centres and sizes
            bboxes[:n_targets] *= 0.95 + 0.1 * np.random.random((n_targets, 6))
        bboxes[n_targets:, :3] = 1000

        box_label_mask = np.zeros(MAX_NUM_OBJ)
        box_label_mask[:n_targets] = 1
        return bboxes, box_label_mask, point_instance_label

    def _get_scene_objects(self, scan):
        """Return class ids, boxes and a validity mask for the scene objects.

        Only the first MAX_NUM_OBJ objects are considered. An object is kept
        when its label maps, through ``label_map``, to a key of
        ``DC.nyu40id2class``.

        Returns:
            class_ids: (MAX_NUM_OBJ,) class id at kept slots, 0 elsewhere.
            all_bboxes: (MAX_NUM_OBJ, 6) (cx, cy, cz, w, h, d) at kept
                slots, 0 elsewhere. Randomly scaled per entry when training
                with augmentation.
            all_bbox_label_mask: (MAX_NUM_OBJ,) bool mask of kept slots.
        """
        in_vocab = np.array([
            self.label_map[scan.get_object_instance_label(obj_idx)]
            in DC.nyu40id2class
            for obj_idx in range(len(scan.three_d_objects))
        ])[:MAX_NUM_OBJ]
        keep = np.zeros(MAX_NUM_OBJ, dtype=bool)
        keep[:len(in_vocab)] = in_vocab
        kept_idx = [k for k, flag in enumerate(keep) if flag]

        class_ids = np.zeros((MAX_NUM_OBJ,))
        class_ids[keep] = np.array([
            DC.nyu40id2class[self.label_map[scan.get_object_instance_label(k)]]
            for k in kept_idx
        ])

        corners = np.stack([scan.get_object_bbox(k).reshape(-1) for k in kept_idx])
        all_bboxes = np.zeros((MAX_NUM_OBJ, 6))
        # cx, cy, cz, w, h, d
        all_bboxes[keep] = np.concatenate((
            (corners[:, :3] + corners[:, 3:]) * 0.5,
            corners[:, 3:] - corners[:, :3]
        ), 1)
        if self.split == 'train' and self.augment:
            all_bboxes *= 0.95 + 0.1 * np.random.random((len(all_bboxes), 6))
        return class_ids, all_bboxes, keep

    def _get_detected_objects(self, split, scan_id, augmentations):
        """Load detector boxes for ``scan_id`` into fresh padded arrays.

        Allocates zeroed (MAX_NUM_OBJ, 6) boxes, an all-False mask and zeroed
        class ids, then lets ``_load_detected_boxes`` fill them. Detections
        are loaded unconditionally; ``self.load_detected_boxes`` is not
        consulted here.
        """
        boxes, mask, class_ids = self._load_detected_boxes(
            split, scan_id,
            np.zeros((MAX_NUM_OBJ, 6)),
            np.zeros(MAX_NUM_OBJ, dtype=bool),
            np.zeros((MAX_NUM_OBJ,)),
            augmentations,
        )
        return boxes, mask, class_ids

    def __getitem__(self, index):
        """Get current batch for input index.

        Builds one sample from ``self.annos[index]`` on a deep copy of its
        scan, so augmentation and anchor rotation never touch ``self.scans``.
        For ScanNet annotations, the target ids, target names and utterance
        are re-sampled on every call and written back into the annotation
        dict.

        Returns:
            dict with the following groups of keys:

            - target boxes: 'center_label', 'size_gts', 'box_label_mask';
            - scene objects: 'all_bboxes', 'all_bbox_label_mask',
              'sem_cls_label';
            - detections: 'all_detected_boxes',
              'all_detected_bbox_label_mask', 'all_detected_class_ids',
              'points_to_boxes';
            - language fields and bookkeeping ids.

            'distractor_ids' and 'anchor_ids' are truncated, or padded with
            -1, to exactly 32 entries so that samples can be batched.
        """
        split = self.split
        if split == 'train100':
            split = 'train'

        # Read annotation
        anno = self.annos[index]
        scan = deepcopy(self.scans[anno['scan_id']])

        # Populate anno (used only for scannet)
        if anno['dataset'] == 'scannet':
            sampled_classes = self._sample_classes(anno['scan_id'])
            utterance = self._create_scannet_utterance(sampled_classes)
            anno['target_id'] = np.where(np.array([
                self.label_map[
                    scan.get_object_instance_label(ind)
                ] in DC.nyu40id2class
                and
                DC.class2type[DC.nyu40id2class[self.label_map[
                    scan.get_object_instance_label(ind)
                ]]] in sampled_classes
                for ind in range(len(scan.three_d_objects))
            ])[:MAX_NUM_OBJ])[0].tolist()
            anno['target'] = [
                DC.class2type[DC.nyu40id2class[self.label_map[
                    scan.get_object_instance_label(ind)
                ]]]
                for ind in anno['target_id']
            ]
            anno['utterance'] = utterance

        # Point cloud representation
        point_cloud, augmentations, og_color = self._get_pc(anno, scan)

        # "Target" boxes: append anchors if they're to be detected
        gt_bboxes, box_label_mask, point_instance_label = self._get_target_boxes(
            anno, scan, augmentations
        )

        # Positive map for soft-token and contrastive losses
        tokens_positive, positive_map = self._get_token_positive_map(anno)

        # Scene gt boxes
        (
            class_ids, all_bboxes, all_bbox_label_mask
        ) = self._get_scene_objects(scan)

        # Detected boxes
        (
            all_detected_bboxes, all_detected_bbox_label_mask,
            detected_class_ids
        ) = self._get_detected_objects(split, anno['scan_id'], augmentations)

        # Assume a perfect object detector. Use copies so the in-place edits
        # below (hide_target / hide_anchors / append_anchors) cannot leak
        # into the scene ground truth returned alongside.
        if self.butd_gt:
            all_detected_bboxes = all_bboxes.copy()
            all_detected_bbox_label_mask = all_bbox_label_mask.copy()
            detected_class_ids = class_ids.copy()

        # Assume a perfect object proposal stage (copies, as above)
        if self.butd_cls:
            all_detected_bboxes = all_bboxes.copy()
            all_detected_bbox_label_mask = all_bbox_label_mask.copy()
            detected_class_ids = np.zeros(len(all_bboxes))
            classes = np.array(self.cls_results[anno['scan_id']])
            detected_class_ids[all_bbox_label_mask] = classes[classes > -1]

        # Find points belonging to each detected box. Iterate over slot
        # indices so that the corners, box parameters and class id all come
        # from the same slot, even when the valid slots are not a contiguous
        # prefix (as with the ground-truth boxes used by butd_gt/butd_cls).
        points_to_boxes = np.zeros((len(point_cloud), 7))
        if self.butd or self.butd_gt or self.butd_cls:
            detected_box_points = box2points(all_detected_bboxes)
            for i in np.flatnonzero(all_detected_bbox_label_mask):
                _, indices = extract_pc_in_box3d(
                    point_cloud[:, :3], detected_box_points[i]
                )
                points_to_boxes[indices, -7:-1] = all_detected_bboxes[i][None]
                points_to_boxes[indices, -1] = (
                    detected_class_ids[i]
                    / NUM_CLASSES
                )

        # Visualize for debugging
        if self.visualize and anno['dataset'].startswith('sr3d'):
            self._visualize_scene(anno, point_cloud, og_color, all_bboxes)

        # Return. ``bool`` replaces the ``np.bool8`` alias, which NumPy 2.0
        # removed.
        ret_dict = {
            'box_label_mask': box_label_mask.astype(np.float32),
            'center_label': gt_bboxes[:, :3].astype(np.float32),
            'sem_cls_label': class_ids.astype(np.int64),
            'size_gts': gt_bboxes[:, 3:].astype(np.float32)
        }
        num_targets = box_label_mask.sum().astype(int)
        if self.hide_target:
            hide_target_mask = self._hide_target(
                all_detected_bboxes.astype(np.float32)[all_detected_bbox_label_mask.astype(bool)],
                gt_bboxes[0]
            )
            all_detected_bbox_label_mask[:len(hide_target_mask)] &= hide_target_mask.astype(bool)
        if self.hide_anchors:
            hide_anchor_mask = self._hide_anchors(
                all_detected_bboxes.astype(np.float32)[all_detected_bbox_label_mask.astype(bool)],
                gt_bboxes[1:num_targets]
            )
            all_detected_bbox_label_mask[:len(hide_anchor_mask)] &= hide_anchor_mask.astype(bool)
        if self.append_anchors:
            num_anchors = num_targets - 1
            # With zero anchors, the slices below would be [-0:], i.e. the
            # whole array, which fails to broadcast. Nothing to append then.
            if num_anchors > 0:
                all_detected_bboxes[-num_anchors:] = gt_bboxes[1:num_targets]
                all_detected_bbox_label_mask[-num_anchors:] = True
                detected_class_ids[-num_anchors:] = np.array(class_ids[anno['anchor_ids']])
        # Fixed length 32 so that samples stack into a batch
        distractor_ids = list(anno['distractor_ids'])[:32]
        anchor_ids = list(anno['anchor_ids'])[:32]
        ret_dict.update({
            "scan_ids": anno['scan_id'],
            "point_clouds": point_cloud.astype(np.float32),
            "utterances": ' '.join(anno['utterance'].replace(',', ' ,').split()),
            "tokens_positive": tokens_positive.astype(np.int64),
            "positive_map": positive_map.astype(np.float32),
            "relation": (
                self._find_rel(anno['utterance'])
                if anno['dataset'].startswith('sr3d')
                else "none"
            ),
            "target_name": scan.get_object_instance_label(
                anno['target_id'] if isinstance(anno['target_id'], int)
                else anno['target_id'][0]
            ),
            "target_id": (
                anno['target_id'] if isinstance(anno['target_id'], int)
                else anno['target_id'][0]
            ),
            "point_instance_label": point_instance_label.astype(np.int64),
            "all_bboxes": all_bboxes.astype(np.float32),
            "all_bbox_label_mask": all_bbox_label_mask.astype(bool),
            "distractor_ids": np.array(
                distractor_ids
                + [-1] * (32 - len(distractor_ids))
            ).astype(int),
            "anchor_ids": np.array(
                anchor_ids
                + [-1] * (32 - len(anchor_ids))
            ).astype(int),
            "all_detected_boxes": all_detected_bboxes.astype(np.float32),
            "all_detected_bbox_label_mask": all_detected_bbox_label_mask.astype(bool),
            "all_detected_class_ids": detected_class_ids.astype(np.int64),
            "points_to_boxes": points_to_boxes.astype(np.float32),
            "eul": np.array([0, 0, 0]).astype(np.float32),
            "is_view_dep": self._is_view_dep(anno['utterance']),
            "is_hard": len(anno['distractor_ids']) > 2,
            "is_unique": len(anno['distractor_ids']) == 1,
            "target_cid": (
                class_ids[anno['target_id']]
                if isinstance(anno['target_id'], int)
                else class_ids[anno['target_id'][0]]
            ),
            "detector_succeeded": self._detector_succeeded(
                all_detected_bboxes.astype(np.float32)[all_detected_bbox_label_mask.astype(bool)],
                gt_bboxes[0]
            )
        })
        if self.use_color:
            ret_dict.update({"og_color": og_color})
        return ret_dict

    @staticmethod
    def _is_view_dep(utterance):
        """Check whether to augment based on nr3d utterance."""
        rels = [
            'front', 'behind', 'back', 'left', 'right', 'facing',
            'leftmost', 'rightmost', 'looking', 'across'
        ]
        words = set(utterance.split())
        return any(rel in words for rel in rels)

    @staticmethod
    def _find_rel(utterance):
        """Return the relation that ``REL_ALIASES`` maps the utterance to.

        Commas are split off as separate tokens, and the text is padded with
        spaces so that an alias only matches when surrounded by spaces.
        Aliases are tried from longest to shortest, so longer aliases take
        precedence. The value of the first matching alias is returned, or
        "none" when no alias occurs.
        """
        padded = ' ' + utterance.replace(',', ' ,') + ' '
        by_length = sorted(REL_ALIASES, key=len, reverse=True)
        return next(
            (
                REL_ALIASES[alias]
                for alias in by_length
                if ' ' + alias + ' ' in padded
            ),
            "none",
        )

    @staticmethod
    def _augment_nr3d(utterance):
        """Check whether to augment based on nr3d utterance."""
        rels = [
            'front', 'behind', 'back', 'left', 'right', 'facing',
            'leftmost', 'rightmost', 'looking', 'across'
        ]
        augment = True
        for rel in rels:
            if ' ' + rel + ' ' in (utterance + ' '):
                augment = False
        return augment

    @staticmethod
    def _detector_succeeded(det_boxes, gt_box):
        """Check whether any detected box matches the ground-truth box.

        Both the detections and the single ``gt_box`` are wrapped in float
        tensors and converted with ``box_cxcyczwhd_to_xyzxyz`` before their
        pairwise IoU is computed by ``_iou3d_par``.

        Returns:
            A 0-dim boolean torch tensor that is True when at least one
            detection has IoU > 0.25 with ``gt_box``.
        """
        det_xyz = box_cxcyczwhd_to_xyzxyz(torch.FloatTensor(det_boxes))
        gt_xyz = box_cxcyczwhd_to_xyzxyz(
            torch.FloatTensor(gt_box).unsqueeze(0)  # single box -> batch of 1
        )
        ious, _ = _iou3d_par(det_xyz, gt_xyz)  # (D, 1)
        matched = ious > 0.25
        return matched.any()

    @staticmethod
    def _hide_target(det_boxes, gt_box):
        """Build a keep-mask that drops detections overlapping the target.

        IoU is computed between every detection and the single ``gt_box``,
        after both are converted with ``box_cxcyczwhd_to_xyzxyz``.

        Returns:
            An int numpy array with one entry per detection: 0 if its IoU
            with ``gt_box`` exceeds 0.25, otherwise 1.
        """
        det_xyz = box_cxcyczwhd_to_xyzxyz(torch.FloatTensor(det_boxes))
        target_xyz = box_cxcyczwhd_to_xyzxyz(
            torch.FloatTensor(gt_box).unsqueeze(0)
        )
        ious, _ = _iou3d_par(det_xyz, target_xyz)  # (D, 1)
        overlaps_target = (ious > 0.25).any(1)
        return (~overlaps_target).int().numpy()

    @staticmethod
    def _hide_anchors(det_boxes, gt_box):
        """Build a keep-mask that drops detections overlapping any anchor.

        ``gt_box`` holds one row per anchor box. IoU is computed between
        every detection and every anchor, after both are converted with
        ``box_cxcyczwhd_to_xyzxyz``.

        Returns:
            An int numpy array with one entry per detection: 0 if its IoU
            with at least one anchor exceeds 0.25, otherwise 1.
        """
        det_xyz = box_cxcyczwhd_to_xyzxyz(torch.FloatTensor(det_boxes))
        anchor_xyz = box_cxcyczwhd_to_xyzxyz(torch.FloatTensor(gt_box))
        ious, _ = _iou3d_par(det_xyz, anchor_xyz)  # (D, len(anchors))
        overlaps_anchor = (ious > 0.25).any(1)
        return (~overlaps_anchor).int().numpy()

    def _visualize_scene(self, anno, point_cloud, og_color, all_bboxes):
        """Log a scene and its annotated boxes to Weights & Biases.

        The scene is logged as a ``wandb.Object3D`` point cloud with one box
        per row of ``all_bboxes``, colored by role: target green, anchors
        blue, distractors cyan, all remaining boxes red. The raw utterance is
        logged next to it as ``wandb.Html``.

        Args:
            anno: annotation dict. Reads 'target_id' (an int, or a list/tuple
                whose first entry is used), 'anchor_ids', 'distractor_ids'
                and 'utterance'.
            point_cloud: per-point array. A copy is logged with columns 3:
                overwritten by colors, so the caller's array is not modified.
            og_color: per-point colors. ``self.mean_rgb`` is added back and
                the sum is multiplied by 256.
            all_bboxes: per-object boxes. Columns :6 are turned into corner
                points with ``box2points`` and indexed by object id.
        """
        target_id = anno['target_id']
        # __getitem__ reports the first entry when the target is stored as a
        # list; do the same here. Indexing with the whole list would give a
        # stack of boxes instead of one 8x3 corner set.
        if isinstance(target_id, (list, tuple)):
            target_id = target_id[0]
        target_id = int(target_id)
        anchor_ids = [int(i) for i in anno['anchor_ids']]
        distractor_ids = [int(i) for i in anno['distractor_ids']]

        # Copy before writing colors, so the visualization-only rescaling
        # does not leak into the point cloud owned by the caller.
        point_cloud = np.array(point_cloud, copy=True)
        point_cloud[:, 3:] = (og_color + self.mean_rgb) * 256

        all_boxes_points = box2points(all_bboxes[..., :6])

        def _box(corners, label, color):
            """Format one box entry for wandb.Object3D."""
            return {"corners": corners.tolist(), "label": label, "color": color}

        # Built once, so the "other" filter below is an O(1) membership test.
        highlighted = {target_id, *anchor_ids, *distractor_ids}
        boxes = (
            [_box(all_boxes_points[target_id], "target", [0, 255, 0])]
            + [
                _box(c, "anchor", [0, 0, 255])
                for c in all_boxes_points[anchor_ids]
            ]
            + [
                _box(c, "distractor", [0, 255, 255])
                for c in all_boxes_points[distractor_ids]
            ]
            + [
                _box(c, "other", [255, 0, 0])
                for i, c in enumerate(all_boxes_points)
                if i not in highlighted
            ]
        )

        wandb.log({
            "ground_truth_point_scene": wandb.Object3D({
                "type": "lidar/beta",
                "points": point_cloud,
                "boxes": np.array(boxes),
            }),
            "utterance": wandb.Html(anno['utterance']),
        })

    def __len__(self):
        """Return number of utterances."""
        return len(self.annos)


def rot_x(pc, theta):
    """Rotate points about the x-axis by ``theta`` degrees.

    Args:
        pc: (N, 3) array of points.
        theta: rotation angle in degrees.

    Returns:
        (N, 3) array of rotated points.
    """
    rad = theta * np.pi / 180
    cos_t, sin_t = np.cos(rad), np.sin(rad)
    rotation = np.array([
        [1.0, 0, 0],
        [0, cos_t, -sin_t],
        [0, sin_t, cos_t],
    ])
    return np.matmul(rotation, pc.T).T


def rot_y(pc, theta):
    """Rotate points about the y-axis by ``theta`` degrees.

    Args:
        pc: (N, 3) array of points.
        theta: rotation angle in degrees.

    Returns:
        (N, 3) array of rotated points.
    """
    rad = theta * np.pi / 180
    cos_t, sin_t = np.cos(rad), np.sin(rad)
    rotation = np.array([
        [cos_t, 0, sin_t],
        [0, 1.0, 0],
        [-sin_t, 0, cos_t],
    ])
    return np.matmul(rotation, pc.T).T


def rot_z(pc, theta):
    """Rotate points about the z-axis by ``theta`` degrees.

    Args:
        pc: (N, 3) array of points.
        theta: rotation angle in degrees.

    Returns:
        (N, 3) array of rotated points.
    """
    rad = theta * np.pi / 180
    cos_t, sin_t = np.cos(rad), np.sin(rad)
    rotation = np.array([
        [cos_t, -sin_t, 0],
        [sin_t, cos_t, 0],
        [0, 0, 1.0],
    ])
    return np.matmul(rotation, pc.T).T


def box2points(box):
    """Convert boxes from center/size form to their eight corner vertices.

    Args:
        box: array of shape (N, >=6). Columns 0-2 are the box center
            (x, y, z) and columns 3-5 are the full extents along x, y, z.
            Any further columns are ignored.

    Returns:
        Array of shape (N, 8, 3). Corners 0-3 lie on the z_min face and
        corners 4-7 on the z_max face. Within each face the order is
        (x_min, y_min), (x_min, y_max), (x_max, y_min), (x_max, y_max).
    """
    box = np.asarray(box)
    center = box[:, :3]
    # Slice 3:6 rather than 3: so that boxes carrying extra trailing columns
    # (e.g. a class id) do not break the center/size broadcast.
    half = box[:, 3:6] / 2
    x_min, y_min, z_min = (center - half).T
    x_max, y_max, z_max = (center + half).T

    corners = (
        (x_min, y_min, z_min),
        (x_min, y_max, z_min),
        (x_max, y_min, z_min),
        (x_max, y_max, z_min),
        (x_min, y_min, z_max),
        (x_min, y_max, z_max),
        (x_max, y_min, z_max),
        (x_max, y_max, z_max),
    )
    # Each corner becomes an (N, 3) array; stacking on axis 1 gives (N, 8, 3).
    return np.stack([np.stack(c, axis=1) for c in corners], axis=1)


def points2box(box):
    """Collapse box vertices back to center/size boxes.

    Args:
        box: array of shape (N, K, 3) holding K vertices per box (K=8 for
            the output of box2points). Only the per-box min/max over
            axis 1 is used.

    Returns:
        Array of shape (N, 6): center (x, y, z) followed by extents.
    """
    lo = box.min(axis=1)
    hi = box.max(axis=1)
    center = (lo + hi) / 2
    size = hi - lo
    return np.concatenate((center, size), axis=1)
