#include <math.h>
#include <stdio.h>
#include <stdlib.h>

// Threads per block used for the kernel launches in this file.
#define THREADS_PER_BLOCK 256
// Integer ceiling division for non-negative m and positive n.
// Note: both arguments are evaluated twice, so avoid passing expressions with side effects.
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
// Uncomment to compile in the printf diagnostics guarded by `#ifdef DEBUG`.
// #define DEBUG


__device__ inline void lidar_to_local_coords(float shift_x, float shift_y, float rot_angle, float &local_x, float &local_y){
    // Rotate the 2D offset (shift_x, shift_y) by -rot_angle and store the result in (local_x, local_y).
    // param shift_x, shift_y: offset to rotate
    // param rot_angle: angle whose negation is applied as the rotation
    // param local_x, local_y: outputs, the rotated offset
    const float neg_angle = -rot_angle;
    const float c = cos(neg_angle);
    const float s = sin(neg_angle);

    local_x = shift_x * c + shift_y * (-s);
    local_y = shift_x * s + shift_y * c;
}


__device__ inline int check_pt_in_box3d(const float *pt, const float *box3d, float &local_x, float &local_y){
    // Test whether a point lies inside a rotated 3D box.
    // param pt: (x, y, z)
    // param box3d: [x, y, z, dx, dy, dz, heading] (x, y, z) is the box center
    // param local_x, local_y: outputs, the point's xy offset from the box center rotated by -heading;
    //     they are only written when the point passes the z-extent test
    // return: 1 if the point is inside the box, 0 otherwise

    const float MARGIN = 1e-5f;  // tolerance added to the x/y half-extents (not applied to z)
    const float x = pt[0], y = pt[1], z = pt[2];
    const float cx = box3d[0], cy = box3d[1], cz = box3d[2];
    const float dx = box3d[3], dy = box3d[4], dz = box3d[5], rz = box3d[6];

    // Reject on height first; this avoids the sin/cos of the rotation for most points.
    if (fabsf(z - cz) > dz / 2.0f) return 0;

    lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);

    // Float literals keep the whole test in single precision (a bare 2.0 would promote
    // the comparison to double). The outcome can only differ from a double-precision
    // comparison for points within rounding distance of the margin boundary.
    const bool inside_x = fabsf(local_x) < dx / 2.0f + MARGIN;
    const bool inside_y = fabsf(local_y) < dy / 2.0f + MARGIN;
    return (inside_x && inside_y) ? 1 : 0;
}


__global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num, int out_x, int out_y, int out_z,
    const float *rois, const float *pts, int *pts_mask){
    // For every (box, point) pair, record which cell of the box's out_x * out_y * out_z voxel grid holds the point.
    // Launch layout: blockIdx.y selects the box, blockIdx.x * blockDim.x + threadIdx.x selects the point.
    // params rois: [x, y, z, dx, dy, dz, heading] (x, y, z) is the box center
    // params pts: (npoints, 3) [x, y, z]
    // params pts_mask: (N, npoints): -1 means the point is not in this box, otherwise the voxel index packed as
    //     (x_idx << 16) + (y_idx << 8) + z_idx. y_idx and z_idx get 8 bits each, and the decoder in
    //     collect_inside_pts_for_box3d masks x_idx to 8 bits as well, so out_x, out_y, out_z must not exceed 256.
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int box_idx = blockIdx.y;
    if (pt_idx >= pts_num || box_idx >= boxes_num) return;

    pts += pt_idx * 3;
    rois += box_idx * 7;
    // 64-bit offset: box_idx * pts_num may not fit in a 32-bit int.
    pts_mask += (long long)box_idx * pts_num + pt_idx;

    float local_x = 0, local_y = 0;
    int cur_in_flag = check_pt_in_box3d(pts, rois, local_x, local_y);

    pts_mask[0] = -1;
    if (cur_in_flag > 0){
        float local_z = pts[2] - rois[2];
        float dx = rois[3], dy = rois[4], dz = rois[5];

        float x_res = dx / out_x;
        float y_res = dy / out_y;
        float z_res = dz / out_z;

        // The indices are signed on purpose: the inside test allows a small margin outside the box,
        // so the quotient can be negative. With an unsigned index a negative value would wrap to a
        // huge number, max(idx, 0) would do nothing, and the point would land in the LAST voxel
        // instead of the first one.
        int x_idx = int((local_x + dx / 2) / x_res);
        int y_idx = int((local_y + dy / 2) / y_res);
        int z_idx = int((local_z + dz / 2) / z_res);

        x_idx = min(max(x_idx, 0), out_x - 1);
        y_idx = min(max(y_idx, 0), out_y - 1);
        z_idx = min(max(z_idx, 0), out_z - 1);

        unsigned int idx_encoding = ((unsigned int)x_idx << 16) + ((unsigned int)y_idx << 8) + (unsigned int)z_idx;
        pts_mask[0] = idx_encoding;
    }
}


