
#include <torch/library.h>
#include <torch/csrc/autograd/custom_function.h>
#include "pytorch_npu_helper.hpp"
#include <torch/extension.h>

// NPU forward implementation of the elementwise subtraction operator.
//
// Allocates an output tensor via at::empty_like(self) and hands `self`,
// `other` and that output to the aclnnElementwiseSub kernel through
// EXEC_NPU_CMD. The filled output tensor is returned to the caller.
//
// NOTE(review): the output is shaped after `self` only, and no shape/dtype
// validation or broadcasting of `other` happens here - presumably the kernel
// expects both operands to match; confirm against the kernel definition.
at::Tensor elementwise_sub_impl_npu(const at::Tensor& self, const at::Tensor& other) {
    // Uninitialized buffer mirroring `self`; the kernel is expected to write it.
    auto difference = at::empty_like(self);

    // Launch the aclnn operator with (lhs, rhs, out) in that order.
    EXEC_NPU_CMD(aclnnElementwiseSub, self, other, difference);

    return difference;
}


// Attach the NPU kernel to the dispatcher: "elementwise_sub" in the `myops`
// namespace is bound to elementwise_sub_impl_npu under the PrivateUse1
// dispatch key (the key used here for the NPU device).
//
// NOTE(review): no TORCH_LIBRARY(myops, ...) schema definition for
// "elementwise_sub" is visible in this file - confirm it is declared elsewhere.
TORCH_LIBRARY_IMPL(myops, PrivateUse1, library) {
    library.impl("elementwise_sub", &elementwise_sub_impl_npu);
}

// Expose the C++ implementation to Python through pybind11. The Python-side
// function `elementwise_sub` calls elementwise_sub_impl_npu directly, i.e.
// without going through the dispatcher registration.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def("elementwise_sub", &elementwise_sub_impl_npu, "x - y");
}
