'''
File Created: Monday, 25th November 2019 1:35:30 pm
Author: Dave Zhenyu Chen (zhenyu.chen@tum.de)
'''

import os
import sys
import time
import h5py
import json
import pickle
import random
import numpy as np
import multiprocessing as mp
import torch
import os.path as osp
import math
import scipy.interpolate
import scipy.ndimage
import clip
import open3d as o3d
import imageio
import gc
from torch.utils.data import Dataset
from glob import glob

from config.config import CONF
from utils.pc_utils import random_sampling, rotx, roty, rotz
from data.scannet.model_util_scannet import rotate_aligned_boxes, ScannetDatasetConfig, rotate_aligned_boxes_along_axis
from models.softgroup.ops import voxelization_idx
from models.long_clip.model import longclip
# Shared ScanNet dataset configuration; its ``type2class`` mapping defines the
# class indices used by ReferenceDataset._get_raw2label.
DC = ScannetDatasetConfig()
# ScanNet label table (TSV) mapping raw category names to NYU40 classes.
SCANNET_V2_TSV = osp.join(CONF.PATH.SCANNET_META, "scannetv2-labels.combined.tsv")

class ReferenceDataset(Dataset):
    """Base class for ScanRefer-style referring-expression datasets.

    Subclasses are expected to set ``self.scanrefer`` (a list of annotation
    dicts with ``scene_id``, ``object_id``, ``object_name`` and ``ann_id``
    keys) before calling :meth:`_load_data`. When ``CONF.pretrain_model_on``
    is false, they must also set ``self.split``.
    """

    # Label used when an object name has no entry in ``self.raw2label``.
    # NOTE(review): presumably the index of the 'others' class in
    # DC.type2class -- TODO confirm.
    _FALLBACK_LABEL = 17

    def __init__(self):
        pass

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def _get_raw2label(self):
        """Build a mapping from raw ScanNet category names to class indices.

        Reads ``SCANNET_V2_TSV`` and skips its header row. Column 1 is taken
        as the raw name and column 7 as the NYU40 name. Raw names whose NYU40
        name is a key of ``DC.type2class`` map to that key's position.
        All other raw names map to the position of ``'others'``.

        Returns:
            dict mapping raw category name (str) -> class index (int).
        """
        scannet2label = {label: i for i, label in enumerate(DC.type2class.keys())}

        # Use a context manager so the file handle is always closed.
        with open(SCANNET_V2_TSV, encoding="utf-8") as tsv_file:
            lines = [line.rstrip() for line in tsv_file][1:]  # drop header

        raw2label = {}
        for line in lines:
            if not line:
                continue  # tolerate blank / trailing lines
            elements = line.split('\t')
            raw_name = elements[1]
            nyu40_name = elements[7]
            if nyu40_name in scannet2label:
                raw2label[raw_name] = scannet2label[nyu40_name]
            else:
                raw2label[raw_name] = scannet2label['others']
        return raw2label

    def _raw_name_to_label(self, object_name):
        """Return the class index for an annotation ``object_name``.

        Underscores are replaced by spaces before the ``self.raw2label``
        lookup. Unknown names fall back to ``_FALLBACK_LABEL``.
        """
        return self.raw2label.get(object_name.replace("_", " "), self._FALLBACK_LABEL)

    def _get_unique_multiple_lookup(self):
        """Flag each annotation as referring to a unique or repeated class.

        An object counts as "unique" (0) when its semantic label occurs
        exactly once among the distinct objects of its scene. Otherwise it
        is "multiple" (1). Each distinct object contributes the label of the
        first annotation that mentions it.

        Returns:
            nested dict ``{scene_id: {object_id: {ann_id: 0 or 1}}}``.
        """
        from collections import Counter

        # Pass 1: count semantic labels over distinct objects per scene.
        label_counts = {}
        seen_objects = {}
        for data in self.scanrefer:
            scene_id = data["scene_id"]
            object_id = data["object_id"]
            counts = label_counts.setdefault(scene_id, Counter())
            seen = seen_objects.setdefault(scene_id, set())
            if object_id not in seen:
                seen.add(object_id)
                counts[self._raw_name_to_label(data["object_name"])] += 1

        # Pass 2: classify every annotation with O(1) count lookups.
        unique_multiple_lookup = {}
        for data in self.scanrefer:
            scene_id = data["scene_id"]
            sem_label = self._raw_name_to_label(data["object_name"])
            unique_multiple = 0 if label_counts[scene_id][sem_label] == 1 else 1
            per_object = unique_multiple_lookup.setdefault(scene_id, {}).setdefault(data["object_id"], {})
            per_object[data["ann_id"]] = unique_multiple

        return unique_multiple_lookup

    def _load_data(self, dataset_name):
        """Populate label mapping, optional pretrained data and scene list.

        Args:
            dataset_name: currently unused; kept for interface compatibility.
        """
        print("loading data...")
        self.raw2label = self._get_raw2label()

        if not CONF.pretrain_model_on:
            # Requires the subclass to have set ``self.split`` beforehand.
            pretrained_path = CONF.PATH.PRETRAINED_TRAIN_DATA if self.split == "train" else CONF.PATH.PRETRAINED_VAL_DATA
            self.pretrained_data = torch.load(pretrained_path)

        # add scannet data
        print("loading point cloud data...")
        self.scene_list = sorted({data["scene_id"] for data in self.scanrefer})


