#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "ball_query_gpu.h"
#include "cuda_utils.h"

// Index arithmetic for a binary tree laid out in a flat array, where the
// children of node i live at 2*i + 1 and 2*i + 2 and its parent at (i - 1) / 2.
// Each macro expands textually in terms of an identifier named `cur`, so a
// variable with exactly that name must be in scope wherever they are used.
// NOTE(review): nothing in the visible part of this file references these
// macros — confirm whether they are still needed.
#define lson (cur * 2 + 1)
#define rson (cur * 2 + 2)
#define father ((cur - 1) / 2)


/*
 * Ball query: for every query point, gather the indices of up to `nsample`
 * points of the same batch element whose squared distance to the query is
 * strictly less than radius * radius.
 *
 * Launch layout: one thread per query point. blockIdx.y selects the batch
 * element, blockIdx.x * blockDim.x + threadIdx.x selects the query point;
 * threads outside (b, m) exit immediately. No shared memory is used.
 *
 * Inputs:
 *   new_xyz: (B, M, 3) query points
 *   xyz:     (B, N, 3) candidate points
 * Outputs:
 *   idx_cnt: (B, M)          number of neighbours found, capped at nsample
 *   idx:     (B, M, nsample) neighbour indices in scan order (ascending k);
 *                            if fewer than nsample were found, the remaining
 *                            slots repeat the found indices cyclically.
 *
 * When no neighbour is found the idx row is left exactly as the caller
 * provided it, so callers should pre-initialise idx. When nsample <= 0 only
 * idx_cnt is written (as 0) and idx is never touched.
 */
__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample,
    const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx_cnt, int *__restrict__ idx) {
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m) return;

    // Flat (batch, query) index in 64 bits: products such as B * M * nsample
    // or B * N * 3 can exceed the range of a 32-bit int for large tensors.
    const long long query = (long long)bs_idx * m + pt_idx;
    new_xyz += query * 3;
    xyz += (long long)bs_idx * n * 3;
    idx_cnt += query;

    // An idx row with no slots must never be written: without this guard the
    // first in-range point would be stored at idx[0], which is out of bounds.
    if (nsample <= 0) {
        idx_cnt[0] = 0;
        return;
    }
    idx += query * nsample;

    float radius2 = radius * radius;
    float new_x = new_xyz[0];
    float new_y = new_xyz[1];
    float new_z = new_xyz[2];

    int cnt = 0;
    const float *pt = xyz;  // walks the candidate points, 3 floats at a time
    for (int k = 0; k < n; ++k, pt += 3) {
        float x = pt[0];
        float y = pt[1];
        float z = pt[2];
        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
        if (d2 < radius2) {
            idx[cnt] = k;
            ++cnt;
            if (cnt >= nsample) break;
        }
    }
    idx_cnt[0] = cnt;

    // Nothing found: there is nothing to repeat. Skipping the padding loop
    // avoids copying each (possibly uninitialised) slot onto itself.
    if (cnt == 0) return;

    // Pad the tail by cycling through the indices that were found.
    for (int l = 0; cnt < nsample; ++l, ++cnt) {
        idx[cnt] = idx[l];
    }
}


/*
 * Host-side launcher for ball_query_kernel_fast.
 *
 * All pointers are passed straight to the kernel and must therefore be
 * device-accessible:
 *   new_xyz: (B, M, 3)
 *   xyz:     (B, N, 3)
 * output:
 *   idx_cnt: (B, M)
 *   idx:     (B, M, nsample)
 *
 * The kernel is launched on the default stream with one thread per query
 * point; blockIdx.x covers the M queries and blockIdx.y the B batch elements.
 * Only launch errors are checked here (no synchronisation is performed). On a
 * launch error the message is printed to stderr and the process exits.
 */
void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
    const float *new_xyz, const float *xyz, int *idx_cnt, int *idx) {
    // Empty batch or no query points: there is no output to produce. A grid
    // with a zero dimension is an invalid launch configuration, so launching
    // anyway would turn a legitimately empty input into exit(-1).
    if (b <= 0 || m <= 0) return;

    // DIVUP / THREADS_PER_BLOCK come from cuda_utils.h (presumably ceil-division
    // and the block size — confirm against that header).
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    ball_query_kernel_fast<<<blocks, threads>>>(b, n, m, radius, nsample, new_xyz, xyz, idx_cnt, idx);
    // cudaDeviceSynchronize();  // for using printf in kernel function
    cudaError_t err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


/*
 * Dilated ball query: for every query point, gather the indices of up to
 * `nsample` points of the same batch element whose squared distance d2 to the
 * query satisfies radius_in^2 <= d2 < radius_out^2 (a spherical shell).
 *
 * Launch layout: one thread per query point. blockIdx.y selects the batch
 * element, blockIdx.x * blockDim.x + threadIdx.x selects the query point;
 * threads outside (b, m) exit immediately. No shared memory is used.
 *
 * Inputs:
 *   new_xyz: (B, M, 3) query points
 *   xyz:     (B, N, 3) candidate points
 * Outputs:
 *   idx_cnt: (B, M)          number of neighbours found, capped at nsample
 *   idx:     (B, M, nsample) neighbour indices in scan order (ascending k);
 *                            if fewer than nsample were found, the remaining
 *                            slots repeat the found indices cyclically.
 *
 * When no neighbour is found the idx row is left exactly as the caller
 * provided it, so callers should pre-initialise idx. When nsample <= 0 only
 * idx_cnt is written (as 0) and idx is never touched.
 */
__global__ void ball_query_dilated_kernel_fast(int b, int n, int m, float radius_in, float radius_out, int nsample,
    const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx_cnt, int *__restrict__ idx) {
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m) return;

    // Flat (batch, query) index in 64 bits: products such as B * M * nsample
    // or B * N * 3 can exceed the range of a 32-bit int for large tensors.
    const long long query = (long long)bs_idx * m + pt_idx;
    new_xyz += query * 3;
    xyz += (long long)bs_idx * n * 3;
    idx_cnt += query;

    // An idx row with no slots must never be written: without this guard the
    // first in-range point would be stored at idx[0], which is out of bounds.
    if (nsample <= 0) {
        idx_cnt[0] = 0;
        return;
    }
    idx += query * nsample;

    float radius_in2 = radius_in * radius_in;
    float radius_out2 = radius_out * radius_out;
    float new_x = new_xyz[0];
    float new_y = new_xyz[1];
    float new_z = new_xyz[2];

    int cnt = 0;
    const float *pt = xyz;  // walks the candidate points, 3 floats at a time
    for (int k = 0; k < n; ++k, pt += 3) {
        float x = pt[0];
        float y = pt[1];
        float z = pt[2];
        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
        // Inner bound is inclusive, outer bound is exclusive.
        if (d2 >= radius_in2 && d2 < radius_out2) {
            idx[cnt] = k;
            ++cnt;
            if (cnt >= nsample) break;
        }
    }
    idx_cnt[0] = cnt;

    // Nothing found: there is nothing to repeat. Skipping the padding loop
    // avoids copying each (possibly uninitialised) slot onto itself.
    if (cnt == 0) return;

    // Pad the tail by cycling through the indices that were found.
    for (int l = 0; cnt < nsample; ++l, ++cnt) {
        idx[cnt] = idx[l];
    }
}


