'''
@Author: Wenhao Ding
@Email: wenhaod@andrew.cmu.edu
@Date: 2020-07-09 13:51:09
LastEditTime: 2020-12-15 21:17:31
@Description: modified from https://github.com/PRBonn/semantic-kitti-api
'''

import os
import numpy as np
import yaml

import torch
import torch.utils.data as torch_data

from utils import COLOR, save_ply, read_ply, pc_to_rangemap


class CarlaDataset(torch_data.Dataset):
    """ Point cloud dataset built from CARLA lidar scans stored as ``.ply`` files.

        Sample ``index`` is read from ``<carla_path>/carla_eval/<index>.ply``, cropped to the
        horizontal FOV, resampled to ``npoints`` points and labelled per point
        (1 = vehicle color [0, 0, 142], 0 = anything else).
    """
    def __init__(self, args):
        self.root_dir = args.carla_path
        # NOTE(review): this counts every entry of root_dir, while __getitem__ reads from
        # root_dir/carla_eval and get_background reads *.ply files from root_dir itself --
        # confirm this is the intended number of samples.
        self.file_num = len(os.listdir(self.root_dir))

        # input pointcloud number and fov
        self.npoints = args.npoints
        self.upper_fov = args.upper_fov
        self.lower_fov = args.lower_fov
        self.left_fov = args.left_fov
        self.right_fov = args.right_fov
        self.width = args.width
        self.height = args.height
        self.max_range = args.max_range

        print(COLOR.GREEN+'Carla Dataset Info:')
        print('\tNumber of samples:', self.file_num)
        print('\tNumber of points:', self.width*self.height)
        print('\tClip FOV (left, right):', [self.left_fov, self.right_fov])
        print(COLOR.WHITE+'')

    def __len__(self):
        return self.file_num

    def get_background(self, name='carla'):
        """ Get an empty pointcloud background from carla for synthesizing.
            The output is in carla coordinate, since we still need to modify it.

            Args:
                name: 'kitti' loads ``background_kitti.ply``; any other value loads ``empty.ply``.

            Returns:
                The transpose of the range map that ``pc_to_rangemap`` builds from the background points.
        """
        # TODO: add some noise to the carla background
        ply_file = 'background_kitti.ply' if name == 'kitti' else 'empty.ply'
        xyzrgb = read_ply(os.path.join(self.root_dir, ply_file))

        # background has no vehicle label, only the coordinates are used
        pts_input = xyzrgb[:, 0:3]

        # convert the points into range map
        fov_horizon_range = self.right_fov - self.left_fov
        range_map = pc_to_rangemap(pts_input, fov_horizon_range, self.lower_fov, self.upper_fov,
                                   self.width, self.height, self.max_range)
        return range_map.T

    @staticmethod
    def kitti_to_carla_coord(kitti_xyz, type='numpy'):
        """ Transfer the coordinate from Kitti (label coordinate) to carla.
            Inverse of ``carla_to_kitti_coord``; this relationship depends on how we crop the carla lidar.
            Carla:  x, y, z
            Kitti: -z, x, -y

            Args:
                kitti_xyz: points with xyz in the first three columns (indexed as [:, k]).
                type: 'numpy' for np.ndarray input, 'torch' for torch.Tensor input.

            Returns:
                (N, 3) points in carla coordinates, of the same container type as the input.

            Raises:
                ValueError: if ``type`` is neither 'numpy' nor 'torch'.
        """
        carla_x = -kitti_xyz[:, 2:3]
        carla_y = kitti_xyz[:, 0:1]
        carla_z = -kitti_xyz[:, 1:2]
        if type == 'numpy':
            return np.concatenate([carla_x, carla_y, carla_z], axis=1)
        if type == 'torch':
            return torch.cat([carla_x, carla_y, carla_z], dim=1)
        raise ValueError("type must be 'numpy' or 'torch', got {!r}".format(type))

    @staticmethod
    def carla_to_kitti_coord(carla_xyz, type='numpy'):
        """ Transfer the coordinate from carla to Kitti (label coordinate).
            This relationship depends on how we crop the carla lidar.
            Kitti: x, y, z
            Carla: y, -z, -x

            Args:
                carla_xyz: points with xyz in the first three columns (indexed as [:, k]).
                type: 'numpy' for np.ndarray input, 'torch' for torch.Tensor input.

            Returns:
                (N, 3) points in Kitti coordinates, of the same container type as the input.

            Raises:
                ValueError: if ``type`` is neither 'numpy' nor 'torch'.
        """
        kitti_x = carla_xyz[:, 1:2]
        kitti_y = -carla_xyz[:, 2:3]
        kitti_z = -carla_xyz[:, 0:1]

        if type == 'numpy':
            return np.concatenate([kitti_x, kitti_y, kitti_z], axis=1)
        if type == 'torch':
            return torch.cat([kitti_x, kitti_y, kitti_z], dim=1)
        raise ValueError("type must be 'numpy' or 'torch', got {!r}".format(type))

    def generate_fov_pointcloud(self, xyzrgb):
        """ Use the left and right bound of the horizontal FOV to crop the pointcloud.

            The horizontal angle arctan2(x, y) of each point is shifted so that the smallest angle
            in the scan becomes 0, converted to degrees, and points strictly between
            ``self.left_fov`` and ``self.right_fov`` are kept. An empty input is returned unchanged.
        """
        if len(xyzrgb) == 0:
            # np.min below cannot reduce an empty array
            return xyzrgb
        x = xyzrgb[:, 0]
        y = xyzrgb[:, 1]
        horizontal_angle = np.arctan2(x, y)
        horizontal_angle -= np.min(horizontal_angle)  # make sure the angle falls into [0, 360)
        horizontal_angle = np.rad2deg(horizontal_angle)
        in_fov = np.logical_and(horizontal_angle < self.right_fov, horizontal_angle > self.left_fov)
        return xyzrgb[in_fov, :]

    @staticmethod
    def generate_label(xyzrgb):
        """ Generate the label of vehicle according to the color information.

            Args:
                xyzrgb: array whose columns 0:3 are xyz and 3:6 are rgb.

            Returns:
                (xyz, label): the xyz columns and a float64 array that is 1.0 where the color is the
                vehicle blue [0, 0, 142] and 0.0 elsewhere.
        """
        xyz = xyzrgb[:, 0:3]
        rgb = xyzrgb[:, 3:6]
        # the color of vehicle is blue [0, 0, 142]; compare all rows at once instead of looping
        is_vehicle = np.all(rgb == np.array([0, 0, 142]), axis=1)
        label = is_vehicle.astype(np.float64)
        return xyz, label

    def point_completion(self, xyzrgb):
        """ Resample the pointcloud so that it has exactly ``self.npoints`` rows.

            With fewer points, existing points are randomly repeated (each at most once more) and the
            order is shuffled. With more points, a random subset is drawn without replacement, the same
            strategy ArgoverseDataset uses. With exactly ``self.npoints`` points, the input is returned.

            Raises:
                RuntimeError: if fewer than half of ``self.npoints`` points are available, since
                    repeating each point at most once cannot reach ``self.npoints``.
        """
        current_npoints = len(xyzrgb)
        if self.npoints > 2*current_npoints:
            raise RuntimeError('np.random.choice will fail if the number of points is less than half of the self.npoints')
        if self.npoints > current_npoints:
            # if point number is less than predefined, randomly repeat existing points
            choice = np.arange(0, current_npoints, dtype=np.int32)
            extra_choice = np.random.choice(choice, self.npoints - current_npoints, replace=False)
            choice = np.concatenate((choice, extra_choice), axis=0)
            np.random.shuffle(choice)
        elif self.npoints < current_npoints:
            # if point number is larger than predefined, randomly drop points
            choice = np.random.choice(current_npoints, self.npoints, replace=False)
        else:
            return xyzrgb
        return xyzrgb[choice, :]

    def __getitem__(self, index):
        """ Load sample ``index`` and return (points [npoints, 3], vehicle labels [npoints]). """
        ply_name = os.path.join(self.root_dir, 'carla_eval', str(index) + '.ply')
        xyzrgb = read_ply(ply_name)
        xyzrgb = self.generate_fov_pointcloud(xyzrgb)

        # complete the pointcloud to match pre-defined number (with rgb as label)
        xyzrgb = self.point_completion(xyzrgb)

        pts_input, cls_labels = self.generate_label(xyzrgb)
        # NOTE: for Semantic-Kitti dataset, there is no need to transfer the coordinate
        #pts_input = self.carla_to_kitti_coord(pts_input)
        return pts_input, cls_labels


