"""Dataset and data loader for ReferIt3D."""

from copy import deepcopy
import csv
import h5py
import json
import multiprocessing as mp
import os

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import RobertaTokenizerFast

from scannet.model_util_scannet import ScannetDatasetConfig
from scannet.scannet_utils import read_label_mapping
from .scannet_classes import REL_ALIASES, UNIQUE_VIEW_DEP_RELS, VIEW_DEP_RELS
from .scannet_dataset import unpickle_data, rot_z, get_positive_map
from sunrgbd.sunrgbd_utils import extract_pc_in_box3d
from utils import pc_util

# ipdb is only an interactive debugging hook (``st()``); fall back to the
# standard-library pdb so this module still imports when ipdb is absent.
try:
    import ipdb
    st = ipdb.set_trace
except ImportError:
    import pdb
    st = pdb.set_trace

# Number of semantic classes: passed to ScannetDatasetConfig and used as the
# width of the per-box one-hot ``all_logits``. Any value other than 18 makes
# the label map use the raw ScanNet 'id' column instead of 'nyu40id'.
NUM_CLASSES = 485
DC = ScannetDatasetConfig(NUM_CLASSES)
# Fixed number of object/box slots every sample is padded to.
MAX_NUM_OBJ = 132
import wandb


class ReferIt3DDataset(Dataset):
    """Dataset utilities for ReferIt3D.

    One item is one referential utterance (Sr3D, Nr3D or Sr3D+) together
    with its ScanNet scene. It contains a sub-sampled, optionally augmented
    point cloud and the boxes of the target (and optionally the anchors). It
    also holds the character spans of the mentioned class names in the
    utterance and a padded set of candidate boxes (ground-truth or detected).
    """

    def __init__(self, anno_file, split='train', num_points=50000,
                 use_color=False, use_height=False, overfit=False,
                 detect_intermediate=False,
                 filter_relations=False, use_detected_boxes=False, visualize=False,
                 augment=True, use_oriented_boxes=False, rotate_pc=False,
                 use_multiview=False, train_viewpoint_module=False,
                 use_predicted_viewpoint=False, butd=False, butd_gt=False):
        """Initialize dataset (here for ReferIt3D utterances).

        Args:
            anno_file: annotation set, e.g. 'sr3d', 'nr3d' or 'sr3d+'. It
                names both the CSV under ``refer_it_3d/`` and the scan-id
                list under ``extra/``.
            split: 'train', 'train100' (first 100 training utterances) or
                an evaluation split such as 'val'. For annotations, 'val'
                reads the 'test' scan list.
            num_points: number of points sampled per scene.
            use_color: append mean-centred RGB to the point features.
            use_height: append height above the estimated floor.
            overfit: keep only the first 256 utterances.
            detect_intermediate: also supervise anchor boxes (Sr3D only).
            filter_relations: keep only Sr3D utterances whose relation is
                in ``self.rel_filter_list``.
            use_detected_boxes: use pre-computed detector boxes as the
                candidate boxes instead of ground-truth ones.
            visualize: log every item to Weights & Biases.
            augment: enable training-time augmentation. Other flags below
                may force this on or off.
            use_oriented_boxes: use oriented GT boxes (9 values per box).
            rotate_pc: rotate the scene into the anchor's frame for
                view-dependent Sr3D relations.
            use_multiview: append pre-computed multiview point features.
            train_viewpoint_module: train on the anchor/viewpoint sub-task.
            use_predicted_viewpoint: stored only; not read in this class.
            butd: load detected boxes as extra (bottom-up) input.
            butd_gt: feed ground-truth boxes in place of detected ones.
        """
        self.anno_file = anno_file  # sr3d, nr3d or sr3d+
        self.split = split
        self.num_points = num_points
        self.use_color = use_color
        self.use_height = use_height
        self.overfit = overfit
        self.rotate_pc = rotate_pc
        self.use_predicted_viewpoint = use_predicted_viewpoint
        self.detect_intermediate = detect_intermediate
        self.filter_relations = filter_relations
        self.use_detected_boxes = use_detected_boxes
        self.use_oriented_boxes = use_oriented_boxes
        self.augment = augment
        self.use_multiview = use_multiview
        self.train_viewpoint_module = train_viewpoint_module
        self.rel_filter_list = ['on']
        self.data_path = './dataset/language_grounding/'
        self.visualize = visualize
        self.butd = butd
        self.butd_gt = butd_gt
        self.load_detected_boxes = False
        if self.visualize:
            wandb.init(project="vis", name="debug")

        # BUTD always needs detected boxes and keeps augmentation on.
        if self.butd:
            self.use_detected_boxes = True
            self.load_detected_boxes = True
            self.augment = True

        if self.train_viewpoint_module:
            self.filter_relations = True
            self.augment = False
            self.rel_filter_list = UNIQUE_VIEW_DEP_RELS

        # do not augment
        if (self.use_detected_boxes or self.use_oriented_boxes or rotate_pc) and not self.butd:
            self.load_detected_boxes = True
            self.augment = False

        print('Loading %s files, take a breath!' % split)
        if split == 'train100':
            split = 'train'
        # elif split == 'val':
        #    split = 'test'
        _, self.scans = unpickle_data(f'{self.data_path}/%s_v2scans.pkl' % split)
        self.annos = self.load_annos()
        self.mean_rgb = np.array([109.8, 97.2, 83.8]) / 256
        self.label_map = read_label_mapping(
            'scannet/meta_data/scannetv2-labels.combined.tsv',
            label_from='raw_category',
            label_to='nyu40id' if NUM_CLASSES == 18 else 'id'
        )
        with open(self.data_path + 'extra/object_oriented_bboxes/object_oriented_bboxes_aligned_scans.json') as fid:
            self.oriented_bboxes_mapping = json.load(fid)
        self.multiview_path = os.path.join(f'{self.data_path}/scanrefer_2d_feats', "enet_feats_maxpool.hdf5")
        # One lazily opened HDF5 handle per process (see __getitem__).
        self.multiview_data = {}
        self.tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
        # NOTE(review): ``self.results`` only exists when this file is present.
        if os.path.exists('cls_results.json'):
            with open('cls_results.json') as fid:
                self.results = json.load(fid)

    def load_annos(self):
        """Load the utterances of ``self.anno_file`` for the current split.

        Only utterances whose scan belongs to the split and whose
        ``mentions_target_class`` column is true are kept.

        Sr3D(+) rows keep the CSV's distractor and anchor annotations and
        get a canonical relation name. They are optionally filtered by
        relation.

        Nr3D rows must also mention the target's instance label verbatim.
        Their distractors are every object of the target's label (the target
        included), and they have no anchors.

        Returns:
            list of dicts, one per utterance.
        """
        split = self.split
        if split == 'train100':
            split = 'train'
        elif split == 'val':
            split = 'test'
        # NOTE(review): eval() on the split list and CSV fields executes
        # arbitrary code; acceptable only for trusted local annotation files
        # (ast.literal_eval would suffice if they hold plain literals).
        with open('%s/extra/%s_%s_scans.txt' % (self.data_path, self.anno_file, split)) as f:
            scan_ids = set(eval(f.read()))
        with open(self.data_path + 'refer_it_3d/%s.csv' % self.anno_file) as f:
            csv_reader = csv.reader(f)
            headers = next(csv_reader)
            headers = {header: h for h, header in enumerate(headers)}

            if self.anno_file.startswith('sr3d'):
                annos = [
                    {
                        'scan_id': line[headers['scan_id']],
                        'target_id': int(line[headers['target_id']]),
                        'distractor_ids': eval(line[headers['distractor_ids']]),
                        'utterance': line[headers['utterance']],
                        'target': line[headers['instance_type']],
                        'anchors': eval(line[headers['anchors_types']]),
                        'anchor_ids': eval(line[headers['anchor_ids']]),
                        'relation': self._find_rel(line[headers['utterance']]),
                    }
                    for line in csv_reader
                    if line[headers['scan_id']] in scan_ids
                    and
                    str(line[headers['mentions_target_class']]).lower() == 'true'
                    and
                    (
                        not self.filter_relations or
                        self._find_rel(line[headers['utterance']]) in self.rel_filter_list
                    )
                ]
            else:
                annos = [
                    {
                        'scan_id': line[headers['scan_id']],
                        'target_id': int(line[headers['target_id']]),
                        'utterance': line[headers['utterance']],
                        'target': self.scans[line[headers['scan_id']]].get_object_instance_label(int(line[headers['target_id']])),
                        # NOTE(review): includes the target itself, unlike
                        # the Sr3D CSV annotation.
                        'distractor_ids': [
                            ind
                            for ind in range(len(self.scans[line[headers['scan_id']]].three_d_objects))
                            if self.scans[line[headers['scan_id']]].get_object_instance_label(ind)
                            == self.scans[line[headers['scan_id']]].get_object_instance_label(int(line[headers['target_id']]))
                        ],
                        'anchor_ids': []
                    }
                    for line in csv_reader
                    if line[headers['scan_id']] in scan_ids
                    and
                    str(line[headers['mentions_target_class']]).lower() == 'true'
                    and
                    self.scans[line[headers['scan_id']]].get_object_instance_label(int(line[headers['target_id']]))
                    in line[headers['utterance']]
                ]
        if self.overfit:
            annos = annos[:256]
        elif self.split == 'train100':
            annos = annos[:100]
        return annos

    def _load_detected_boxes(
        self, split, scan_id,
        all_detected_bboxes, all_detected_bbox_label_mask,
        all_detected_bbox_class_names, detected_class_ids,
        augmentations
    ):
        """Fill the padded detected-box buffers for ``scan_id``.

        Reads ``group_free_pred_bboxes_{split}/{scan_id}.npy``, a dict with
        'box' and 'class' entries saved with ``np.save``. Boxes are
        converted from (min corner, max corner) to (center, size) and
        written into the first slots of the given buffers.

        During training with augmentation, the geometric augmentation
        recorded in ``augmentations`` by ``__getitem__`` is replayed on the
        box corners in the same order it was applied to the points (flips,
        rotation, shift, scale), so boxes stay aligned with the scene.

        Returns:
            (boxes, label_mask, class_names, class_ids). The mask, names
            and ids buffers are filled in place. The boxes array is the
            input buffer, or a new array when augmentation is applied.
        """
        detected_dict = np.load(
            f'{self.data_path}group_free_pred_bboxes_{split}/{scan_id}.npy',
            allow_pickle=True
        ).item()

        all_bboxes_ = np.array(detected_dict['box'])
        classes = detected_dict['class']
        cid = np.array([DC.nyu40id2class[
            self.label_map[c]] for c in detected_dict['class']
        ])

        all_bboxes_ = np.concatenate((
            (all_bboxes_[:, :3] + all_bboxes_[:, 3:]) * 0.5,
            all_bboxes_[:, 3:] - all_bboxes_[:, :3]
        ), 1)

        # Exactly MAX_NUM_OBJ detections still fit in the padded buffers.
        assert len(classes) <= MAX_NUM_OBJ
        assert len(classes) == all_bboxes_.shape[0]

        num_objs = len(classes)
        all_detected_bboxes[:num_objs] = all_bboxes_
        all_detected_bbox_class_names[:num_objs] = classes
        all_detected_bbox_label_mask[:num_objs] = np.array([True] * num_objs)
        detected_class_ids[:num_objs] = cid
        if self.augment and self.split == 'train':
            all_detected_pts = box2points(all_detected_bboxes).reshape(-1, 3)
            # Flips come BEFORE the rotation, exactly as in __getitem__;
            # flip and rotation do not commute, so any other order would
            # misalign the boxes with the augmented point cloud.
            if augmentations.get('yz_flip', False):
                all_detected_pts[:, 0] = -all_detected_pts[:, 0]
            if augmentations.get('xz_flip', False):
                all_detected_pts[:, 1] = -all_detected_pts[:, 1]
            all_detected_pts = rot_z(all_detected_pts, augmentations['theta'])
            all_detected_pts += augmentations['shift']
            all_detected_pts *= augmentations['scale']
            # Re-fit axis-aligned boxes around the transformed corners.
            all_detected_bboxes = points2box(all_detected_pts.reshape(-1, 8, 3))
        return (
            all_detected_bboxes, all_detected_bbox_label_mask,
            all_detected_bbox_class_names, detected_class_ids
        )

    def __getitem__(self, index):
        """Build the sample for utterance ``index``.

        Returns:
            dict of numpy arrays and python values. Box tensors are padded
            to MAX_NUM_OBJ slots, and point tensors have ``self.num_points``
            rows.
        """
        split = self.split
        if split == 'train100':
            split = 'train'

        # Work on a copy: the viewpoint branch below rewrites 'utterance',
        # 'target_id' and 'target', which must not leak into self.annos
        # (a second pass over a truncated utterance could hit assert False).
        anno = deepcopy(self.annos[index])
        rel_name = "none"
        if self.anno_file.startswith('sr3d'):
            rel_name = self._find_rel(anno['utterance'])

        # Pointcloud
        scan_id = anno['scan_id']
        scan = deepcopy(self.scans[scan_id])

        if self.use_multiview:
            # load multiview database (one HDF5 handle per process)
            pid = mp.current_process().pid
            if pid not in self.multiview_data:
                self.multiview_data[pid] = h5py.File(self.multiview_path, "r", libver="latest")

            multiview = self.multiview_data[pid][scan_id]
            scan.pc = np.concatenate([scan.pc, multiview], 1)

        if self.rotate_pc and rel_name in UNIQUE_VIEW_DEP_RELS:
            # load anchor
            anchor_id = anno['anchor_ids'][0]
            id_ = str(scan_id) + "_" + str(anchor_id)
            assert id_ in self.oriented_bboxes_mapping
            oriented_data = self.oriented_bboxes_mapping[id_]
            box = oriented_data['obj_bbox']
            orot = np.array(oriented_data['obj_rot'])

            # make pc face the object: translate to the anchor's center,
            # then apply the inverse of its homogeneous rotation
            pc = scan.pc
            pc = pc - np.array(box[:3])[None]
            pc = np.concatenate([pc, np.ones((pc.shape[0], 1))], 1)
            pc = (np.linalg.inv(orot) @ pc.T).T[:, :3]

            scan.pc = pc

        # Viewpoint Euler angles; only non-zero for the viewpoint task.
        eul = np.array([0, 0, 0])

        if self.train_viewpoint_module and self.anno_file.startswith('sr3d'):
            # load anchor
            anchor_id = anno['anchor_ids'][0]
            id_ = str(scan_id) + "_" + str(anchor_id)
            assert id_ in self.oriented_bboxes_mapping
            oriented_data = self.oriented_bboxes_mapping[id_]
            box = oriented_data['obj_bbox']
            orot = np.array(oriented_data['obj_rot'])

            eul = np.array(self._rotm2eul(orot))

            # update utterance to have only the subpart of sent
            # which mentions viewpoint
            sent = anno['utterance']
            if ',' in sent:
                anno['utterance'] = sent.split(',')[0]
            elif 'front of' in sent:
                anno['utterance'] = 'front of ' + sent.split('front of')[1]
            elif 'back of' in sent:
                anno['utterance'] = 'back of ' + sent.split('back of')[1]
            else:
                assert False

            # target is basically anchor in this case
            anno['target_id'] = anno['anchor_ids'][0]
            anno['target'] = anno['anchors'][0]

        # Record every random transform so detected boxes can replay it.
        augmentations = {}
        if self.split == 'train' and self.augment:
            augment_nr3d = (
                self.anno_file.startswith('nr3d')
                and self._augment_nr3d(anno['utterance'])
            )
            augment_sr3d = (
                self.anno_file.startswith('sr3d')
                and rel_name not in VIEW_DEP_RELS
            )
            # Rotate/flip only if we don't have a view_dep sentence
            if augment_nr3d or augment_sr3d:
                theta = (2*np.random.rand() - 1) * 180
                # Flipping along the YZ plane
                augmentations['yz_flip'] = np.random.random() > 0.5
                if augmentations['yz_flip']:
                    scan.pc[:, 0] = -scan.pc[:, 0]
                # Flipping along the XZ plane
                augmentations['xz_flip'] = np.random.random() > 0.5
                if augmentations['xz_flip']:
                    scan.pc[:, 1] = -scan.pc[:, 1]
            else:
                # view-dependent: only a small rotation (presumably degrees)
                theta = (2*np.random.rand() - 1) * 5
            augmentations['theta'] = theta
            scan.pc[:, :3] = rot_z(scan.pc[:, :3], theta)
            # Add noise
            noise = np.random.rand(len(scan.pc), 3) * 5e-3
            augmentations['noise'] = noise
            scan.pc[:, :3] = scan.pc[:, :3] + noise
            # Translate/shift
            augmentations['shift'] = np.random.random((3,))[None, :] - 0.5
            scan.pc[:, :3] += augmentations['shift']
            # Scale
            augmentations['scale'] = 0.95 + 0.1*np.random.random((3,))[None, :]
            scan.pc[:, :3] *= augmentations['scale']
        point_cloud = scan.pc

        # Objects to keep: those whose label maps to a known class
        keep_ = np.array([
            self.label_map[
                scan.get_object_instance_label(ind)
            ] in DC.nyu40id2class
            for ind in range(len(scan.three_d_objects))
        ])[:MAX_NUM_OBJ]
        keep = np.array([False] * MAX_NUM_OBJ)
        keep[:len(keep_)] = keep_

        if not self.use_detected_boxes:
            all_bbox_class_names = ["none"] * MAX_NUM_OBJ
            for k, kept in enumerate(keep):
                if not kept:
                    continue
                all_bbox_class_names[k] = scan.get_object_instance_label(k)

            cid = np.array([
                DC.nyu40id2class[self.label_map[scan.get_object_instance_label(k)]]
                for k, kept in enumerate(keep) if kept
            ])

            class_ids = np.zeros((MAX_NUM_OBJ,))
            class_ids[keep] = cid

        # Color
        if self.use_color:
            point_cloud = np.concatenate((
                point_cloud,
                scan.color - self.mean_rgb
            ), 1)
            if self.split == 'train' and self.augment:
                point_cloud[:, -3:] *= (
                    0.98 + 0.04*np.random.random((len(point_cloud), 3))
                )

        # Height: np.percentile takes a 0-100 scale, so this is the 0.99th
        # percentile of z (close to the minimum) used as the floor height.
        if self.use_height:
            floor_height = np.percentile(point_cloud[:, 2], 0.99)
            height = np.expand_dims(point_cloud[:, 2] - floor_height, 1)
            point_cloud = np.concatenate([point_cloud, height], 1)

        # "Target" boxes: append anchors if Sr3D and training
        bboxes = np.zeros((MAX_NUM_OBJ, 6))
        tids = [anno['target_id']]  # + anno['distractor_ids']
        if self.anno_file.startswith('sr3d') and self.detect_intermediate:
            tids += anno['anchor_ids']

        bboxes[:len(tids)] = np.stack([
            scan.get_object_bbox(
                tid).reshape(-1) for tid in tids
        ])
        # (min corner, max corner) -> (center, size)
        bboxes = np.concatenate((
            (bboxes[:, :3] + bboxes[:, 3:]) * 0.5,
            bboxes[:, 3:] - bboxes[:, :3]
        ), 1)

        # distractors and anchors, padded with -1 to 10 entries
        distractor_ids = np.array(
            anno['distractor_ids']
            + [-1] * (10 - len(anno['distractor_ids']))
        ).astype(int)
        anchor_ids = np.array(
            anno['anchor_ids']
            + [-1] * (10 - len(anno['anchor_ids']))
        ).astype(int)

        if not self.use_detected_boxes:
            if self.use_oriented_boxes:
                all_bboxes = np.zeros((MAX_NUM_OBJ, 9))

                # end_points
                all_bboxes_ = np.stack([
                    self._get_oriented_bbox_kitti(scan, scan_id, k).reshape(-1)
                    for k, kept in enumerate(keep) if kept
                ])
                # cx, cy, cz, w, h, d
                all_bboxes_ = np.concatenate((
                    (all_bboxes_[:, :3] + all_bboxes_[:, 3:6]) * 0.5,
                    all_bboxes_[:, 3:6] - all_bboxes_[:, :3],
                    all_bboxes_[:, 6:]
                ), 1)
                all_bboxes[keep] = all_bboxes_
            else:
                all_bboxes = np.zeros((MAX_NUM_OBJ, 6))

                # end_points
                all_bboxes_ = np.stack([
                    scan.get_object_bbox(k).reshape(-1)
                    for k, kept in enumerate(keep) if kept
                ])
                # cx, cy, cz, w, h, d
                all_bboxes_ = np.concatenate((
                    (all_bboxes_[:, :3] + all_bboxes_[:, 3:]) * 0.5,
                    all_bboxes_[:, 3:] - all_bboxes_[:, :3]
                ), 1)
                all_bboxes[keep] = all_bboxes_

            all_bbox_label_mask = keep

        if self.split == 'train' and self.augment:  # jitter boxes
            bboxes[:len(tids)] *= (0.95 + 0.1*np.random.random((len(tids), 6)))
        # push padding boxes far away
        bboxes[len(tids):, :3] = 1000

        # Target points: 1 for the target, 2.. for anchors, 0 elsewhere
        point_obj_mask = np.zeros(len(scan.pc))
        for t, tid in enumerate(tids):
            point_obj_mask[scan.three_d_objects[tid]['points']] = t + 1

        # Per-instance points
        point_instance_label = (point_obj_mask > 0).astype(np.int64) - 1

        # Sample points (fixed seed outside training for determinism;
        # note this reseeds the global NumPy RNG)
        if self.split != 'train':
            np.random.seed(1184)
        choices = np.random.choice(
            point_cloud.shape[0],
            self.num_points,
            replace=len(point_cloud) < self.num_points
        )
        point_cloud = point_cloud[choices]
        point_obj_mask = point_obj_mask[choices]
        point_instance_label = point_instance_label[choices]

        # Token start-end span in characters
        caption = ' '.join(anno['utterance'].replace(',', ' ,').split())
        caption = ' ' + caption + ' '
        tokens_positive = np.zeros((MAX_NUM_OBJ, 2))
        cat_names = [anno['target']]  # * (1 + len(anno['distractor_ids']))
        if self.anno_file.startswith('sr3d') and self.detect_intermediate:
            cat_names += anno['anchors']
        for c, cat_name in enumerate(cat_names):
            # 1) exact whole-word match
            start_span = caption.find(' ' + cat_name + ' ')
            len_ = len(cat_name)
            # 2) word-prefix match (e.g. plural); span covers the full word
            if start_span < 0:
                start_span = caption.find(' ' + cat_name)
                len_ = len(caption[start_span+1:].split()[0])
            # 3) substring match; expand to the enclosing word
            if start_span < 0:
                start_span = caption.find(cat_name)
                orig_start_span = start_span
                while caption[start_span - 1] != ' ':
                    start_span -= 1
                len_ = len(cat_name) + orig_start_span - start_span
                while caption[len_ + start_span] != ' ':
                    len_ += 1
            end_span = start_span + len_
            tokens_positive[c][0] = start_span
            tokens_positive[c][1] = end_span
        # tokens_positive[1:len(anno['distractor_ids']) + 1, 0] = 0
        # tokens_positive[1:len(anno['distractor_ids']) + 1, 1] = len(
        #    anno['utterance'].replace(',', ' ,').split()
        # )
        # Positive map (for soft token prediction)
        tokenized = self.tokenizer.batch_encode_plus(
            [' '.join(anno['utterance'].replace(',', ' ,').split())],
            padding="longest", return_tensors="pt"
        )
        positive_map = np.zeros((MAX_NUM_OBJ, 256))
        gt_map = get_positive_map(tokenized, tokens_positive[:len(cat_names)])
        positive_map[:len(cat_names)] = gt_map
        # positive_map[1:len(anno['distractor_ids']) + 1] = 0

        # Fill return dict
        # '''
        box_label_mask = np.zeros(MAX_NUM_OBJ)
        box_label_mask[:len(tids)] = 1

        # load detected boxes
        all_detected_bboxes = np.zeros((MAX_NUM_OBJ, 6))
        all_detected_bbox_label_mask = np.array([False] * MAX_NUM_OBJ)
        all_detected_bbox_class_names = ["none"] * MAX_NUM_OBJ
        detected_class_ids = np.zeros((MAX_NUM_OBJ,))
        if self.load_detected_boxes:
            (
                all_detected_bboxes, all_detected_bbox_label_mask,
                all_detected_bbox_class_names, detected_class_ids
            ) = self._load_detected_boxes(
                split, scan_id,
                all_detected_bboxes, all_detected_bbox_label_mask,
                all_detected_bbox_class_names, detected_class_ids,
                augmentations
            )

        if self.use_detected_boxes:
            assert self.load_detected_boxes
            all_bboxes = all_detected_bboxes
            all_bbox_label_mask = all_detected_bbox_label_mask
            all_bbox_class_names = all_detected_bbox_class_names
            class_ids = detected_class_ids

        if self.butd_gt:
            all_detected_bboxes = all_bboxes
            all_detected_bbox_label_mask = all_bbox_label_mask
            all_detected_bbox_class_names = all_bbox_class_names
            detected_class_ids = class_ids

        # Per point: (center, size) of its enclosing detected box and the
        # box's normalized class id.
        points_to_boxes = np.zeros((len(point_cloud), 7))
        if self.butd or self.butd_gt:
            # find points belonging to each detected box; iterate over the
            # indices of valid boxes so box, corners and class stay aligned
            # even when the mask is not a contiguous prefix (e.g. butd_gt)
            detected_box_points = pc_util.box2points(all_detected_bboxes)
            for i in np.flatnonzero(all_detected_bbox_label_mask):
                _, indices = extract_pc_in_box3d(
                    point_cloud[:, :3], detected_box_points[i])
                points_to_boxes[indices, -7:-1] = all_detected_bboxes[i][None]
                points_to_boxes[indices, -1] = detected_class_ids[i] / NUM_CLASSES
        if self.split == 'train' and self.augment:
            # NOTE(review): in-place; when all_bboxes aliases
            # all_detected_bboxes (use_detected_boxes / butd_gt) the returned
            # detected boxes are jittered too, after points_to_boxes was set.
            all_bboxes *= (0.95 + 0.1*np.random.random((len(all_bboxes), 6)))

        # One-hot class "logits" per box slot
        all_logits = np.zeros((MAX_NUM_OBJ, NUM_CLASSES))
        all_logits[np.arange(
            len(class_ids)), class_ids.astype(np.int64)] = 1

        if self.split == 'train' and self.augment:
            noise_logits = np.random.rand(*all_logits.shape)
            noise_logits /= noise_logits.sum(1)[:, None]
            all_logits = (all_logits + noise_logits) / 2

        if self.visualize:
            target_id = anno['target_id']
            point_cloud[:, 3:] = (point_cloud[:, 3:] + self.mean_rgb) * 256

            all_boxes_points = pc_util.box2points(all_bboxes[..., :6])
            if all_bboxes.shape[-1] == 9:
                # per-box angles; kept separate from the viewpoint ``eul``
                # returned in ret_dict
                box_eul = all_bboxes[..., 6:]
                mat = self._eul2rotm_py(box_eul)
                all_boxes_points = self._oriented_corners(
                    all_boxes_points, all_bboxes[..., :3], mat)

            target_box = all_boxes_points[target_id]
            anchors_boxes = all_boxes_points[[
                i.item() for i in anchor_ids if i != -1
            ]]
            distractors_boxes = all_boxes_points[[
                i.item() for i in distractor_ids if i != -1
            ]]
            num_distr = (distractor_ids > -1).sum()

            names = (
                [all_bbox_class_names[target_id]]
                + [
                    all_bbox_class_names[i]
                    for i in distractor_ids if i != -1
                ]
                + [
                    all_bbox_class_names[i]
                    for i in anchor_ids if i != -1
                ]
            )
            wandb.log({
                "ground_truth_point_scene": wandb.Object3D(
                    {
                        "type": "lidar/beta",
                        "points": point_cloud,
                        "boxes": np.array(
                        [  # target
                            {
                                "corners": target_box.tolist(),
                                "label": "gt: " + names[0],
                                "color": [0, 255, 0]
                            }
                        ]
                        + [  # anchors
                            {
                                "corners": c.tolist(),
                                "label": "anchor: " + names[1 + num_distr + i],
                                "color": [0, 0, 255]
                            }
                            for i, c in enumerate(anchors_boxes)
                        ]
                        + [  # distractors
                            {
                                "corners": c.tolist(),
                                "label": "distractor",
                                "color": [0, 255, 255]
                            }
                            for c in distractors_boxes
                        ]
                        + [
                            {
                                "corners": c.tolist(),
                                "label": "butd",
                                "color": [255, 0, 0]
                            }
                            for c in pc_util.box2points(all_detected_bboxes)
                        ]
                    )
                    }
                ),
                "utterance": wandb.Html(anno['utterance']),
            })

        ret_dict = {
            'box_label_mask': box_label_mask.astype(np.float32),
            'center_label': bboxes[:, :3].astype(np.float32),
            'heading_class_label': np.zeros((MAX_NUM_OBJ,)).astype(np.int64),
            'heading_residual_label': np.zeros((MAX_NUM_OBJ,)),
            'point_obj_mask': point_obj_mask.astype(np.int64),
            'point_instance_label': point_instance_label.astype(np.int64),
            'sem_cls_label': class_ids.astype(np.int64),
            'size_gts': bboxes[:, 3:].astype(np.float32),
            'size_class_label': np.zeros(MAX_NUM_OBJ).astype(np.int64),
            'size_residual_label': bboxes[:, 3:].astype(np.float32)
        }
        ret_dict.update({
            "scan_ids": anno['scan_id'],
            "point_clouds": point_cloud.astype(np.float32),
            "utterances": ' '.join(anno['utterance'].replace(',', ' ,').split()),
            "tokens_positive": tokens_positive.astype(np.int64),
            "positive_map": positive_map.astype(np.float32),
            "is_grounding": 1,
            "num_point_class": len(self.scans[anno['scan_id']].pc) // 50000,
            "relation": rel_name,
            "target_name": cat_names[0],
            "target_id": anno['target_id'],
            "all_bboxes": all_bboxes.astype(np.float32),
            # np.bool8 was removed in NumPy 2.0; np.bool_ is the same dtype
            "all_bbox_label_mask": all_bbox_label_mask.astype(np.bool_),
            "all_bbox_class_names": all_bbox_class_names,
            "distractor_ids": distractor_ids,
            "anchor_ids": anchor_ids,
            "all_logits": all_logits.astype(np.float32),
            "eul": eul.astype(np.float32),
            "all_detected_boxes": all_detected_bboxes.astype(np.float32),
            "all_detected_bbox_label_mask": all_detected_bbox_label_mask.astype(np.bool_),
            "all_detected_class_ids": detected_class_ids.astype(np.int64),
            "points_to_boxes": points_to_boxes.astype(np.float32)
        })
        if self.use_color:
            ret_dict.update({"og_color": scan.color[choices]})
        return ret_dict

    @staticmethod
    def _find_rel(utterance):
        """Return the canonical relation named in ``utterance``, or "none".

        Aliases from REL_ALIASES are tried longest first and matched as
        whole space-delimited phrases (commas are split off first).
        """
        utterance = ' ' + utterance.replace(',', ' ,') + ' '
        relation = "none"
        sorted_rel_list = sorted(REL_ALIASES, key=len, reverse=True)
        for rel in sorted_rel_list:
            if ' ' + rel + ' ' in utterance:
                relation = REL_ALIASES[rel]
                break
        return relation

    @staticmethod
    def _augment_nr3d(utterance):
        """Return True unless the Nr3D utterance uses a view-dependent word.

        Words are matched as whole tokens. The utterance is padded on both
        sides and commas are split off (as in ``_find_rel``), so a keyword
        at the very start or directly before a comma is still detected.
        """
        rels = ['front', 'behind', 'back', 'left', 'right']
        padded = ' ' + utterance.replace(',', ' ,') + ' '
        return not any(' ' + rel + ' ' in padded for rel in rels)

    @staticmethod
    def _eul2rotm_py(eul):
        """Convert (B, 3) Euler angles (rx, ry, rz) to (B, 3, 3) rotations."""
        # inputs are shaped B
        # this func is copied from matlab
        # R = [  cy*cz   sy*sx*cz-sz*cx    sy*cx*cz+sz*sx
        #        cy*sz   sy*sx*sz+cz*cx    sy*cx*sz-cz*sx
        #        -sy            cy*sx             cy*cx]
        rx, ry, rz = eul.T
        rx = rx[:, np.newaxis]
        ry = ry[:, np.newaxis]
        rz = rz[:, np.newaxis]
        # these are B x 1
        sinz = np.sin(rz)
        siny = np.sin(ry)
        sinx = np.sin(rx)
        cosz = np.cos(rz)
        cosy = np.cos(ry)
        cosx = np.cos(rx)
        r11 = cosy*cosz
        r12 = sinx*siny*cosz - cosx*sinz
        r13 = cosx*siny*cosz + sinx*sinz
        r21 = cosy*sinz
        r22 = sinx*siny*sinz + cosx*cosz
        r23 = cosx*siny*sinz - sinx*cosz
        r31 = -siny
        r32 = sinx*cosy
        r33 = cosx*cosy
        r1 = np.stack([r11, r12, r13], axis=2)
        r2 = np.stack([r21, r22, r23], axis=2)
        r3 = np.stack([r31, r32, r33], axis=2)
        r = np.concatenate([r1, r2, r3], axis=1)
        return r

    @staticmethod
    def _oriented_corners(axis_aligned_corners, centers, rot):
        """Rotate (N, 8, 3) box corners about their (N, 3) centers.

        ``rot`` holds one (3, 3) rotation per box. Returns (N, 8, 3).
        """
        # calculate the relative coordinates to the center of axis_aligned_bbox
        axis_aligned_corners = axis_aligned_corners - centers[:, None]

        # transform the points also plus the center coordinates.
        axis_aligned_corners = np.concatenate([
            axis_aligned_corners,
            np.ones((
                axis_aligned_corners.shape[0], axis_aligned_corners.shape[1], 1))
        ], 2)
        rotation = np.repeat(np.eye(4)[None], rot.shape[0], axis=0)
        rotation[:, :3, :3] = rot.copy()
        rotation[:, :3, 3] = centers

        corners = torch.bmm(
            torch.from_numpy(rotation),
            torch.from_numpy(axis_aligned_corners.transpose(0, 2, 1))
            ).numpy().transpose(0, 2, 1)[:, :, 0:3]

        return corners

    def _get_oriented_bbox_kitti(self, scan, scan_id, object_id):
        """Return a 9-vector (min corner, max corner, Euler angles) for one object.

        Uses the oriented box from ``self.oriented_bboxes_mapping`` when
        available. Otherwise it falls back to the scan's axis-aligned box
        with zero angles.
        """
        key = str(scan_id) + "_" + str(object_id)
        if key in self.oriented_bboxes_mapping:
            oriented_data = self.oriented_bboxes_mapping[key]
            # cx, cy, cz, w, h, d
            box = oriented_data['obj_bbox']
            box_corners = self._box_cxcyczwhd_to_xyzxyz(box)
            orot = np.array(oriented_data['obj_rot'])
            eul = self._rotm2eul(orot)
            obox = np.array(box_corners + eul)
        else:
            box = scan.get_object_bbox(object_id)
            obox = np.zeros((9))
            # flatten: other call sites reshape get_object_bbox output with
            # .reshape(-1), so it is not assumed to be 1-D here
            obox[:6] = np.asarray(box).reshape(-1)
        return obox

    @staticmethod
    def _box_cxcyczwhd_to_xyzxyz(x):
        """Convert [cx, cy, cz, w, h, d] to [xmin, ymin, zmin, xmax, ymax, zmax]."""
        x_c, y_c, z_c, w, h, d = x
        assert w > 0
        assert h > 0
        assert d > 0
        b = [
                x_c - 0.5 * w, y_c - 0.5 * h, z_c - 0.5 * d,
                x_c + 0.5 * w, y_c + 0.5 * h, z_c + 0.5 * d
            ]
        return b

    @staticmethod
    def _rotm2eul(r):
        """Convert a rotation matrix to Euler angles [rx, ry, rz].

        Only the top-left 3x3 entries are read, so a 4x4 homogeneous
        matrix also works. Near gimbal lock (sy ~ 0), rz is set to 0.
        """
        # r is 3x3, or 4x4
        r00 = r[0, 0]
        r10 = r[1, 0]
        r11 = r[1, 1]
        r12 = r[1, 2]
        r20 = r[2, 0]
        r21 = r[2, 1]
        r22 = r[2, 2]

        sy = np.sqrt(r00 * r00 + r10 * r10)

        cond = (sy > 1e-6)
        rx = np.where(cond, np.arctan2(r21, r22), np.arctan2(-r12, r11))
        ry = np.where(cond, np.arctan2(-r20, sy), np.arctan2(-r20, sy))
        rz = np.where(cond, np.arctan2(r10, r00), np.zeros_like(r20))

        return [rx, ry, rz]

    @staticmethod
    def _rotz(t):
        """Rotation about the z-axis."""
        c = np.cos(t)
        s = np.sin(t)
        return np.array([[c, -s, 0],
                        [s, c, 0],
                        [0, 0, 1]])

    def __len__(self):
        """Return number of utterances."""
        return len(self.annos)


