"""Dataset and data loader for ScanNet."""

from copy import deepcopy
import h5py
from six.moves import cPickle
import multiprocessing as mp
import os
import random

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import RobertaTokenizerFast

from scannet.model_util_scannet import ScannetDatasetConfig
from scannet.scannet_utils import read_label_mapping
from src.visual_data_handlers import Scan, ScanNetMappings
from sunrgbd.sunrgbd_utils import extract_pc_in_box3d
from utils import pc_util

import ipdb
# Debugging shortcut: call st() to drop into ipdb. It is not called anywhere in
# the visible source, but importing this module still requires ipdb installed.
st = ipdb.set_trace

# Size of the class space. Besides configuring DC, it selects the label column
# read from scannetv2-labels.combined.tsv in ScanNetDataset ('nyu40id' only
# when this equals 18, otherwise 'id') and sets the width of 'all_logits'.
NUM_CLASSES = 485
# Dataset config built once at import time (with agnostic=True). It supplies
# the class maps (nyu40id2class, class2type) and mean_size_arr used below.
DC = ScannetDatasetConfig(NUM_CLASSES, agnostic=True)
# Number of per-object slots; every per-object output array is padded to it.
MAX_NUM_OBJ = 132


class ScanNetDataset(Dataset):
    """Dataset utilities for ScanNet.

    Every item is one scan:
    - a sampled point cloud, augmented for training;
    - center/size boxes for the objects of a random subset of the scan's
      classes;
    - a synthetic utterance listing those class names plus four negative
      class names.
    """

    def __init__(self, split='train', num_points=50000,
                 use_color=False, use_height=False, overfit=False,
                 agnostic=False, use_multiview=False, butd=False, butd_gt=False):
        """Initialize dataset.

        Args:
            split: 'train' enables augmentation. Any other value loads the
                'val' scans and reseeds the RNGs so sampling is repeatable.
            num_points: number of points sampled for every item.
            use_color: append mean-subtracted RGB to each point.
            use_height: append height relative to an estimated floor level.
            overfit: keep only the first 50 scans.
            agnostic: zero 'sem_cls_label' / 'size_class_label' in outputs.
            use_multiview: append per-point features read from an HDF5 file.
            butd: load precomputed detector boxes from disk.
            butd_gt: use the ground-truth boxes as the "detected" boxes.

        If the cached scan pickle is missing, it is built with ``save_data``.
        """
        self.split = split
        self.num_points = num_points
        self.use_color = use_color
        self.use_height = use_height
        self.overfit = overfit
        self.agnostic = agnostic
        self.use_multiview = use_multiview
        self.butd = butd
        self.butd_gt = butd_gt
        self.load_detected_boxes = bool(butd)
        self.augment = True
        self.data_path = './dataset/language_grounding/'
        print('Loading %s files, take a breath!' % split)
        # Everything that is not 'train' is served from the 'val' scans.
        file_split = 'val' if split != 'train' else 'train'
        pkl_path = f'{self.data_path}/{file_split}_v2scans.pkl'
        if not os.path.exists(pkl_path):
            save_data(pkl_path, file_split)
        # save_data stores (ScanNetMappings, {scan_id: Scan}); keep the scans.
        _, self.scans = unpickle_data(pkl_path)
        self.mean_rgb = np.array([109.8, 97.2, 83.8]) / 256
        self.label_map = read_label_mapping(
            'scannet/meta_data/scannetv2-labels.combined.tsv',
            label_from='raw_category',
            label_to='nyu40id' if NUM_CLASSES == 18 else 'id'
        )
        self.scan_ids = self.load_annos()
        self.multiview_path = os.path.join(
            f'{self.data_path}/scanrefer_2d_feats', "enet_feats_maxpool.hdf5"
        )
        # One HDF5 handle per process id, opened lazily in _get_pc.
        self.multiview_data = {}
        self.tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    def load_annos(self):
        """Return the split's scan ids that contain a known-class object.

        A scan is kept when at least one object label maps, via
        ``self.label_map``, to a key of ``DC.nyu40id2class``. With
        ``overfit``, only the first 50 kept scans are returned.
        """
        split = 'train' if self.split == 'train' else 'val'
        with open('scannet/meta_data/scannetv2_%s.txt' % split) as f:
            scan_ids = [line.rstrip() for line in f]
        keep_ids = []
        for scan_id in scan_ids:
            scan = self.scans[scan_id]
            known = [
                self.label_map[scan.get_object_instance_label(ind)]
                in DC.nyu40id2class
                for ind in range(len(scan.three_d_objects))
            ]
            if any(known):
                keep_ids.append(scan_id)
        if self.overfit:
            keep_ids = keep_ids[:50]
        return keep_ids

    @staticmethod
    def _augment_pc(scan):
        """Augment ``scan.pc`` in place and record what was applied.

        The augmentations are applied in this order:
        1. optional flip of x (YZ plane);
        2. optional flip of y (XZ plane);
        3. rotation about z by ``theta`` degrees in [-180, 180);
        4. per-point noise;
        5. a global shift in [-0.5, 0.5);
        6. a per-axis scale in [0.95, 1.05).

        Code that replays these on other geometry (``_load_detected_boxes``)
        must use the same order.

        Returns:
            ``(scan, augmentations)``. ``augmentations`` has the keys
            'yz_flip', 'xz_flip', 'theta', 'noise', 'shift' and 'scale'.
        """
        augmentations = {}
        # Flipping along the YZ plane
        augmentations['yz_flip'] = np.random.random() > 0.5
        if augmentations['yz_flip']:
            scan.pc[:, 0] = -scan.pc[:, 0]
        # Flipping along the XZ plane
        augmentations['xz_flip'] = np.random.random() > 0.5
        if augmentations['xz_flip']:
            scan.pc[:, 1] = -scan.pc[:, 1]
        theta = (2*np.random.rand() - 1) * 180
        augmentations['theta'] = theta
        scan.pc[:, :3] = rot_z(scan.pc[:, :3], theta)
        # Add noise
        noise = np.random.rand(len(scan.pc), 3) * 5e-3
        augmentations['noise'] = noise
        scan.pc[:, :3] = scan.pc[:, :3] + noise
        # Translate/shift
        augmentations['shift'] = np.random.random((3,))[None, :] - 0.5
        scan.pc[:, :3] += augmentations['shift']
        # Scale
        augmentations['scale'] = 0.95 + 0.1*np.random.random((3,))[None, :]
        scan.pc[:, :3] *= augmentations['scale']
        return scan, augmentations

    @staticmethod
    def _augment_color(color_rgb):
        """Jitter colors: scale by 0.98 and add uniform noise in [0, 0.04)."""
        return color_rgb * 0.98 + 0.04*np.random.random((len(color_rgb), 3))

    def _get_pc(self, scan):
        """Build the per-point feature matrix for ``scan``.

        When training with ``self.augment``, ``scan.pc`` is first augmented
        in place. Columns are then appended in this order:
        - multiview features (``use_multiview``);
        - mean-subtracted RGB (``use_color``; color-jittered when augmenting);
        - height relative to a floor estimate (``use_height``).

        Returns:
            ``(point_cloud, augmentations)``. ``augmentations`` is empty when
            no augmentation was applied.
        """
        augmentations = {}
        if self.split == 'train' and self.augment:
            scan, augmentations = self._augment_pc(scan)
        point_cloud = scan.pc

        # 2d feats
        if self.use_multiview:
            # Lazily open the HDF5 file once per process.
            pid = mp.current_process().pid
            if pid not in self.multiview_data:
                self.multiview_data[pid] = h5py.File(self.multiview_path, "r", libver="latest")

            multiview = self.multiview_data[pid][scan.scan_id]
            point_cloud = np.concatenate([point_cloud, multiview], 1)

        # Color
        if self.use_color:
            point_cloud = np.concatenate((
                point_cloud,
                scan.color - self.mean_rgb
            ), 1)
            if self.split == 'train' and self.augment:
                point_cloud[:, -3:] = self._augment_color(point_cloud[:, -3:])

        # Height
        if self.use_height:
            # np.percentile takes q on a 0-100 scale, so this is the 0.99th
            # percentile of z (just under the lowest 1% of points).
            floor_height = np.percentile(point_cloud[:, 2], 0.99)
            height = np.expand_dims(point_cloud[:, 2] - floor_height, 1)
            point_cloud = np.concatenate([point_cloud, height], 1)

        return point_cloud, augmentations

    @staticmethod
    def _corners_to_center_size(boxes):
        """Convert (N, 6) ``[min_xyz, max_xyz]`` rows to ``[center, size]``."""
        return np.concatenate((
            (boxes[:, :3] + boxes[:, 3:]) * 0.5,
            boxes[:, 3:] - boxes[:, :3]
        ), 1)

    def _class_id(self, scan, k):
        """Return the config class index of object ``k`` of ``scan``."""
        return DC.nyu40id2class[
            self.label_map[scan.get_object_instance_label(k)]
        ]

    def _load_detected_boxes(
        self, split, scan_id,
        all_detected_bboxes, all_detected_bbox_label_mask,
        all_detected_bbox_class_names, detected_class_ids,
        augmentations
    ):
        """Fill the detected-box buffers from a precomputed detector file.

        Reads ``{data_path}group_free_pred_bboxes_{split}/{scan_id}.npy``. It
        holds a pickled dict with 'box' rows in [min_xyz, max_xyz] layout and
        'class' label names looked up in ``self.label_map``.

        The boxes are converted to center/size and written into the leading
        slots of the given buffers. When training with augmentation, the
        point-cloud augmentation is then replayed on the boxes.

        Raises:
            ValueError: if the file holds more than ``MAX_NUM_OBJ`` boxes, or
                the numbers of boxes and classes differ.
        """
        detected_dict = np.load(
            f'{self.data_path}group_free_pred_bboxes_{split}/{scan_id}.npy',
            allow_pickle=True
        ).item()

        classes = detected_dict['class']
        cid = np.array([
            DC.nyu40id2class[self.label_map[c]] for c in classes
        ])
        all_bboxes_ = self._corners_to_center_size(
            np.array(detected_dict['box'])
        )

        # Exactly MAX_NUM_OBJ boxes still fit in the buffers.
        if len(classes) > MAX_NUM_OBJ:
            raise ValueError(
                f'{scan_id}: {len(classes)} detected boxes exceed '
                f'MAX_NUM_OBJ={MAX_NUM_OBJ}'
            )
        if len(classes) != all_bboxes_.shape[0]:
            raise ValueError(
                f'{scan_id}: {all_bboxes_.shape[0]} boxes but '
                f'{len(classes)} classes'
            )

        num_objs = len(classes)
        all_detected_bboxes[:num_objs] = all_bboxes_
        all_detected_bbox_class_names[:num_objs] = classes
        all_detected_bbox_label_mask[:num_objs] = np.array([True] * num_objs)
        detected_class_ids[:num_objs] = cid
        if self.augment and self.split == 'train':
            pts = box2points(all_detected_bboxes).reshape(-1, 3)
            # Replay _augment_pc in the same order: flips, rotation, shift,
            # scale. Flips and rotation do not commute, so applying the
            # rotation first would misalign the boxes with the point cloud.
            if augmentations.get('yz_flip', False):
                pts[:, 0] = -pts[:, 0]
            if augmentations.get('xz_flip', False):
                pts[:, 1] = -pts[:, 1]
            pts = rot_z(pts, augmentations['theta'])
            pts += augmentations['shift']
            pts *= augmentations['scale']
            all_detected_bboxes = points2box(pts.reshape(-1, 8, 3))
        return (
            all_detected_bboxes, all_detected_bbox_label_mask,
            all_detected_bbox_class_names, detected_class_ids
        )

    def __getitem__(self, index):
        """Build the training/evaluation dict for scan ``self.scan_ids[index]``."""
        # Point cloud (augmentation mutates this copy, not the cached scan)
        scan_id = self.scan_ids[index]
        scan = deepcopy(self.scans[scan_id])
        point_cloud, augmentations = self._get_pc(scan)

        # Classes present in the scan that the config knows about
        sampled_classes = set([
            self.label_map[scan.get_object_instance_label(ind)]
            for ind in range(len(scan.three_d_objects))
        ])
        sampled_classes = list(sampled_classes & set(DC.nyu40id2class))

        # Subsample classes; outside training the draw is made repeatable.
        # NOTE(review): the threshold is 10 but only 8 classes are kept, so a
        # scene with 9-10 classes keeps all of them while a scene with 11+
        # keeps 8. Confirm which count is intended.
        if self.split != 'train':
            random.seed(1184)
        if len(sampled_classes) > 10:
            sampled_classes = random.sample(sampled_classes, 8)

        # Objects to keep (only the first MAX_NUM_OBJ objects have slots)
        keep_ = np.array([
            self.label_map[scan.get_object_instance_label(ind)]
            in sampled_classes
            for ind in range(len(scan.three_d_objects))
        ])[:MAX_NUM_OBJ]
        keep = np.array([False] * MAX_NUM_OBJ)
        keep[:len(keep_)] = keep_
        kept_ids = [k for k, kept in enumerate(keep) if kept]

        # Target boxes in center/size layout; padded rows get center 1000.
        t_bboxes = np.zeros((MAX_NUM_OBJ, 6))
        t_bboxes[:, :3] = 1000
        bboxes = self._corners_to_center_size(np.stack([
            scan.get_object_bbox(k).reshape(-1) for k in kept_ids
        ]))
        if self.split == 'train' and self.augment:  # jitter boxes
            bboxes *= (0.95 + 0.1*np.random.random((len(bboxes), 6)))
        t_bboxes[:len(bboxes)] = bboxes

        # Objectness mask and per-instance labels (-1 = background)
        point_obj_mask = np.zeros(len(scan.pc))
        point_instance_label = np.full(len(scan.pc), -1.0)
        for t, k in enumerate(kept_ids):
            obj_points = scan.three_d_objects[k]['points']
            point_obj_mask[obj_points] = 1
            point_instance_label[obj_points] = t

        # Sample points
        if self.split != 'train':
            np.random.seed(1184)
        choices = np.random.choice(
            point_cloud.shape[0],
            self.num_points,
            replace=len(point_cloud) < self.num_points
        )
        point_cloud = point_cloud[choices]
        point_obj_mask = point_obj_mask[choices]
        point_instance_label = point_instance_label[choices]

        # Create utterance: kept class names plus 4 negative class names
        cat_names = [DC.class2type[self._class_id(scan, k)] for k in kept_ids]
        neg_names = []
        while len(neg_names) < 4:
            _ind = np.random.randint(0, len(DC.class2type))
            if DC.class2type[_ind] not in neg_names + cat_names:
                neg_names.append(DC.class2type[_ind])
        mixed_names = sorted(list(set(cat_names + neg_names)))
        if self.split != 'train':
            random.seed(1184)
        random.shuffle(mixed_names)
        separator = ' . '
        utterance = separator.join(mixed_names) + ' .'

        # Character span [start, end) of each name inside ``utterance``.
        # The spans come from the join itself rather than a substring search,
        # so a name that is part of a longer name ("cabinet" in "kitchen
        # cabinet") still maps to its own occurrence.
        name_spans = {}
        start = 0
        for name in mixed_names:
            name_spans[name] = (start, start + len(name))
            start += len(name) + len(separator)
        tokens_positive = np.zeros((MAX_NUM_OBJ, 2))
        for c, cat_name in enumerate(cat_names):
            tokens_positive[c] = name_spans[cat_name]

        # Positive map (for soft token prediction)
        tokenized = self.tokenizer.batch_encode_plus(
            [utterance],
            padding="longest", return_tensors="pt"
        )
        positive_map = np.zeros((MAX_NUM_OBJ, 256))
        gt_map = get_positive_map(tokenized, tokens_positive[:len(cat_names)])
        positive_map[:len(cat_names)] = gt_map

        # Class ids compacted to the kept objects (row t = t-th kept object)
        cid = np.array([self._class_id(scan, k) for k in kept_ids])
        class_ids = np.zeros((MAX_NUM_OBJ,))
        class_ids[:len(cid)] = cid
        # The same class ids, but indexed by object index like all_bboxes
        all_class_ids = np.zeros((MAX_NUM_OBJ,))
        all_class_ids[kept_ids] = cid
        size_residuals = np.copy(t_bboxes[:, 3:])
        size_residuals[:len(cid)] -= DC.mean_size_arr[cid, :]
        box_label_mask = np.zeros((MAX_NUM_OBJ))
        box_label_mask[:len(bboxes)] = 1

        # Per-object boxes indexed by object index; only rows in ``keep`` are
        # valid. NOTE(review): these stay in the [min_xyz, max_xyz] layout
        # returned by get_object_bbox, unlike every other box in this item.
        # Confirm which layout consumers of 'all_bboxes' expect.
        all_bboxes = np.stack([
            scan.get_object_bbox(k).reshape(-1)
            if k < len(scan.three_d_objects)
            else np.zeros(6)
            for k in range(MAX_NUM_OBJ)
        ]).astype(np.float32)
        all_bbox_class_names = ["none"] * MAX_NUM_OBJ
        for k in kept_ids:
            all_bbox_class_names[k] = scan.get_object_instance_label(k)
        all_bbox_label_mask = keep

        # One-hot class "logits" (noised for training)
        all_logits = np.zeros((MAX_NUM_OBJ, NUM_CLASSES))
        all_logits[np.arange(len(class_ids)), class_ids.astype(np.int64)] = 1
        if self.split == 'train' and self.augment:
            noise_logits = np.random.rand(*all_logits.shape)
            noise_logits /= noise_logits.sum(1)[:, None]
            all_logits = (all_logits + noise_logits) / 2

        # load detected boxes
        all_detected_bboxes = np.zeros((MAX_NUM_OBJ, 6))
        all_detected_bbox_label_mask = np.array([False] * MAX_NUM_OBJ)
        all_detected_bbox_class_names = ["none"] * MAX_NUM_OBJ
        detected_class_ids = np.zeros((MAX_NUM_OBJ,))
        if self.load_detected_boxes:
            (
                all_detected_bboxes, all_detected_bbox_label_mask,
                all_detected_bbox_class_names, detected_class_ids
            ) = self._load_detected_boxes(
                self.split, scan_id,
                all_detected_bboxes, all_detected_bbox_label_mask,
                all_detected_bbox_class_names, detected_class_ids,
                augmentations
            )

        if self.butd_gt:
            # GT boxes stand in for detections. Convert them to the
            # center/size layout of the detector boxes, and use class ids
            # indexed per object so they line up with the boxes and the mask.
            all_detected_bboxes = self._corners_to_center_size(all_bboxes)
            all_detected_bbox_label_mask = all_bbox_label_mask
            all_detected_bbox_class_names = all_bbox_class_names
            detected_class_ids = all_class_ids

        # Per point: the box (6 values) and scaled class id of a containing
        # valid box; a later box overwrites an earlier one.
        points_to_boxes = np.zeros((len(point_cloud), 7))
        if self.butd or self.butd_gt:
            detected_box_points = pc_util.box2points(all_detected_bboxes)
            # Iterate over the actual slot indices, so boxes, class ids and
            # corner points stay aligned even when the mask is not a prefix.
            for i in np.flatnonzero(all_detected_bbox_label_mask):
                _, indices = extract_pc_in_box3d(
                    point_cloud[:, :3], detected_box_points[i])
                points_to_boxes[indices, -7:-1] = all_detected_bboxes[i][None]
                points_to_boxes[indices, -1] = detected_class_ids[i] / NUM_CLASSES

        ret_dict = {
            'box_label_mask': box_label_mask.astype(np.float32),
            'center_label': t_bboxes[:, :3].astype(np.float32),
            'heading_class_label': np.zeros((MAX_NUM_OBJ,)).astype(np.int64),
            'heading_residual_label': np.zeros((MAX_NUM_OBJ,)),
            'point_obj_mask': point_obj_mask.astype(np.int64),
            'point_instance_label': point_instance_label.astype(np.int64),
            'sem_cls_label': class_ids.astype(np.int64),
            'size_gts': t_bboxes[:, 3:].astype(np.float32),
            'size_class_label': class_ids.astype(np.int64),
            'size_residual_label': size_residuals.astype(np.float32)
        }
        if self.agnostic:
            for key in ['sem_cls_label', 'size_class_label']:
                ret_dict[key] = np.zeros_like(ret_dict[key])
        ret_dict.update({
            "scan_ids": scan_id,
            "point_clouds": point_cloud.astype(np.float32),
            "utterances": utterance,
            "tokens_positive": tokens_positive.astype(np.int64),
            "positive_map": positive_map.astype(np.float32),
            "is_grounding": 0,
            "num_point_class": len(self.scans[scan_id].pc) // 50000,
            "relation": "none",
            "target_name": "none",
            "target_id": 0,
            "all_bboxes": all_bboxes,
            # np.bool_ instead of the np.bool8 alias removed in NumPy 2.0
            "all_bbox_label_mask": all_bbox_label_mask.astype(np.bool_),
            "all_bbox_class_names": all_bbox_class_names,
            "distractor_ids": -np.ones((10,)).astype(int),
            "anchor_ids": -np.ones((10,)).astype(int),
            "all_logits": all_logits.astype(np.float32),
            "eul": np.array([0, 0, 0]).astype(np.float32),
            "all_detected_boxes": all_detected_bboxes.astype(np.float32),
            "all_detected_bbox_label_mask": all_detected_bbox_label_mask.astype(np.bool_),
            "all_detected_class_ids": detected_class_ids.astype(np.int64),
            "points_to_boxes": points_to_boxes.astype(np.float32)
        })
        if self.use_color:
            ret_dict.update({"og_color": scan.color[choices]})
        return ret_dict

    def __len__(self):
        """Return the number of scans in the dataset."""
        return len(self.scan_ids)