class SemanticKittiDataset(torch_data.Dataset):
    """ SemanticKITTI lidar scans with a binary per-point label (1 = car, 0 = everything else).

        Scan and label paths are read (one per line) from ``<kitti_path>/<split>_data_list.txt`` and
        ``<kitti_path>/<split>_label_list.txt``; the semantic remapping comes from
        ``<kitti_path>/semantic-kitti.yaml``.
    """
    def __init__(self, args, split='train'):
        self.split = split
        self.data_list_name = os.path.join(args.kitti_path, self.split+'_data_list.txt')
        self.label_list_name = os.path.join(args.kitti_path, self.split+'_label_list.txt')
        self.config_name = os.path.join(args.kitti_path, 'semantic-kitti.yaml')

        # context managers make sure the list/config files are closed
        with open(self.data_list_name) as f:
            self.data_list = [x.strip() for x in f.readlines()]
        with open(self.label_list_name) as f:
            self.label_list = [x.strip() for x in f.readlines()]

        if len(self.data_list) != len(self.label_list):
            raise ValueError('data list ({}) and label list ({}) have different lengths'.format(
                len(self.data_list), len(self.label_list)))
        self.num_sample = len(self.data_list)
        self.npoints = args.npoints

        # a whitelist for training label
        with open(self.config_name, 'r') as f:
            data_config = yaml.safe_load(f)
        learning_map = data_config["learning_map"]
        learning_ignore = data_config["learning_ignore"]
        self.label_map_dic = self.process_label_map(learning_map, learning_ignore)

        # print info
        print(COLOR.GREEN+'Semantic Kitti Dataset Info (' + split + '):')
        print('\tNumber of points:', self.npoints)
        print('\tNumber of samples:', self.num_sample)
        print(COLOR.WHITE+'')

    @staticmethod
    def process_label_map(learning_map, learning_ignore):
        """ Build a lookup table from raw semantic ids to a binary label.

            Args:
                learning_map: dict {raw semantic id: learning class id} (the yaml ``learning_map``).
                    It is not modified.
                learning_ignore: currently unused; kept for interface compatibility.

            Returns:
                np.int32 array of length ``max(raw id) + 100``: 1 for raw ids mapped to learning class 1
                (car), 0 for every other id (including class 6, people, and unknown ids).
        """
        # FIXME: we should also consider the inverse map
        binary_map = {raw_id: (1 if learn_id == 1 else 0) for raw_id, learn_id in learning_map.items()}

        # make lookup table for mapping
        maxkey = max(binary_map.keys())
        # +100 hack making lut bigger just in case there are unknown labels
        remap_lut = np.zeros((maxkey + 100), dtype=np.int32)
        remap_lut[list(binary_map.keys())] = list(binary_map.values())
        return remap_lut

    def get_lidar(self, idx):
        """ Read scan ``idx`` as float32 rows of 4 values and return the first three (xyz) columns. """
        lidar_file = self.data_list[idx]
        if not os.path.exists(lidar_file):
            raise FileNotFoundError(lidar_file)
        one_scan = np.fromfile(lidar_file, dtype=np.float32).reshape((-1, 4))
        return one_scan[:, 0:3]

    def get_label(self, idx):
        """ Read the uint32 label file of scan ``idx`` and return the remapped semantic label (int32).
        """
        label_file = self.label_list[idx]
        if not os.path.exists(label_file):
            raise FileNotFoundError(label_file)
        label = np.fromfile(label_file, dtype=np.uint32)
        label = label.reshape((-1))

        # process data
        #upper_half = label >> 16                     # get upper half for instances
        lower_half = label & 0xFFFF                   # get lower half for semantics
        sem_label = self.label_map_dic[lower_half]    # do the remapping of semantics
        #label = (upper_half << 16) + lower_half      # reconstruct full label
        return sem_label.astype(np.int32)

    def __len__(self):
        return self.num_sample

    def __getitem__(self, index):
        """ Return (points [npoints, 3], labels [npoints]) for scan ``index``.

            Points farther than 60 from the origin are dropped, then existing points are randomly
            repeated (each at most once more) until ``self.npoints`` rows are reached, and rows are shuffled.

            Raises:
                ValueError: if more than ``self.npoints`` points, or fewer than half of it, remain
                    after filtering.
        """
        points = self.get_lidar(index)
        sem_label = self.get_label(index)

        # filter far points (not many points as I expected)
        pts_depth = (points[:, 0]**2 + points[:, 1]**2 + points[:, 2]**2)**0.5
        near_mask = pts_depth < 60
        filtered_points = points[near_mask]
        filtered_labels = sem_label[near_mask]
        current_npoints = len(filtered_points)

        '''
        # DEBUG:
        print(max(filtered_labels))
        rgb = np.zeros_like(filtered_points)
        for i in range(filtered_labels.shape[0]):
            if filtered_labels[i] == 0:
                rgb[i] = [255, 255, 255]
            elif filtered_labels[i] == 1:
                rgb[i] = [255, 0, 0]
            elif filtered_labels[i] == 2:
                rgb[i] = [0, 255, 0]
        xyzrgb = np.concatenate([filtered_points, rgb], axis=1)
        save_ply('./log/semantic_kitti_sample.ply', xyzrgb)
        input()
        '''

        # point number should be less than predefined, randomly repeat existing points
        if current_npoints > self.npoints:
            raise ValueError('scan has {} points within range but npoints is {}; Semantic Kitti should use '
                             '130000 points, and Argoverse uses 100000'.format(current_npoints, self.npoints))
        # sampling without replacement can repeat each point at most once
        if 2 * current_npoints < self.npoints:
            raise ValueError('scan has {} points within range, fewer than half of npoints ({})'.format(
                current_npoints, self.npoints))
        choice = np.arange(0, current_npoints, dtype=np.int32)
        extra_choice = np.random.choice(choice, self.npoints - current_npoints, replace=False)
        choice = np.concatenate((choice, extra_choice), axis=0)

        np.random.shuffle(choice)
        shuffled_points = filtered_points[choice]
        shuffled_labels = filtered_labels[choice]

        return shuffled_points, shuffled_labels

    '''
    def collate_batch(self, batch):
        batch_size = batch.__len__()
        ans_dict = {}

        for key in batch[0].keys():
            if isinstance(batch[0][key], np.ndarray):
                ans_dict[key] = np.concatenate([batch[k][key][np.newaxis, ...] for k in range(batch_size)], axis=0)
            else:
                ans_dict[key] = [batch[k][key] for k in range(batch_size)]
                if isinstance(batch[0][key], int):
                    ans_dict[key] = np.array(ans_dict[key], dtype=np.int32)
                elif isinstance(batch[0][key], float):
                    ans_dict[key] = np.array(ans_dict[key], dtype=np.float32)
        return ans_dict
    '''


