#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>


#include "sqllm.h"
#include "lutgemm.h"

#include "dp.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
	// Registration helper that forwards one (Python name, C++ entry point,
	// docstring) triple to pybind11's module_::def. Each bind(...) line below
	// is therefore equivalent to a direct m.def(...) call with the same
	// arguments. `fn` is deduced as the function-pointer type, so pybind11
	// still sees the full C++ signature.
	auto bind = [&m](const char *py_name, auto fn, const char *doc) {
		m.def(py_name, fn, doc);
	};

	// Dequantization entry point (declared in one of the included headers).
	bind("anyprec_dequant", &anyprec_dequant, "ANYPREC dequantization");

	// anyprec_gemv family and the base gemvNormTH entry point.
	bind("anyprec_gemv",          &anyprec_gemv,          "ANYPREC GEMV");
	bind("anyprec_gemv_sel",      &anyprec_gemv_sel,      "ANYPREC GEMV");
	bind("anyprec_gemv_sel_fake", &anyprec_gemv_sel_fake, "ANYPREC GEMV");
	bind("anyprec_gemv_sel_two",  &anyprec_gemv_sel_two,  "ANYPREC GEMV");
	bind("gemvNormTH",            &gemvNormTH,            "ANYPREC GEMV");

	// gemvNormTH variants with per-suffix entry points (q/k/v/g/u/qkv/gu).
	bind("gemvNormTHq",   &gemvNormTHq,   "ANYPREC GEMV");
	bind("gemvNormTHk",   &gemvNormTHk,   "ANYPREC GEMV");
	bind("gemvNormTHv",   &gemvNormTHv,   "ANYPREC GEMV");
	bind("gemvNormTHg",   &gemvNormTHg,   "ANYPREC GEMV");
	bind("gemvNormTHu",   &gemvNormTHu,   "ANYPREC GEMV");
	bind("gemvNormTHqkv", &gemvNormTHqkv, "ANYPREC GEMV");
	bind("gemvNormTHgu",  &gemvNormTHgu,  "ANYPREC GEMV");

	// normTH variants with the same suffix set.
	// NOTE(review): the docstrings say "ANYPREC GEMV" although the names lack
	// "gemv"; they may be copy-pasted. Confirm against the implementations.
	bind("normTHq",   &normTHq,   "ANYPREC GEMV");
	bind("normTHk",   &normTHk,   "ANYPREC GEMV");
	bind("normTHv",   &normTHv,   "ANYPREC GEMV");
	bind("normTHg",   &normTHg,   "ANYPREC GEMV");
	bind("normTHu",   &normTHu,   "ANYPREC GEMV");
	bind("normTHqkv", &normTHqkv, "ANYPREC GEMV");
	bind("normTHgu",  &normTHgu,  "ANYPREC GEMV");

	// Further numbered and ln-prefixed variants.
	bind("gemvNormTH2",     &gemvNormTH2,     "ANYPREC GEMV");
	bind("gemvNormTH3",     &gemvNormTH3,     "ANYPREC GEMV");
	bind("gemvNormTH3Full", &gemvNormTH3Full, "ANYPREC GEMV");
	bind("normTH",          &normTH,          "ANYPREC GEMV");
	bind("normTH2",         &normTH2,         "ANYPREC GEMV");
	bind("lnNormTH2",       &lnNormTH2,       "ANYPREC GEMV");
	bind("lnGemvNormTH",    &lnGemvNormTH,    "ANYPREC GEMV");

	// Miscellaneous helpers, followed by the LUT-GEMM and SqueezeLLM GEMV
	// entry points from lutgemm.h / sqllm.h.
	// NOTE(review): the "ANYPREC GEMV" docstrings on fakeTrigger and
	// create_streamNevent_full look copy-pasted. Confirm what these functions do.
	bind("fakeTrigger",              &fakeTrigger,              "ANYPREC GEMV");
	bind("create_streamNevent_full", &create_streamNevent_full, "ANYPREC GEMV");
	bind("lutgemm_gemv",             &lutgemm_gemv,             "LUTGEMM GEMV");
	bind("sqllm_gemv",               &sqllm_gemv,               "SQLLM GEMV");
}
