
#include "defines.hpp"
#include "math_helper.cuh"
#include "cuda_kernel_utils.cuh"

#include "knn.cuh"

///////////////////////// GPU

/**
 *  GPU kernel to compute the knn on the gpu. Each block processes the
 *  samples blockIdx.x, blockIdx.x + gridDim.x, ... and all the threads
 *  of the block cooperate to select the neighbors of a single sample.
 *  Only points with a scaled distance to the sample smaller than 1 are
 *  considered. The selected indices of sample s are stored, sorted by
 *  distance (ascending for the closest, descending for the farthest),
 *  in pOutKnn[s*|pKnn| ... s*|pKnn| + |pKnn| - 1]. Slots for which no
 *  point was found are filled with -1.
 *
 *  The kernel requires (blockDim.x + 6*|pKnn|) 4-byte elements of
 *  dynamic shared memory.
 *
 *  @param  pKnn            Number of neighbors. A positive value
 *      selects the closest points, a negative value the farthest ones.
 *  @param  pNumRanges      Number of search ranges per sample.
 *  @param  pNumSamples     Number of samples.
 *  @param  pPts            Input array with the points.
 *  @param  pSamples        Input array with the samples.
 *  @param  pInvRadii       Input array with the inverse of
 *      the radii. Only the first element is read.
 *  @param  pRanges         Input array with the search ranges
 *      (pNumRanges entries per sample). Each entry stores the index
 *      of the first point (x) and one past the last point (y).
 *  @param  pOutKnn         Output array with the knn.
 *  @tparam D               Number of dimensions. 
 */