class ArgoverseDataset(torch_data.Dataset):
    """ Argoverse scans stored as flat binary files under ``<argo_path>/<split>``.

        Each file is read with ``np.fromfile`` (default dtype) and reshaped to rows of 5 values:
        xyz in columns 0:3 and the semantic label in column 3.
    """
    def __init__(self, args, split='train'):
        self.data_path = os.path.join(args.argo_path, split)
        # os.listdir order is arbitrary; sort so that a given index always maps to the same file
        self.data_label_list = sorted(os.listdir(self.data_path))

        self.num_sample = len(self.data_label_list)
        self.npoints = args.npoints

        # print info
        print(COLOR.GREEN+'Argoverse Dataset Info (' + split + '):')
        print('\tNumber of points:', self.npoints)
        print('\tNumber of samples:', self.num_sample)
        print(COLOR.WHITE+'')

    def get_lidar_label(self, idx):
        """ Return (xyz [N, 3], semantic label [N]) read from file ``idx``. """
        file_name = self.data_label_list[idx]
        file_path = os.path.join(self.data_path, file_name)
        if not os.path.exists(file_path):
            raise FileNotFoundError(file_path)
        # NOTE(review): np.fromfile defaults to float64 -- assumes the files were written as float64
        one_scan = np.fromfile(file_path).reshape((-1, 5))
        lidar = one_scan[:, 0:3]
        sem_label = one_scan[:, 3]
        #ins_lbel = one_scan[:, 4]
        return lidar, sem_label

    def __len__(self):
        return self.num_sample

    def __getitem__(self, index):
        """ Return (points [npoints, 3], labels [npoints]) for sample ``index``.

            Points farther than 60 from the origin are dropped. If fewer than ``self.npoints`` remain,
            existing points are randomly repeated (each at most once more); otherwise a random subset is
            drawn without replacement. Rows are shuffled.

            Raises:
                ValueError: if fewer than half of ``self.npoints`` points remain after filtering.
        """
        points, sem_label = self.get_lidar_label(index)

        # filter far points (not many outlier points as I expected)
        pts_depth = (points[:, 0]**2 + points[:, 1]**2 + points[:, 2]**2)**0.5
        near_mask = pts_depth < 60
        filtered_points = points[near_mask]
        filtered_labels = sem_label[near_mask]
        current_npoints = len(filtered_points)

        '''
        # DEBUG:
        print(max(filtered_labels))
        rgb = np.zeros_like(filtered_points)
        for i in range(filtered_labels.shape[0]):
            if filtered_labels[i] == 0:
                rgb[i] = [255, 255, 255]
            elif filtered_labels[i] == 1:
                rgb[i] = [255, 0, 0]
            elif filtered_labels[i] == 2:
                rgb[i] = [0, 255, 0]
        xyzrgb = np.concatenate([filtered_points, rgb], axis=1)
        save_ply('./log/semantic_kitti_sample.ply', xyzrgb)
        input()
        '''

        choice = np.arange(0, current_npoints, dtype=np.int32)
        if current_npoints < self.npoints:
            # point number is less than predefined, randomly repeat existing points;
            # sampling without replacement needs at least half of the predefined points
            if 2 * current_npoints < self.npoints:
                raise ValueError('sample has {} points within range, fewer than half of npoints ({})'.format(
                    current_npoints, self.npoints))
            extra_choice = np.random.choice(choice, self.npoints - current_npoints, replace=False)
            choice = np.concatenate((choice, extra_choice), axis=0)
        else:
            # point number is larger than predefined, randomly delete some points
            choice = np.random.choice(choice, self.npoints, replace=False)

        np.random.shuffle(choice)
        shuffled_points = filtered_points[choice]
        shuffled_labels = filtered_labels[choice]

        # generate training labels
        return shuffled_points, shuffled_labels


