
#include "defines.hpp"
#include "tf_utils.hpp"
#include "tf_gpu_device.hpp"

#include "count_unique_keys.cuh"
#include "store_unique_keys.cuh"
#include "pooling_avg.cuh"
#include "count_pooling_pd.cuh"
#include "store_pooled_pts.cuh"

/**
 *  Declaration of the tensorflow operation.
 *
 *  Inputs:
 *      points                  float32 matrix; its second dimension is 
 *                              propagated to the first output.
 *      batch_ids               int32 tensor.
 *      point_keys              int64 tensor.
 *      num_cells               int32 tensor.
 *      neighbors               int32 tensor.
 *      sample_neigh_indices    int32 tensor.
 *  Outputs:
 *      pool_pts                float32 matrix [?, points.dim(1)].
 *      pool_batch_ids          int32 vector [?].
 *      pool_indices            int32 vector [?].
 *  Attributes:
 *      mode                    Integer selecting the pooling algorithm.
 */
REGISTER_OP("Pooling")
    .Input("points: float32")
    .Input("batch_ids: int32")
    .Input("point_keys: int64")
    .Input("num_cells: int32")
    .Input("neighbors: int32")
    .Input("sample_neigh_indices: int32")
    .Output("pool_pts: float32")
    .Output("pool_batch_ids: int32")
    .Output("pool_indices: int32")
    .Attr("mode: int")
    .SetShapeFn([](shape_inference::InferenceContext* pIC) {
        //Constrain the points to rank 2 before reading dimension 1, so a
        //tensor of lower rank yields an error instead of an out of range
        //access into its dimension list.
        shape_inference::ShapeHandle ptsShape;
        TF_RETURN_IF_ERROR(pIC->WithRank(pIC->input(0), 2, &ptsShape));

        //The number of pooled points is only known at execution time.
        shape_inference::ShapeHandle outputDims1 = 
            pIC->MakeShape({-1, pIC->Dim(ptsShape, 1)});
        shape_inference::ShapeHandle outputDims2 = 
            pIC->MakeShape({-1});
        pIC->set_output(0, outputDims1);
        pIC->set_output(1, outputDims2);
        pIC->set_output(2, outputDims2);
        return Status::OK();
    });

namespace mccnn{

    /**
     *  Operation to perform a pooling operation on a regular grid.
     */
    class PoolingOp: public OpKernel{
        
        public:
        
            /**
             *  Constructor.
             *  @param  pContext    Constructor context of the operation.
             */
            explicit PoolingOp(
                OpKernelConstruction* pContext)
                :OpKernel(pContext){

                OP_REQUIRES_OK(pContext, pContext->GetAttr("mode", &mode_));
                OP_REQUIRES(pContext, mode_ >= 0 && mode_ < 2, 
                    errors::InvalidArgument("PoolingOp requires a valid pooling mode."));
            }
        
