# -*- coding: utf-8 -*-
# Author: Runsheng Xu <rxx3386@ucla.edu>, Hao Xiang <haxiang@g.ucla.edu>,
# License: TDG-Attribution-NonCommercial-NoDistrib


"""
Utility functions related to point cloud
"""

import open3d as o3d
import numpy as np
from pypcd import pypcd


def pcd_to_np(pcd_file):
    """
    Read a pcd file and return its points as a numpy array.

    Parameters
    ----------
    pcd_file : str
        The pcd file that contains the point cloud.

    Returns
    -------
    pcd_np : np.ndarray
        The lidar data in numpy format, shape (n, 4), dtype float32.
        Columns 0-2 are x, y, z; column 3 is the intensity, taken from the
        first channel of the point colors.

    Raises
    ------
    ValueError
        If the file does not provide one color entry per point, so no
        intensity can be recovered.
    """
    pcd = o3d.io.read_point_cloud(pcd_file)

    xyz = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)
    # Without colors np.hstack below would fail with an opaque shape error.
    if colors.shape[0] != xyz.shape[0]:
        raise ValueError(
            "pcd file %s has %d points but %d color entries; cannot read "
            "intensity" % (pcd_file, xyz.shape[0], colors.shape[0]))

    # The intensity is saved in the first color channel of the pcd file.
    intensity = np.expand_dims(colors[:, 0], -1)
    pcd_np = np.hstack((xyz, intensity))

    return np.asarray(pcd_np, dtype=np.float32)


def mask_points_by_range(points, limit_range):
    """
    Remove the lidar points out of the boundary.

    A point is kept only if each of its x, y and z coordinates lies strictly
    inside the corresponding (min, max) interval; points exactly on a
    boundary are removed.

    Parameters
    ----------
    points : np.ndarray
        Lidar points under lidar sensor coordinate system; columns 0-2 are
        x, y, z.

    limit_range : list, tuple, np.ndarray or dict
        Either [x_min, y_min, z_min, x_max, y_max, z_max], or a mapping
        whose 'lidar_range' entry holds such a sequence.

    Returns
    -------
    points : np.ndarray
        Filtered lidar points.
    """
    # Only mapping-like configs carry the range under 'lidar_range'. Skip the
    # membership test for plain sequences/arrays: on an ndarray it would
    # compare the string against every element.
    if not isinstance(limit_range, (list, tuple, np.ndarray)) \
            and 'lidar_range' in limit_range:
        limit_range = limit_range['lidar_range']

    mask = (points[:, 0] > limit_range[0]) & (points[:, 0] < limit_range[3]) \
        & (points[:, 1] > limit_range[1]) & (points[:, 1] < limit_range[4]) \
        & (points[:, 2] > limit_range[2]) & (points[:, 2] < limit_range[5])

    return points[mask]

def v2xsim_mask_points_by_range_late(points, cav_id):
    """
    Remove the lidar points out of a fixed, agent-dependent boundary.

    The boundary [x_min, y_min, z_min, x_max, y_max, z_max] is hard-coded
    and selected by ``cav_id``:

    * ``int(cav_id) > 0``:  [-32, -32, -3, 32, 32, 2]
    * ``int(cav_id) <= 0``: [-32, -32, -8.5, 32, 32, -3.5]

    A point is kept only if each of its x, y and z coordinates lies strictly
    inside the corresponding (min, max) interval.

    Parameters
    ----------
    points : np.ndarray
        Lidar points under lidar sensor coordinate system; columns 0-2 are
        x, y, z.

    cav_id : int or str
        Agent identifier; any value accepted by ``int()``.

    Returns
    -------
    points : np.ndarray
        Filtered lidar points.
    """
    if int(cav_id) > 0:
        adaptive_range = [-32, -32, -3, 32, 32, 2]
    else:
        # Non-positive ids use a lower z window than the other agents.
        adaptive_range = [-32, -32, -8.5, 32, 32, -3.5]

    mask = (points[:, 0] > adaptive_range[0]) & (points[:, 0] < adaptive_range[3]) \
        & (points[:, 1] > adaptive_range[1]) & (points[:, 1] < adaptive_range[4]) \
        & (points[:, 2] > adaptive_range[2]) & (points[:, 2] < adaptive_range[5])

    return points[mask]


def mask_ego_points(points, cav_id=None, dataset=None):
    """
    Remove the lidar points that fall on the vehicle carrying the sensor.

    The removed footprint lies in the x-y plane (z is ignored), its
    boundaries are inclusive, and it depends on ``dataset`` and ``cav_id``:

    * ``dataset == 'v2x-sim'`` and ``int(cav_id) == 0``: disc of radius 1.5
      around the sensor origin.
    * ``dataset == 'v2x-sim'`` otherwise: box x in [-2, 2], y in [-1.1, 1.1].
    * other datasets with ``int(cav_id) < 0``: box x in [-1, 1],
      y in [-1, 1].
    * other datasets otherwise: box x in [-1.95, 2.95], y in [-1.1, 1.1].

    Parameters
    ----------
    points : np.ndarray
        Lidar points under lidar sensor coordinate system; columns 0-1 are
        x, y.

    cav_id : int or str, optional
        Agent identifier; any value accepted by ``int()``. ``None`` selects
        the default box of the dataset.

    dataset : str, optional
        Dataset name; only ``'v2x-sim'`` is handled specially.

    Returns
    -------
    points : np.ndarray
        Filtered lidar points.
    """
    # NOTE: these per-dataset footprints were added as a temporary fix and
    # still need to be revisited.
    x = points[:, 0]
    y = points[:, 1]

    if dataset == 'v2x-sim':
        if cav_id is not None and int(cav_id) == 0:
            mask = np.sqrt(x ** 2 + y ** 2) <= 1.5
        else:
            mask = (x >= -2) & (x <= 2) & (y >= -1.1) & (y <= 1.1)
    elif cav_id is not None and int(cav_id) < 0:
        mask = (x >= -1) & (x <= 1) & (y >= -1) & (y <= 1)
    else:
        mask = (x >= -1.95) & (x <= 2.95) & (y >= -1.1) & (y <= 1.1)

    return points[np.logical_not(mask)]