template<int D>
__global__ void compute_knn_gpu_kernel(
    const int pKnn,
    const unsigned int pNumRanges,
    const unsigned int pNumSamples,
    const mccnn::fpoint<D>* __restrict__ pPts,
    const mccnn::fpoint<D>* __restrict__ pSamples,
    const float* __restrict__ pInvRadii,
    const int2* __restrict__ pRanges,
    int* __restrict__ pOutKnn)
{
    //Declare the shared memory.
    extern __shared__ float sharedMem[];

    //Check if we want to select the closest or the farthests.
    const int absKnn = (pKnn<0)?-pKnn:pKnn;
    const bool closest = (pKnn > 0);

    //Distance used to mark an empty slot. It always loses against
    //the distance of a valid neighbor (which is in [0, 1)).
    const float emptyDist = (closest)?
        3.402823466e+38f:-3.402823466e+38f;

    //Signed copies of the thread configuration to avoid mixing
    //signed and unsigned values in the index arithmetic.
    const int tid = (int)threadIdx.x;
    const int numThreads = (int)blockDim.x;

    //Get the pointers to the different regions of the shared memory.
    // - auxDistances: distance computed by each thread (blockDim.x).
    // - local*: knn of the current batch of blockDim.x points.
    // - aux*: temporal buffer used to merge local and global lists.
    // - global*: knn of all the points processed so far.
    float* auxDistances = &sharedMem[0];
    int* localKnnIndexs = (int*)&sharedMem[blockDim.x];
    int* auxKnnIndexs = (int*)&sharedMem[blockDim.x+absKnn];
    int* globalKnnIndexs = (int*)&sharedMem[blockDim.x+absKnn*2];
    float* localKnnDistances = &sharedMem[blockDim.x+absKnn*3];
    float* auxKnnDistances = &sharedMem[blockDim.x+absKnn*4];
    float* globalKnnDistances = &sharedMem[blockDim.x+absKnn*5];

    //Process all the samples.
    for(unsigned int curSample = blockIdx.x; 
        curSample < pNumSamples; 
        curSample += gridDim.x)
    {        
        //Initialize shared memory. The strided loop covers all the
        //slots even if the block has less than absKnn threads.
        for(int k = tid; k < absKnn; k += numThreads){
            globalKnnIndexs[k] = -1;
            globalKnnDistances[k] = emptyDist;
        }

        //Get the sample position (it is the same for all the ranges).
        mccnn::fpoint<D> sampleCoords = pSamples[curSample];

        //Iterate over the search ranges.
        for(unsigned int c = 0; c < pNumRanges; ++c)
        {
            //Get the number of neighbors. The index is computed with
            //64 bits to avoid overflows with large inputs.
            const int2 curRange = pRanges[(size_t)curSample*pNumRanges + c];
            const int numPts = curRange.y - curRange.x;

            //Skip empty (or malformed) ranges. The range is the same
            //for all the threads of the block, so no thread skips a 
            //barrier that another thread reaches.
            if(numPts <= 0)
                continue;

            //Iterate over the neighbors in batches of blockDim.x points.
            int numIters = numPts/numThreads;
            numIters += (numPts%numThreads != 0)?1:0;
            for(int i = 0; i < numIters; ++i)
            {
                //Compute the index.
                const int currentIndex = i*numThreads + tid;

                //Initialize shared memory.
                auxDistances[tid] = emptyDist;
                for(int k = tid; k < absKnn; k += numThreads){
                    localKnnIndexs[k] = -1;
                    localKnnDistances[k] = emptyDist;
                }

                __syncthreads();

                int neighIndex = -1;
                float neighDist = 0.0f;
                if(currentIndex < numPts){
                    //Get the distance of the neighbor.
                    //NOTE(review): only pInvRadii[0] is used to scale
                    //all the dimensions - confirm this is intended.
                    neighIndex = curRange.x+currentIndex;
                    mccnn::fpoint<D> neighPts = pPts[neighIndex];
                    neighDist = mccnn::length((neighPts-sampleCoords)*pInvRadii[0]);

                    //If it is inside the radius.
                    if(neighDist < 1.0f){
                        //Store the results in shared memory.
                        auxDistances[tid] = neighDist;
                    }else{
                        neighIndex = -1;
                        neighDist = 0.0f;
                    }
                }

                __syncthreads();

                if(neighIndex >= 0){
                    //Select the local knn. Each thread computes the rank
                    //of its point by counting the points of the batch 
                    //that are in front of it.
                    int numPtsInFront = 0;
                    for(int localIter = 0; localIter < numThreads; ++localIter)
                    {
                        if(localIter != tid){
                            const float otherDist = auxDistances[localIter];
                            if(closest){
                                numPtsInFront += (otherDist < neighDist)?1:0;
                            }else{
                                numPtsInFront += (otherDist > neighDist)?1:0;
                            }
                            //Ties are resolved by the thread index, so
                            //each valid point gets a unique rank.
                            if(otherDist == neighDist){
                                numPtsInFront += (localIter < tid)?1:0;
                            }
                        }
                    }
                    
                    //Update the local knn list.
                    if(numPtsInFront < absKnn){
                        localKnnIndexs[numPtsInFront] = neighIndex;
                        localKnnDistances[numPtsInFront] = neighDist;
                    }
                }

                __syncthreads();

                //Update the global knn list. A single thread merges the
                //sorted local and global lists keeping the best absKnn.
                if(tid == 0){
                    int localIter = 0;
                    int globalIter = 0;
                    bool local = false;
                    for(int finalIter = 0; finalIter < absKnn; ++finalIter){
                        if(closest){
                            local = globalKnnDistances[globalIter] > 
                                localKnnDistances[localIter];
                        }else{
                            local = globalKnnDistances[globalIter] < 
                                localKnnDistances[localIter];
                        }

                        if(local){
                            auxKnnIndexs[finalIter] = localKnnIndexs[localIter];
                            auxKnnDistances[finalIter] = localKnnDistances[localIter];
                            localIter++;
                        }else{
                            auxKnnIndexs[finalIter] = globalKnnIndexs[globalIter];
                            auxKnnDistances[finalIter] = globalKnnDistances[globalIter];
                            globalIter++;
                        }
                    }

                    //Save the result into the final buffer.
                    for(int finalIter = 0; finalIter < absKnn; ++finalIter){
                        globalKnnIndexs[finalIter] = auxKnnIndexs[finalIter];
                        globalKnnDistances[finalIter] = auxKnnDistances[finalIter];
                    }
                }

                __syncthreads();
            }
        }    

        //Update the final knn. The index is computed with 64 bits to 
        //avoid overflows with large outputs.
        for(int k = tid; k < absKnn; k += numThreads)
            pOutKnn[(size_t)curSample*absKnn + k] = globalKnnIndexs[k];

        //Wait before the global list is reset for the next sample.
        __syncthreads();  
    }
}

///////////////////////// CPU

/**
 *  Host function that launches compute_knn_gpu_kernel on the cuda 
 *  stream of the device. The block size is max(64, |pKnn|) and the 
 *  number of blocks is limited by the number of samples and by the 
 *  number of blocks that can be active at the same time on the gpu.
 *  If there are no samples or pKnn is zero the function returns 
 *  without launching the kernel, since there is nothing to write.
 *  @param  pDevice             Device.
 *  @param  pKnn                Number of neighbors. A positive value
 *      selects the closest points, a negative value the farthest ones.
 *  @param  pNumRanges          Number of search ranges per sample.
 *  @param  pNumSamples         Number of samples.
 *  @param  pInGPUPtrPts        Input gpu pointer to the points.
 *  @param  pInGPUPtrSamples    Input gpu pointer to the samples.
 *  @param  pInGPUPtrInvRadii   Input gpu pointer to the inverse of 
 *      the radii.
 *  @param  pInGPUPtrRanges     Input gpu pointer to the search ranges.
 *  @param  pOutGPUPtrKnn       Output gpu pointer to the knn indices
 *      (|pKnn| per sample).
 *  @tparam D                   Number of dimensions.
 */
