#include "defines.hpp"
#include "tf_utils.hpp"
#include "tf_gpu_device.hpp"
#include "compute_smooth_weights.cuh"

/**
 *  Declaration of the tensorflow operations.
 */
// Forward op: one smooth weight per neighbor entry.
REGISTER_OP("ComputeSmoothW")
    .Input("points: float32")
    .Input("samples: float32")
    .Input("neighbors: int32")
    .Input("sample_neigh_indices: int32")
    .Input("inv_radii: float32")
    .Output("sweights: float32")
    .SetShapeFn([](shape_inference::InferenceContext* pIC) {
        //Validate the rank of the neighbor list before indexing it with Dim(),
        //so a mis-shaped input fails shape inference with an error instead of
        //being indexed past its rank. Inputs of unknown rank still pass here.
        shape_inference::ShapeHandle neighborsShape;
        TF_RETURN_IF_ERROR(pIC->WithRank(pIC->input(2), 2, &neighborsShape));
        //Output: a vector with one entry per neighbor.
        shape_inference::ShapeHandle outputDims = 
            pIC->MakeShape({pIC->Dim(neighborsShape, 0)});
        pIC->set_output(0, outputDims);
        return Status::OK();
    });

// Same signature and shape function as ComputeSmoothW and served by the same
// kernel; presumably a separate name so a different gradient can be attached
// on the Python side (TODO confirm).
REGISTER_OP("ComputeSmoothWWithPtGrads")
    .Input("points: float32")
    .Input("samples: float32")
    .Input("neighbors: int32")
    .Input("sample_neigh_indices: int32")
    .Input("inv_radii: float32")
    .Output("sweights: float32")
    .SetShapeFn([](shape_inference::InferenceContext* pIC) {
        //Validate the rank of the neighbor list before indexing it with Dim(),
        //so a mis-shaped input fails shape inference with an error instead of
        //being indexed past its rank. Inputs of unknown rank still pass here.
        shape_inference::ShapeHandle neighborsShape;
        TF_RETURN_IF_ERROR(pIC->WithRank(pIC->input(2), 2, &neighborsShape));
        //Output: a vector with one entry per neighbor.
        shape_inference::ShapeHandle outputDims = 
            pIC->MakeShape({pIC->Dim(neighborsShape, 0)});
        pIC->set_output(0, outputDims);
        return Status::OK();
    });

// Backward op: gradients of the smooth weights with respect to the point and
// sample coordinates.
REGISTER_OP("ComputeSmoothWPtGrads")
    .Input("points: float32")
    .Input("samples: float32")
    .Input("neighbors: int32")
    .Input("sample_neigh_indices: int32")
    .Input("inv_radii: float32")
    .Input("sweights_grads: float32")
    .Output("pt_grads: float32")
    .Output("sample_grads: float32")
    .SetShapeFn([](shape_inference::InferenceContext* pIC) {
        //pt_grads has the shape of the points and sample_grads the shape of the
        //samples. Both inputs must be rank-2; checking the rank (instead of
        //indexing dimensions 0 and 1 blindly) makes mis-shaped inputs fail
        //shape inference with an error. Unknown ranks yield [?, ?] as before.
        shape_inference::ShapeHandle ptsShape;
        shape_inference::ShapeHandle samplesShape;
        TF_RETURN_IF_ERROR(pIC->WithRank(pIC->input(0), 2, &ptsShape));
        TF_RETURN_IF_ERROR(pIC->WithRank(pIC->input(1), 2, &samplesShape));
        pIC->set_output(0, ptsShape);
        pIC->set_output(1, samplesShape);
        return Status::OK();
    });

namespace mccnn{

    /**
     *  Operation to compute the smooth weights of each neighbor.
     */
    class ComputeSmoothWOp: public OpKernel{
        
        public:
        
            /**
             *  Constructor.
             *  @param  pContext    Constructor context of the operation.
             */
            explicit ComputeSmoothWOp(
                OpKernelConstruction* pContext)
                :OpKernel(pContext){
            }
        
