#include <torch/extension.h>
#include <torch/torch.h>


// Forward declaration; the definition is not in this file (presumably a CUDA
// translation unit -- confirm). The name suggests a bf16 -> e2m1 (4-bit float)
// conversion -- TODO confirm against the definition.
// Arg 1 is a mutable tensor (presumably the destination); arg 2 is read-only.
void cvt_bf16_e2m1(torch::Tensor&, torch::Tensor const&);


// Forward declaration; the definition lives elsewhere (presumably a CUDA
// translation unit -- confirm). The name suggests the inverse of cvt_bf16_e2m1,
// i.e. e2m1 -> bf16 -- TODO confirm against the definition.
// Arg 1 is a mutable tensor (presumably the destination); arg 2 is read-only.
void cvt_e2m1_bf16(torch::Tensor&, torch::Tensor const&);


void forward_bf16(
    const torch::Tensor&,
    const torch::Tensor&,
    torch::Tensor&,
    torch::Tensor&,
    torch::Tensor&
);


void backward_bf16(
    const torch::Tensor&,
    const torch::Tensor&,
    torch::Tensor&,
    torch::Tensor&
);


void backward_t_bf16(
    const torch::Tensor&,
    const torch::Tensor&,
    torch::Tensor&,
    torch::Tensor&
);


// Forward declaration of a bf16 backward variant that additionally takes a
// third read-only tensor and a float scalar; the definition is not in this file.
// The meaning of the scalar (e.g. a scale factor) is not visible here -- TODO
// confirm. Returns void, so results are written into the two non-const tensors.
//
// The float is passed by value. Top-level `const` on a by-value parameter is not
// part of the function type, so it is omitted in this declaration. A definition
// that spells it `const float` still matches and links unchanged.
void backward_qt_bf16(
    const torch::Tensor&,
    const torch::Tensor&,
    const torch::Tensor&,
    float,
    torch::Tensor&,
    torch::Tensor&
);


// Python bindings for the six kernels declared in this file.
//
// TORCH_EXTENSION_NAME is not defined here. It is expected to come from the build
// (torch.utils.cpp_extension passes -DTORCH_EXTENSION_NAME=<module name>).
// PYBIND11_MODULE itself is provided via <torch/extension.h>.
//
// Each kernel is exposed under its C++ name. Every bound function returns void,
// so callers must pre-allocate the non-const tensor arguments and read results
// back from them.
// NOTE(review): the docstrings are currently just the function names; consider
// describing argument roles once they are confirmed against the definitions.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Format conversions (names suggest bf16 <-> e2m1; confirm).
    m.def("cvt_bf16_e2m1", &cvt_bf16_e2m1, "cvt_bf16_e2m1");
    m.def("cvt_e2m1_bf16", &cvt_e2m1_bf16, "cvt_e2m1_bf16");

    // Forward / backward kernels.
    m.def("forward_bf16", &forward_bf16, "forward_bf16");
    m.def("backward_bf16", &backward_bf16, "backward_bf16");
    m.def("backward_t_bf16", &backward_t_bf16, "backward_t_bf16");
    m.def("backward_qt_bf16", &backward_qt_bf16, "backward_qt_bf16");
}