__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num, int max_pts_each_voxel,
    int out_x, int out_y, int out_z, const int *pts_mask, int *pts_idx_of_voxels){
    // Scatter the point indices recorded in pts_mask into per-voxel lists.
    // Launch layout: one thread per box (blockIdx.x * blockDim.x + threadIdx.x); the thread walks all points serially.
    // params pts_mask: (N, npoints): -1 means the point is not in the box, otherwise the voxel index packed as
    //     (x_idx << 16) + (y_idx << 8) + z_idx
    // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel); slot 0 of every voxel is the point
    //     counter and is read before being incremented, so the buffer must hold valid counts (normally zeros)
    //     on entry. At most max_pts_each_voxel - 1 points are stored per voxel; further points are dropped.

    int box_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (box_idx >= boxes_num) return;

    int max_num_pts = max_pts_each_voxel - 1;  // index 0 is the counter
    // 64-bit offsets: the flattened sizes may not fit in a 32-bit int.
    pts_idx_of_voxels += (long long)box_idx * out_x * out_y * out_z * max_pts_each_voxel;
    pts_mask += (long long)box_idx * pts_num;

    for (int k = 0; k < pts_num; k++){
        const int mask_val = pts_mask[k];
        if (mask_val == -1) continue;

        unsigned int idx_encoding = mask_val;
        int x_idx = (idx_encoding >> 16) & 0xFF;
        int y_idx = (idx_encoding >> 8) & 0xFF;
        int z_idx = idx_encoding & 0xFF;
        long long base_offset = (((long long)x_idx * out_y + y_idx) * out_z + z_idx) * max_pts_each_voxel;

        // Signed comparison: with an unsigned counter, max_pts_each_voxel == 0 turned the limit (-1)
        // into UINT_MAX and every point was written out of bounds. The cnt >= 0 test keeps a corrupt
        // (negative) counter from being used as an index.
        int cnt = pts_idx_of_voxels[base_offset];
        if (cnt >= 0 && cnt < max_num_pts){
            pts_idx_of_voxels[base_offset + cnt + 1] = k;
            pts_idx_of_voxels[base_offset]++;
        }
#ifdef DEBUG
        printf("collect: pts_%d, idx(%d, %d, %d), idx_encoding=%x\n",
            k, x_idx, y_idx, z_idx, idx_encoding);
#endif
    }
}