class ProcessDataset(torch_data.Dataset):
    """ This class is used to process the original semantic-kitti dataset.

        For each scan it keeps the vehicle (car) points farther than 10 from the lidar in the x-y
        plane, selects the largest vehicle instance with more than 300 points, and returns those
        points together with their stacked range / instance-label map.
    """
    def __init__(self, args, split='train'):
        # input pointcloud number and fov
        self.upper_fov = args.upper_fov
        self.lower_fov = args.lower_fov
        self.left_fov = args.left_fov
        self.right_fov = args.right_fov
        self.width = args.width
        self.height = args.height
        self.max_range = args.max_range

        self.split = split
        self.data_list_name = os.path.join(args.kitti_path, self.split+'_data_list.txt')
        self.label_list_name = os.path.join(args.kitti_path, self.split+'_label_list.txt')
        self.config_name = os.path.join(args.kitti_path, 'semantic-kitti.yaml')

        # context managers make sure the list/config files are closed
        with open(self.data_list_name) as f:
            self.data_list = [x.strip() for x in f.readlines()]
        with open(self.label_list_name) as f:
            self.label_list = [x.strip() for x in f.readlines()]

        if len(self.data_list) != len(self.label_list):
            raise ValueError('data list ({}) and label list ({}) have different lengths'.format(
                len(self.data_list), len(self.label_list)))
        self.num_sample = len(self.data_list)
        self.npoints = args.npoints

        # a whitelist for training label
        with open(self.config_name, 'r') as f:
            data_config = yaml.safe_load(f)
        learning_map = data_config["learning_map"]
        learning_ignore = data_config["learning_ignore"]
        self.label_map_dic = self.process_label_map(learning_map, learning_ignore)

        # print info
        print(COLOR.GREEN+'Vehicle Dataset Info (' + split + '):')
        print('\tNumber of points:', self.npoints)
        print('\tNumber of samples:', self.num_sample)
        print(COLOR.WHITE+'')

    @staticmethod
    def process_label_map(learning_map, learning_ignore):
        """ Build a lookup table from raw semantic ids to a binary label.

            Args:
                learning_map: dict {raw semantic id: learning class id}; it is not modified.
                learning_ignore: currently unused; kept for interface compatibility.

            Returns:
                np.int32 array of length ``max(raw id) + 100``: 1 for raw ids mapped to learning
                class 1 (car), 0 for all other ids.
        """
        # FIXME: we should also consider the inverse map
        binary_map = {raw_id: (1 if learn_id == 1 else 0) for raw_id, learn_id in learning_map.items()}

        # make lookup table for mapping
        maxkey = max(binary_map.keys())
        # +100 hack making lut bigger just in case there are unknown labels
        remap_lut = np.zeros((maxkey + 100), dtype=np.int32)
        remap_lut[list(binary_map.keys())] = list(binary_map.values())
        return remap_lut

    def get_lidar(self, idx):
        """ Read scan ``idx`` as float32 rows of 4 values and return the first three (xyz) columns. """
        lidar_file = self.data_list[idx]
        if not os.path.exists(lidar_file):
            raise FileNotFoundError(lidar_file)
        one_scan = np.fromfile(lidar_file, dtype=np.float32).reshape((-1, 4))
        return one_scan[:, 0:3]

    def get_label(self, idx):
        """ Read the uint32 label file of scan ``idx``.

            Returns:
                (sem_label, inst_label) as int32 arrays: the remapped lower 16 bits (semantics) and
                the upper 16 bits (instance id).
        """
        label_file = self.label_list[idx]
        if not os.path.exists(label_file):
            raise FileNotFoundError(label_file)
        label = np.fromfile(label_file, dtype=np.uint32)
        label = label.reshape((-1))

        # process data
        upper_half = label >> 16                      # get upper half for instances
        lower_half = label & 0xFFFF                   # get lower half for semantics
        sem_label = self.label_map_dic[lower_half]    # do the remapping of semantics
        inst_label = upper_half                       # instance label
        return sem_label.astype(np.int32), inst_label.astype(np.int32)

    def __len__(self):
        return self.num_sample

    def _invalid_return(self):
        """ Placeholder sample: one zero point with instance label -1. """
        vehicle_points = np.zeros((1, 3))
        vehicle_label = -1*np.ones((1,), dtype=np.int32)
        return vehicle_points, vehicle_label

    @staticmethod
    def _select_largest_instances(instance_label, point_threshold, select_number):
        """ Boolean mask over ``instance_label`` keeping the ``select_number`` instances with the most
            points among those with more than ``point_threshold`` points; None if too few qualify.
        """
        all_id, counts = np.unique(instance_label, return_counts=True)
        keep = counts > point_threshold
        candidate_id = all_id[keep]
        # if total instance number is less than the selected number
        if len(candidate_id) < select_number:
            return None
        # stable sort on negated counts: most points first, ties resolved by ascending instance id
        order = np.argsort(-counts[keep], kind='stable')
        selected_id = candidate_id[order[:select_number]]
        return np.isin(instance_label, selected_id)

    def __getitem__(self, index):
        """ Extract the largest vehicle instance of scan ``index``.

            Returns:
                (vehicle_points, feature_map, index): the selected vehicle points (xyz columns), the
                range map and label map from ``pc_to_rangemap`` stacked along a new first axis, and
                the sample index. When no vehicle instance qualifies, ``self._invalid_return()`` is
                returned instead.
                NOTE(review): the invalid return has 2 elements while the normal one has 3 -- confirm
                how the callers / collate function handle this.
        """
        points = self.get_lidar(index)
        sem_label, inst_label = self.get_label(index)

        # get the points of vehicle
        vehicle_idx = sem_label == 1
        vehicle_points = points[vehicle_idx, :]
        vehicle_label = inst_label[vehicle_idx]

        # check if there is no vehicle points
        if vehicle_label.shape[0] == 0:
            return self._invalid_return()

        # delete the points that are too close to the lidar (x-y distance not above 10)
        range_xy = (vehicle_points[:, 0]**2 + vehicle_points[:, 1]**2)**0.5
        far_idx = range_xy > 10
        vehicle_points = vehicle_points[far_idx, :]
        vehicle_label = vehicle_label[far_idx]

        # keep the instance with the most points (only instances above 300 points are candidates)
        total_mask = self._select_largest_instances(vehicle_label, point_threshold=300, select_number=1)
        if total_mask is None:
            return self._invalid_return()
        vehicle_points = vehicle_points[total_mask, :]
        vehicle_label = vehicle_label[total_mask]

        '''
        # DEBUG:
        # instance color mapping
        inst_color_lut = np.random.randint(low=0, high=255, size=(np.max(vehicle_label)+1, 3))
        vehicle_color = inst_color_lut[vehicle_label]
        print(vehicle_color.shape)
        xyzrgb = np.concatenate([vehicle_points, vehicle_color], axis=1)
        save_ply('../log/test.ply', xyzrgb)
        '''

        # get the range map
        horizon_range = self.right_fov - self.left_fov
        range_map, label_map = pc_to_rangemap(
            vehicle_points, horizon_range, self.lower_fov, self.upper_fov, self.width, self.height, self.max_range, vehicle_label)
        feature_map = np.concatenate([range_map[None], label_map[None]])

        # NOTE: here we dont upsample the points, since the output from the render also has no fixed length
        # [N, 3], [2, 64, 2048]

        # DEBUG: opposite the x axis
        #vehicle_points[:, 0] = -vehicle_points[:, 0]
        #vehicle_points[:, 0] += 13
        #vehicle_points[:, 1] -= 1.7

        return vehicle_points, feature_map, index


        '''
        # NOTE: delete the instance that has too few points
        point_threshold = 300
        # delete instance that has too few points
        all_id = np.unique(vehicle_label)
        threshold_mask = np.zeros_like(vehicle_label)
        for a_i in all_id:
            label_mask = (vehicle_label == a_i)
            number = np.sum(label_mask)
            if number > point_threshold:
                threshold_mask = np.logical_or(threshold_mask, label_mask)
        vehicle_label = vehicle_label[threshold_mask]

        if vehicle_label.shape[0] == 0:
            return -1*np.ones((1,), dtype=np.int32)
            
        # reorder the index, start from 1
        instance_label_mapping = {v_i: c_i+1 for c_i, v_i in enumerate(np.unique(vehicle_label))}
        vehicle_label = np.vectorize(instance_label_mapping.get)(vehicle_label)
        return vehicle_label[None]
        '''

        '''
        # if no vehicle exists, set a far point
        if vehicle_points.shape[0] == 0:
            vehicle_points = self.max_range*np.ones((1, 3))
            vehicle_label = -1*np.ones((1, 3), dtype=np.int32)

        horizon_range = self.right_fov - self.left_fov
        range_map, label_map = pc_to_rangemap(vehicle_points, horizon_range, self.lower_fov, self.upper_fov, self.width, self.height, self.max_range, vehicle_label)
        # [1, 64, 2048], [1, 64, 2048]
        return range_map[None], label_map[None]
        '''