def _char_to_token_fallback(tokenized, char_index, fallback_offsets):
    """Map ``char_index`` to a token, retrying at nearby characters.

    The first lookup is not guarded, so errors from it propagate. If it
    returns None, ``char_index + offset`` is tried for each offset in order.
    Any error during those retries gives up and returns None.
    """
    token = tokenized.char_to_token(char_index)
    if token is not None:
        return token
    try:
        for offset in fallback_offsets:
            token = tokenized.char_to_token(char_index + offset)
            if token is not None:
                return token
    except Exception:  # e.g. an index the tokenizer cannot map
        return None
    return None


def get_positive_map(tokenized, tokens_positive, max_len=256):
    """Construct a map of box-token associations.

    Args:
        tokenized: tokenizer output exposing ``char_to_token(char_index)``,
            e.g. a Hugging Face ``BatchEncoding`` for a single sentence.
        tokens_positive: sequence of ``(begin, end)`` character spans, one
            per box. ``end`` is exclusive.
        max_len: number of token columns in the map (default 256). Spans
            beyond it are truncated by slicing.

    Returns:
        ``np.ndarray`` of shape ``(len(tokens_positive), max_len)``. Each row
        is uniform over the tokens covering its span (summing to ~1), or all
        zeros when the span could not be mapped to tokens.
    """
    positive_map = torch.zeros(
        (len(tokens_positive), max_len), dtype=torch.float
    )
    for j, (beg, end) in enumerate(tokens_positive):
        # Spans may start/end on characters without a token (e.g. spaces),
        # so step inward by up to two characters.
        beg_pos = _char_to_token_fallback(tokenized, int(beg), (1, 2))
        end_pos = _char_to_token_fallback(tokenized, int(end) - 1, (-1, -2))
        if beg_pos is None or end_pos is None:
            continue
        positive_map[j, beg_pos:end_pos + 1].fill_(1)

    # Normalize each row; the epsilon keeps all-zero rows at zero.
    positive_map = positive_map / (positive_map.sum(-1)[:, None] + 1e-12)
    return positive_map.numpy()