__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x,
    int out_y, int out_z, const float *pts_feature, const int *pts_idx_of_voxels, float *pooled_features, int *argmax){
    // Max-pool the features of the points stored in each voxel, one (box, voxel, channel) per thread.
    // Launch layout: blockIdx.z = box, blockIdx.y = channel, blockIdx.x * blockDim.x + threadIdx.x = flat voxel index.
    // params pts_num: unused by this kernel
    // params pts_feature: (npoints, C)
    // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), index 0 is the counter
    // params pooled_features: (N, out_x, out_y, out_z, C); only written when a maximum was found, so entries of
    //     empty voxels keep whatever value the buffer held on entry
    // params argmax: (N, out_x, out_y, out_z, C); index of the point that supplied the maximum, or -1 if none

    int box_idx = blockIdx.z;
    int channel_idx = blockIdx.y;
    int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;

    int x_idx = voxel_idx_flat / (out_y * out_z);
    int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
    int z_idx = voxel_idx_flat % out_z;
    if (box_idx >= boxes_num || channel_idx >= channels|| x_idx >= out_x || y_idx >= out_y || z_idx >= out_z) return;

#ifdef DEBUG
    printf("src pts_idx_of_voxels: (%p, ), argmax: %p\n", pts_idx_of_voxels, argmax);
#endif

    int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
    // 64-bit offsets: N * out_x * out_y * out_z * C (or * max_pts_each_voxel) may not fit in a 32-bit int.
    long long voxel_offset = (long long)box_idx * out_x * out_y * out_z + offset_base;
    pts_idx_of_voxels += voxel_offset * max_pts_each_voxel;
    pooled_features += voxel_offset * channels + channel_idx;
    argmax += voxel_offset * channels + channel_idx;

    int argmax_idx = -1;
    // Explicit negative infinity. The previous literal -1e50 is outside the float range and only
    // became -inf through an out-of-range double-to-float conversion.
    float max_val = -INFINITY;

    int total_pts = pts_idx_of_voxels[0];

    for (int k = 1; k <= total_pts; k++){
        const int pt_idx = pts_idx_of_voxels[k];
        // Load the feature once instead of once for the test and once for the assignment.
        const float val = pts_feature[(long long)pt_idx * channels + channel_idx];
        // Strict '>' keeps the first point on ties; NaN and -inf features never win.
        if (val > max_val){
            max_val = val;
            argmax_idx = pt_idx;
        }
    }

    if (argmax_idx != -1){
        pooled_features[0] = max_val;
    }
    argmax[0] = argmax_idx;

#ifdef DEBUG
    printf("channel_%d idx(%d, %d, %d), argmax_idx=(%d, %.3f), total=%d, after pts_idx: %p, argmax: (%p, %d)\n",
        channel_idx, x_idx, y_idx, z_idx, argmax_idx, max_val, total_pts, pts_idx_of_voxels, argmax, argmax_idx);
#endif
}


__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x,
    int out_y, int out_z, const float *pts_feature, const int *pts_idx_of_voxels, float *pooled_features){
    // Average the features of the points stored in each voxel, one (box, voxel, channel) per thread.
    // Launch layout: blockIdx.z = box, blockIdx.y = channel, blockIdx.x * blockDim.x + threadIdx.x = flat voxel index.
    // params pts_num: unused by this kernel
    // params pts_feature: (npoints, C)
    // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), index 0 is the counter
    // params pooled_features: (N, out_x, out_y, out_z, C); entries of empty voxels are not written and keep
    //     whatever value the buffer held on entry

    int box_idx = blockIdx.z;
    int channel_idx = blockIdx.y;
    int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;

    int x_idx = voxel_idx_flat / (out_y * out_z);
    int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
    int z_idx = voxel_idx_flat % out_z;
    if (box_idx >= boxes_num || channel_idx >= channels|| x_idx >= out_x || y_idx >= out_y || z_idx >= out_z) return;

    int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
    // 64-bit offsets: N * out_x * out_y * out_z * C (or * max_pts_each_voxel) may not fit in a 32-bit int.
    long long voxel_offset = (long long)box_idx * out_x * out_y * out_z + offset_base;
    pts_idx_of_voxels += voxel_offset * max_pts_each_voxel;
    pooled_features += voxel_offset * channels + channel_idx;

    int total_pts = pts_idx_of_voxels[0];
    if (total_pts <= 0) return;  // empty voxel: leave the output untouched

    float sum_val = 0;
    for (int k = 1; k <= total_pts; k++){
        sum_val += pts_feature[(long long)pts_idx_of_voxels[k] * channels + channel_idx];
    }

    pooled_features[0] = sum_val / total_pts;
}


// Report a failed CUDA runtime call or kernel launch made by roiaware_pool3d_launcher and abort.
static void roiaware_pool3d_check_cuda(cudaError_t err, const char *what){
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA error in roiaware_pool3d_launcher (%s): %s\n", what, cudaGetErrorString(err));
        exit(-1);
    }
}


