
#include "tf_utils.hpp"
#include "tf_gpu_device.hpp"
#include "find_neighbors_topo.cuh"

/**
 *  Declaration of the tensorflow operation.
 *
 *  Inputs:
 *      - points: float32 tensor of rank 2, [numPts, numDimensions].
 *      - graph_neigh_list: int32 tensor of rank 2.
 *      - graph_node_start_ids: int32 tensor of rank 1, [numSamples].
 *  Output:
 *      - neigh_ids: int32 tensor of shape [numSamples, 32]. The first dimension
 *        follows graph_node_start_ids, which is the shape the GPU kernel allocates
 *        (it only equals numPts when const_edge is 0).
 *  Attributes:
 *      - max_dist: maximum distance (the kernel requires it to be positive).
 *      - const_edge: 0 or 1, whether a constant value is used for the edges.
 */
REGISTER_OP("FindNeighborsTopo")
    .Input("points: float32")
    .Input("graph_neigh_list: int32")
    .Input("graph_node_start_ids: int32")
    .Output("neigh_ids: int32")
    .Attr("max_dist: float")
    .Attr("const_edge: int")
    .SetShapeFn([](shape_inference::InferenceContext* pIC) {
        //Enforce the ranks the kernel relies on (unknown ranks still pass).
        shape_inference::ShapeHandle ptsShape;
        shape_inference::ShapeHandle topoShape;
        shape_inference::ShapeHandle startIdsShape;
        TF_RETURN_IF_ERROR(pIC->WithRank(pIC->input(0), 2, &ptsShape));
        TF_RETURN_IF_ERROR(pIC->WithRank(pIC->input(1), 2, &topoShape));
        TF_RETURN_IF_ERROR(pIC->WithRank(pIC->input(2), 1, &startIdsShape));

        //One row of 32 neighbor ids per sample node (input 2), not per point.
        shape_inference::ShapeHandle outputDims = pIC->MakeShape({pIC->Dim(startIdsShape, 0), 32});
        pIC->set_output(0, outputDims);
        return Status::OK();
    });

namespace mccnn{
        
    /**
     *  Operation to find the neighboring points to a set of nodes by using
     *  a graph topology, executed on the GPU.
     *
     *  Inputs (validated in Compute):
     *      - points: float tensor of rank 2, [numPts, numDimensions], with
     *        numDimensions in [MIN_DIMENSIONS, MAX_DIMENSIONS].
     *      - graph_neigh_list: int tensor of rank 2 with 2 indices per entry.
     *      - graph_node_start_ids: int tensor of rank 1, [numSamples]. When
     *        const_edge is 0, numSamples must equal numPts.
     *  Output:
     *      - neigh_ids: int tensor of shape [numSamples, 32].
     *  NOTE(review): the meaning of the 2 indices per neighbor entry and of the
     *  start ids is defined by find_neighbors_topo_gpu; confirm there.
     */
    class FindNeighborsTopoOp: public OpKernel{
        
        public:
        
            /**
             *  Constructor.
             *  @param  pContext    Constructor context of the operation.
             */
            explicit FindNeighborsTopoOp(
                OpKernelConstruction* pContext)
                :OpKernel(pContext){

                OP_REQUIRES_OK(pContext, pContext->GetAttr("max_dist", &maxDist_));
                OP_REQUIRES(pContext, maxDist_ > 0.0f, 
                    errors::InvalidArgument("FindNeighborsTopo requires positive max distance."));

                int constEdge;
                OP_REQUIRES_OK(pContext, pContext->GetAttr("const_edge", &constEdge));
                OP_REQUIRES(pContext, constEdge == 0 || constEdge == 1, 
                    errors::InvalidArgument("FindNeighborsTopo requires const_edge equal to 1 or 0."));
                constEdge_ = constEdge == 1;
            }
        
            /**
             *  Method to compute the operation.
             *  @param  pContext    Context of the operation.
             */
            void Compute(OpKernelContext * pContext) override{

                //Get the input tensors.
                const Tensor& inPts = pContext->input(0); 
                const Tensor& inTopology = pContext->input(1); 
                const Tensor& inSampleTopoStartIndices = pContext->input(2);

                //Validate the ranks before reading any dimension: dim_size() on a
                //missing axis is a fatal check failure, not a recoverable error.
                OP_REQUIRES(pContext, inPts.dims() == 2,
                    errors::InvalidArgument("FindNeighborsTopo expects a points tensor with 2 dimensions."));
                OP_REQUIRES(pContext, inTopology.dims() == 2 && 
                    inTopology.shape().dim_size(1) == 2, 
                    errors::InvalidArgument("FindNeighborsTopo expects a neighbor tensor with 2 dimensions "
                    "and 2 indices per neighbor."));
                OP_REQUIRES(pContext, inSampleTopoStartIndices.dims() == 1,
                    errors::InvalidArgument("FindNeighborsTopo expects a start indices tensor with 1 dimension."));

                //Get variables from tensors.
                unsigned int numPts = inPts.shape().dim_size(0);
                unsigned int numSamples = inSampleTopoStartIndices.shape().dim_size(0);
                unsigned int numDimensions = inPts.shape().dim_size(1);

                //Check for the correctness of the input.
                if(!constEdge_){
                    OP_REQUIRES(pContext, numSamples == numPts, 
                        errors::InvalidArgument("FindNeighborsTopo expects the same points as samples."));
                }
                OP_REQUIRES(pContext, numDimensions >= MIN_DIMENSIONS 
                    && numDimensions <= MAX_DIMENSIONS, 
                    errors::InvalidArgument("FindNeighborsTopo expects a valid number of dimension"));

                //Get the pointers to GPU data from the tensors.
                const float* ptsGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<float>(inPts);
                const int* topoGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<int>(inTopology);
                const int* sampleTopoIGPUPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<int>(inSampleTopoStartIndices);

                //Get the gpu device.
                std::unique_ptr<mccnn::IGPUDevice> gpuDevice = make_unique<mccnn::TFGPUDevice>(pContext);

                //Create the output tensor (32 neighbor ids per sample).
                int* outputGPUPtr = nullptr;
                TensorShape outShape = TensorShape{numSamples, 32};
                OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<int>
                    (0, pContext, outShape, &outputGPUPtr));

                //An empty sample set yields an empty [0, 32] output; skip the GPU
                //call so no zero-sized grid is ever launched.
                if(numSamples == 0){
                    return;
                }

                //Run the GPU search, selecting the implementation for numDimensions.
                //NOTE(review): the original comment states that all nodes at 6 hops
                //distance are processed; confirm in find_neighbors_topo.cuh.
                DIMENSION_SWITCH_CALL(numDimensions, mccnn::find_neighbors_topo_gpu,
                    gpuDevice, constEdge_, maxDist_, numSamples, ptsGPUPtr,
                    topoGPUPtr, sampleTopoIGPUPtr, outputGPUPtr);
            }

        private:

            /**Maximum distance (validated to be positive in the constructor).*/
            float   maxDist_ = 0.0f;
            /**Boolean that indicates if we use a constant value for the edge.*/
            bool    constEdge_ = false;
    };
}
            
//Bind the FindNeighborsTopo op to its GPU kernel implementation.
REGISTER_KERNEL_BUILDER(
    Name("FindNeighborsTopo").Device(DEVICE_GPU),
    mccnn::FindNeighborsTopoOp);