# Copyright (c) Facebook, Inc. and its affiliates.


""" 
Modified from https://github.com/facebookresearch/votenet
Dataset for 3D object detection on SUN RGB-D (with support of vote supervision).

A sunrgbd oriented bounding box is parameterized by (cx,cy,cz), (l,w,h) -- (dx,dy,dz) in upright depth coord
(Z is up, Y is forward, X is rightward), heading angle (from +X rotating to -Y) and semantic class

Point clouds are in **upright_depth coordinate (X right, Y forward, Z upward)**
Return heading class, heading residual, size class and size residual for 3D bounding boxes.
Oriented bounding box is parameterized by (cx,cy,cz), (l,w,h), heading_angle and semantic class label.
(cx,cy,cz) is in upright depth coordinate
(l,w,h) are *half lengths* of the object sizes
The heading angle is a rotation rad from +X rotating towards -Y. (+X is 0, -Y is pi/2)

Author: Charles R. Qi
Date: 2019

"""
import os
import sys
import numpy as np
from torch.utils.data import Dataset
import scipy.io as sio  # to load .mat files for depth points
import cv2
import random

import utils.pc_util as pc_util
from utils.sunrgbd_pc_util import write_oriented_bbox
from utils.random_cuboid import RandomCuboid
from utils.pc_util import shift_scale_points, scale_points
from utils.box_util import (
    flip_axis_to_camera_tensor,
    get_3d_box_batch_tensor,
    flip_axis_to_camera_np,
    get_3d_box_batch_np,
)


# Per-channel offset subtracted from point colors when `use_color` is enabled.
MEAN_COLOR_RGB = np.array([0.5, 0.5, 0.5])  # sunrgbd color is in 0~1
# Dataset location prefixes. SunrgbdDetectionDataset appends "_<split_set>"
# to the SUN RGB-D and ImageNet prefixes to obtain the directory it lists.
DATA_PATH_V1 = "" ## Replace with path to dataset
DATA_PATH_V2 = "" ## Not used in the codebase.
DATA_PATH_ImageNet = "" ## Replace with path to ImageNet
# Directory searched for "<scan_name>_pred_bbox.npy" pseudo-label files.
DATA_PATH_Pseudo_label = "" ## Replace with path to Pseudo label