def rot_z(pc, theta):
    """Rotate row-vector points about the z-axis.

    Args:
        pc: points as an (N, 3) array.
        theta: angle in degrees. Positive values rotate counter-clockwise
            when looking down from +z.

    Returns:
        A new (N, 3) array of rotated points.
    """
    angle = theta * np.pi / 180
    c = np.cos(angle)
    s = np.sin(angle)
    rotation = np.array([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1.0],
    ])
    # Rotate the column vectors pc.T, then transpose back to rows.
    rotated_columns = np.matmul(rotation, pc.T)
    return rotated_columns.T


def box2points(box):
    """Expand center/size boxes into their 8 corner points.

    Args:
        box: (N, 6) array of ``[cx, cy, cz, dx, dy, dz]``.

    Returns:
        (N, 8, 3) array of corners. z varies slowest, then x, then y:
        - the four z-min corners, in the order (xmin, ymin), (xmin, ymax),
          (xmax, ymin), (xmax, ymax);
        - then the same four at z-max.
    """
    half_size = box[:, 3:] / 2
    lo = box[:, :3] - half_size
    hi = box[:, :3] + half_size
    corners = [
        np.stack((xs[:, 0], ys[:, 1], zs[:, 2]), axis=1)
        for zs in (lo, hi)
        for xs in (lo, hi)
        for ys in (lo, hi)
    ]
    return np.stack(corners, axis=1)


