#include <torch/extension.h>

#include <vector>

// Argument-validation helpers for tensor inputs. Every helper expands to a
// TORCH_CHECK whose failure message names the checked expression through
// preprocessor stringification (#x). Macro arguments are parenthesised so
// that expressions, not just plain identifiers, can be passed.

// Fails unless x reports a CUDA device.
#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")
// Fails unless x reports itself as contiguous.
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Fails unless the dtype of x is float16 or bfloat16.
#define CHECK_IS_HALF_OR_BFLOAT(x) TORCH_CHECK((x).dtype() == torch::kFloat16 || (x).dtype() == torch::kBFloat16, #x " must be float16 or bfloat16")
// Runs the device, contiguity and dtype checks in that order. The body is
// wrapped in do { ... } while (0) so the macro behaves as ONE statement: in
// `if (cond) CHECK_INPUT(x);` all three checks are guarded by `cond`, and a
// following `else` still binds to that `if`.
#define CHECK_INPUT(x)              \
    do {                            \
        CHECK_CUDA(x);              \
        CHECK_CONTIGUOUS(x);        \
        CHECK_IS_HALF_OR_BFLOAT(x); \
    } while (0)
// Fails unless x.sizes() equals the dimension list given as the remaining arguments.
#define CHECK_SHAPE(x, ...) TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")


// Forward declaration of the entry point that butterfly() dispatches to.
// The definition is not in this file (presumably a CUDA translation unit
// linked into the same extension — confirm against the build script).
// Takes four tensors by value and returns a vector of tensors; the number
// and meaning of the returned tensors are not visible from here.
std::vector<torch::Tensor> butterfly_cuda(
    torch::Tensor x,
    torch::Tensor d_f_T,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag
);


// Forward declaration of the entry point that butterfly_padded() dispatches
// to. The definition is not in this file (presumably a CUDA translation unit
// — confirm against the build script). Same parameter list and return type
// as butterfly_cuda(); how padding is handled is not visible from here.
std::vector<torch::Tensor> butterfly_padded_cuda(
    torch::Tensor x,
    torch::Tensor d_f_T,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag
);

// Forward declaration of the entry point that butterfly_ifft() dispatches to.
// The definition is not in this file (presumably a CUDA translation unit —
// confirm against the build script). Unlike butterfly_cuda(), the input is
// supplied as two tensors (x_real, x_imag) and a single tensor is returned.
torch::Tensor butterfly_ifft_cuda(
    torch::Tensor x_real,
    torch::Tensor x_imag,
    torch::Tensor d_f_T,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag
);

// Forward declaration of the entry point that butterfly_ifft_padded()
// dispatches to. The definition is not in this file (presumably a CUDA
// translation unit — confirm against the build script). Takes the same
// tensors as butterfly_ifft_cuda() plus an integer N and returns one tensor.
// NOTE(review): the third parameter is named d_f here but d_f_T in the three
// sibling declarations — confirm whether a non-transposed matrix is really
// expected. N is presumably the unpadded length — TODO confirm.
torch::Tensor butterfly_ifft_padded_cuda(
    torch::Tensor x_real,
    torch::Tensor x_imag,
    torch::Tensor d_f,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag,
    int N
);

/**
 * Python-facing wrapper around butterfly_cuda().
 *
 * Verifies that every tensor argument lives on a CUDA device and then
 * forwards the arguments, unchanged and in the same order, to
 * butterfly_cuda().
 *
 * @param x                     Input tensor.
 * @param d_f_T                 Tensor forwarded as the second argument
 *                              (presumably a transposed DFT matrix — TODO confirm).
 * @param twiddle_factors_real  Tensor forwarded as the third argument.
 * @param twiddle_factors_imag  Tensor forwarded as the fourth argument.
 * @return Whatever butterfly_cuda() returns.
 * @throws c10::Error (raised by TORCH_CHECK) if any tensor is not on a CUDA device.
 */
std::vector<torch::Tensor> butterfly(
    torch::Tensor x,
    torch::Tensor d_f_T,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag
){
    // Reject non-CUDA tensors here with a readable error instead of handing
    // them to the CUDA entry point.
    CHECK_CUDA(x);
    CHECK_CUDA(d_f_T);
    CHECK_CUDA(twiddle_factors_real);
    CHECK_CUDA(twiddle_factors_imag);

    // NOTE(review): contiguity, dtype and shape are deliberately NOT enforced;
    // the kernel's requirements are not visible from this file — confirm
    // before adding CHECK_CONTIGUOUS / CHECK_IS_HALF_OR_BFLOAT / CHECK_SHAPE.

    return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag);
}

/**
 * Python-facing wrapper around butterfly_ifft_cuda().
 *
 * Verifies that every tensor argument lives on a CUDA device and then
 * forwards the arguments, unchanged and in the same order, to
 * butterfly_ifft_cuda().
 *
 * @param x_real                Tensor forwarded as the first argument
 *                              (presumably the real part of the input — TODO confirm).
 * @param x_imag                Tensor forwarded as the second argument
 *                              (presumably the imaginary part of the input — TODO confirm).
 * @param d_f_T                 Tensor forwarded as the third argument.
 * @param twiddle_factors_real  Tensor forwarded as the fourth argument.
 * @param twiddle_factors_imag  Tensor forwarded as the fifth argument.
 * @return Whatever butterfly_ifft_cuda() returns.
 * @throws c10::Error (raised by TORCH_CHECK) if any tensor is not on a CUDA device.
 */