class SunrgbdDatasetConfig(object):
    """Label space and box/heading encoding helpers for SUN RGB-D.

    Attributes set in ``__init__``:
        num_semcls: number of semantic classes (20).
        num_angle_bin: number of discrete heading bins (12).
        max_num_obj: maximum number of objects per scene (64).
        open_class_top_k: class-id threshold used by the dataset when
            filtering training labels (10).
        type2class / type2onehotclass: class name -> integer id.
        class2type: integer id -> class name.
        type2imagenet_id: integer id -> list of ImageNet synset ids.
    """

    def __init__(self):
        self.num_semcls = 20
        self.num_angle_bin = 12
        self.max_num_obj = 64
        self.open_class_top_k = 10

        # Position in this tuple is the integer class id.
        ordered_type_names = (
            "toilet",
            "bed",
            "chair",
            "bathtub",
            "sofa",
            "dresser",
            "scanner",
            "fridge",
            "lamp",
            "desk",
            "table",
            "stand",
            "cabinet",
            "counter",
            "bin",
            "bookshelf",
            "pillow",
            "microwave",
            "sink",
            "stool",
        )
        self.type2class = {
            type_name: class_id
            for class_id, type_name in enumerate(ordered_type_names)
        }

        # Class id -> ImageNet synsets. Ids 11 (stand) and 13 (counter) are
        # absent because no matching ImageNet class was found for them.
        self.type2imagenet_id = {
            0: ["n04447861"],  # toilet
            1: ["n03225988"],  # bed
            2: ["n02791124", "n03376595", "n04099969"],  # chair
            3: ["n02808440"],  # bathtub
            4: ["n04256520"],  # sofa
            5: ["n03237340"],  # dresser
            6: ["n04142731"],  # scanner
            7: ["n03273913"],  # fridge
            8: ["n04380533", "n03637318"],  # lamp
            9: ["n03179701"],  # desk
            10: ["n03201208"],  # table
            12: ["n03018349", "n03337140"],  # cabinet
            14: ["n02747177"],  # bin
            15: ["n02871439"],  # bookshelf
            16: ["n03938244"],  # pillow
            17: ["n03761084"],  # microwave
            18: ["n02998563"],  # sink
            19: ["n04326896"],  # stool
        }

        self.class2type = {
            class_id: type_name for type_name, class_id in self.type2class.items()
        }
        # Same mapping as type2class, held in an independent dict object.
        self.type2onehotclass = dict(self.type2class)

    def angle2class(self, angle):
        """Quantize a scalar heading angle into a bin id plus a residual.

        The angle is wrapped into [0, 2pi). Bin centers sit at
        k * (2pi / N) for k in 0..N-1, with N = ``num_angle_bin``.

        Returns:
            (class_id, residual_angle) such that
            class_id * (2pi / N) + residual_angle equals the wrapped angle
            (modulo 2pi).
        """
        full_turn = 2 * np.pi
        angle = angle % full_turn
        assert 0 <= angle <= full_turn
        bin_width = full_turn / float(self.num_angle_bin)
        half_bin = bin_width / 2
        # Shift by half a bin so that each bin is centered on its class angle.
        shifted = (angle + half_bin) % full_turn
        class_id = int(shifted / bin_width)
        residual_angle = shifted - (class_id * bin_width + half_bin)
        return class_id, residual_angle

    def class2angle(self, pred_cls, residual, to_label_format=True):
        """Decode a scalar (bin id, residual) pair back to an angle.

        With ``to_label_format`` set, results above pi are shifted down by
        2pi.
        """
        bin_width = 2 * np.pi / float(self.num_angle_bin)
        angle = pred_cls * bin_width + residual
        needs_wrap = to_label_format and angle > np.pi
        return angle - 2 * np.pi if needs_wrap else angle

    def class2angle_batch(self, pred_cls, residual, to_label_format=True):
        """Array version of ``class2angle``; wraps entries above pi element-wise."""
        bin_width = 2 * np.pi / float(self.num_angle_bin)
        angle = pred_cls * bin_width + residual
        if to_label_format:
            above_pi = angle > np.pi
            angle[above_pi] = angle[above_pi] - 2 * np.pi
        return angle

    def class2anglebatch_tensor(self, pred_cls, residual, to_label_format=True):
        """Alias that forwards to ``class2angle_batch``."""
        return self.class2angle_batch(pred_cls, residual, to_label_format)

    def box_parametrization_to_corners(self, box_center_unnorm, box_size, box_angle):
        """Build box corners from center/size/angle using the tensor helpers."""
        return get_3d_box_batch_tensor(
            box_size, box_angle, flip_axis_to_camera_tensor(box_center_unnorm)
        )

    def box_parametrization_to_corners_np(self, box_center_unnorm, box_size, box_angle):
        """Build box corners from center/size/angle using the NumPy helpers."""
        return get_3d_box_batch_np(
            box_size, box_angle, flip_axis_to_camera_np(box_center_unnorm)
        )

    def my_compute_box_3d(self, center, size, heading_angle):
        """Return the 8 corners (8x3) of a box rotated about Z and moved to ``center``.

        ``size`` is unpacked as (l, w, h) and used as offsets of +/-l, +/-w,
        +/-h from the center, i.e. as half extents.
        """
        rotation = pc_util.rotz(-1 * heading_angle)
        half_l, half_w, half_h = size
        local_corners = np.vstack(
            [
                [-half_l, half_l, half_l, -half_l, -half_l, half_l, half_l, -half_l],
                [half_w, half_w, -half_w, -half_w, half_w, half_w, -half_w, -half_w],
                [half_h, half_h, half_h, half_h, -half_h, -half_h, -half_h, -half_h],
            ]
        )
        corners_3d = np.dot(rotation, local_corners)
        for axis in range(3):
            corners_3d[axis, :] += center[axis]
        return np.transpose(corners_3d)


def save_img_with_bbox(img, bbox, filename, class2type):
    """Draw labelled 2D boxes onto ``img`` and write the result to ``filename``.

    Args:
        img: image array; it is drawn on in place.
        bbox: array whose rows are read as (x1, y1, x2, y2, class_id).
        filename: output path handed to ``cv2.imwrite``.
        class2type: mapping from class id to the class name used in the label.

    The image is written exactly once, after all boxes are drawn. The file is
    therefore also produced when ``bbox`` has no rows (previously the write
    sat inside the loop, so it ran once per box and not at all for zero boxes).
    """
    for ind in range(bbox.shape[0]):
        top_left = (int(bbox[ind, 0]), int(bbox[ind, 1]))
        down_right = (int(bbox[ind, 2]), int(bbox[ind, 3]))

        cv2.rectangle(img, top_left, down_right, (0, 255, 0), 2)
        # Keep the label at least 15 px from the top/left border so the text
        # stays inside the image.
        label_origin = (max(top_left[0], 15), max(top_left[1], 15))
        cv2.putText(
            img,
            '%d %s' % (ind, class2type[bbox[ind, 4]]),
            label_origin,
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (255, 0, 0),
            2,
        )
    cv2.imwrite(filename, img)

