
#include "defines.hpp"
#include "tf_utils.hpp"
#include "tf_gpu_device.hpp"
#include "compute_keys.cuh"
#include "find_ranges_grid_ds.cuh"
#include "knn.cuh"

/**
 *  Declaration of the tensorflow operations.
 */
REGISTER_OP("Knn")
    .Input("samples: float32")
    .Input("samples_batch_ids: int32")
    .Input("points: float32")
    .Input("point_keys: int64")
    .Input("grid_acc_ds: int32")
    .Input("num_cells: int32")
    .Input("scaled_aabb_min: float32")
    .Input("inv_cell_sizes: float32")
    .Input("inv_radii: float32")
    .Output("knn_indices: int32")
    .Attr("knn: int")
    .SetShapeFn([](shape_inference::InferenceContext* pIC) {
        // The output has one row per sample id (dimension 0 of input 1,
        // samples_batch_ids) and |knn| columns; the sign of knn does not
        // affect the output width.
        int numNeighbors = 0;
        TF_RETURN_IF_ERROR(pIC->GetAttr("knn", &numNeighbors));
        const int numCols = (numNeighbors < 0) ? -numNeighbors : numNeighbors;
        pIC->set_output(0, pIC->MakeShape(
            {pIC->Dim(pIC->input(1), 0), numCols}));
        return Status::OK();
    });

namespace mccnn{

    /**
     *  Operation to compute the knn from a regular grid.
     */
    class KnnOp: public OpKernel{
        
        public:
        
            /**
             *  Constructor.
             *  @param  pContext    Constructor context of the operation.
             */
            explicit KnnOp(
                OpKernelConstruction* pContext)
                :OpKernel(pContext){
                
                OP_REQUIRES_OK(pContext, pContext->GetAttr("knn", &knn_));
                OP_REQUIRES(pContext, knn_ != 0 && abs(knn_) < 257, 
                    errors::InvalidArgument("KnnOp expects a non zero number " 
                        "of points and smaller than 257."));  
            }
        
            /**
             *  Method to compute the operation.
             *  @param  pContext    Context of the operation.
             */
            void Compute(OpKernelContext * pContext) override{

                //Get the input tensors.
                const Tensor& inSamples = pContext->input(0); 
                const Tensor& inSamplesBatchIds = pContext->input(1); 
                const Tensor& inPts = pContext->input(2); 
                const Tensor& inPtsKeys = pContext->input(3); 
                const Tensor& inGridDs = pContext->input(4); 
                const Tensor& inNumCells = pContext->input(5); 
                const Tensor& inSAABBMin = pContext->input(6); 
                const Tensor& inInvCellSizes = pContext->input(7); 
                const Tensor& inInvRadii = pContext->input(8); 

                //Get variables from tensors.
                unsigned int numPts = inPts.shape().dim_size(0);
                unsigned int numSamples = inSamples.shape().dim_size(0);
                unsigned int numDimensions = inPts.shape().dim_size(1);

                //Get the pointers to GPU data from the tensors.
                const float* samplesGPUPtr = mccnn::tensorflow_utils::
                    get_const_tensor_pointer<float>(inSamples);
                const int* samplesBatchIdsGPUPtr = mccnn::tensorflow_utils::
                    get_const_tensor_pointer<int>(inSamplesBatchIds);
                const float* ptsGPUPtr = mccnn::tensorflow_utils::
                    get_const_tensor_pointer<float>(inPts);
                const mccnn::int64_m* ptsKeysGPUPtr = mccnn::tensorflow_utils::
                    get_const_tensor_pointer<mccnn::int64_m>(inPtsKeys);
                const int* gridDsGPUPtr = mccnn::tensorflow_utils::
                    get_const_tensor_pointer<int>(inGridDs);
                const int* numCellsGPUPtr = mccnn::tensorflow_utils::
                    get_const_tensor_pointer<int>(inNumCells);
                const float* sAABBMinGPUPtr = mccnn::tensorflow_utils::
                    get_const_tensor_pointer<float>(inSAABBMin);
                const float* invCellSizeGPUPtr = mccnn::tensorflow_utils::
                    get_const_tensor_pointer<float>(inInvCellSizes);
                const float* invRadiiGPUPtr = mccnn::tensorflow_utils::
                    get_const_tensor_pointer<float>(inInvRadii);
                
                //Check for the correctness of the input.
                OP_REQUIRES(pContext, numDimensions >= MIN_DIMENSIONS 
                    && numDimensions <= MAX_DIMENSIONS, 
                    errors::InvalidArgument("KnnOp expects a valid number of dimension"));
                OP_REQUIRES(pContext, inSamples.shape().dim_size(1) == numDimensions, 
                    errors::InvalidArgument("KnnOp expects a number of dimensions in"
                    " inSamples equal to the number of dimensions in the input points"));
                OP_REQUIRES(pContext, inSamplesBatchIds.shape().dim_size(0) == numSamples, 
                    errors::InvalidArgument("KnnOp expects the same number of batch"
                    " ids as samples."));
                OP_REQUIRES(pContext, inPts.shape().dim_size(1) == numDimensions, 
                    errors::InvalidArgument("KnnOp expects a number of dimensions in"
                    " inPts equal to the number of dimensions in the input points"));
                OP_REQUIRES(pContext, inPtsKeys.shape().dim_size(0) == numPts, 
                    errors::InvalidArgument("KnnOp expects the same number of keys"
                    " as points."));
                OP_REQUIRES(pContext, inNumCells.shape().dim_size(0) == numDimensions, 
                    errors::InvalidArgument("KnnOp expects a number of dimensions in"
                    " inNumCells equal to the number of dimensions in the input points"));
                OP_REQUIRES(pContext, inSAABBMin.shape().dim_size(1) == numDimensions, 
                    errors::InvalidArgument("KnnOp expects a number of dimensions in"
                    " inSAABBMin equal to the number of dimensions in the input points"));
                OP_REQUIRES(pContext, inInvCellSizes.shape().dim_size(0) == numDimensions, 
                    errors::InvalidArgument("KnnOp expects a number of dimensions in"
                    " inInvCellSizes equal to the number of dimensions in the input points"));
                OP_REQUIRES(pContext, inInvRadii.shape().dim_size(0) == numDimensions, 
                    errors::InvalidArgument("KnnOp expects a number of dimensions in"
                    " inInvRadii equal to the number of dimensions in the input points"));
                OP_REQUIRES(pContext, inGridDs.dims() == 4 &&
                    inGridDs.shape().dim_size(3) == 2, 
                    errors::InvalidArgument("KnnOp expects a grid acceleration data"
                    " structure with the right format (B, X, Y, 2)."));

                //Get the gpu device.
                std::unique_ptr<mccnn::IGPUDevice> gpuDevice = make_unique<mccnn::TFGPUDevice>(pContext);

                //Create the output tensor.
                int* outputGPUPtr = nullptr;
                TensorShape outShape = TensorShape{numSamples, abs(knn_)};
                OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<int>
                    (0, pContext, outShape, &outputGPUPtr));