void roiaware_pool3d_launcher(int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x, int out_y, int out_z,
    const float *rois, const float *pts, const float *pts_feature, int *argmax, int *pts_idx_of_voxels, float *pooled_features, int pool_method){
    // Forward pass of RoI-aware pooling: (1) mark which voxel of which box each point falls in,
    // (2) gather the point indices per voxel, (3) max- or average-pool the point features per voxel.
    // All pointers are passed straight to the kernels and must therefore be device-accessible.
    // params rois: (N, 7) [x, y, z, dx, dy, dz, heading] (x, y, z) is the box center
    // params pts: (npoints, 3) [x, y, z]
    // params pts_feature: (npoints, C)
    // params argmax: (N, out_x, out_y, out_z, C), written by max pooling only
    // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel); slot 0 of each voxel is a counter that
    //     the gather step increments, so the buffer must hold valid counts (normally zeros) on entry
    // params pooled_features: (N, out_x, out_y, out_z, C); entries of empty voxels are not written
    // params pool_method: 0: max_pool 1: avg_pool
    // Limits: boxes_num and channels are used as grid y/z dimensions, and the voxel-index packing uses
    //     8 bits per axis, so out_x, out_y, out_z must not exceed 256.

    // No boxes or an empty voxel grid means every output is empty; launching would use a zero-sized grid.
    if (boxes_num <= 0 || out_x <= 0 || out_y <= 0 || out_z <= 0) return;

    dim3 threads(THREADS_PER_BLOCK);
    int *pts_mask = NULL;

    // Without points there is nothing to mark or gather (and the mask grid would be zero-sized).
    if (pts_num > 0){
        // Compute the size in size_t; boxes_num * pts_num can overflow int.
        const size_t mask_bytes = (size_t)boxes_num * pts_num * sizeof(int);  // (N, M)
        roiaware_pool3d_check_cuda(cudaMalloc(&pts_mask, mask_bytes), "cudaMalloc pts_mask");
        // Byte-wise fill with 0xFF makes every int equal to -1 ("point not in box").
        roiaware_pool3d_check_cuda(cudaMemset(pts_mask, -1, mask_bytes), "cudaMemset pts_mask");

        dim3 blocks_mask(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num);
        generate_pts_mask_for_box3d<<<blocks_mask, threads>>>(boxes_num, pts_num, out_x, out_y, out_z, rois, pts, pts_mask);
        roiaware_pool3d_check_cuda(cudaGetLastError(), "generate_pts_mask_for_box3d");

        dim3 blocks_collect(DIVUP(boxes_num, THREADS_PER_BLOCK));
        collect_inside_pts_for_box3d<<<blocks_collect, threads>>>(boxes_num, pts_num, max_pts_each_voxel,
            out_x, out_y, out_z, pts_mask, pts_idx_of_voxels);
        roiaware_pool3d_check_cuda(cudaGetLastError(), "collect_inside_pts_for_box3d");
    }

    // With zero channels the pooled outputs are empty, but pts_idx_of_voxels above is still filled.
    if (channels > 0){
        dim3 blocks_pool(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels, boxes_num);
        if (pool_method == 0){
            roiaware_maxpool3d<<<blocks_pool, threads>>>(boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
                pts_feature, pts_idx_of_voxels, pooled_features, argmax);
            roiaware_pool3d_check_cuda(cudaGetLastError(), "roiaware_maxpool3d");
        }
        else if (pool_method == 1){
            roiaware_avgpool3d<<<blocks_pool, threads>>>(boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
                pts_feature, pts_idx_of_voxels, pooled_features);
            roiaware_pool3d_check_cuda(cudaGetLastError(), "roiaware_avgpool3d");
        }
        else {
            // Previously ignored silently; the pooled outputs are left untouched in this case.
            fprintf(stderr, "roiaware_pool3d_launcher: unsupported pool_method %d, no pooling performed\n", pool_method);
        }
    }

    // cudaFree(NULL) is a no-op, so this is safe when the mask was never allocated.
    cudaFree(pts_mask);

#ifdef DEBUG
    cudaDeviceSynchronize();  // for using printf in kernel function
#endif
}