template<int D>
void mccnn::compute_knn_gpu(
    std::unique_ptr<IGPUDevice>& pDevice,
    const int pKnn,
    const unsigned int pNumRanges,
    const unsigned int pNumSamples,
    const float* pInGPUPtrPts,
    const float* pInGPUPtrSamples,
    const float* pInGPUPtrInvRadii,
    const int* pInGPUPtrRanges,
    int* pOutGPUPtrKnn)
{
    //Nothing to compute. A kernel launch with zero blocks is an 
    //invalid configuration, so return before trying to launch it.
    if(pNumSamples == 0 || pKnn == 0)
        return;

    //Get the cuda stream.
    auto cudaStream = pDevice->getCUDAStream();

#ifdef DEBUG_INFO
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start, cudaStream);
#endif

    //Get the device properties.
    const GpuDeviceProperties& gpuProps = pDevice->get_device_properties();

    //Compute the block size and the dynamic shared memory required by 
    //the kernel: one float per thread plus six lists of |pKnn| elements 
    //(three of indices and three of distances).
    //NOTE(review): |pKnn| larger than the maximum number of threads per
    //block of the device results in an invalid launch configuration.
    const unsigned int absKnn = (unsigned int)abs(pKnn);
    unsigned int blockSize = std::max(64u, absKnn);
    const size_t sharedMemSize = blockSize*sizeof(float) + 
        (3*absKnn)*(sizeof(int) + sizeof(float));

    //Calculate the ideal number of blocks for the selected block size.
    unsigned int numMP = gpuProps.numMPs_;
    unsigned int numBlocks = pDevice->get_max_active_block_x_sm(
        blockSize, (const void*)compute_knn_gpu_kernel<D>, 
        sharedMemSize);
    pDevice->check_error(__FILE__, __LINE__);

    //Calculate the total number of blocks to execute.
    unsigned int execBlocks = pNumSamples;
    unsigned int totalNumBlocks = numMP*numBlocks;
    totalNumBlocks = (totalNumBlocks > execBlocks)?execBlocks:totalNumBlocks;

    //Execute the appropriate cuda kernel based on the selected mode.
    compute_knn_gpu_kernel<D><<<totalNumBlocks, blockSize, 
        sharedMemSize, cudaStream>>>(
        pKnn, pNumRanges, pNumSamples,
        (const mccnn::fpoint<D>*)pInGPUPtrPts,
        (const mccnn::fpoint<D>*)pInGPUPtrSamples,
        pInGPUPtrInvRadii,
        (const int2*)pInGPUPtrRanges,
        pOutGPUPtrKnn);

    pDevice->check_error(__FILE__, __LINE__);

#ifdef DEBUG_INFO
    cudaEventRecord(stop, cudaStream);
    cudaEventSynchronize(stop);
    float milliseconds = 0;
    cudaEventElapsedTime(&milliseconds, start, stop);

    //Release the events used to measure the execution time.
    cudaEventDestroy(start);
    cudaEventDestroy(stop);

    float gpuOccupancy = (float)(numBlocks*blockSize)/(float)gpuProps.maxThreadsXMP_;

    fprintf(stderr, "### COMPUTE KNN ###\n");
    fprintf(stderr, "KNN: %d\n", pKnn);
    fprintf(stderr, "Num samples: %u\n", pNumSamples);
    fprintf(stderr, "Occupancy: %f\n", gpuOccupancy);
    fprintf(stderr, "Execution time: %f\n", milliseconds);
    fprintf(stderr, "\n");
#endif
}

///////////////////////// CPU Template declaration

//Explicit instantiation of mccnn::compute_knn_gpu for the number of 
//dimensions received as parameter.
#define COMPUTE_KNN_TEMP_DECL(NumDims) \
    template void mccnn::compute_knn_gpu<NumDims>( \
        std::unique_ptr<IGPUDevice>& pDevice, \
        const int pKnn, \
        const unsigned int pNumRanges, \
        const unsigned int pNumSamples, \
        const float* pInGPUPtrPts, \
        const float* pInGPUPtrSamples, \
        const float* pInGPUPtrInvRadii, \
        const int* pInGPUPtrRanges, \
        int* pOutGPUPtrKnn);

//DECLARE_TEMPLATE_DIMS is defined outside of this file; presumably it 
//applies the macro to each supported number of dimensions.
DECLARE_TEMPLATE_DIMS(COMPUTE_KNN_TEMP_DECL)