
#include <algorithm>
#include <cstdint>
#include <vector>

#include "defines.hpp"
#include "tf_utils.hpp"
#include "tf_gpu_device.hpp"

/**
 *  Declaration of the tensorflow operation.
 */
/**
 *  Declaration of the tensorflow operation.
 *
 *  Inputs:
 *      edge_sorted_ids: edge indices, in the order in which edges are visited.
 *      edge_ids: the two node ids of each edge.
 *      start_node_edge_ids: one entry per node; only its first dimension
 *          (the number of nodes) is used.
 *  Outputs:
 *      node_collapse_ids: [numNodes] index of the new node each node maps to.
 *      node_scores_ids: [numNewNodes] id of the edge that created each new
 *          node, or -1 for nodes that were not collapsed. Its length depends
 *          on the input values, so it is unknown at graph-construction time.
 */
REGISTER_OP("CollapseEdges")
    .Input("edge_sorted_ids: int32")
    .Input("edge_ids: int32")
    .Input("start_node_edge_ids: int32")
    .Output("node_collapse_ids: int32")
    .Output("node_scores_ids: int32")
    .SetShapeFn([](shape_inference::InferenceContext* pIC) {
        //The number of nodes is the first dimension of the third input.
        shape_inference::ShapeHandle nodesShape;
        TF_RETURN_IF_ERROR(pIC->WithRankAtLeast(pIC->input(2), 1, &nodesShape));
        shape_inference::ShapeHandle outputDims = 
            pIC->MakeShape({pIC->Dim(nodesShape, 0)});
        pIC->set_output(0, outputDims);
        //The number of new nodes depends on which edges get collapsed,
        //so it cannot be known statically.
        pIC->set_output(1, pIC->Vector(pIC->UnknownDim()));
        return Status::OK();
    });

namespace mccnn{

    /**
     *  Operation to collapse edges.
     *
     *  Edges are visited in the order given by edge_sorted_ids. An edge is
     *  collapsed when neither of its two nodes was already used by a
     *  previously collapsed edge; both nodes are then mapped to one new node.
     *  Every node left unused afterwards becomes a new node on its own.
     *
     *  Outputs:
     *      0: for each input node, the index of the new node it maps to.
     *      1: for each new node, the id of the edge that created it, or -1
     *         for nodes that were not collapsed.
     */
    class CollapseEdgesOp: public OpKernel{
        
        public:
        
            /**
             *  Constructor.
             *  @param  pContext    Constructor context of the operation.
             */
            explicit CollapseEdgesOp(
                OpKernelConstruction* pContext)
                :OpKernel(pContext){
            }
        
            /**
             *  Method to compute the operation.
             *  Fails with InvalidArgument if an edge id or node id read from
             *  the inputs is out of range.
             *  @param  pContext    Context of the operation.
             */
            void Compute(OpKernelContext * pContext) override{

                //Get the input tensors.
                const Tensor& inEdgeSortedIds = pContext->input(0); 
                const Tensor& inEdgeIds = pContext->input(1); 
                const Tensor& inSNodeEdgeIds = pContext->input(2); 

                //dim_size(0) is only valid on tensors of rank >= 1.
                OP_REQUIRES(pContext, inEdgeSortedIds.dims() >= 1 && 
                    inEdgeIds.dims() >= 1 && inSNodeEdgeIds.dims() >= 1,
                    errors::InvalidArgument("CollapseEdgesOp expects inputs of rank at least 1"));

                //Get variables from tensors.
                const int64_t numNodes = inSNodeEdgeIds.shape().dim_size(0);
                const int64_t numEdges = inEdgeSortedIds.shape().dim_size(0);

                //Check for the correctness of the input.
                OP_REQUIRES(pContext, inEdgeIds.shape().dim_size(0) == numEdges, 
                    errors::InvalidArgument("CollapseEdgesOp expects the same number of edges and votes"));
                //edge_ids is read with a stride of 2 (two node ids per edge).
                OP_REQUIRES(pContext, inEdgeIds.NumElements() == 2 * numEdges,
                    errors::InvalidArgument("CollapseEdgesOp expects two node ids per edge"));

                //Get raw pointers to the input data (this kernel runs on the CPU).
                const int* edgeSortedIdsPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<int>(inEdgeSortedIds);
                const int* edgeIdsPtr = mccnn::tensorflow_utils::get_const_tensor_pointer<int>(inEdgeIds);

                //Create the output tensor with the new node index of each node.
                int* outputNodeIds = nullptr;
                TensorShape outShape = TensorShape{numNodes};
                OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<int>
                    (0, pContext, outShape, &outputNodeIds));
                
                //Collapse edges whose two nodes are both still unused.
                int numCollapsedNodes = 0;
                std::vector<int> usedNode(numNodes, 0);
                //At most numNodes new nodes can be created, so numNodes slots suffice;
                //slots of nodes that are not collapsed keep the value -1.
                std::vector<int> voteIds(numNodes, -1);
                for(int64_t i = 0; i < numEdges; ++i)
                {
                    const int edgeId = edgeSortedIdsPtr[i];
                    OP_REQUIRES(pContext, edgeId >= 0 && edgeId < numEdges,
                        errors::InvalidArgument("CollapseEdgesOp: edge id ", edgeId,
                            " is out of range [0, ", numEdges, ")"));
                    const int64_t edgeOffset = 2 * static_cast<int64_t>(edgeId);
                    const int nodeId1 = edgeIdsPtr[edgeOffset];
                    const int nodeId2 = edgeIdsPtr[edgeOffset + 1];
                    OP_REQUIRES(pContext, 
                        nodeId1 >= 0 && nodeId1 < numNodes && 
                        nodeId2 >= 0 && nodeId2 < numNodes,
                        errors::InvalidArgument("CollapseEdgesOp: edge ", edgeId,
                            " references a node id out of range [0, ", numNodes, ")"));
                    if(usedNode[nodeId1] == 0 && usedNode[nodeId2] == 0){
                        voteIds[numCollapsedNodes] = edgeId;
                        outputNodeIds[nodeId1] = numCollapsedNodes;
                        outputNodeIds[nodeId2] = numCollapsedNodes;
                        usedNode[nodeId1] = 1;
                        usedNode[nodeId2] = 1;
                        numCollapsedNodes++;
                    }
                }

                //Every node not collapsed above becomes a new node on its own.
                for(int64_t i = 0; i < numNodes; ++i)
                {
                    if(usedNode[i] == 0){
                        outputNodeIds[i] = numCollapsedNodes;
                        numCollapsedNodes++;
                    }
                }

                //Save the final votes ids (one per new node).
                int* outputVotesIds = nullptr;
                TensorShape outShape2 = TensorShape{numCollapsedNodes};
                OP_REQUIRES_OK(pContext, mccnn::tensorflow_utils::allocate_output_tensor<int>
                    (1, pContext, outShape2, &outputVotesIds));
                //std::copy_n is safe for zero elements, unlike &voteIds[0] on an empty vector.
                std::copy_n(voteIds.begin(), numCollapsedNodes, outputVotesIds);
            }
    };
}
            
//Bind the CollapseEdges op to its kernel implementation; in this file it is registered for the CPU device only.
REGISTER_KERNEL_BUILDER(Name("CollapseEdges").Device(DEVICE_CPU), mccnn::CollapseEdgesOp);