def shuffle_points(points):
    """
    Return the rows of ``points`` in a random order.

    Parameters
    ----------
    points : np.ndarray
        Array whose rows (first axis) are shuffled.

    Returns
    -------
    np.ndarray
        The same rows, reordered by a permutation drawn from NumPy's global
        random state.
    """
    return points[np.random.permutation(points.shape[0])]


def lidar_project(lidar_data, extrinsic):
    """
    Project lidar points into another frame with a homogeneous transform.

    Parameters
    ----------
    lidar_data : np.ndarray
        Lidar data, shape (n, 4): x, y, z, intensity.

    extrinsic : np.ndarray
        Extrinsic matrix, shape (4, 4).

    Returns
    -------
    projected_lidar : np.ndarray
        Transformed x, y, z with the original intensity appended,
        shape (n, 4).
    """
    num_points = lidar_data.shape[0]

    # Homogeneous coordinates, shape (4, n): xyz transposed plus a row of 1s.
    homogeneous = np.vstack((lidar_data[:, :3].T, np.ones((1, num_points))))

    # Apply the transform and keep the first three rows, back to (n, 3).
    transformed_xyz = extrinsic.dot(homogeneous)[:3].T

    # Intensity passes through unchanged as an (n, 1) column.
    intensity = lidar_data[:, 3:4]
    return np.hstack((transformed_xyz, intensity))


def projected_lidar_stack(projected_lidar_list):
    """
    Stack all projected lidar together.

    Parameters
    ----------
    projected_lidar_list : iterable of np.ndarray
        The projected lidar arrays, all with the same number of columns.

    Returns
    -------
    stack_lidar : np.ndarray
        All arrays concatenated along the first (row) axis.

    Raises
    ------
    ValueError
        If ``projected_lidar_list`` is empty (raised by ``np.vstack``).
    """
    # list() also materialises generators, since np.vstack expects a sequence.
    return np.vstack(list(projected_lidar_list))


def downsample_lidar(pcd_np, num):
    """
    Randomly keep ``num`` distinct rows of a point cloud.

    Parameters
    ----------
    pcd_np : np.ndarray
        The lidar points, (n, 4).

    num : int
        Number of points to keep; must not exceed n.

    Returns
    -------
    pcd_np : np.ndarray
        ``num`` rows sampled without replacement using NumPy's global random
        state.

    Raises
    ------
    AssertionError
        If the cloud holds fewer than ``num`` points (while assertions are
        enabled).
    """
    total = pcd_np.shape[0]
    assert total >= num

    keep = np.random.choice(total, num, replace=False)
    return pcd_np[keep]


def downsample_lidar_minimum(pcd_np_list):
    """
    Given a list of pcd, find the minimum number of points and downsample
    all point clouds to that number.

    Parameters
    ----------
    pcd_np_list : list
        A list of pcd numpy arrays, each (n_i, 4).

    Returns
    -------
    pcd_np_list : list
        The same list object, with every entry replaced in place by its
        randomly downsampled version. An empty list is returned unchanged.
    """
    if not pcd_np_list:
        return pcd_np_list

    # Builtin min instead of an np.Inf sentinel (np.Inf was removed in
    # NumPy 2.0).
    minimum = min(pcd_np.shape[0] for pcd_np in pcd_np_list)

    for i, pcd_np in enumerate(pcd_np_list):
        pcd_np_list[i] = downsample_lidar(pcd_np, minimum)

    return pcd_np_list


def read_pcd(pcd_path):
    """
    Read a pcd file with pypcd and return its x, y, z and intensity columns.

    Only the ``x``, ``y``, ``z`` and ``intensity`` fields are read; any other
    fields in the file are ignored.

    Parameters
    ----------
    pcd_path : str
        Path to the pcd file.

    Returns
    -------
    pcd_np_points : np.ndarray
        float32 array of shape (m, 4): x, y, z and intensity / 256.0. Rows
        containing a NaN in any column are dropped, so m may be smaller than
        the point count in the file header.
    time : None
        Placeholder; no timestamp is read, so this is always ``None``.
    """
    pcd = pypcd.PointCloud.from_path(pcd_path)
    fields = pcd.pc_data
    time = None

    pcd_np_points = np.zeros((pcd.points, 4), dtype=np.float32)
    pcd_np_points[:, 0] = fields["x"]
    pcd_np_points[:, 1] = fields["y"]
    pcd_np_points[:, 2] = fields["z"]
    # NOTE(review): the 1/256 factor presumably rescales 8-bit intensities
    # to [0, 1) -- confirm against the sensor's intensity encoding.
    pcd_np_points[:, 3] = fields["intensity"] / 256.0

    # Keep only rows without any NaN (order of the remaining rows preserved).
    valid_rows = ~np.isnan(pcd_np_points).any(axis=1)
    return pcd_np_points[valid_rows], time