                //Compute the number of offsets to used in the search.
                std::vector<int> combOffsets;
                unsigned int numOffsets = mccnn::computeTotalNumOffsets(
                    numDimensions, 1, combOffsets);

                //Create the temporal tensors.
                mccnn::int64_m* tmpGPUPtr1 = gpuDevice->getInt64TmpGPUBuffer(numSamples);
                int* tmpGPUPtr2 = gpuDevice->getIntTmpGPUBuffer(numSamples*numOffsets*2);
                int* tmpGPUPtr3 = gpuDevice->getIntTmpGPUBuffer(numOffsets*numDimensions);

                //Compute the keys of the samples.
                DIMENSION_SWITCH_CALL(numDimensions, mccnn::compute_keys_gpu,
                    gpuDevice, numSamples, samplesGPUPtr, samplesBatchIdsGPUPtr, 
                    sAABBMinGPUPtr, numCellsGPUPtr, invCellSizeGPUPtr, tmpGPUPtr1);

                //Copy to device the offsets.
                gpuDevice->memcpy_host_to_device(tmpGPUPtr3, 
                    &combOffsets[0], sizeof(int)*numDimensions*numOffsets);
                gpuDevice->check_error(__FILE__, __LINE__);

                //Find ranges of points to check in the data structure.
                DIMENSION_SWITCH_CALL(numDimensions, mccnn::find_ranges_grid_ds_gpu,
                    gpuDevice, numSamples, numPts, 1, numOffsets, 
                    tmpGPUPtr3, tmpGPUPtr1, ptsKeysGPUPtr, gridDsGPUPtr, 
                    numCellsGPUPtr, tmpGPUPtr2);

                //Compute the knns.
                DIMENSION_SWITCH_CALL(numDimensions, mccnn::compute_knn_gpu,
                    gpuDevice, knn_, numOffsets, numSamples, ptsGPUPtr, samplesGPUPtr,
                    invRadiiGPUPtr, tmpGPUPtr2, outputGPUPtr);
            }

        private:

            /**Number of knn to select.*/
            int     knn_;
    };
}
            
// Bind the "Knn" op to its GPU implementation (no CPU kernel is registered here).
REGISTER_KERNEL_BUILDER(
    Name("Knn").Device(DEVICE_GPU),
    mccnn::KnnOp);