#include <torch/extension.h>

#include <vector>

// Fail with a TORCH_CHECK error unless the tensor expression `x` is on a CUDA device.
// `x` is parenthesized so compound arguments (e.g. `cond ? a : b`) bind as a whole
// before `.device()` is applied; the error text still uses the argument as written.
#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")

// CUDA Methods
// Forward declaration of the CUDA implementation. Its definition is not in this
// file (presumably a companion .cu source — TODO confirm). The parameter types
// declared here must match that definition exactly, or the extension will fail
// to link.
// NOTE(review): THREADS looks like a per-block launch size (default 1024);
// confirm against the definition. dimmedian_idx_forward below never passes it,
// so the default is always used from Python.
at::Tensor dimmedian_idx_forward_cuda(
    torch::Tensor X,
    torch::Tensor rowIndices,
    torch::Tensor colIndices,
    torch::Tensor edge_weights,
    const int64_t N,
    const int THREADS = 1024);

// C++ entry points: validate inputs, then forward to the CUDA implementation
/// @brief Validate the inputs and forward them to dimmedian_idx_forward_cuda.
///
/// @param X             Must be a CUDA tensor.
/// @param rowIndices    Must be a CUDA tensor on the same device as X.
/// @param colIndices    Must be a CUDA tensor on the same device as X.
/// @param edge_weights  Must be a CUDA tensor on the same device as X.
/// @param N             Passed through unchanged.
/// @return The tensor produced by dimmedian_idx_forward_cuda, called with its
///         default THREADS value.
/// @throws c10::Error (a RuntimeError on the Python side) when a check fails.
at::Tensor dimmedian_idx_forward(
    torch::Tensor X,
    torch::Tensor rowIndices,
    torch::Tensor colIndices,
    torch::Tensor edge_weights,
    const int64_t N)
{
  CHECK_CUDA(X);
  CHECK_CUDA(rowIndices);
  CHECK_CUDA(colIndices);
  CHECK_CUDA(edge_weights);

  // Being "a CUDA tensor" is not enough: tensors spread over different GPUs
  // cannot safely be consumed together by one CUDA call. Reject mixed devices
  // here with a clear message rather than leaving it to the CUDA implementation.
  const auto check_same_device = [&X](const torch::Tensor& t, const char* name) {
    TORCH_CHECK(t.device() == X.device(),
                name, " must be on the same device as X (", X.device(),
                "), but is on ", t.device());
  };
  check_same_device(rowIndices, "rowIndices");
  check_same_device(colIndices, "colIndices");
  check_same_device(edge_weights, "edge_weights");

  // Kernel launches and the default stream follow the *current* CUDA device,
  // not the device the tensors live on. Make X's device current for the
  // duration of this call, so inputs on e.g. cuda:1 are not processed with
  // cuda:0 current. The previous device is restored when the guard goes out of scope.
  const c10::DeviceGuard device_guard(X.device());

  return dimmedian_idx_forward_cuda(X, rowIndices, colIndices, edge_weights, N);
}

// Python bindings: expose the validating C++ wrapper (not the raw CUDA entry
// point) as `dimmedian_idx`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod)
{
  const char* const py_name = "dimmedian_idx";
  const char* const py_doc = "dimension wise median idx forward";
  mod.def(py_name, &dimmedian_idx_forward, py_doc);
}