# for debug
class FilePointcloudDataset(torch_data.Dataset):
    """ Debug dataset holding exactly one point cloud read from a ``.ply`` file.

        The range map is computed once in the constructor over a full 360-degree horizontal
        sweep and divided by the maximum range (20).
    """
    def __init__(self, path='./data/vehicle_n10_n10.ply'):
        self.path = path
        self.xyz = read_ply(self.path)

        # fixed sensor configuration used for this debug file
        max_range = 20
        fov_left, fov_right = -180, 180
        fov_lower, fov_upper = -25.0, 2.0
        map_width, map_height = 2048, 64

        sweep = fov_right - fov_left
        raw_map = pc_to_rangemap(self.xyz, sweep, fov_lower, fov_upper, map_width, map_height, max_range, None)
        # normalize by the same maximum range that was passed to pc_to_rangemap
        self.range_map = raw_map / max_range

        print(COLOR.GREEN+'Pointcloud from file:')
        print('\tFilename:', self.path)
        print('\tNumber of points:', self.xyz.shape[0])
        print(COLOR.WHITE+'')

    def __len__(self):
        # the dataset always consists of the single loaded cloud
        return 1

    def __getitem__(self, index):
        """ Ignore ``index`` and return (points, range map with a leading axis of size 1). """
        batched_map = self.range_map[None]  # [1, 64, 2048]
        return self.xyz, batched_map


