#include <torch/extension.h>

#include <vector>

// Argument-validation helpers. Each macro expands to a TORCH_CHECK whose
// message names the offending expression through stringification (#x).
// The macro argument is parenthesized before member access so that passing
// a compound expression (e.g. `a + b`) checks the whole expression.

// Fails unless x is placed on a CUDA device.
#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")
// Fails unless x reports itself as contiguous.
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Fails unless the dtype of x is float16 or bfloat16.
#define CHECK_IS_HALF_OR_BFLOAT(x) TORCH_CHECK((x).dtype() == torch::kFloat16 || (x).dtype() == torch::kBFloat16, #x " must be float16 or bfloat16")
// Runs the three checks above in order. The do { ... } while (0) wrapper makes
// the expansion a single statement, so `if (cond) CHECK_INPUT(t);` guards all
// three checks rather than only the first one.
#define CHECK_INPUT(x)              \
    do {                            \
        CHECK_CUDA(x);              \
        CHECK_CONTIGUOUS(x);        \
        CHECK_IS_HALF_OR_BFLOAT(x); \
    } while (0)
// Fails unless x.sizes() equals the dimension list given after x.
#define CHECK_SHAPE(x, ...) TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

// Forward declaration of the implementation behind butterfly_copy().
// No definition is visible in this file (presumably it lives in a separately
// compiled CUDA source — confirm against the build setup).
// Takes the input tensor and the twiddle factors and returns a tensor.
torch::Tensor butterfly_copy_cuda(
    torch::Tensor x,
    torch::Tensor twiddle_factors);

// Forward declaration of the implementation behind butterfly_transcendental().
// No definition is visible in this file (presumably it lives in a separately
// compiled CUDA source — confirm against the build setup).
// Takes a single tensor and returns a tensor.
torch::Tensor butterfly_transcendental_cuda(
    torch::Tensor x
);

// Forward declaration of the implementation behind const_compute().
// No definition is visible in this file (presumably it lives in a separately
// compiled CUDA source — confirm against the build setup).
// Takes a single tensor and returns a tensor.
torch::Tensor const_compute_cuda(
    torch::Tensor x
);

// Entry point exposed to Python for the butterfly copy op.
//
// Passes `x` and `twiddle_factors` straight to butterfly_copy_cuda() and
// returns whatever that call produces.
//
// No argument validation is performed here: the CHECK_INPUT / CHECK_SHAPE
// guards are currently disabled. When they were written they expected
// x with shape (B, H, N) and twiddle_factors with shape (H, N, 2).
torch::Tensor butterfly_copy(torch::Tensor x, torch::Tensor twiddle_factors)
{
    torch::Tensor result = butterfly_copy_cuda(x, twiddle_factors);
    return result;
}


// Entry point exposed to Python for the butterfly transcendental op.
//
// Passes `x` straight to butterfly_transcendental_cuda() and returns
// whatever that call produces.
//
// No argument validation is performed here: the CHECK_INPUT / CHECK_SHAPE
// guards are currently disabled. When they were written they expected
// x with shape (B, H, N).
torch::Tensor butterfly_transcendental(torch::Tensor x)
{
    torch::Tensor result = butterfly_transcendental_cuda(x);
    return result;
}


// Entry point exposed to Python for the constant-compute op.
//
// Passes `x` straight to const_compute_cuda() and returns whatever that
// call produces; no argument validation is performed here.
torch::Tensor const_compute(torch::Tensor x)
{
    torch::Tensor result = const_compute_cuda(x);
    return result;
}

// Python bindings for this extension module.
//
// Each m.def() maps a Python-visible function name to one of the host-side
// wrappers defined above; the third argument is the docstring shown by
// Python's help(). All three exported functions are "*_forward" entry points,
// so their docstrings describe a forward op.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{

    m.def("butterfly_copy_forward", &butterfly_copy, "Butterfly forward (CUDA)");
    m.def("butterfly_transcendental_forward", &butterfly_transcendental, "Butterfly transcendental forward (CUDA)");
    m.def("butterfly_const_forward", &const_compute, "Butterfly const forward (CUDA)");
}