from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Build configuration for the ``nsn_tools`` C++/CUDA extension.
#
# All native sources live under CSRC_DIR; a second set of kernels lives in
# the ``dq_kernels`` sub-directory. File names are listed per directory and
# turned into the relative paths handed to CUDAExtension below. The listing
# order is kept exactly as it is passed to the extension.
CSRC_DIR = 'src/csrc'
DQ_KERNELS_DIR = CSRC_DIR + '/dq_kernels'

# Sources located directly in CSRC_DIR (the .cpp file comes first).
CSRC_FILES = (
    'nsn_tools.cpp',
    'scale_adjustment.cu',
    'scale_adjustment_1bit.cu',
    'dist_argmin_half.cu',
    'dist_argmin_half_batched.cu',
    'dist_argmin_half_packed.cu',
    'dist_argmin_half_packed_1bit.cu',
    'restore_quantized.cu',
    'restore_quantized_1bit.cu',
    'quantized_weighted_sum.cu',
    'quantized_weighted_sum_1bit.cu',
    'quantized_weighted_sum_residual.cu',
    'quantized_weighted_sum_residual_1bit.cu',
    'window_rope_dot_product.cu',
    'quantized_dot_product_fused_residual.cu',
    'quantized_dot_product_fused_residual_1bit.cu',
    'quantized_dot_product.cu',
    'quantized_dot_product_fused.cu',
    'quantized_dot_product_fused_1bit.cu',
)

# Sources located in DQ_KERNELS_DIR.
DQ_KERNEL_FILES = (
    'scale_adjustment.cu',
    'scale_adjustment_1bit.cu',
    'restore_quantized.cu',
    'restore_quantized_1bit.cu',
    'dist_argmin_half_packed_1bit.cu',
    'dist_argmin_half_packed.cu',
    'quantized_dot_product_fused.cu',
    'quantized_dot_product_fused_1bit.cu',
    'quantized_dot_product_fused_residual.cu',
    'quantized_dot_product_fused_residual_1bit.cu',
    'quantized_weighted_sum.cu',
    'quantized_weighted_sum_1bit.cu',
    'quantized_weighted_sum_residual.cu',
    'quantized_weighted_sum_residual_1bit.cu',
)

# Paths are joined with a literal '/' (not os.path.join) so the strings
# handed to the build are the same on every platform.
EXTENSION_SOURCES = (
    [CSRC_DIR + '/' + filename for filename in CSRC_FILES]
    + [DQ_KERNELS_DIR + '/' + filename for filename in DQ_KERNEL_FILES]
)

# Compiler flags, keyed by toolchain ("cxx" for the host compiler, "nvcc"
# for the CUDA compiler).  From KIVI.
COMPILE_FLAGS = {
    "cxx": ["-O2"],
    "nvcc": ["-O2", "-lineinfo"],
}

nsn_tools_extension = CUDAExtension(
    name='nsn_tools',
    sources=EXTENSION_SOURCES,
    include_dirs=[CSRC_DIR],
    extra_compile_args=COMPILE_FLAGS,
)

setup(
    name='nsn_tools',
    ext_modules=[nsn_tools_extension],
    # BuildExtension is torch's build_ext command for C++/CUDA extensions.
    cmdclass={'build_ext': BuildExtension},
)