__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels, int out_x, int out_y, int out_z,
    const int *argmax, const float *grad_out, float *grad_in){
    // Backward of max pooling: route each output gradient to the point that supplied the maximum.
    // Launch layout: blockIdx.z = box, blockIdx.y = channel, blockIdx.x * blockDim.x + threadIdx.x = flat voxel index.
    // params argmax: (N, out_x, out_y, out_z, C); -1 marks an output without a contributing point
    // params grad_out: (N, out_x, out_y, out_z, C)
    // params grad_in: (npoints, C), return value; gradients are accumulated into it with atomicAdd,
    //     so it must hold the desired starting values (normally zeros) on entry

    int box_idx = blockIdx.z;
    int channel_idx = blockIdx.y;
    int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;

    int x_idx = voxel_idx_flat / (out_y * out_z);
    int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
    int z_idx = voxel_idx_flat % out_z;
    if (box_idx >= boxes_num || channel_idx >= channels|| x_idx >= out_x || y_idx >= out_y || z_idx >= out_z) return;

    int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
    // 64-bit offset: N * out_x * out_y * out_z * C may not fit in a 32-bit int.
    long long out_offset = ((long long)box_idx * out_x * out_y * out_z + offset_base) * channels + channel_idx;

    const int pt_idx = argmax[out_offset];
    if (pt_idx == -1) return;

    // Several voxels (from different boxes) can select the same point, hence the atomic accumulation.
    atomicAdd(grad_in + (long long)pt_idx * channels + channel_idx, grad_out[out_offset]);
}


__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels, int out_x, int out_y, int out_z,
    int max_pts_each_voxel, const int *pts_idx_of_voxels, const float *grad_out, float *grad_in){
    // Backward of average pooling: spread each output gradient evenly over the points stored in the voxel.
    // Launch layout: blockIdx.z = box, blockIdx.y = channel, blockIdx.x * blockDim.x + threadIdx.x = flat voxel index.
    // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), index 0 is the counter
    // params grad_out: (N, out_x, out_y, out_z, C)
    // params grad_in: (npoints, C), return value; gradients are accumulated into it with atomicAdd,
    //     so it must hold the desired starting values (normally zeros) on entry

    int box_idx = blockIdx.z;
    int channel_idx = blockIdx.y;
    int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;

    int x_idx = voxel_idx_flat / (out_y * out_z);
    int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
    int z_idx = voxel_idx_flat % out_z;
    if (box_idx >= boxes_num || channel_idx >= channels|| x_idx >= out_x || y_idx >= out_y || z_idx >= out_z) return;

    int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
    // 64-bit offsets: N * out_x * out_y * out_z * C (or * max_pts_each_voxel) may not fit in a 32-bit int.
    long long voxel_offset = (long long)box_idx * out_x * out_y * out_z + offset_base;
    pts_idx_of_voxels += voxel_offset * max_pts_each_voxel;
    grad_out += voxel_offset * channels + channel_idx;

    int total_pts = pts_idx_of_voxels[0];
    if (total_pts <= 0) return;  // empty voxel: nothing receives gradient

    // The per-point share is the same for every point of the voxel, so compute it once.
    const float cur_grad = 1.0f / fmaxf(float(total_pts), 1.0f);
    const float grad_per_pt = grad_out[0] * cur_grad;
    for (int k = 1; k <= total_pts; k++){
        atomicAdd(grad_in + (long long)pts_idx_of_voxels[k] * channels + channel_idx, grad_per_pt);
    }
}


