from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os

# Build-mode switch: exporting DEBUG_MODE=1 selects the debug flag set below;
# any other value (or leaving the variable unset) selects the release set.
DEBUG_MODE = os.environ.get("DEBUG_MODE", "0") == "1"

# nvcc optimization/debug flags per build mode.
# The debug set deliberately omits -lineinfo: when -G (full device debug info)
# is present nvcc lets it override -lineinfo and prints a warning for every
# translation unit, so passing both only adds noise to the build log.
DEBUG_OPT_ARGS = ["-O0", "-G", "-g"]
RELEASE_OPT_ARGS = ["-O3"]
opt_args = list(DEBUG_OPT_ARGS if DEBUG_MODE else RELEASE_OPT_ARGS)

# Final nvcc command-line flags: the mode-dependent flags come first, followed
# by the flags that are applied in every build mode.
extra_args = opt_args + [
    # Fast-math code generation. -use_fast_math already implies the two
    # -prec-* settings; they are kept so the intent stays explicit.
    "-use_fast_math",
    "-prec-div=false",
    "-prec-sqrt=false",
    # Undefine the __CUDA_NO_* guard macros for half / half2 / bfloat16.
    # NOTE(review): presumably needed because the PyTorch extension builder
    # defines these by default and the kernels use the native operators.
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
    # Experimental nvcc language extensions enabled for the kernel sources.
    "--expt-relaxed-constexpr",
    "--expt-extended-lambda",
]
# Echo the resolved build configuration (debug switch plus one nvcc flag per
# line) between two separator rules so it is easy to spot in the build log.
print(
    "===========================================================",
    f"DEBUG_MODE: {DEBUG_MODE}",
    "\n".join(extra_args),
    "===========================================================",
    sep="\n",
)
# Package metadata and build definition: one CUDA extension module
# ("fastgemv_lib") compiled from the sources under fastgemv/csrc, shipped
# together with the pure-Python "fastgemv" package.
setup(
    name="gemv_lib",
    version="0.0.1",
    license="Apache 2",
    ext_modules=[
        CUDAExtension(
            "fastgemv_lib",
            [
                "fastgemv/csrc/cuda/gemv.cu",
                "fastgemv/csrc/init.cpp",
            ],
            extra_compile_args={
                # Host-compiler flags follow the same debug/release switch as
                # the nvcc flags assembled in extra_args.
                "cxx": ["-O3", "-std=c++17"] if not DEBUG_MODE else ["-O0", "-g", "-std=c++17"],
                "nvcc": extra_args,
            },
        ),
    ],
    packages=find_packages(include=["fastgemv", "fastgemv.*"]),
    package_data={
        # Glob patterns are resolved relative to the package directory, so the
        # previous "fastgemv/*.py" pointed at fastgemv/fastgemv/*.py. Use a
        # package-relative pattern instead.
        "fastgemv": ["*.py"],
    },
    include_package_data=True,
    # BuildExtension is PyTorch's build_ext replacement for C++/CUDA sources.
    cmdclass={"build_ext": BuildExtension},
    # NOTE(review): torch is imported by this script but is not listed as a
    # runtime requirement - confirm that omission is intentional.
    install_requires=["numpy", "ninja"],
)