def box2points(box):
    """Return the 8 corners, shape (N, 8, 3), of (N, 6) center/size boxes.

    Corner order: z varies slowest, then x, then y (min before max for
    each). Corner 0 is (xmin, ymin, zmin) and corner 7 is (xmax, ymax, zmax).
    """
    half_size = box[:, 3:] / 2
    lower = box[:, :3] - half_size
    upper = box[:, :3] + half_size
    extremes = (lower, upper)
    corners = [
        np.stack(
            (extremes[ix][:, 0], extremes[iy][:, 1], extremes[iz][:, 2]),
            axis=1
        )
        for iz in (0, 1)
        for ix in (0, 1)
        for iy in (0, 1)
    ]
    return np.stack(corners, axis=1)


def points2box(box):
    """Fit axis-aligned (N, 6) center/size boxes around (N, 8, 3) corner sets."""
    lower = box.min(axis=1)
    upper = box.max(axis=1)
    center = (lower + upper) / 2
    size = upper - lower
    return np.concatenate((center, size), axis=1)


if __name__ == '__main__':
    # Smoke test: build the Sr3D training set with BUTD boxes and colour,
    # then load the first 50 items.
    dataset = ReferIt3DDataset(
        'sr3d',
        split='train',
        use_color=True,
        butd=True
    )
    for idx in range(50):
        dataset[idx]