def points2box(box):
    """Collapse point sets into axis-aligned center/size boxes.

    Args:
        box: (N, P, 3) array of points, typically P = 8 corners.

    Returns:
        (N, 6) array ``[cx, cy, cz, dx, dy, dz]``: the tightest axis-aligned
        box around each point set.
    """
    lower = np.min(box, axis=1)
    upper = np.max(box, axis=1)
    center = (lower + upper) / 2
    extent = upper - lower
    return np.concatenate((center, extent), axis=1)


def scannet_loader(iter_obj):
    """Build one ``Scan`` from a ``(scan_id, scan_path, mappings)`` tuple.

    This is the worker function that ``save_data`` passes to ``Pool.imap``.
    It prints the scan id as a progress indicator.
    """
    sid, root, mappings = iter_obj
    print(sid)
    # NOTE(review): the trailing True is a Scan constructor flag whose
    # meaning is defined by Scan; confirm there.
    loaded = Scan(sid, root, mappings, True)
    return loaded


def save_data(filename, split, n_processes=4):
    """Load every scan of ``split`` in parallel and pickle them to ``filename``.

    The file is written with ``pickle_data`` and holds two objects: the
    ``ScanNetMappings`` instance and a dict mapping scan id to ``Scan``.

    Args:
        filename: output pickle path.
        split: split name; selects ``scannet/meta_data/scannetv2_<split>.txt``.
        n_processes: number of worker processes (default 4).
    """
    # Read all scan files
    scan_path = './dataset/language_grounding/scans/'
    with open('scannet/meta_data/scannetv2_%s.txt' % split) as f:
        scan_ids = [line.rstrip() for line in f]
    print('{} scans found.'.format(len(scan_ids)))
    scannet = ScanNetMappings()

    # Load data
    n_items = len(scan_ids)
    # Pool.imap requires chunksize >= 1. Plain division gives 0 whenever
    # there are fewer scans than workers.
    chunks = max(1, n_items // n_processes)
    print(n_processes, chunks)
    iter_obj = [(scan_id, scan_path, scannet) for scan_id in scan_ids]
    # imap yields results in input order, so they line up with scan_ids.
    # All results are consumed inside the ``with`` block, so the pool is
    # released even if a worker raises.
    with mp.Pool(n_processes) as pool:
        all_scans = dict(zip(
            scan_ids,
            pool.imap(scannet_loader, iter_obj, chunksize=chunks)
        ))

    # Save data
    print('pickle time')
    pickle_data(filename, scannet, all_scans)


def pickle_data(file_name, *args):
    """Use (c)Pickle to save multiple objects in a single file.

    Writes ``len(args)`` first and then each object in order, all with
    pickle protocol 2. ``unpickle_data`` reads this layout back. The file is
    closed even if pickling an object fails.
    """
    with open(file_name, 'wb') as out_file:
        cPickle.dump(len(args), out_file, protocol=2)
        for item in args:
            cPickle.dump(item, out_file, protocol=2)


def unpickle_data(file_name, python2_to_3=False):
    """Restore data previously saved with pickle_data().

    This is a generator: it yields the stored objects in the order they were
    saved. The file is opened on the first ``next()``. It is closed when the
    generator is exhausted or closed early.

    Args:
        file_name: path of a file written by ``pickle_data``.
        python2_to_3: decode Python 2 ``str`` payloads with 'latin1'.

    Security: unpickling can execute arbitrary code from the file. Only load
    files you produced yourself.
    """
    load_kwargs = {'encoding': 'latin1'} if python2_to_3 else {}
    with open(file_name, 'rb') as in_file:
        size = cPickle.load(in_file, **load_kwargs)
        for _ in range(size):
            yield cPickle.load(in_file, **load_kwargs)


def main():
    """Smoke test: build the validation split and load its first item."""
    dataset = ScanNetDataset('val')
    _ = dataset[0]


if __name__ == '__main__':
    main()