def img_norm(img):
    """Scale an image from 0-255 to 0-1, then standardize each channel.

    The last axis of ``img`` is broadcast against three per-channel mean/std
    values (the ImageNet statistics).
    NOTE(review): the constants are in R, G, B order while cv2.imread yields
    B, G, R channels - confirm the intended channel order with the callers.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    unit_range = img / 255
    centered = unit_range - channel_mean
    return centered / channel_std

def flip_horizon(img, bbox):
    """Mirror the image left-right and remap the x coordinates of the boxes.

    Box columns 0 and 2 are read as x1 and x2. After the flip the new x1 is
    ``width - old x2`` and the new x2 is ``width - old x1``. The input
    ``bbox`` is not modified; a copy is returned.
    """
    mirrored = cv2.flip(img, 1)
    width = mirrored.shape[1]
    flipped_boxes = bbox.copy()
    flipped_boxes[:, 0], flipped_boxes[:, 2] = width - bbox[:, 2], width - bbox[:, 0]
    return mirrored, flipped_boxes
	
def flip_vertical(img, bbox):
    """Mirror the image top-bottom and remap the y coordinates of the boxes.

    Box columns 1 and 3 are read as y1 and y2. After the flip the new y1 is
    ``height - old y2`` and the new y2 is ``height - old y1``. The input
    ``bbox`` is not modified; a copy is returned.
    """
    mirrored = cv2.flip(img, 0)
    height = mirrored.shape[0]
    flipped_boxes = bbox.copy()
    flipped_boxes[:, 1], flipped_boxes[:, 3] = height - bbox[:, 3], height - bbox[:, 1]
    return mirrored, flipped_boxes

def random_rotate(img, bbox):
    """Rotate the image about its center by a random angle and refit the boxes.

    The angle is drawn from a Gaussian with mean 0 and standard deviation 10
    and passed to ``cv2.getRotationMatrix2D``. Each box's four corners are
    transformed with the same matrix; the returned box is the axis-aligned
    hull of the rotated corners, clipped to the image bounds. Column 4 (the
    label) is copied through unchanged.

    Returns:
        (rotated image with the original width/height, new (N, 5) box array).
    """
    height, width = img.shape[0], img.shape[1]
    num_boxes = bbox.shape[0]

    angle = random.gauss(0, 10)
    rot_mat = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1)

    # Column pairs selecting (x, y) for: top-left, bottom-right, bottom-left,
    # top-right. The third homogeneous coordinate stays 1.
    corner_columns = ([0, 1], [2, 3], [0, 3], [2, 1])
    corners = np.ones([4, num_boxes, 3])
    for slot, columns in enumerate(corner_columns):
        corners[slot, :, :2] = bbox[:, columns]

    rotated = np.matmul(corners, rot_mat.T)
    rotated_x = rotated[:, :, 0]
    rotated_y = rotated[:, :, 1]

    out_bbox = np.zeros([num_boxes, 5])
    out_bbox[:, 0] = rotated_x.min(axis=0)
    out_bbox[:, 1] = rotated_y.min(axis=0)
    out_bbox[:, 2] = rotated_x.max(axis=0)
    out_bbox[:, 3] = rotated_y.max(axis=0)
    out_bbox[:, 4] = bbox[:, 4]

    x_columns = [0, 2]
    y_columns = [1, 3]
    out_bbox[:, x_columns] = np.clip(out_bbox[:, x_columns], 0, width - 1)
    out_bbox[:, y_columns] = np.clip(out_bbox[:, y_columns], 0, height - 1)

    out_img = cv2.warpAffine(img, rot_mat, (width, height))
    return out_img, out_bbox

def scale_img_bbox(img, bbox, scale):
    """Resize an image by ``scale`` and scale its 2D boxes to match.

    Args:
        img: image array (H, W[, C]).
        bbox: array whose first four columns are x1, y1, x2, y2. It is
            modified IN PLACE and also returned.
        scale: multiplicative factor applied to the image size and box
            coordinates.

    Returns:
        (resized image, bbox) with box coordinates clipped to the new image.

    Raises:
        ValueError: if ``img`` is None (e.g. an image that failed to load).
        RuntimeError: if ``cv2.resize`` fails; the message carries the
            requested size and the input shape, and the original error is
            chained. The previous bare ``except`` + ``exit()`` swallowed
            every exception and terminated the process without a traceback.
    """
    if img is None:
        raise ValueError("scale_img_bbox received img=None; was the image loaded?")

    # Target size is one pixel short of the exact scaled size, with a floor
    # of 10 pixels per side.
    scaled_width = max(int(img.shape[1] * scale) - 1, 10)
    scaled_hight = max(int(img.shape[0] * scale) - 1, 10)
    dsize = (scaled_width, scaled_hight)

    try:
        img = cv2.resize(img, dsize)
    except Exception as err:
        raise RuntimeError(
            "cv2.resize failed for dsize=%s on image of shape %s" % (dsize, img.shape)
        ) from err

    # NOTE: boxes use the nominal ``scale`` while the image uses the rounded
    # ``dsize``; the clip below keeps the boxes inside the resized image.
    bbox[:, :4] *= scale

    img_width = img.shape[1]
    img_hight = img.shape[0]
    width_ind = [0, 2]
    heigh_ind = [1, 3]

    bbox[:, width_ind] = np.clip(bbox[:, width_ind], 0, img_width - 1)
    bbox[:, heigh_ind] = np.clip(bbox[:, heigh_ind], 0, img_hight - 1)
    return img, bbox

def random_scale(img, bbox):
    """Resize image and boxes by a random factor.

    A factor is drawn uniformly from [1, 2]; with probability 0.5 its
    reciprocal is used instead, so the image is either enlarged or shrunk.
    The resize itself is delegated to ``scale_img_bbox``.
    """
    factor = random.uniform(1, 2)
    shrink = random.random() < 0.5
    scale = 1 / factor if shrink else factor
    return scale_img_bbox(img, bbox, scale)
	

def img_det_aug(img, bbox, split, class2type):
    """Apply 2D augmentation (train split only) followed by normalization.

    For ``split == "train"`` each of the following is applied independently
    with probability 0.5, in this order: horizontal flip, vertical flip,
    random rotation, random scale. Normalization via ``img_norm`` is applied
    for every split.

    ``class2type`` is accepted but not used by this function.

    Returns:
        (normalized image, boxes).
    """
    if split in ["train"]:
        augmentations = (flip_horizon, flip_vertical, random_rotate, random_scale)
        for transform in augmentations:
            # One coin toss per transform, drawn right before it runs.
            if random.random() < 0.5:
                img, bbox = transform(img, bbox)

    # Normalization is shared by training and evaluation.
    return img_norm(img), bbox
	


class SunrgbdDetectionDataset(Dataset):
    """SUN RGB-D detection dataset with paired 2D image labels.

    Each item is a dict holding a sampled point cloud, padded 3D box targets
    (ground truth plus optional pseudo labels), a zero-padded RGB image with
    its padding mask and 2D boxes, and - on the "train" split - a stack of
    ImageNet images with image-level labels.

    Files read per scan (prefix ``<data_path>/<scan_name>``):
        ``_pc.npz`` (key "pc"), ``_bbox.npy``, ``_2d_bbox.npy``, ``.jpg`` and,
        when present, ``<DATA_PATH_Pseudo_label>/<scan_name>_pred_bbox.npy``.
    """

    # (height, width) of the zero-padded canvases the images are pasted onto.
    _SUNRGBD_CANVAS_HW = (531, 730)
    _IMAGENET_CANVAS_HW = (255, 255)
    # Number of 2D box slots allocated per image.
    _MAX_NUM_2D_BOXES = 64
    # Number of ImageNet samples attached to every item on the train split.
    _NUM_IMAGENET_SAMPLES = 9

    def __init__(
        self,
        dataset_config,
        split_set="train",
        root_dir=None,
        num_points=20000,
        use_color=False,
        use_height=False,
        use_v1=True,
        augment=False,
        use_random_cuboid=False,
        random_cuboid_min_points=30000,
    ):
        """Index the scans (and, for "train", the ImageNet images) on disk.

        Args:
            dataset_config: config object providing ``open_class_top_k``,
                ``type2imagenet_id``, ``class2type`` and the angle/box helpers.
            split_set: one of "train", "val", "trainval".
            root_dir: SUN RGB-D path prefix; defaults to DATA_PATH_V1 or
                DATA_PATH_V2 depending on ``use_v1``. "_<split_set>" is
                appended to form the directory.
            num_points: number of points sampled per scan (at most 50000).
            use_color: keep and center the RGB channels of the point cloud.
            use_height: append height above the estimated floor as a feature.
            use_v1: choose DATA_PATH_V1 over DATA_PATH_V2 when ``root_dir``
                is None.
            augment: enable point-cloud augmentation.
            use_random_cuboid: additionally apply RandomCuboid when augmenting.
            random_cuboid_min_points: ``min_points`` passed to RandomCuboid.
        """
        assert num_points <= 50000
        assert split_set in ["train", "val", "trainval"]
        self.dataset_config = dataset_config
        self.use_v1 = use_v1
        self.split_set = split_set

        if root_dir is None:
            root_dir = DATA_PATH_V1 if use_v1 else DATA_PATH_V2
        # Bound unconditionally: it used to be set only when root_dir was
        # None, so passing an explicit root_dir raised NameError below.
        root_dir_imagenet = DATA_PATH_ImageNet

        self.data_path = root_dir + "_%s" % (split_set)
        self.data_path_imagenet = root_dir_imagenet + "_%s" % (split_set)
        self.data_path_pseudo = DATA_PATH_Pseudo_label

        # Defaults for every split, so that __getitem__ can always test
        # imagenet_scan_name_dict (it was missing for "trainval").
        self.imagenet_scan_name_dict = {}
        self.imagenet_classes = []
        self.imagenet_class_num = 0

        if split_set in ["train"]:
            self.scan_names = self._list_scan_names(self.data_path)

            self.imagenet_classes = sorted(
                {os.path.basename(x) for x in os.listdir(self.data_path_imagenet)}
            )
            self.imagenet_class_num = len(self.imagenet_classes)
            for cls in self.imagenet_classes:
                self.imagenet_scan_name_dict[cls] = self._list_scan_names(
                    os.path.join(self.data_path_imagenet, cls), prefix_len=18
                )

        elif split_set in ["val"]:
            self.scan_names = self._list_scan_names(self.data_path)
            self.scan_names_imagenet = []

        elif split_set in ["trainval"]:
            # Combine the names from both sub-splits, stored as joined paths.
            all_paths = []
            for sub_split in ["train", "val"]:
                data_path = self.data_path.replace("trainval", sub_split)
                all_paths.extend(
                    os.path.join(data_path, x)
                    for x in self._list_scan_names(data_path)
                )
            all_paths.sort()
            self.scan_names = all_paths

        self.num_points = num_points
        self.augment = augment
        self.use_color = use_color
        self.use_height = use_height
        self.use_random_cuboid = use_random_cuboid
        self.random_cuboid_augmentor = RandomCuboid(
            min_points=random_cuboid_min_points,
            aspect=0.75,
            min_crop=0.75,
            max_crop=1.0,
        )
        self.center_normalizing_range = [
            np.zeros((1, 3), dtype=np.float32),
            np.ones((1, 3), dtype=np.float32),
        ]
        self.max_num_obj = 64

    def __len__(self):
        """Return the number of indexed scans."""
        return len(self.scan_names)

    @staticmethod
    def _list_scan_names(directory, prefix_len=6):
        """Return the sorted unique ``prefix_len``-char prefixes of the file names in ``directory``."""
        return sorted(
            {os.path.basename(x)[:prefix_len] for x in os.listdir(directory)}
        )

    @staticmethod
    def _read_image(path):
        """Read an image with OpenCV, failing loudly when it cannot be read."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError("Could not read image: %s" % path)
        return img

    @staticmethod
    def _fit_to_canvas(img, bboxes, canvas_h, canvas_w):
        """Shrink image and boxes (aspect preserved) if the image exceeds the canvas."""
        width_ratio = img.shape[1] / canvas_w
        height_ratio = img.shape[0] / canvas_h
        if max(width_ratio, height_ratio) > 1:
            scale = 1 / max(width_ratio, height_ratio)
            img, bboxes = scale_img_bbox(img, bboxes, scale)
        return img, bboxes

    @staticmethod
    def _pad_image(img, canvas_h, canvas_w):
        """Paste ``img`` into the top-left of a zero canvas.

        Returns:
            image: float32 array of shape (3, canvas_h, canvas_w).
            mask: bool array of shape (canvas_h, canvas_w); True marks padding.
        """
        image = np.zeros([canvas_h, canvas_w, 3], dtype=np.float32)
        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        mask = np.ones([canvas_h, canvas_w], dtype=bool)
        img_height = img.shape[0]
        img_width = img.shape[1]
        image[:img_height, :img_width, :3] = img
        mask[:img_height, :img_width] = False
        return np.transpose(image, [2, 0, 1]), mask

    # Load data from ImageNet
    def imagenet_item(self, idx=0):
        """Draw one random ImageNet image with its image-level label(s).

        With probability 0.9 a class is picked from
        ``dataset_config.type2imagenet_id`` (then one of its synsets);
        otherwise any indexed ImageNet class is picked. ``idx`` is ignored.

        Returns:
            image (3, 255, 255) float32, mask (255, 255) bool,
            bboxes_2d (64, 4) float32 - box coordinates are zeroed, only the
            labels are kept - bboxes_2d_label (64,) float32, and bbox_num.
        """
        type2imagenet_id = self.dataset_config.type2imagenet_id
        if random.random() < 0.9:
            candidate_keys = list(type2imagenet_id.keys())
            sel_key = candidate_keys[random.randint(0, len(candidate_keys) - 1)]
            chosen_class_list = type2imagenet_id[sel_key]

            sel_id = 0
            if len(chosen_class_list) > 0:
                sel_id = random.randint(0, len(chosen_class_list) - 1)
            chosen_class = chosen_class_list[sel_id]
        else:
            cls_idx = random.randint(0, self.imagenet_class_num - 1)
            chosen_class = self.imagenet_classes[cls_idx]

        class_scans = self.imagenet_scan_name_dict[chosen_class]
        scan_name = class_scans[random.randint(0, len(class_scans) - 1)]

        if scan_name.startswith("/"):
            scan_path = scan_name
        else:
            scan_path = os.path.join(self.data_path_imagenet, chosen_class, scan_name)

        # Load image and bbox_2d
        bboxes_2d_with_label = np.load(scan_path + "_2d_bbox.npy")  # K,5
        if len(bboxes_2d_with_label.shape) == 1:
            # A single box stored as a flat vector.
            bboxes_2d_with_label = np.expand_dims(bboxes_2d_with_label, axis=0)
        bbox_num = bboxes_2d_with_label.shape[0]

        img = self._read_image(scan_path + ".JPEG")

        # Image augmentation such as: normalization
        img, bboxes_2d_with_label = img_det_aug(
            img, bboxes_2d_with_label, self.split_set, self.dataset_config.class2type
        )
        # Only the labels are used for ImageNet images; drop the coordinates.
        bboxes_2d_with_label[:, :4] = 0

        canvas_h, canvas_w = self._IMAGENET_CANVAS_HW
        img, bboxes_2d_with_label = self._fit_to_canvas(
            img, bboxes_2d_with_label, canvas_h, canvas_w
        )
        img_width = img.shape[1]
        img_height = img.shape[0]
        image, mask = self._pad_image(img, canvas_h, canvas_w)

        # Padding with box & Normalizing Box
        num_rows = bboxes_2d_with_label.shape[0]
        bboxes_2d = np.zeros([self._MAX_NUM_2D_BOXES, 4], dtype=np.float32)
        bboxes_2d_label = np.zeros([self._MAX_NUM_2D_BOXES], dtype=np.float32)
        bboxes_2d[:num_rows, :] = bboxes_2d_with_label[:, :4]
        bboxes_2d[:, 0] /= img_width
        bboxes_2d[:, 2] /= img_width
        bboxes_2d[:, 1] /= img_height
        bboxes_2d[:, 3] /= img_height
        bboxes_2d_label[:num_rows] = bboxes_2d_with_label[:, 4]

        return image, mask, bboxes_2d, bboxes_2d_label, bbox_num

    def load_n_imagenet_item(self, num):
        """Draw ``num`` ImageNet samples and stack each output along a new first axis."""
        samples = [self.imagenet_item() for _ in range(num)]
        images, masks, boxes, labels, counts = zip(*samples)

        image_list = np.stack(images, axis=0)
        mask_list = np.stack(masks, axis=0)
        bboxes_2d_list = np.stack(boxes, axis=0)
        bboxes_2d_label_list = np.stack(labels, axis=0)
        bbox_num_list = np.array(counts)

        return image_list, mask_list, bboxes_2d_list, bboxes_2d_label_list, bbox_num_list

    def __getitem__(self, idx):
        """Build the sample dict for scan ``idx``.

        If the scan has neither 2D labels (after the train-time class filter)
        nor pseudo labels, a different scan is drawn at random until one with
        at least one label is found.
        """
        # ------------------------------- LABEL FILES ------------------------------
        while True:
            scan_name = self.scan_names[idx]
            if scan_name.startswith("/"):
                scan_path = scan_name
            else:
                scan_path = os.path.join(self.data_path, scan_name)

            bboxes_2d_with_label = np.load(scan_path + "_2d_bbox.npy")  # K,5

            seen_class_index = None
            if self.split_set in ["train"]:
                # Keep only rows whose class id is >= open_class_top_k.
                seen_class_index = np.where(
                    bboxes_2d_with_label[:, -1] >= self.dataset_config.open_class_top_k
                )[0]
                bboxes_2d_with_label = bboxes_2d_with_label[seen_class_index, :]

            # Load pseudo label
            bboxes_pseudo = None
            pseudo_label_filename = os.path.join(
                self.data_path_pseudo, scan_name + "_pred_bbox.npy"
            )
            if os.path.exists(pseudo_label_filename):
                bboxes_pseudo = np.load(pseudo_label_filename)  # K,8
                # Pseudo labels store full sizes; the pipeline uses half sizes.
                bboxes_pseudo[:, 3:6] /= 2

            total_labels = bboxes_2d_with_label.shape[0]
            if bboxes_pseudo is not None:
                total_labels += bboxes_pseudo.shape[0]

            if total_labels > 0:
                break
            idx = random.randint(0, len(self.scan_names) - 1)

        point_cloud = np.load(scan_path + "_pc.npz")["pc"]  # Nx6
        bboxes = np.load(scan_path + "_bbox.npy")  # K,8

        if seen_class_index is not None:
            # The 3D boxes are filtered with the row indices selected on the 2D labels.
            bboxes = bboxes[seen_class_index, :]

        pseudo_bbox_num = 0
        if bboxes_pseudo is not None:
            pseudo_bbox_num = bboxes_pseudo.shape[0]
            bboxes = np.concatenate([bboxes, bboxes_pseudo], axis=0)

        # ------------------------------- IMAGE ------------------------------
        bbox_num = bboxes_2d_with_label.shape[0]
        img = self._read_image(scan_path + ".jpg")

        # Image augmentation such as: normalization
        img, bboxes_2d_with_label = img_det_aug(
            img, bboxes_2d_with_label, self.split_set, self.dataset_config.class2type
        )

        # Make sure the image is no larger than the (730, 531) canvas.
        canvas_h, canvas_w = self._SUNRGBD_CANVAS_HW
        img, bboxes_2d_with_label = self._fit_to_canvas(
            img, bboxes_2d_with_label, canvas_h, canvas_w
        )
        img_width = img.shape[1]
        img_height = img.shape[0]
        image, mask = self._pad_image(img, canvas_h, canvas_w)

        # Padding with box & Normalizing Box: corners -> (cx, cy, w, h),
        # then divide by the (unpadded) image size.
        num_rows = bboxes_2d_with_label.shape[0]
        bboxes_2d = np.zeros([self._MAX_NUM_2D_BOXES, 4], dtype=np.float32)
        bboxes_2d_label = np.zeros([self._MAX_NUM_2D_BOXES], dtype=np.float32)
        bboxes_2d[:num_rows, :] = bboxes_2d_with_label[:, :4]
        center_w = (bboxes_2d[:, 0] + bboxes_2d[:, 2]) / 2
        center_h = (bboxes_2d[:, 1] + bboxes_2d[:, 3]) / 2
        size_w = np.abs(bboxes_2d[:, 0] - bboxes_2d[:, 2])
        size_h = np.abs(bboxes_2d[:, 1] - bboxes_2d[:, 3])

        bboxes_2d[:, 0] = center_w
        bboxes_2d[:, 1] = center_h
        bboxes_2d[:, 2] = size_w
        bboxes_2d[:, 3] = size_h

        bboxes_2d[:, 0] /= img_width
        bboxes_2d[:, 2] /= img_width
        bboxes_2d[:, 1] /= img_height
        bboxes_2d[:, 3] /= img_height
        bboxes_2d_label[:num_rows] = bboxes_2d_with_label[:, 4]

        # ------------------------------- POINT FEATURES ------------------------------
        if not self.use_color:
            point_cloud = point_cloud[:, 0:3]
        else:
            assert point_cloud.shape[1] == 6
            point_cloud = point_cloud[:, 0:6]
            point_cloud[:, 3:] = point_cloud[:, 3:] - MEAN_COLOR_RGB

        if self.use_height:
            # Floor estimate: the 0.99th percentile of the Z coordinates.
            floor_height = np.percentile(point_cloud[:, 2], 0.99)
            height = point_cloud[:, 2] - floor_height
            point_cloud = np.concatenate(
                [point_cloud, np.expand_dims(height, 1)], 1
            )  # (N,4) or (N,7)

        # ------------------------------- DATA AUGMENTATION ------------------------------
        if self.augment:
            if np.random.random() > 0.5:
                # Flipping along the YZ plane
                point_cloud[:, 0] = -1 * point_cloud[:, 0]
                bboxes[:, 0] = -1 * bboxes[:, 0]
                bboxes[:, 6] = np.pi - bboxes[:, 6]

            # Rotation along up-axis/Z-axis
            rot_angle = (np.random.random() * np.pi / 3) - np.pi / 6  # -30 ~ +30 degree
            rot_mat = pc_util.rotz(rot_angle)

            point_cloud[:, 0:3] = np.dot(point_cloud[:, 0:3], np.transpose(rot_mat))
            bboxes[:, 0:3] = np.dot(bboxes[:, 0:3], np.transpose(rot_mat))
            bboxes[:, 6] -= rot_angle

            # Augment RGB color
            if self.use_color:
                rgb_color = point_cloud[:, 3:6] + MEAN_COLOR_RGB
                rgb_color *= (
                    1 + 0.4 * np.random.random(3) - 0.2
                )  # brightness change for each channel
                rgb_color += (
                    0.1 * np.random.random(3) - 0.05
                )  # color shift for each channel
                rgb_color += np.expand_dims(
                    (0.05 * np.random.random(point_cloud.shape[0]) - 0.025), -1
                )  # jittering on each pixel
                rgb_color = np.clip(rgb_color, 0, 1)
                # randomly drop out 30% of the points' colors
                rgb_color *= np.expand_dims(
                    np.random.random(point_cloud.shape[0]) > 0.3, -1
                )
                point_cloud[:, 3:6] = rgb_color - MEAN_COLOR_RGB

            # Augment point cloud scale: 0.85x-1.15x
            scale_ratio = np.random.random() * 0.3 + 0.85
            scale_ratio = np.expand_dims(np.tile(scale_ratio, 3), 0)
            point_cloud[:, 0:3] *= scale_ratio
            bboxes[:, 0:3] *= scale_ratio
            bboxes[:, 3:6] *= scale_ratio

            if self.use_height:
                point_cloud[:, -1] *= scale_ratio[0, 0]

            if self.use_random_cuboid:
                point_cloud, bboxes, _ = self.random_cuboid_augmentor(
                    point_cloud, bboxes
                )

        # ------------------------------- LABELS ------------------------------
        num_boxes = bboxes.shape[0]
        angle_classes = np.zeros((self.max_num_obj,), dtype=np.float32)
        angle_residuals = np.zeros((self.max_num_obj,), dtype=np.float32)
        raw_angles = np.zeros((self.max_num_obj,), dtype=np.float32)
        raw_sizes = np.zeros((self.max_num_obj, 3), dtype=np.float32)
        label_mask = np.zeros((self.max_num_obj))
        label_mask[0:num_boxes] = 1

        target_bboxes_mask = label_mask
        target_bboxes = np.zeros((self.max_num_obj, 6))

        for i in range(num_boxes):
            bbox = bboxes[i]
            # Wrap into [0, 2pi). The parentheses matter: `% 2 * np.pi`
            # parsed as `(x % 2) * np.pi`. The value is re-derived from the
            # class/residual encoding below, so outputs are unaffected.
            raw_angles[i] = bbox[6] % (2 * np.pi)
            # Stored sizes are half extents; targets use full extents.
            raw_sizes[i, :] = bbox[3:6] * 2
            angle_class, angle_residual = self.dataset_config.angle2class(bbox[6])
            angle_classes[i] = angle_class
            angle_residuals[i] = angle_residual
            corners_3d = self.dataset_config.my_compute_box_3d(
                bbox[0:3], bbox[3:6], bbox[6]
            )
            # compute axis aligned box
            corner_min = np.min(corners_3d[:, 0:3], axis=0)
            corner_max = np.max(corners_3d[:, 0:3], axis=0)
            target_bboxes[i, 0:3] = (corner_min + corner_max) / 2
            target_bboxes[i, 3:6] = corner_max - corner_min

        point_cloud, _ = pc_util.random_sampling(
            point_cloud, self.num_points, return_choices=True
        )

        point_cloud_dims_min = point_cloud.min(axis=0)
        point_cloud_dims_max = point_cloud.max(axis=0)

        mult_factor = point_cloud_dims_max - point_cloud_dims_min
        box_sizes_normalized = scale_points(
            raw_sizes.astype(np.float32)[None, ...],
            mult_factor=1.0 / mult_factor[None, ...],
        )
        box_sizes_normalized = box_sizes_normalized.squeeze(0)

        box_centers = target_bboxes.astype(np.float32)[:, 0:3]
        box_centers_normalized = shift_scale_points(
            box_centers[None, ...],
            src_range=[
                point_cloud_dims_min[None, ...],
                point_cloud_dims_max[None, ...],
            ],
            dst_range=self.center_normalizing_range,
        )
        box_centers_normalized = box_centers_normalized.squeeze(0)
        # Zero out the normalized centers of the padding slots.
        box_centers_normalized = box_centers_normalized * target_bboxes_mask[..., None]

        # re-encode angles to be consistent with VoteNet eval
        angle_classes = angle_classes.astype(np.int64)
        angle_residuals = angle_residuals.astype(np.float32)
        raw_angles = self.dataset_config.class2angle_batch(
            angle_classes, angle_residuals
        )

        box_corners = self.dataset_config.box_parametrization_to_corners_np(
            box_centers[None, ...],
            raw_sizes.astype(np.float32)[None, ...],
            raw_angles.astype(np.float32)[None, ...],
        )
        box_corners = box_corners.squeeze(0)

        target_bboxes_semcls = np.zeros((self.max_num_obj))
        target_bboxes_semcls[0:num_boxes] = bboxes[:, -1]

        ret_dict = {}
        ret_dict["point_clouds"] = point_cloud.astype(np.float32)
        ret_dict["gt_box_corners"] = box_corners.astype(np.float32)
        ret_dict["gt_box_centers"] = box_centers.astype(np.float32)
        ret_dict["gt_box_centers_normalized"] = box_centers_normalized.astype(
            np.float32
        )
        ret_dict["gt_box_sem_cls_label"] = target_bboxes_semcls.astype(np.int64)
        ret_dict["gt_box_present"] = target_bboxes_mask.astype(np.float32)
        ret_dict["scan_idx"] = np.array(idx).astype(np.int64)
        ret_dict["gt_box_sizes"] = raw_sizes.astype(np.float32)
        ret_dict["gt_box_sizes_normalized"] = box_sizes_normalized.astype(np.float32)
        ret_dict["gt_box_angles"] = raw_angles.astype(np.float32)
        ret_dict["gt_angle_class_label"] = angle_classes
        ret_dict["gt_angle_residual_label"] = angle_residuals
        ret_dict["point_cloud_dims_min"] = point_cloud_dims_min
        ret_dict["point_cloud_dims_max"] = point_cloud_dims_max
        ret_dict["image"] = image
        ret_dict["mask"] = mask
        ret_dict["bboxes_2d"] = bboxes_2d
        ret_dict["bboxes_2d_label"] = bboxes_2d_label.astype(np.int64)
        ret_dict["bbox_num"] = bbox_num
        ret_dict["pseudo_bbox_num"] = pseudo_bbox_num

        # ImageNet samples are attached only when ImageNet data was indexed
        # (the "train" split).
        if self.imagenet_scan_name_dict:
            (
                image_imagenet,
                mask_imagenet,
                bboxes_2d_imagenet,
                bboxes_2d_label_imagenet,
                bbox_num_imagenet,
            ) = self.load_n_imagenet_item(self._NUM_IMAGENET_SAMPLES)

            ret_dict["image_imagenet"] = image_imagenet
            ret_dict["mask_imagenet"] = mask_imagenet
            ret_dict["bboxes_2d_imagenet"] = bboxes_2d_imagenet
            ret_dict["bboxes_2d_label_imagenet"] = bboxes_2d_label_imagenet.astype(np.int64)
            ret_dict["bbox_num_imagenet"] = bbox_num_imagenet

        return ret_dict