            /**
             *  Method to compute the operation.
             *  @param  pContext    Context of the operation.
             */
            void Compute(OpKernelContext * pContext) override{

                //Get the input tensors.
                const Tensor& inPts = pContext->input(0); 
                const Tensor& inBatchIds = pContext->input(1); 
                const Tensor& inPtKeys = pContext->input(2); 
                const Tensor& inNumCells = pContext->input(3);
                const Tensor& inNeighbors = pContext->input(4); 
                const Tensor& inStartNeighIds = pContext->input(5); 

                //Get variables from tensors.
                unsigned int numPts = inPts.shape().dim_size(0);
                unsigned int numDimensions = inPts.shape().dim_size(1);
                 unsigned int numNeighbors = inNeighbors.shape().dim_size(0);

                //Get the pointers to GPU data from the tensors.
                const float* ptsGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<float>(inPts);
                const int* batchIdsGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<int>(inBatchIds);
                const mccnn::int64_m* ptKeysGPUPtr = mccnn::tensorflow_utils::
                    get_const_tensor_pointer<mccnn::int64_m>(inPtKeys);
                const int* numCellsGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<int>(inNumCells);
                const int* inNeighborsGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<int>(inNeighbors);
                const int* inStartNeighIdsGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<int>(inStartNeighIds);
                
                
                //Check for the correctness of the input.
                OP_REQUIRES(pContext, numDimensions >= MIN_DIMENSIONS 
                    && numDimensions <= MAX_DIMENSIONS, 
                    errors::InvalidArgument("PoolingOp expects a valid number of dimension"));
                OP_REQUIRES(pContext, inBatchIds.shape().dim_size(0) == numPts, 
                    errors::InvalidArgument("PoolingOp expects the same number of keys"
                    " as number of points."));
                OP_REQUIRES(pContext, inPtKeys.shape().dim_size(0) == numPts, 
                    errors::InvalidArgument("PoolingOp expects the same number of keys"
                    " as number of points."));
                OP_REQUIRES(pContext, inNumCells.shape().dim_size(0) == numDimensions, 
                    errors::InvalidArgument("PoolingOp expects a number of dimensions in"
                    " inNumCells equal to the number of dimensions in the input points"));
                OP_REQUIRES(pContext, inNeighbors.dims() == 2 && 
                    inNeighbors.shape().dim_size(1) == 2, 
                    errors::InvalidArgument("PoolingOp expects a neighbor tensor with 2 dimensions "
                    "and 2 indices per neighbor."));
                OP_REQUIRES(pContext, inStartNeighIds.shape().dim_size(0) == numPts, 
                    errors::InvalidArgument("PoolingOp expects the same number of points "
                    "in inStartNeighIds as in the points tensor."));
                

                //Get the gpu device.
                std::unique_ptr<mccnn::IGPUDevice> gpuDevice = make_unique<mccnn::TFGPUDevice>(pContext);

                //Count the number of unique keys.
                unsigned int numKeys = mccnn::count_unique_keys_gpu(gpuDevice, numPts, ptKeysGPUPtr);
                
                //Declare the temporal buffers.
                int* tmpGPUPtr1 = gpuDevice->getIntTmpGPUBuffer(numKeys);

                //Store the first indices of each key.
                mccnn::store_unique_keys_gpu(gpuDevice, numPts, ptKeysGPUPtr, tmpGPUPtr1);

                //Poisson disk sampling.
                if(mode_ == 0){
                    
                    //Declare temporal buffers.
                    int* tmpGPUPtr2 = gpuDevice->getIntTmpGPUBuffer(numPts);

                    //Count the number of pooled points.
                    int numPooledPts;
                    DIMENSION_SWITCH_CALL(numDimensions, mccnn::count_pooling_pd_gpu, 
                        gpuDevice, numPts, numKeys, tmpGPUPtr1, ptKeysGPUPtr,
                        ptsGPUPtr, inNeighborsGPUPtr, inStartNeighIdsGPUPtr,
                        numCellsGPUPtr, numPooledPts, tmpGPUPtr2);

                    //Create the output tensors.
                    float* outPtsGPUPtr = nullptr;
                    int* outBatchIdsGPUPtr = nullptr;
                    int* outIndicesGPUPtr = nullptr;
                    TensorShape outShape1 = TensorShape{numPooledPts, numDimensions};
                    TensorShape outShape2 = TensorShape{numPooledPts};
                    TensorShape outShape3 = TensorShape{numPooledPts};
                    OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<float>
                        (0, pContext, outShape1, &outPtsGPUPtr));
                    OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<int>
                        (1, pContext, outShape2, &outBatchIdsGPUPtr));
                    OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<int>
                        (2, pContext, outShape3, &outIndicesGPUPtr));

                    //Store the pooled points.
                    DIMENSION_SWITCH_CALL(numDimensions, mccnn::store_pooled_pts_gpu,
                        gpuDevice, numPts, numPooledPts, ptsGPUPtr,
                        batchIdsGPUPtr, tmpGPUPtr2, outPtsGPUPtr, 
                        outBatchIdsGPUPtr, outIndicesGPUPtr);

                //Cell average.
                }else{
                    //Create the output tensors.
                    float* outPtsGPUPtr = nullptr;
                    int* outBatchIdsGPUPtr = nullptr;
                    int* outIndicesGPUPtr = nullptr;
                    TensorShape outShape1 = TensorShape{numKeys, numDimensions};
                    TensorShape outShape2 = TensorShape{numKeys};
                    TensorShape outShape3 = TensorShape{1};
                    OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<float>
                        (0, pContext, outShape1, &outPtsGPUPtr));
                    OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<int>
                        (1, pContext, outShape2, &outBatchIdsGPUPtr));
                    OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<int>
                        (2, pContext, outShape3, &outIndicesGPUPtr));
                    

                    //Select the new point.
                    DIMENSION_SWITCH_CALL(numDimensions, mccnn::pooling_avg_gpu,
                        gpuDevice, numPts, numKeys, ptKeysGPUPtr, ptsGPUPtr, 
                        numCellsGPUPtr, tmpGPUPtr1, outPtsGPUPtr, outBatchIdsGPUPtr);
                }
            }

        private:

            /**Mode used to pool points.*/
            int mode_;
    };
}
            
//Bind the "Pooling" operation to its GPU kernel implementation.
REGISTER_KERNEL_BUILDER(
    Name("Pooling").Device(DEVICE_GPU), 
    mccnn::PoolingOp);