            /**
             *  Method to compute the operation.
             *  @param  pContext    Context of the operation.
             */
            void Compute(OpKernelContext * pContext) override{

                //Get the input tensors.
                const Tensor& inPts = pContext->input(0); 
                const Tensor& inSamples = pContext->input(1); 
                const Tensor& inNeighbors = pContext->input(2); 
                const Tensor& inSampleNeighIndices = pContext->input(3);
                const Tensor& inInvRadii = pContext->input(4);

                //Get variables from tensors.
                unsigned int numSamples = inSamples.shape().dim_size(0);
                unsigned int numNeighbors = inNeighbors.shape().dim_size(0);
                unsigned int numDimensions = inPts.shape().dim_size(1);

                //Get the pointers to GPU data from the tensors.
                const float* ptsGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<float>(inPts);
                const float* samplesGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<float>(inSamples);
                const int* neighborsGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<int>(inNeighbors);
                const int* sampleNeighIGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<int>(inSampleNeighIndices);
                const float* inInvRadiiGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<float>(inInvRadii);

                //Check for the correctness of the input.
                OP_REQUIRES(pContext, numDimensions >= MIN_DIMENSIONS 
                    && numDimensions <= MAX_DIMENSIONS, 
                    errors::InvalidArgument("ComputeSmoothWOp expects a valid number of dimension"));
                OP_REQUIRES(pContext, inInvRadii.shape().dim_size(0) == numDimensions, 
                    errors::InvalidArgument("ComputeSmoothWOp expects a number of dimensions in"
                    " inInvRadii equal to the number of dimensions in the input points"));
                OP_REQUIRES(pContext, inNeighbors.dims() == 2 && 
                    inNeighbors.shape().dim_size(1) == 2, 
                    errors::InvalidArgument("ComputeSmoothWOp expects a neighbor tensor with 2 dimensions "
                    "and 2 indices per neighbor."));
                OP_REQUIRES(pContext, inSampleNeighIndices.dims() == 1 && 
                    inSampleNeighIndices.shape().dim_size(0) == numSamples, 
                    errors::InvalidArgument("ComputeSmoothWOp expects a correct tensor for the sample indices."));

                //Get the gpu device.
                std::unique_ptr<mccnn::IGPUDevice> gpuDevice = make_unique<mccnn::TFGPUDevice>(pContext);

                //Create the output tensor.
                float* outputGPUPtr = nullptr;
                TensorShape outShape = TensorShape{numNeighbors};
                OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<float>
                    (0, pContext, outShape, &outputGPUPtr));

