"""Dataset and data loader for ReferIt3D classifiers."""

import ast
from copy import deepcopy
import csv

import numpy as np
from torch.utils.data import Dataset

from .scannet_classes import REL_ALIASES, VIEW_DEP_RELS
from .scannet_dataset import unpickle_data, rot_z


# Number of rows in the zero-padded target/distractor box array (and its
# mask and label vectors) built by ReferIt3DClassDataset.__getitem__.
MAX_NUM_OBJ = 24


class ReferIt3DClassDataset(Dataset):
    """Dataset utilities for ReferIt3D classifiers.

    Every item corresponds to one utterance and provides the boxes of the
    target object and its distractors plus the boxes of the anchor objects,
    each zero-padded to a fixed number of rows with a matching mask.
    """

    # Number of rows reserved for anchor boxes in every sample.
    MAX_NUM_ANCHORS = 2

    def __init__(self, anno_file, split='train', filter_relations=None):
        """Initialize dataset (here for ReferIt3D utterances).

        Args:
            anno_file: name (without extension) of the csv file under
                ``refer_it_3d/``, e.g. sr3d, nr3d or sr3d+.
            split: 'train' selects the train files and enables
                augmentation; any other value selects the test files.
            filter_relations: optional collection of relation names. When
                non-empty, only utterances whose detected relation is in
                the collection are kept. ``None`` means no filtering.
        """
        self.anno_file = anno_file  # sr3d, nr3d or sr3d+
        self.split = split
        # A None default avoids sharing one mutable list between instances.
        self.filter_relations = (
            [] if filter_relations is None else filter_relations
        )
        self.data_path = './dataset/language_grounding/'
        self.annos = self.load_annos()
        print('Loading %s files, take a breath!' % split)
        split = 'test' if split != 'train' else 'train'
        _, self.scans = unpickle_data(
            '%s/%s_scans.pkl' % (self.data_path, split)
        )

    def load_annos(self):
        """Load annotations.

        Keeps the csv rows whose scan belongs to the current split, whose
        ``mentions_target_class`` field reads 'true' (case-insensitive) and,
        if ``self.filter_relations`` is non-empty, whose detected relation
        is listed there.

        Returns:
            A list of dicts with keys 'scan_id', 'target_id',
            'distractor_ids', 'utterance', 'target', 'anchors',
            'anchor_ids' and 'relation'.
        """
        split = 'train' if self.split == 'train' else 'test'
        with open('%s/extra/sr3d_%s_scans.txt' % (self.data_path, split)) as f:
            # NOTE(review): this used to be eval(); literal_eval only
            # accepts Python literals, so file contents cannot run code.
            scan_ids = set(ast.literal_eval(f.read().strip()))

        annos = []
        csv_path = self.data_path + 'refer_it_3d/%s.csv' % self.anno_file
        # newline='' lets the csv module handle quoted line breaks itself.
        with open(csv_path, newline='') as f:
            csv_reader = csv.reader(f)
            headers = {
                name: col for col, name in enumerate(next(csv_reader))
            }
            for line in csv_reader:
                if line[headers['scan_id']] not in scan_ids:
                    continue
                mentions = str(line[headers['mentions_target_class']])
                if mentions.lower() != 'true':
                    continue
                # Detect the relation once; it serves both the filter and
                # the stored annotation.
                utterance = line[headers['utterance']]
                relation = self._find_rel(utterance)
                if (
                    self.filter_relations
                    and relation not in self.filter_relations
                ):
                    continue
                # NOTE(review): list-valued columns were parsed with eval();
                # literal_eval accepts only literals such as "[1, 2]".
                annos.append({
                    'scan_id': line[headers['scan_id']],
                    'target_id': int(line[headers['target_id']]),
                    'distractor_ids': ast.literal_eval(
                        line[headers['distractor_ids']]
                    ),
                    'utterance': utterance,
                    'target': line[headers['instance_type']],
                    'anchors': ast.literal_eval(
                        line[headers['anchors_types']]
                    ),
                    'anchor_ids': ast.literal_eval(
                        line[headers['anchor_ids']]
                    ),
                    'relation': relation,
                })
        return annos

    def __getitem__(self, index):
        """Get current batch for input index.

        Returns:
            A dict with
            'scan_ids': the scan id of the utterance,
            'target_boxes': float32 (MAX_NUM_OBJ, 6), target first, then
                distractors, zero-padded,
            'anchor_boxes': float32 (MAX_NUM_ANCHORS, 6), zero-padded,
            't_mask' / 'a_mask': int64 vectors, 1 for real boxes,
            'labels': int64 (MAX_NUM_OBJ,), 1 only at index 0 (the target).

        Raises:
            ValueError: if there are more boxes than padded rows.
        """
        anno = self.annos[index]
        is_train = self.split == 'train'

        # Pointcloud (copied so augmentation never touches the cached scan)
        scan = deepcopy(self.scans[anno['scan_id']])
        if is_train:
            # The relation was already detected in load_annos; relations in
            # VIEW_DEP_RELS only get a small rotation range.
            max_angle = 5 if anno['relation'] in VIEW_DEP_RELS else 180
            # presumably degrees, as expected by rot_z -- TODO confirm
            theta = (2*np.random.rand() - 1) * max_angle
            scan.pc = rot_z(scan.pc, theta)
            scan.pc += np.random.random((3,))[None, :] - 0.5
            scan.pc *= 0.95 + 0.1*np.random.random((3,))[None, :]

        # "Target" and "distractor" boxes (target always in row 0)
        target_ids = [anno['target_id'], *anno['distractor_ids']]
        d_bboxes, d_mask = self._pack_boxes(scan, target_ids, MAX_NUM_OBJ)

        # "Anchor" boxes
        a_bboxes, a_mask = self._pack_boxes(
            scan, anno['anchor_ids'], self.MAX_NUM_ANCHORS
        )

        if is_train:  # jitter boxes (padding rows stay zero)
            d_bboxes *= (0.95 + 0.1*np.random.random((len(d_bboxes), 6)))
            a_bboxes *= (0.95 + 0.1*np.random.random((len(a_bboxes), 6)))

        labels = np.zeros_like(d_mask)
        labels[0] = 1

        ret_dict = {
            "scan_ids": anno['scan_id'],
            "target_boxes": d_bboxes.astype(np.float32),  # min-max
            "anchor_boxes": a_bboxes.astype(np.float32),  # min-max
            "t_mask": d_mask.astype(np.int64),
            "a_mask": a_mask.astype(np.int64),
            "labels": labels.astype(np.int64)
        }
        return ret_dict

    @staticmethod
    def _pack_boxes(scan, obj_ids, capacity):
        """Collect the flattened boxes of obj_ids into padded arrays.

        Args:
            scan: object exposing ``get_object_bbox(obj_id)`` whose result
                flattens to 6 values.
            obj_ids: iterable of object ids; may be empty.
            capacity: number of rows of the returned arrays.

        Returns:
            (boxes, mask): float arrays of shape (capacity, 6) and
            (capacity,); rows beyond ``len(obj_ids)`` are zero.

        Raises:
            ValueError: if there are more ids than ``capacity``.
        """
        obj_ids = list(obj_ids)
        if len(obj_ids) > capacity:
            raise ValueError(
                'got %d boxes but only %d rows are available'
                % (len(obj_ids), capacity)
            )
        boxes = np.zeros((capacity, 6))
        mask = np.zeros((capacity,))
        # Row-wise filling also works for an empty id list, where np.stack
        # would raise.
        for row, obj_id in enumerate(obj_ids):
            boxes[row] = scan.get_object_bbox(obj_id).reshape(-1)
        mask[:len(obj_ids)] = 1
        return boxes, mask

    @staticmethod
    def _find_rel(utterance):
        """Return the canonical relation mentioned in utterance, or "none".

        Aliases from REL_ALIASES are tried longest first and must appear as
        space-delimited phrases; commas are detached from the preceding
        word so they do not block a match.
        """
        utterance = ' ' + utterance.replace(',', ' ,') + ' '
        relation = "none"
        sorted_rel_list = sorted(REL_ALIASES, key=len, reverse=True)
        for rel in sorted_rel_list:
            if ' ' + rel + ' ' in utterance:
                relation = REL_ALIASES[rel]
                break
        return relation

    def __len__(self):
        """Return number of utterances."""
        return len(self.annos)