torch::Tensor butterfly_ifft(
    torch::Tensor x_real,
    torch::Tensor x_imag,
    torch::Tensor d_f_T,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag
){
    // Reject non-CUDA tensors here with a readable error instead of handing
    // them to the CUDA entry point.
    CHECK_CUDA(x_real);
    CHECK_CUDA(x_imag);
    CHECK_CUDA(d_f_T);
    CHECK_CUDA(twiddle_factors_real);
    CHECK_CUDA(twiddle_factors_imag);

    // NOTE(review): contiguity, dtype and shape are deliberately NOT enforced;
    // the kernel's requirements are not visible from this file — confirm
    // before adding CHECK_CONTIGUOUS / CHECK_IS_HALF_OR_BFLOAT / CHECK_SHAPE.

    return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag);
}

/**
 * Python-facing wrapper around butterfly_padded_cuda().
 *
 * Verifies that every tensor argument lives on a CUDA device and then
 * forwards the arguments, unchanged and in the same order, to
 * butterfly_padded_cuda().
 *
 * @param x                     Input tensor.
 * @param d_f_T                 Tensor forwarded as the second argument
 *                              (presumably a transposed DFT matrix — TODO confirm).
 * @param twiddle_factors_real  Tensor forwarded as the third argument.
 * @param twiddle_factors_imag  Tensor forwarded as the fourth argument.
 * @return Whatever butterfly_padded_cuda() returns.
 * @throws c10::Error (raised by TORCH_CHECK) if any tensor is not on a CUDA device.
 */
std::vector<torch::Tensor> butterfly_padded(
    torch::Tensor x,
    torch::Tensor d_f_T,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag
){
    // Reject non-CUDA tensors here with a readable error instead of handing
    // them to the CUDA entry point.
    CHECK_CUDA(x);
    CHECK_CUDA(d_f_T);
    CHECK_CUDA(twiddle_factors_real);
    CHECK_CUDA(twiddle_factors_imag);

    // NOTE(review): contiguity, dtype and shape are deliberately NOT enforced;
    // the kernel's requirements are not visible from this file — confirm
    // before adding CHECK_CONTIGUOUS / CHECK_IS_HALF_OR_BFLOAT / CHECK_SHAPE.

    return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag);
}

/**
 * Python-facing wrapper around butterfly_ifft_padded_cuda().
 *
 * Verifies that every tensor argument lives on a CUDA device and then
 * forwards all arguments, unchanged and in the same order, to
 * butterfly_ifft_padded_cuda().
 *
 * @param x_real                Tensor forwarded as the first argument
 *                              (presumably the real part of the input — TODO confirm).
 * @param x_imag                Tensor forwarded as the second argument
 *                              (presumably the imaginary part of the input — TODO confirm).
 * @param d_f                   Tensor forwarded as the third argument.
 * @param twiddle_factors_real  Tensor forwarded as the fourth argument.
 * @param twiddle_factors_imag  Tensor forwarded as the fifth argument.
 * @param N                     Integer forwarded unchanged (presumably the
 *                              unpadded output length — TODO confirm). Not validated here.
 * @return Whatever butterfly_ifft_padded_cuda() returns.
 * @throws c10::Error (raised by TORCH_CHECK) if any tensor is not on a CUDA device.
 */
torch::Tensor butterfly_ifft_padded(
    torch::Tensor x_real,
    torch::Tensor x_imag,
    torch::Tensor d_f,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag,
    int N
){
    // Reject non-CUDA tensors here with a readable error instead of handing
    // them to the CUDA entry point.
    CHECK_CUDA(x_real);
    CHECK_CUDA(x_imag);
    CHECK_CUDA(d_f);
    CHECK_CUDA(twiddle_factors_real);
    CHECK_CUDA(twiddle_factors_imag);

    // NOTE(review): contiguity, dtype and shape are deliberately NOT enforced;
    // the kernel's requirements are not visible from this file — confirm
    // before adding CHECK_CONTIGUOUS / CHECK_IS_HALF_OR_BFLOAT / CHECK_SHAPE.

    return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N);
}

// Python module definition: exposes the four host-side wrappers under their
// "*_forward" names. The third argument of each m.def() is the docstring that
// Python shows for the function, so each one describes its own entry point.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("butterfly_forward", &butterfly, "Butterfly forward (CUDA)");
    m.def("butterfly_padded_forward", &butterfly_padded, "Butterfly padded (CUDA)");
    m.def("butterfly_ifft_forward", &butterfly_ifft, "Butterfly ifft forward (CUDA)");
    m.def("butterfly_ifft_padded_forward", &butterfly_ifft_padded, "Butterfly ifft padded (CUDA)");
}