/*
 * Host-side launcher for ball_query_dilated_kernel_fast.
 *
 * All pointers are passed straight to the kernel and must therefore be
 * device-accessible:
 *   new_xyz: (B, M, 3)
 *   xyz:     (B, N, 3)
 * output:
 *   idx_cnt: (B, M)
 *   idx:     (B, M, nsample)
 *
 * The kernel is launched on the default stream with one thread per query
 * point; blockIdx.x covers the M queries and blockIdx.y the B batch elements.
 * Only launch errors are checked here (no synchronisation is performed). On a
 * launch error the message is printed to stderr and the process exits.
 */
void ball_query_dilated_kernel_launcher_fast(int b, int n, int m, float radius_in, float radius_out, int nsample,
    const float *new_xyz, const float *xyz, int *idx_cnt, int *idx) {
    // Empty batch or no query points: there is no output to produce. A grid
    // with a zero dimension is an invalid launch configuration, so launching
    // anyway would turn a legitimately empty input into exit(-1).
    if (b <= 0 || m <= 0) return;

    // DIVUP / THREADS_PER_BLOCK come from cuda_utils.h (presumably ceil-division
    // and the block size — confirm against that header).
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    ball_query_dilated_kernel_fast<<<blocks, threads>>>(b, n, m, radius_in, radius_out, nsample, new_xyz, xyz, idx_cnt, idx);
    // cudaDeviceSynchronize();  // for using printf in kernel function
    cudaError_t err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


// __global__ void ball_query_dilated_kernel_fast(int b, int n, int m, float max_radius, float min_radius, int nsample, 
//     const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
//     // new_xyz: (B, M, 3)
//     // xyz: (B, N, 3)
//     // output:
//     // idx: (B, M, nsample)
//     int bs_idx = blockIdx.y;
//     int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
//     if (bs_idx >= b || pt_idx >= m) return;

//     new_xyz += bs_idx * m * 3 + pt_idx * 3;
//     xyz += bs_idx * n * 3;
//     idx += bs_idx * m * nsample + pt_idx * nsample;

//     float radius1 = max_radius * max_radius;
//     float radius2 = min_radius * min_radius;
//     float new_x = new_xyz[0];
//     float new_y = new_xyz[1];
//     float new_z = new_xyz[2];

//     int cnt = 0;
//     for (int k = 0; k < n; ++k) {
//         float x = xyz[k * 3 + 0];
//         float y = xyz[k * 3 + 1];
//         float z = xyz[k * 3 + 2];
//         float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
//         if (d2 == 0){
//             if (cnt == 0){
//                 for (int l = 0; l < nsample; ++l) {
//                     idx[l] = k;
//                 }
//             }
//             idx[cnt] = k;
//             ++cnt;
//             if (cnt >= nsample) break;
//         }
//         if (d2 >= radius2 && d2 < radius1){
//             if (cnt == 0){
//                 for (int l = 0; l < nsample; ++l) {
//                     idx[l] = k;
//                 }
//             }
//             idx[cnt] = k;
//             ++cnt;
//             if (cnt >= nsample) break;
//         }
//     }
// }


// void ball_query_dilated_kernel_launcher_fast(int b, int n, int m, float max_radius, float min_radius, int nsample, \
//     const float *new_xyz, const float *xyz, int *idx) {
//     // new_xyz: (B, M, 3)
//     // xyz: (B, N, 3)
//     // output:
//     //      idx: (B, M, nsample)

//     cudaError_t err;

//     dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
//     dim3 threads(THREADS_PER_BLOCK);

//     ball_query_dilated_kernel_fast<<<blocks, threads>>>(b, n, m, max_radius, min_radius, nsample, new_xyz, xyz, idx);
//     // cudaDeviceSynchronize();  // for using printf in kernel function
//     err = cudaGetLastError();
//     if (cudaSuccess != err) {
//         fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
//         exit(-1);
//     }
// }