#include <torch/extension.h>

#include <vector>

// Forward declaration only; the definition is not part of this file.
// NOTE(review): presumably implemented in the accompanying CUDA (.cu)
// translation unit — confirm against the build configuration.
// Takes four tensors by value and returns a tensor. It is invoked by
// domafilter_forward() once that wrapper's argument checks have passed.
torch::Tensor domafilter_cuda_forward(
    torch::Tensor input,
    torch::Tensor memory,
    torch::Tensor mappings,
    torch::Tensor keys);

// Forward declaration only; the definition is not part of this file.
// NOTE(review): presumably implemented in the accompanying CUDA (.cu)
// translation unit — confirm against the build configuration.
// Takes the same four tensors as the forward entry point plus `output_grad`
// and returns a single tensor. It is invoked by domafilter_backward() once
// that wrapper's argument checks have passed.
torch::Tensor domafilter_cuda_backward(
    torch::Tensor input,
    torch::Tensor memory,
    torch::Tensor mappings,
    torch::Tensor keys,
    torch::Tensor output_grad);

// Argument-validation helpers. Each check is delegated to TORCH_CHECK with a
// message that names the offending argument (via the stringified macro
// parameter).
//
// The argument is parenthesized so that an expression (not just a plain
// identifier) can be passed without operator-precedence surprises.
#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Runs both checks, CUDA first and contiguity second. Wrapped in
// do { ... } while (0) so the macro behaves as ONE statement: without the
// wrapper, `if (cond) CHECK_INPUT(t);` would guard only the first check and
// an `if ... else` around it would not compile.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)

// Python-facing forward entry point.
//
// Verifies, argument by argument and in declaration order, that each tensor
// lives on a CUDA device and is contiguous, then hands all four tensors to
// domafilter_cuda_forward() and returns its result unchanged.
//
// A failed check is reported through TORCH_CHECK with a message naming the
// offending argument.
// NOTE(review): no device guard is set here, so the launch presumably relies
// on the caller's current CUDA device — confirm for multi-GPU use.
torch::Tensor domafilter_forward(
    torch::Tensor input,
    torch::Tensor memory,
    torch::Tensor mappings,
    torch::Tensor keys) {
  CHECK_CUDA(input);
  CHECK_CONTIGUOUS(input);

  CHECK_CUDA(memory);
  CHECK_CONTIGUOUS(memory);

  CHECK_CUDA(mappings);
  CHECK_CONTIGUOUS(mappings);

  CHECK_CUDA(keys);
  CHECK_CONTIGUOUS(keys);

  torch::Tensor result = domafilter_cuda_forward(input, memory, mappings, keys);
  return result;
}

// Python-facing backward entry point.
//
// Verifies, argument by argument and in declaration order, that each tensor
// (including `output_grad`) lives on a CUDA device and is contiguous, then
// hands all five tensors to domafilter_cuda_backward() and returns its result
// unchanged.
//
// A failed check is reported through TORCH_CHECK with a message naming the
// offending argument.
// NOTE(review): no device guard is set here, so the launch presumably relies
// on the caller's current CUDA device — confirm for multi-GPU use.
torch::Tensor domafilter_backward(
    torch::Tensor input,
    torch::Tensor memory,
    torch::Tensor mappings,
    torch::Tensor keys,
    torch::Tensor output_grad) {
  CHECK_CUDA(input);
  CHECK_CONTIGUOUS(input);

  CHECK_CUDA(memory);
  CHECK_CONTIGUOUS(memory);

  CHECK_CUDA(mappings);
  CHECK_CONTIGUOUS(mappings);

  CHECK_CUDA(keys);
  CHECK_CONTIGUOUS(keys);

  CHECK_CUDA(output_grad);
  CHECK_CONTIGUOUS(output_grad);

  torch::Tensor result =
      domafilter_cuda_backward(input, memory, mappings, keys, output_grad);
  return result;
}

// Python module definition. Registers the two validating wrappers under the
// attribute names "forward" and "backward"; the third argument of each def()
// is the docstring attached to the bound function.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // def() returns the module itself, so the registrations can be chained.
  m.def("forward", &domafilter_forward, "DoMaFilter CUDA forward")
      .def("backward", &domafilter_backward, "DoMaFilter CUDA backward");
}