void roiaware_pool3d_backward_launcher(int boxes_num, int out_x, int out_y, int out_z, int channels, int max_pts_each_voxel,
    const int *pts_idx_of_voxels, const int *argmax, const float *grad_out, float *grad_in, int pool_method){
    // Backward pass of RoI-aware pooling: accumulates grad_out into grad_in.
    // All pointers are passed straight to the kernels and must therefore be device-accessible.
    // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), used by avg_pool only
    // params argmax: (N, out_x, out_y, out_z, C), used by max_pool only
    // params grad_out: (N, out_x, out_y, out_z, C)
    // params grad_in: (npoints, C), return value; the kernels add into it, so it must be initialised by the caller
    // params pool_method: 0: max_pool, 1: avg_pool
    // Limits: channels and boxes_num are used as grid y/z dimensions.

    // Nothing to propagate; launching would use a zero-sized grid, which is an invalid configuration.
    if (boxes_num <= 0 || channels <= 0 || out_x <= 0 || out_y <= 0 || out_z <= 0) return;

    dim3 blocks(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels, boxes_num);
    dim3 threads(THREADS_PER_BLOCK);
    if (pool_method == 0){
        roiaware_maxpool3d_backward<<<blocks, threads>>>(
            boxes_num, channels, out_x, out_y, out_z, argmax, grad_out, grad_in
        );
    }
    else if (pool_method == 1){
        roiaware_avgpool3d_backward<<<blocks, threads>>>(
            boxes_num, channels, out_x, out_y, out_z, max_pts_each_voxel, pts_idx_of_voxels, grad_out, grad_in
        );
    }
    else {
        // Previously ignored silently; grad_in is left untouched in this case.
        fprintf(stderr, "roiaware_pool3d_backward_launcher: unsupported pool_method %d, no gradient computed\n", pool_method);
    }

    // Launch failures are not reported by the <<<>>> syntax itself; check explicitly,
    // using the same handling as points_in_boxes_launcher.
    cudaError_t err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


__global__ void points_in_boxes_kernel(int batch_size, int boxes_num, int pts_num, const float *boxes,
    const float *pts, int *box_idx_of_points){
    // For each point, store the index of the first box of its batch sample that contains it.
    // Launch layout: blockIdx.y = batch sample, blockIdx.x * blockDim.x + threadIdx.x = point.
    // params boxes: (B, N, 7) [x, y, z, dx, dy, dz, heading] (x, y, z) is the box center
    // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
    // params boxes_idx_of_points: (B, npoints), default -1 (entries of points outside every box are not written)

    const int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int bs_idx = blockIdx.y;
    if (bs_idx >= batch_size || pt_idx >= pts_num) return;

    // Move each pointer to the data owned by this thread.
    const float *sample_boxes = boxes + bs_idx * boxes_num * 7;
    const float *cur_pt = pts + (bs_idx * pts_num * 3 + pt_idx * 3);
    int *cur_result = box_idx_of_points + (bs_idx * pts_num + pt_idx);

    float local_x = 0, local_y = 0;
    int k = 0;
    while (k < boxes_num){
        const int inside = check_pt_in_box3d(cur_pt, sample_boxes + k * 7, local_x, local_y);
        if (inside){
            // The first match wins; later boxes are not examined.
            cur_result[0] = k;
            return;
        }
        k++;
    }
}


void points_in_boxes_launcher(int batch_size, int boxes_num, int pts_num, const float *boxes,
    const float *pts, int *box_idx_of_points){
    // Launch points_in_boxes_kernel with one thread per (batch sample, point).
    // All pointers are passed straight to the kernel and must therefore be device-accessible.
    // params boxes: (B, N, 7) [x, y, z, dx, dy, dz, heading] (x, y, z) is the box center
    // params pts: (B, npoints, 3) [x, y, z]
    // params boxes_idx_of_points: (B, npoints), default -1; the kernel only writes entries of points that fall
    //     inside a box, so the caller must pre-fill the buffer
    // Limits: batch_size is used as the grid y dimension.

    // An empty batch or point set has no output to fill. Launching would use a zero-sized grid,
    // which is an invalid configuration and would make the check below abort the process.
    if (batch_size <= 0 || pts_num <= 0) return;

    dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size);
    dim3 threads(THREADS_PER_BLOCK);
    points_in_boxes_kernel<<<blocks, threads>>>(batch_size, boxes_num, pts_num, boxes, pts, box_idx_of_points);

    // Launch failures are not reported by the <<<>>> syntax itself; check explicitly.
    cudaError_t err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }

#ifdef DEBUG
    cudaDeviceSynchronize();  // for using printf in kernel function
#endif
}