class ScannetReferenceDataset(ReferenceDataset):
    """ScanRefer language dataset that groups descriptions per scene.

    Each item of ``self.scanrefer_new`` is a list of annotation dicts that
    share one ``scene_id`` (see :meth:`split_scene_new`). :meth:`__getitem__`
    currently returns only language data. The point-cloud helpers
    (augmentation, cropping, visibility) are not called from within this
    class.
    """

    def __init__(self,
                 scanrefer,
                 scanrefer_new,
                 scanrefer_all_scene,
                 data_root,
                 prefix,
                 suffix,
                 voxel_cfg=None,
                 training=True,
                 with_label=True,
                 repeat=1,
                 name="ScanRefer",
                 lang_num_max=8,
                 augment=False,
                 shuffle=False,
                 scan2cad_rotation=None):
        """Store configuration and load label / scene metadata.

        Args:
            scanrefer: flat list of annotation dicts.
            scanrefer_new: list of per-scene annotation groups (list of lists
                of dicts); its length defines ``len(self)``.
            scanrefer_all_scene: all scene ids in scanrefer.
            prefix: split name; also reported as ``"split"`` by
                :meth:`collate_fn`.
            lang_num_max: number of description slots per item.
            shuffle: shuffle annotations inside / across scenes in
                :meth:`shuffle_data`.
        """
        # NOTE only feed the scan2cad_rotation when on the training mode and train split

        self.scanrefer = scanrefer
        self.scanrefer_new = scanrefer_new
        self.scanrefer_new_len = len(scanrefer_new)
        self.scanrefer_all_scene = scanrefer_all_scene  # all scene_ids in scanrefer
        self.name = name
        self.lang_num_max = lang_num_max
        self.augment = augment
        self.scan2cad_rotation = scan2cad_rotation

        self.data_root = data_root
        self.prefix = prefix
        self.suffix = suffix
        self.voxel_cfg = voxel_cfg
        self.training = training
        self.with_label = with_label
        self.repeat = repeat
        self.mode = 'train' if training else 'test'
        # ``_load_data`` reads ``self.split`` when CONF.pretrain_model_on is
        # false, but nothing set it, which raised AttributeError. collate_fn
        # reports ``self.prefix`` as the split, so reuse it here.
        # NOTE(review): assumes prefix is e.g. "train"/"val" -- confirm.
        self.split = prefix
        # load data
        self._load_data(name)
        self.multiview_data = {}
        self.should_shuffle = shuffle
        self.depth_scale = 1000.0
        self.vis_threshold = 0.05

    def load(self, filename):
        """Load a torch-serialized file."""
        return torch.load(filename)

    def __len__(self):
        return self.scanrefer_new_len

    def split_scene_new(self, scanrefer_data):
        """Group consecutive same-scene annotations into chunks.

        Runs of consecutive entries with equal ``scene_id`` form a scene.
        Each scene is optionally shuffled in place, then cut into chunks of
        at most ``self.lang_num_max`` entries.

        Returns:
            list of lists of annotation dicts.
        """
        from itertools import groupby

        chunk_size = max(self.lang_num_max, 1)  # guard against a zero step
        groups = []
        for _, scene_iter in groupby(scanrefer_data, key=lambda entry: entry["scene_id"]):
            scene = list(scene_iter)
            if self.should_shuffle:
                random.shuffle(scene)
            groups.extend(scene[start:start + chunk_size] for start in range(0, len(scene), chunk_size))
        return groups

    def shuffle_data(self):
        """Rebuild (and optionally shuffle) the per-scene annotation groups."""
        print('shuffle dataset data(lang)', flush=True)
        self.scanrefer_new = self.split_scene_new(self.scanrefer)
        if self.should_shuffle:
            random.shuffle(self.scanrefer_new)
        assert len(self.scanrefer_new) == self.scanrefer_new_len, 'assert scanrefer length right'
        print('shuffle done', flush=True)

    def dataAugment(self, xyz, jitter=False, flip=False, rot=False, scale=False, prob=1.0):
        """Apply random jitter/flip/z-rotation/scale to points ``xyz``.

        Each enabled augmentation fires with probability ``prob``. When the
        random rotation does not fire, a fixed 0.35*pi z-rotation is applied.

        Returns:
            ``xyz`` (optionally scaled) right-multiplied by the 3x3 transform.
        """
        def z_rotation(theta):
            return [[math.cos(theta), math.sin(theta), 0],
                    [-math.sin(theta), math.cos(theta), 0], [0, 0, 1]]

        m = np.eye(3)
        if jitter and np.random.rand() < prob:
            m += np.random.randn(3, 3) * 0.1
        if flip and np.random.rand() < prob:
            m[0][0] *= np.random.randint(0, 2) * 2 - 1
        if rot and np.random.rand() < prob:
            m = np.matmul(m, z_rotation(np.random.rand() * 2 * math.pi))
        else:
            # Empirically, a slight fixed rotation reproduces the checkpoint results.
            m = np.matmul(m, z_rotation(0.35 * math.pi))
        if scale and np.random.rand() < prob:
            xyz = xyz * np.random.uniform(0.95, 1.05)
        return np.matmul(xyz, m)

    def elastic(self, x, gran, mag):
        """Elastic distortion: add smoothed random noise sampled at ``x``.

        Args:
            x: (N, 3) point coordinates.
            gran: noise grid spacing.
            mag: displacement magnitude.
        """
        blurs = [np.ones(shape).astype('float32') / 3
                 for shape in ((3, 1, 1), (1, 3, 1), (1, 1, 3))]

        bb = np.abs(x).max(0).astype(np.int32) // gran + 3
        noise = [np.random.randn(bb[0], bb[1], bb[2]).astype('float32') for _ in range(3)]
        # Two passes of a separable 3-tap box blur along each axis.
        # (scipy.ndimage.filters is a deprecated alias namespace.)
        for blur in blurs * 2:
            noise = [scipy.ndimage.convolve(n, blur, mode='constant', cval=0) for n in noise]
        ax = [np.linspace(-(b - 1) * gran, (b - 1) * gran, b) for b in bb]
        interp = [
            scipy.interpolate.RegularGridInterpolator(ax, n, bounds_error=False, fill_value=0)
            for n in noise
        ]
        displacement = np.hstack([i(x)[:, None] for i in interp])
        return x + displacement * mag

    def crop(self, xyz, step=32):
        """Randomly shift ``xyz`` so at most ``voxel_cfg.max_npoint`` points fit.

        Shrinks the x/y extent of the crop window by ``step`` (doubled above
        1e6 points) until few enough points fall inside it.

        Returns:
            (shifted xyz, boolean mask of points inside the window).
        """
        xyz_offset = xyz.copy()
        valid_idxs = xyz_offset.min(1) >= 0
        assert valid_idxs.sum() == xyz.shape[0]
        spatial_shape = np.array([self.voxel_cfg.spatial_shape[1]] * 3)
        room_range = xyz.max(0) - xyz.min(0)
        while (valid_idxs.sum() > self.voxel_cfg.max_npoint):
            step_temp = step
            if valid_idxs.sum() > 1e6:
                step_temp = step * 2
            offset = np.clip(spatial_shape - room_range + 0.001, None, 0) * np.random.rand(3)
            xyz_offset = xyz + offset
            valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < spatial_shape).sum(1) == 3)
            spatial_shape[:2] -= step_temp
        return xyz_offset, valid_idxs

    def getInstanceInfo(self, xyz, instance_label, semantic_label):
        """Return (instance count, points per instance, class per instance).

        The class of an instance is the semantic label of its first point.
        """
        instance_pointnum = []
        instance_cls = []
        # max(instance_num, 0) to support instance_label with no valid instance_id
        instance_num = max(int(instance_label.max()) + 1, 0)
        for i_ in range(instance_num):
            inst_idx_i = np.where(instance_label == i_)
            instance_pointnum.append(inst_idx_i[0].size)
            cls_idx = inst_idx_i[0][0]
            instance_cls.append(semantic_label[cls_idx])
        return instance_num, instance_pointnum, instance_cls

    def getCroppedInstLabel(self, instance_label, valid_idxs):
        """Keep ``valid_idxs`` entries and relabel instance ids contiguously.

        Any id missing below the current maximum is filled by renaming the
        maximum id to it.
        """
        instance_label = instance_label[valid_idxs]
        j = 0
        while (j < instance_label.max()):
            if (len(np.where(instance_label == j)[0]) == 0):
                instance_label[instance_label == instance_label.max()] = j
            j += 1
        return instance_label

    def transform_train(self, xyz, rgb, semantic_label, instance_label, aug_prob=1.0):
        """Augment, elastically distort, voxel-scale and crop a training scene.

        Returns:
            (xyz, xyz_middle, rgb, semantic_label, instance_label, valid_idxs),
            or None if no crop keeps ``voxel_cfg.min_npoint`` points.
        """
        # Pass aug_prob by keyword: positionally it landed in dataAugment's
        # ``scale`` slot, leaving ``prob`` stuck at 1.0.
        xyz_middle = self.dataAugment(xyz, True, True, True, scale=True, prob=aug_prob)
        xyz = xyz_middle * self.voxel_cfg.scale
        if np.random.rand() < aug_prob:
            xyz = self.elastic(xyz, 6, 40.)
            xyz = self.elastic(xyz, 20, 160.)
        xyz = xyz - xyz.min(0)
        max_tries = 5
        while (max_tries > 0):
            xyz_offset, valid_idxs = self.crop(xyz)
            if valid_idxs.sum() >= self.voxel_cfg.min_npoint:
                xyz = xyz_offset
                break
            max_tries -= 1
        if valid_idxs.sum() < self.voxel_cfg.min_npoint:
            return None
        xyz = xyz[valid_idxs]
        xyz_middle = xyz_middle[valid_idxs]
        rgb = rgb[valid_idxs]
        semantic_label = semantic_label[valid_idxs]
        instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)
        return xyz, xyz_middle, rgb, semantic_label, instance_label, valid_idxs

    def transform_test(self, xyz, rgb, semantic_label, instance_label):
        """Deterministic (fixed-rotation) transform for evaluation; keeps all points."""
        xyz_middle = self.dataAugment(xyz, False, False, False, False)
        xyz = xyz_middle * self.voxel_cfg.scale
        xyz -= xyz.min(0)
        valid_idxs = np.ones(xyz.shape[0], dtype=bool)
        instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)
        return xyz, xyz_middle, rgb, semantic_label, instance_label, valid_idxs

    def get_ref_mask(self, instance_label, object_id):
        """Float mask that is 1 where ``instance_label == object_id``, else 0."""
        ref_lbl = torch.zeros_like(instance_label)
        ref_lbl[instance_label == object_id] = 1
        return ref_lbl.float()

    def load_depth_maps(self, depth_maps_paths):
        """Read depth images, divide by ``depth_scale``, move to CUDA, stack."""
        depth_maps = [torch.from_numpy(imageio.imread(path) / self.depth_scale).cuda()
                      for path in depth_maps_paths]
        return torch.stack(depth_maps)

    def adjust_intrinsic(self, intrinsic, original_resolution, new_resolution):
        """Rescale a camera intrinsic matrix to a new image resolution.

        Resolutions appear to be indexed as (width, height) -- TODO confirm.
        """
        if original_resolution == new_resolution:
            return intrinsic

        resize_width = int(math.floor(new_resolution[1] * float(
                        original_resolution[0]) / float(original_resolution[1])))

        adapted_intrinsic = intrinsic.copy()
        adapted_intrinsic[0, 0] *= float(resize_width) / float(original_resolution[0])
        adapted_intrinsic[1, 1] *= float(new_resolution[1]) / float(original_resolution[1])
        adapted_intrinsic[0, 2] *= float(new_resolution[0] - 1) / float(original_resolution[0] - 1)
        adapted_intrinsic[1, 2] *= float(new_resolution[1] - 1) / float(original_resolution[1] - 1)
        return adapted_intrinsic

    def get_points_visibility(self, points, intrinsics, extrinsics, depth, height, width):
        """Mask of points visible in each frame.

        A point is visible in a frame when its projection lies strictly
        inside the image and its camera depth is within ``vis_threshold`` of
        the depth map at that pixel. Points with zero camera depth are marked
        invisible; previously they made the reshape fail.

        Args:
            points: assumed (N, 4) homogeneous world coordinates -- TODO confirm.
            intrinsics, extrinsics: per-frame matrices, batched on dim 0.
            depth: per-frame depth maps indexed [frame, row, col].

        Returns:
            bool tensor (frames, N).
        """
        cam_points = torch.einsum('bij, jk -> bik', torch.einsum('bij,bjk -> bik', intrinsics, extrinsics), points.T).permute(0, 2, 1)
        point_depth = cam_points[:, :, 2]
        has_depth = point_depth != 0
        # Divide by 1 where depth is 0; those points are excluded below.
        safe_depth = torch.where(has_depth, point_depth, torch.ones_like(point_depth))
        projected_points = torch.stack([cam_points[:, :, 0] / safe_depth,
                                        cam_points[:, :, 1] / safe_depth], dim=-1).long()
        # ``> 0`` (not ``>= 0``) also rejects coords in (-1, 0) truncated to 0.
        inside_mask = (has_depth
                       & (projected_points[:, :, 0] < width) & (projected_points[:, :, 0] > 0)
                       & (projected_points[:, :, 1] < height) & (projected_points[:, :, 1] > 0))
        for frame_id in range(depth.shape[0]):
            points_in_frame_mask = inside_mask[frame_id].clone()
            points_in_frame = projected_points[frame_id][points_in_frame_mask]
            depth_in_frame = point_depth[frame_id][points_in_frame_mask]
            visibility_mask = (torch.abs(depth[frame_id][points_in_frame[:, 1].long(), points_in_frame[:, 0].long()]
                                         - depth_in_frame) <= self.vis_threshold)
            inside_mask[frame_id][points_in_frame_mask] = visibility_mask

        del cam_points, projected_points, safe_depth
        gc.collect()
        torch.cuda.empty_cache()
        return inside_mask

    def __getitem__(self, idx):
        """Return language data for the ``idx``-th per-scene group.

        Exactly ``lang_num_max`` slots are filled. Slots beyond the group's
        real size repeat its last description as padding.
        """
        start = time.time()
        data_dict = {}
        entries = self.scanrefer_new[idx]
        lang_num = len(entries)
        scene_id = entries[0]["scene_id"]
        clip_tokens = []
        object_name_list = []
        ann_id_list = []

        for i in range(self.lang_num_max):
            if i < lang_num:
                entry = entries[i]
                object_name = " ".join(entry["object_name"].split("_"))
                ann_id = entry["ann_id"]
                clip_token = clip.tokenize(entry["description"].strip(), truncate=True)[0]
            # Otherwise the previous slot's values are reused as padding.
            clip_tokens.append(clip_token)
            object_name_list.append(object_name)
            ann_id_list.append(int(ann_id))

        # ------------------------------- LABELS ------------------------------
        object_cat_list = []
        for object_name in object_name_list:
            object_cat = 17  # fallback label for unknown names
            if object_name in self.raw2label:
                object_cat = self.raw2label[object_name]
            else:
                # The last word found in raw2label wins.
                for tmp in object_name.split(" "):
                    if tmp in self.raw2label:
                        object_cat = self.raw2label[tmp]
            object_cat_list.append(object_cat)

        data_dict["istrain"] = 1 if self.mode == "train" else 0
        data_dict["scan_idx"] = torch.from_numpy(np.array(idx).astype(np.int64))
        data_dict["lang_num"] = torch.from_numpy(np.array(lang_num).astype(np.int64))
        data_dict["load_time"] = time.time() - start
        data_dict["ann_id_list"] = torch.tensor(ann_id_list, dtype=torch.int64)
        data_dict["object_cat_list"] = torch.tensor(object_cat_list, dtype=torch.int64)
        data_dict["scan_id"] = scene_id
        data_dict["clip_token"] = torch.stack(clip_tokens, 0)
        return data_dict

    def collate_fn(self, batch):
        """Collate ``__getitem__`` outputs into a batch, skipping None items.

        The collate time is appended to the per-item ``load_time`` list.
        """
        start = time.time()
        data_dict = {}
        load_time, istrain, lang_num, scan_idx = [], [], [], []
        ann_id_list, object_cat_list = [], []
        scan_ids = []
        clip_tokens = []

        batch_id = 0
        for data in batch:
            if data is None:
                continue

            scan_ids.append(data["scan_id"])
            istrain.append(data["istrain"])
            lang_num.append(data["lang_num"])
            load_time.append(data["load_time"])
            scan_idx.append(data["scan_idx"])
            ann_id_list.append(data["ann_id_list"])
            object_cat_list.append(data["object_cat_list"])
            clip_tokens.append(data["clip_token"])
            batch_id += 1

        assert batch_id > 0, 'empty batch'

        data_dict["istrain"] = istrain
        data_dict["lang_num"] = torch.tensor(lang_num, dtype=torch.int64)
        data_dict["scan_idx"] = torch.stack(scan_idx, 0)
        data_dict["ann_id_list"] = torch.stack(ann_id_list, 0)
        data_dict["object_cat_list"] = torch.stack(object_cat_list, 0)

        data_dict["scene_id"] = scan_ids
        data_dict["clip_token"] = torch.stack(clip_tokens, 0)

        collate_time = time.time() - start
        load_time.append(collate_time)
        data_dict["load_time"] = torch.tensor(load_time, dtype=torch.float64)
        data_dict["split"] = self.prefix

        return data_dict

