'''
Author: Wenhao Ding
Email: wenhaod@andrew.cmu.edu
Date: 2021-02-20 11:21:41
LastEditTime: 2021-05-06 15:59:58
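Description: Point-cloud voxelization helpers for Cartesian and polar (rho, phi, z) grids: per-voxel majority labels and per-point features.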
'''
import numpy as np
import numba as nb

import torch

# transformation between Cartesian coordinates and polar coordinates
def cart2polar(input_xyz):
    """Convert Cartesian points to polar (cylindrical) coordinates.

    Args:
        input_xyz: array indexed as ``input_xyz[:, k]``. Columns 0, 1 and 2 are
            read as x, y and z; any further columns are ignored.

    Returns:
        ``(N, 3)`` array with columns ``rho = sqrt(x**2 + y**2)``,
        ``phi = arctan2(y, x)`` (radians) and the unchanged z.
    """
    x_coord = input_xyz[:, 0]
    y_coord = input_xyz[:, 1]
    radius = np.sqrt(x_coord ** 2 + y_coord ** 2)
    azimuth = np.arctan2(y_coord, x_coord)
    columns = (radius, azimuth, input_xyz[:, 2])
    return np.stack(columns, axis=1)


def polar2cat(input_xyz_polar):
    """Convert polar (cylindrical) coordinates back to Cartesian.

    The coordinate axis is the FIRST axis: ``input_xyz_polar[0]``, ``[1]`` and
    ``[2]`` are read as rho, phi (radians) and z. Each entry may be a scalar or an
    array; the three results must share a shape so they can be stacked.

    Returns:
        Array stacked along axis 0 holding ``x = rho*cos(phi)``,
        ``y = rho*sin(phi)`` and the unchanged z.
    """
    rho, phi, height = input_xyz_polar[0], input_xyz_polar[1], input_xyz_polar[2]
    return np.stack((rho * np.cos(phi), rho * np.sin(phi), height), axis=0)


@nb.jit('u1[:,:,:](u1[:,:,:],i8[:,:])', nopython=True, cache=True, parallel=False)
def nb_process_label(processed_label, sorted_label_voxel_pair):
    """Write the majority label of every occupied voxel into ``processed_label``.

    Args:
        processed_label: uint8 3-D grid, modified IN PLACE. Voxels that receive no
            point keep their existing value.
        sorted_label_voxel_pair: int64 rows ``[i, j, k, label]`` (column 3 is the
            label). Rows that belong to the same voxel must be contiguous, e.g.
            ordered with ``np.lexsort`` over the voxel indices. Labels must lie in
            ``[0, 256)``; Numba performs no bounds check.

    Returns:
        ``processed_label`` (the same array that was passed in). Ties between
        labels resolve to the smallest label id, because ``np.argmax`` returns
        the first maximum.
    """
    label_size = 256
    num_rows = sorted_label_voxel_pair.shape[0]
    if num_rows == 0:
        # Nothing to vote on; avoid reading a non-existent row 0.
        return processed_label
    # uint32 instead of uint16. More than 65535 same-label points in one voxel
    # would otherwise wrap around and silently corrupt the vote.
    counter = np.zeros((label_size,), dtype=np.uint32)
    counter[sorted_label_voxel_pair[0, 3]] = 1
    cur_sear_ind = sorted_label_voxel_pair[0, :3]
    for i in range(1, num_rows):
        cur_ind = sorted_label_voxel_pair[i, :3]
        if not np.all(np.equal(cur_ind, cur_sear_ind)):
            # Voxel boundary: commit the finished voxel's vote, then reset the histogram.
            processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter)
            counter[:] = 0
            cur_sear_ind = cur_ind
        counter[sorted_label_voxel_pair[i, 3]] += 1
    # Commit the last voxel, which the loop never reaches a boundary for.
    processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter)
    return processed_label


def voxel_dataset(data, config):
    """Voxelize a Cartesian point cloud on a regular grid.

    Args:
        data: ``(xyz, labels)``. ``xyz`` has one row per point and one column per
            grid axis (presumably ``(N, 3)`` x/y/z; TODO confirm). ``labels`` is
            concatenated column-wise to the ``(N, 3)`` grid indices, so it is
            expected to be ``(N, 1)`` integer labels in ``[0, 256)``.
        config: mapping with keys
            ``'fixed_volume_space'``: use the configured bounds instead of the
            point cloud's own per-axis min/max;
            ``'max_volume_space'`` / ``'min_volume_space'``: per-axis bounds
            (only read when ``fixed_volume_space`` is true);
            ``'ignore_label'``: value written to voxels that contain no point;
            ``'grid_size'``: number of voxels per axis.

    Returns:
        ``(voxel_position, processed_label, grid_ind, labels, return_fea)``:
        ``voxel_position`` is the ``(D, *grid_size)`` lower-corner coordinate of
        every voxel (``index * interval + min_bound``). ``processed_label`` is a
        uint8 grid with each occupied voxel's majority label and
        ``ignore_label`` elsewhere. ``grid_ind`` is the int64 ``(N, D)`` voxel
        index of each point (points are clipped into the bounds first).
        ``labels`` is passed through unchanged. ``return_fea`` is ``(N, 2*D)``:
        the offset of each point from its voxel centre, followed by the raw
        coordinates.

    Raises:
        ValueError: if the bounds span zero along some axis. The voxel size would
            be zero and every index on that axis NaN.
    """
    grid_size = np.asarray(config['grid_size'])
    ignore_label = config["ignore_label"]

    xyz, labels = data

    if config['fixed_volume_space']:
        max_bound = np.asarray(config['max_volume_space'], dtype=np.float64)
        min_bound = np.asarray(config['min_volume_space'], dtype=np.float64)
    else:
        # 100th / 0th percentile == per-axis max / min, returned as float64.
        max_bound = np.percentile(xyz, 100, axis=0)
        min_bound = np.percentile(xyz, 0, axis=0)

    # Voxel size per axis. Dividing by (grid_size - 1) keeps points clipped to
    # max_bound at an index of at most grid_size - 1.
    intervals = (max_bound - min_bound) / (grid_size - 1)
    if (intervals == 0).any():
        # Continuing would turn 0/0 into NaN indices, which the unchecked
        # Numba kernel would then use to write out of bounds.
        raise ValueError(
            f"Zero interval! Bounds {min_bound} .. {max_bound} span zero on some axis.")
    clipped_xyz = np.clip(xyz, min_bound, max_bound)
    grid_ind = np.floor((clipped_xyz - min_bound) / intervals).astype(np.int64)

    # Lower-corner coordinate of every voxel: shape (D, *grid_size).
    bcast_shape = (-1,) + (1,) * len(grid_size)
    voxel_position = (np.indices(grid_size) * intervals.reshape(bcast_shape)
                      + min_bound.reshape(bcast_shape))

    # Majority label per voxel; voxels without points keep ignore_label.
    processed_label = np.full(grid_size, ignore_label, dtype=np.uint8)
    # nb_process_label is compiled for an int64 pair array, so cast explicitly.
    label_voxel_pair = np.concatenate([grid_ind, labels], axis=1).astype(np.int64)
    # Any order that makes rows of the same voxel contiguous is sufficient.
    order = np.lexsort((grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2]))
    processed_label = nb_process_label(processed_label, label_voxel_pair[order, :])

    # Per-point features: offset from the centre of the point's voxel, then raw xyz.
    voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound
    return_fea = np.concatenate((xyz - voxel_centers, xyz), axis=1)

    return voxel_position, processed_label, grid_ind, labels, return_fea


def spherical_dataset(data, config):
    """Voxelize a point cloud on a polar (rho, phi, z) grid and build a visibility feature.

    Args:
        data: ``(xyz, labels)``. Columns 0, 1 and 2 of ``xyz`` are read as x, y, z
            (via ``cart2polar``), and columns 0-1 are reused as x, y in the
            features. ``labels`` is concatenated column-wise to the ``(N, 3)``
            grid indices, so it is expected to be ``(N, 1)`` integer labels in
            ``[0, 256)``.
        config: mapping with keys ``'fixed_volume_space'``,
            ``'max_volume_space'`` / ``'min_volume_space'`` (bounds in polar
            ``(rho, phi, z)`` coordinates; only read when fixed),
            ``'ignore_label'`` and ``'grid_size'`` (voxels per polar axis).

    Returns:
        ``(distance_feature, processed_label, grid_ind, labels, return_fea)``:
        ``distance_feature`` is a float grid laid out
        ``(grid_size[2], grid_size[0], grid_size[1])``, i.e. (z, rho, phi). It is
        1 for voxels containing points, -1 for voxels strictly closer than the
        farthest occupied voxel of their (phi, z) column, and 0 otherwise.
        ``processed_label`` is a uint8 (rho, phi, z) grid of per-voxel majority
        labels, with ``ignore_label`` where there are no points. ``grid_ind`` is
        the int64 ``(N, 3)`` polar voxel index of each point. ``labels`` is
        passed through unchanged. ``return_fea`` is ``(N, 8)``:
        ``[rel rho, rel phi, rel z, rho, phi, z, x, y]``, where "rel" is the
        offset from the voxel centre.

    Raises:
        ValueError: if the bounds span zero along some axis (for example, every
            point has the same z and the bounds are not fixed).
    """
    grid_size = np.asarray(config['grid_size'])
    ignore_label = config["ignore_label"]

    xyz, labels = data

    # Work in polar coordinates; columns are (rho, phi, z).
    xyz_pol = cart2polar(xyz)

    if config['fixed_volume_space']:
        max_bound = np.asarray(config['max_volume_space'], dtype=np.float64)
        min_bound = np.asarray(config['min_volume_space'], dtype=np.float64)
    else:
        max_bound_r = np.percentile(xyz_pol[:, 0], 100, axis=0)
        min_bound_r = np.percentile(xyz_pol[:, 0], 0, axis=0)
        max_bound = np.concatenate(([max_bound_r], np.max(xyz_pol[:, 1:], axis=0)))
        min_bound = np.concatenate(([min_bound_r], np.min(xyz_pol[:, 1:], axis=0)))

    # Voxel size per axis. Dividing by (grid_size - 1) keeps points clipped to
    # max_bound at an index of at most grid_size - 1.
    intervals = (max_bound - min_bound) / (grid_size - 1)
    if (intervals == 0).any():
        # Continuing would turn 0/0 into NaN indices, which the unchecked
        # Numba kernel would then use to write out of bounds.
        raise ValueError(
            f"Zero interval! Bounds {min_bound} .. {max_bound} span zero on some axis.")
    clipped_pol = np.clip(xyz_pol, min_bound, max_bound)
    grid_ind = np.floor((clipped_pol - min_bound) / intervals).astype(np.int64)

    # Lower-corner polar coordinate of every voxel: shape (3, R, A, H).
    bcast_shape = (-1,) + (1,) * len(grid_size)
    voxel_position = (np.indices(grid_size) * intervals.reshape(bcast_shape)
                      + min_bound.reshape(bcast_shape))

    # Majority label per voxel; voxels without points keep ignore_label.
    processed_label = np.full(grid_size, ignore_label, dtype=np.uint8)
    # nb_process_label is compiled for an int64 pair array, so cast explicitly.
    label_voxel_pair = np.concatenate([grid_ind, labels], axis=1).astype(np.int64)
    # Any order that makes rows of the same voxel contiguous is sufficient.
    order = np.lexsort((grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2]))
    processed_label = nb_process_label(processed_label, label_voxel_pair[order, :])

    # Visibility feature. Reversing the rho axis makes argmax return, for every
    # (phi, z) column, the distance in voxels from the far end to the farthest
    # occupied voxel. An empty column gets 0, i.e. max_distance == max_bound[0].
    occupied = np.zeros(grid_size, dtype=bool)
    occupied[grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2]] = True
    steps_from_far_end = np.argmax(occupied[::-1], axis=0)
    # Lower-corner rho of the farthest occupied voxel in each (phi, z) column.
    max_distance = max_bound[0] - intervals[0] * steps_from_far_end
    # (A, H, 1) - (A, H, R) -> (A, H, R), then transpose to (H, R, A).
    distance_feature = max_distance[:, :, np.newaxis] - np.transpose(voxel_position[0], (1, 2, 0))
    distance_feature = np.transpose(distance_feature, (1, 2, 0))
    # Boolean feature: -1 in front of the farthest occupied voxel, 0 at/behind
    # it (presumably free vs. unobserved space; TODO confirm), 1 where points are.
    distance_feature = (distance_feature > 0) * -1.
    distance_feature[grid_ind[:, 2], grid_ind[:, 0], grid_ind[:, 1]] = 1.

    # Per-point features: [relative rho, relative phi, relative z, rho, phi, z, x, y].
    voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound
    return_fea = np.concatenate((xyz_pol - voxel_centers, xyz_pol, xyz[:, :2]), axis=1)

    return distance_feature, processed_label, grid_ind, labels, return_fea
