#include <torch/extension.h>
#include <pybind11/pybind11.h>
#include <torch/script.h>


// Forward declaration of the fast_gemv entry point that is exported to Python
// by the PYBIND11_MODULE below. No definition appears in this file; it must be
// provided by another translation unit linked into the same extension
// (presumably the CUDA/kernel source — confirm).
//
// The parameter list here must stay in sync with that definition. Top-level
// `const` on `w_zero` is not part of the function type, so it only matters in
// the definition.
//
// NOTE(review): the meanings below are inferred from parameter names only and
// must be confirmed against the definition:
//   x                   - presumably the input vector / activation tensor.
//   W                   - presumably the weight matrix multiplied with x.
//   w_zero              - looks like a zero-point offset for a quantized W; verify.
//   num_threads_per_row - presumably the number of threads cooperating on one row.
//   cols_per_warp       - presumably the number of columns handled per warp.
// Returns a torch::Tensor (presumably the GEMV result; shape/dtype/device are
// not visible from this file).
torch::Tensor fast_gemv(
    torch::Tensor x,
    torch::Tensor W,
    const int w_zero,
    unsigned int num_threads_per_row,
    unsigned int cols_per_warp);

// Python bindings for this extension. TORCH_EXTENSION_NAME is not defined in
// this file; it is normally injected by the torch C++ extension build so that
// the module name matches the one Python imports.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Naming every argument lets Python callers pass them by keyword and makes
    // help()/__doc__ show real parameter names instead of arg0..arg4.
    // Positional calls behave exactly as before. The number of pybind11::arg
    // entries must equal fast_gemv's parameter count (5).
    m.def("fast_gemv", &fast_gemv, "fast_gemv",
          pybind11::arg("x"),
          pybind11::arg("W"),
          pybind11::arg("w_zero"),
          pybind11::arg("num_threads_per_row"),
          pybind11::arg("cols_per_warp"));
}
