#include "../cuda_utils.h"
#include "sampling_cuda_kernel.h"


// Folds reduction slot `idx2` into slot `idx1`: afterwards dists[idx1] holds
// the larger of the two distances and dists_i[idx1] the index that was paired
// with it. On a tie the index already stored at `idx1` is kept.
// Slot `idx2` is left untouched.
__device__ void __update(float *dists, int *dists_i, int idx1, int idx2) {
    // Read both slots completely before writing anything back.
    const float lhs = dists[idx1];
    const float rhs = dists[idx2];
    const int lhs_index = dists_i[idx1];
    const int rhs_index = dists_i[idx2];

    int winner = lhs_index;
    if (rhs > lhs) {
        winner = rhs_index;
    }

    dists[idx1] = max(lhs, rhs);
    dists_i[idx1] = winner;
}

// Farthest point sampling over a batch of point clouds stored back to back.
//
// Launch layout: one thread block per segment (blockIdx.x selects the segment).
// blockDim.x must equal the `block_size` template argument, and `block_size`
// must be a power of two because the tree reduction halves the active range
// on every step. Uses 2 * block_size statically allocated shared-memory slots.
//
// xyz:        point coordinates, read as xyz[k * 3 + {0, 1, 2}] for point k.
// offset:     offset[i] is the exclusive end of segment i in the point array;
//             segment 0 starts at 0, segment i at offset[i - 1].
// new_offset: same convention as `offset`, but for the output array `idx`.
// tmp:        per-point running minimum squared distance to the points picked
//             so far, indexed by the same global point index as `xyz`. It is
//             read before this kernel ever writes it, so the caller must
//             pre-fill it (presumably with a large value -- confirm at the
//             call site).
// idx:        output. Slots [start, new_offset[i]) of segment i receive global
//             point indices; the first sample is always the segment's first
//             point.
//
// NOTE(review): a segment with no points but a non-empty sample range is not
// handled specially; it still reads xyz at the segment start.
template <unsigned int block_size>
__global__ void farthest_point_sampling_cuda_kernel(const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx)
{
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    const int bid = blockIdx.x;
    const int tid = threadIdx.x;
    const int stride = block_size;

    // Half-open ranges [start, end) of this segment's points and samples.
    const int start_n = (bid == 0) ? 0 : offset[bid - 1];
    const int end_n = offset[bid];
    const int start_m = (bid == 0) ? 0 : new_offset[bid - 1];
    const int end_m = new_offset[bid];

    // No samples requested for this segment: writing idx[start_m] would land
    // in the next segment's slot (or past the end of idx). The condition is
    // uniform across the block, so returning before any barrier is safe.
    if (start_m >= end_m) return;

    // Seed the selection with the first point of the segment.
    int old = start_n;
    if (tid == 0) idx[start_m] = start_n;

    for (int j = start_m + 1; j < end_m; j++)
    {
        // Per-thread candidate: farthest point seen by this thread so far.
        int besti = start_n;
        float best = -1.0f;

        // Coordinates of the most recently selected point.
        const float x1 = xyz[old * 3 + 0];
        const float y1 = xyz[old * 3 + 1];
        const float z1 = xyz[old * 3 + 2];

        // Each thread scans every block_size-th point of the segment.
        for (int k = start_n + tid; k < end_n; k += stride)
        {
            const float dx = xyz[k * 3 + 0] - x1;
            const float dy = xyz[k * 3 + 1] - y1;
            const float dz = xyz[k * 3 + 2] - z1;
            const float d = dx * dx + dy * dy + dz * dz;

            // Distance of point k to the selected set is the minimum over all
            // selected points; fold in the newest one.
            const float d2 = min(d, tmp[k]);
            tmp[k] = d2;

            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }

        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        // Block-wide arg-max: halve the active range until slot 0 holds the
        // winner. The barrier sits outside the `if` so every thread reaches it.
        #pragma unroll
        for (int active = block_size / 2; active > 0; active >>= 1)
        {
            if (tid < active) {
                __update(dists, dists_i, tid, tid + active);
            }
            __syncthreads();
        }

        old = dists_i[0];
        if (tid == 0)
            idx[j] = old;

        // All threads must have read dists_i[0] before thread 0 overwrites it
        // at the start of the next iteration (write-after-read hazard).
        __syncthreads();
    }
}

// Host-side launcher for farthest_point_sampling_cuda_kernel.
//
// Launches one thread block per segment (`b` blocks) on the default stream.
// The block size is chosen by opt_n_threads(n) and selects the matching
// template instantiation. The remaining arguments are forwarded to the kernel
// unchanged and are dereferenced there, so they must be device-accessible.
// The launch is asynchronous and no error status is checked here.
void farthest_point_sampling_cuda_launcher(int b, int n, const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx)
{
    // A grid of zero (or fewer) blocks is an invalid launch configuration;
    // there is nothing to sample in that case.
    if (b <= 0) return;

    const unsigned int n_threads = opt_n_threads(n);

    // The kernel sizes its shared arrays, scan stride and reduction depth from
    // the template argument, so the launched block size must always equal it.
    // Each case therefore launches with the same literal it instantiates.
    switch (n_threads) {
        case 1024:
            farthest_point_sampling_cuda_kernel<1024><<<b, 1024, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 512:
            farthest_point_sampling_cuda_kernel<512><<<b, 512, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 256:
            farthest_point_sampling_cuda_kernel<256><<<b, 256, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 128:
            farthest_point_sampling_cuda_kernel<128><<<b, 128, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 64:
            farthest_point_sampling_cuda_kernel<64><<<b, 64, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 32:
            farthest_point_sampling_cuda_kernel<32><<<b, 32, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 16:
            farthest_point_sampling_cuda_kernel<16><<<b, 16, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 8:
            farthest_point_sampling_cuda_kernel<8><<<b, 8, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 4:
            farthest_point_sampling_cuda_kernel<4><<<b, 4, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 2:
            farthest_point_sampling_cuda_kernel<2><<<b, 2, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 1:
            farthest_point_sampling_cuda_kernel<1><<<b, 1, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        default:
            // Unexpected thread count: fall back to a 512-thread block. The
            // launch must use 512 threads too; launching the <512> kernel with
            // any other block size would leave shared slots unwritten or
            // overrun them.
            farthest_point_sampling_cuda_kernel<512><<<b, 512, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
    }
}