                //Compute the PDF  
                DIMENSION_SWITCH_CALL(numDimensions, mccnn::compute_smooth_weights_gpu,
                    gpuDevice, numSamples, numNeighbors, inInvRadiiGPUPtr, 
                    ptsGPUPtr, samplesGPUPtr, neighborsGPUPtr, sampleNeighIGPUPtr, 
                    outputGPUPtr);   
            }
    };

    /**
     *  Operation to compute the gradients of each point wrt the smooth weights.
     */
    class ComputeSmoothWPtGradsOp: public OpKernel{
        
        public:
        
            /**
             *  Constructor.
             *  @param  pContext    Constructor context of the operation.
             */
            explicit ComputeSmoothWPtGradsOp(
                OpKernelConstruction* pContext)
                :OpKernel(pContext){
            }
        
            /**
             *  Method to compute the operation.
             *  @param  pContext    Context of the operation.
             */
            void Compute(OpKernelContext * pContext) override{

                //Get the input tensors.
                const Tensor& inPts = pContext->input(0); 
                const Tensor& inSamples = pContext->input(1); 
                const Tensor& inNeighbors = pContext->input(2); 
                const Tensor& inSampleNeighIndices = pContext->input(3);
                const Tensor& inInvRadii = pContext->input(4);
                const Tensor& inSmoothWGrads = pContext->input(5);

                //Get variables from tensors.
                unsigned int numSamples = inSamples.shape().dim_size(0);
                unsigned int numPts = inPts.shape().dim_size(0);
                unsigned int numNeighbors = inNeighbors.shape().dim_size(0);
                unsigned int numDimensions = inPts.shape().dim_size(1);

                //Get the pointers to GPU data from the tensors.
                const float* ptsGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<float>(inPts);
                const float* samplesGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<float>(inSamples);
                const int* neighborsGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<int>(inNeighbors);
                const int* sampleNeighIGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<int>(inSampleNeighIndices);
                const float* inInvRadiiGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<float>(inInvRadii);
                const float* inSmoothWGradsGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<float>(inSmoothWGrads);

                //Check for the correctness of the input.
                OP_REQUIRES(pContext, numDimensions >= MIN_DIMENSIONS 
                    && numDimensions <= MAX_DIMENSIONS, 
                    errors::InvalidArgument("ComputeSmoothWPtGradsOp expects a valid number of dimension"));
                OP_REQUIRES(pContext, inInvRadii.shape().dim_size(0) == numDimensions, 
                    errors::InvalidArgument("ComputeSmoothWPtGradsOp expects a number of dimensions in"
                    " inInvRadii equal to the number of dimensions in the input points"));
                OP_REQUIRES(pContext, inNeighbors.dims() == 2 && 
                    inNeighbors.shape().dim_size(1) == 2, 
                    errors::InvalidArgument("ComputeSmoothWPtGradsOp expects a neighbor tensor with 2 "
                    "dimensions and 2 indices per neighbor."));
                OP_REQUIRES(pContext, inSampleNeighIndices.dims() == 1 && 
                    inSampleNeighIndices.shape().dim_size(0) == numSamples, 
                    errors::InvalidArgument("ComputeSmoothWPtGradsOp expects a correct tensor for the sample indices."));
                OP_REQUIRES(pContext, inSmoothWGrads.shape().dim_size(0) == numNeighbors, 
                    errors::InvalidArgument("ComputeSmoothWPtGradsOp expects a number of smooth weights "
                    " gradients equal to the number of neighbors."));

                //Get the gpu device.
                std::unique_ptr<mccnn::IGPUDevice> gpuDevice = make_unique<mccnn::TFGPUDevice>(pContext);

                //Create the output tensor.
                float* outputGPUPtr = nullptr;
                float* output2GPUPtr = nullptr;
                TensorShape outShape = TensorShape{numPts, numDimensions};
                TensorShape outShape2 = TensorShape{numSamples, numDimensions};
                OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<float>
                    (0, pContext, outShape, &outputGPUPtr));
                OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<float>
                    (1, pContext, outShape2, &output2GPUPtr));

                //Compute the PDF  
                DIMENSION_SWITCH_CALL(numDimensions, mccnn::compute_smooth_weights_grads_gpu,
                    gpuDevice, numPts, numSamples, numNeighbors, inInvRadiiGPUPtr,
                    ptsGPUPtr, samplesGPUPtr, neighborsGPUPtr, sampleNeighIGPUPtr, 
                    inSmoothWGradsGPUPtr, outputGPUPtr, output2GPUPtr);   
            }
    };
}
            
// GPU kernel registrations.
// Forward op.
REGISTER_KERNEL_BUILDER(
    Name("ComputeSmoothW").Device(DEVICE_GPU),
    mccnn::ComputeSmoothWOp);
// Same forward kernel, registered under the second op name declared above.
REGISTER_KERNEL_BUILDER(
    Name("ComputeSmoothWWithPtGrads").Device(DEVICE_GPU),
    mccnn::ComputeSmoothWOp);
// Gradients wrt point and sample coordinates.
REGISTER_KERNEL_BUILDER(
    Name("ComputeSmoothWPtGrads").Device(DEVICE_GPU),
    mccnn::ComputeSmoothWPtGradsOp);