'''
Author: Wenhao Ding
Email: wenhaod@andrew.cmu.edu
Date: 2021-02-20 11:21:41
LastEditTime: 2021-05-06 15:29:11
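Description: Point-cloud helpers: Cartesian/cylindrical coordinate conversion and cylindrical voxelization with per-voxel majority-vote labels.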
'''
import numpy as np
import numba as nb


# transformation between Cartesian coordinates and polar coordinates
def cart2polar(input_xyz):
    """Convert Cartesian points to cylindrical coordinates.

    Args:
        input_xyz: 2-D array whose first three columns are x, y, z; any
            further columns are ignored.

    Returns:
        Array with one row per point and columns [rho, phi, z], where
        ``rho`` is the distance from the z axis and ``phi`` is
        ``arctan2(y, x)`` in radians.
    """
    x, y, z = input_xyz[:, 0], input_xyz[:, 1], input_xyz[:, 2]
    radial = np.sqrt(np.square(x) + np.square(y))
    azimuth = np.arctan2(y, x)
    return np.column_stack((radial, azimuth, z))


def polar2cat(input_xyz_polar):
    """Convert cylindrical coordinates back to Cartesian ones.

    The coordinate components live on axis 0: ``input_xyz_polar[0]`` is rho,
    ``input_xyz_polar[1]`` is phi (radians) and ``input_xyz_polar[2]`` is z.
    The remaining axes may have any shape; they are carried through unchanged.

    Returns:
        Array of the same layout holding x, y, z along axis 0.
    """
    radial = input_xyz_polar[0]
    azimuth = input_xyz_polar[1]
    height = input_xyz_polar[2]
    x = radial * np.cos(azimuth)
    y = radial * np.sin(azimuth)
    return np.stack((x, y, height), axis=0)


@nb.jit('u1[:,:,:](u1[:,:,:],i8[:,:])', nopython=True, cache=True, parallel=False)
def nb_process_label(processed_label, sorted_label_voxel_pair):
    """Write the most frequent point label of each occupied voxel into a grid.

    Compiled by numba against a fixed signature: ``processed_label`` must be a
    3-D ``uint8`` array and ``sorted_label_voxel_pair`` a 2-D ``int64`` array
    whose rows are ``[i0, i1, i2, label]``.

    Rows that belong to the same voxel must be contiguous (``pc_to_cylinder``
    lexsorts them before calling). If a voxel index reappears after a
    different one, counting restarts and the earlier result for that voxel is
    overwritten. Ties go to the smallest label, because ``np.argmax`` returns
    the first maximum. Voxels that no row references keep their existing
    value.

    NOTE(review): labels are assumed to lie in [0, 256), the size of
    ``counter``. Per-label counts are ``uint16`` and would wrap past 65535
    points of one label in one voxel.

    Args:
        processed_label: Voxel grid, updated in place.
        sorted_label_voxel_pair: Grouped (voxel index, label) rows. It must
            contain at least one row, because row 0 seeds the first group.

    Returns:
        ``processed_label`` (the same array, mutated).
    """
    label_size = 256
    # Histogram of labels for the voxel currently being accumulated.
    counter = np.zeros((label_size,), dtype=np.uint16)
    counter[sorted_label_voxel_pair[0, 3]] = 1
    cur_sear_ind = sorted_label_voxel_pair[0, :3]
    for i in range(1, sorted_label_voxel_pair.shape[0]):
        cur_ind = sorted_label_voxel_pair[i, :3]
        if not np.all(np.equal(cur_ind, cur_sear_ind)):
            # Voxel changed: commit the majority label of the finished group
            # and start a fresh histogram for the new voxel.
            processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter)
            counter = np.zeros((label_size,), dtype=np.uint16)
            cur_sear_ind = cur_ind
        counter[sorted_label_voxel_pair[i, 3]] += 1
    # Commit the last group, which the loop never closes.
    processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter)
    return processed_label


def pc_to_cylinder(data, dataset_config, model_config):
    """Voxelize a labelled point cloud on a cylindrical (rho, phi, z) grid.

    Args:
        data: ``(xyz, labels)`` pair.
            ``xyz`` is a 2-D array whose first three columns are x, y, z.
            ``labels`` holds one integer label per point, presumably shaped
            ``(N, 1)``; a 1-D ``(N,)`` array is also accepted for the voxel
            vote. Labels are assumed to lie in [0, 256) (TODO confirm against
            the dataset loader).
        dataset_config: Mapping with the following keys.
            ``'fixed_volume_space'``: when truthy, ``'max_volume_space'`` and
            ``'min_volume_space'`` (3-element [rho, phi, z] bounds) are used;
            otherwise the bounds are taken from this scan.
            ``'ignore_label'``: fill value for voxels that contain no point.
        model_config: Mapping whose ``'output_shape'`` gives the grid size
            along (rho, phi, z).

    Returns:
        5-tuple ``(voxel_position, processed_label, grid_ind, labels,
        return_fea)``:

        * ``voxel_position``: shape ``(3, *grid_size)``. Cartesian (x, y, z)
          of every grid index, computed as ``index * interval + min_bound``
          in polar space and converted with ``polar2cat``.
        * ``processed_label``: ``uint8`` array of shape ``grid_size``. Holds
          the majority label of the points in each voxel, and
          ``ignore_label`` where no point falls.
        * ``grid_ind``: ``(N, 3)`` int64 voxel index of every point.
        * ``labels``: the input labels, unchanged.
        * ``return_fea``: ``(N, 8)`` per-point features. Columns are the
          offset from the voxel centre in (rho, phi, z), the absolute
          (rho, phi, z), and (x, y).
    """
    xyz, labels = data
    xyz_pol = cart2polar(xyz)

    # No trailing commas here: they used to wrap each value in a 1-tuple.
    fixed_volume_space = dataset_config['fixed_volume_space']
    max_volume_space = dataset_config['max_volume_space']
    min_volume_space = dataset_config['min_volume_space']
    ignore_label = dataset_config["ignore_label"]
    grid_size = np.asarray(model_config['output_shape'])

    if fixed_volume_space:
        max_bound = np.asarray(max_volume_space)
        min_bound = np.asarray(min_volume_space)
    else:
        # Tight bounds of this scan in (rho, phi, z).
        max_bound_r = np.percentile(xyz_pol[:, 0], 100, axis=0)
        min_bound_r = np.percentile(xyz_pol[:, 0], 0, axis=0)
        max_bound = np.concatenate(([max_bound_r], np.max(xyz_pol[:, 1:], axis=0)))
        min_bound = np.concatenate(([min_bound_r], np.min(xyz_pol[:, 1:], axis=0)))

    # [min_bound, max_bound] is split into grid_size - 1 intervals, so a
    # clipped point gets an index in [0, grid_size - 1] along each axis.
    crop_range = max_bound - min_bound
    intervals = crop_range / (grid_size - 1)
    if (intervals == 0).any():
        print("Zero interval!")
    # int64 explicitly: np.int no longer exists (NumPy >= 1.24), and
    # nb_process_label is compiled for i8 input.
    grid_ind = np.floor(
        (np.clip(xyz_pol, min_bound, max_bound) - min_bound) / intervals
    ).astype(np.int64)

    # Position of every grid index. It is built in polar space, then
    # converted to Cartesian. dim_array reshapes the per-axis vectors to
    # (3, 1, 1, ...) so they broadcast against np.indices(grid_size), whose
    # leading axis enumerates (rho, phi, z).
    dim_array = np.ones(len(grid_size) + 1, int)
    dim_array[0] = -1
    voxel_position = np.indices(grid_size) * intervals.reshape(dim_array) + min_bound.reshape(dim_array)
    voxel_position = polar2cat(voxel_position)

    # Majority-vote label per voxel; voxels without points keep ignore_label.
    processed_label = np.full(grid_size, np.uint8(ignore_label), dtype=np.uint8)
    label_voxel_pair = np.concatenate(
        [grid_ind, np.reshape(labels, (-1, 1))], axis=1
    ).astype(np.int64, copy=False)
    # Make rows of the same voxel contiguous. lexsort uses its last key as
    # the primary key: z, then phi, then rho.
    label_voxel_pair = label_voxel_pair[
        np.lexsort((grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2])), :
    ]
    processed_label = nb_process_label(processed_label, label_voxel_pair)
    data_tuple = (voxel_position, processed_label)

    # Per-point features (voxel-centred data, as used for PTnet): offset from
    # the voxel centre, absolute polar coordinates, and Cartesian x, y.
    voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound
    return_xyz = xyz_pol - voxel_centers
    return_fea = np.concatenate((return_xyz, xyz_pol, xyz[:, :2]), axis=1)

    data_tuple += (grid_ind, labels, return_fea)
    return data_tuple