class VehiclePoseDataset(torch_data.Dataset):
    """ Collection of vehicle point clouds loaded from a ``.npy`` file, with precomputed range maps.

        ``self.xyz`` holds the clouds along its first axis and ``self.all_range_map`` the matching
        range maps, each divided by the maximum range (40).
    """
    def __init__(self, path='./data/vehicle_pose_dataset.npy'):
        self.path = path
        # NOTE(review): allow_pickle=True lets a crafted .npy file run arbitrary code on load;
        # only point this at trusted files.
        self.xyz = np.load(self.path, allow_pickle=True)  # [N, P, 3]

        # fixed sensor configuration: full 360-degree sweep rendered into a 64 x 2048 map
        max_range = 40
        sweep = 180 - (-180)
        fov_lower, fov_upper = -25.0, 2.0
        map_width, map_height = 2048, 64

        normalized_maps = [
            (pc_to_rangemap(self.xyz[k], sweep, fov_lower, fov_upper,
                            map_width, map_height, max_range, None) / max_range)[None]
            for k in range(self.xyz.shape[0])
        ]
        self.all_range_map = np.array(normalized_maps)  # [N, 1, 64, 2048]

        print(COLOR.GREEN+'Vehicle Pose Dataset:')
        print('\tFilename:', self.path)
        # NOTE(review): shape[0] is the number of clouds (N), although the label says points
        print('\tNumber of points:', self.xyz.shape[0])
        print(COLOR.WHITE+'')

    def __len__(self):
        return len(self.xyz)

    def __getitem__(self, index):
        """ Return (cloud ``index``, its normalized range map [1, 64, 2048]). """
        cloud = self.xyz[index]
        cloud_map = self.all_range_map[index]
        return cloud, cloud_map
