# from setuptools import setup
# from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# setup(
#     name='atomic_max',
#     ext_modules=[
#         CUDAExtension('atomic_max_custom', [
#             './cuda_ops/atomic_max_kernel.cu'
#         ]),
#     ],
#     cmdclass={
#         'build_ext': BuildExtension
#     }
# )


from setuptools import setup
import torch
import os, glob
from torch.utils.cpp_extension import CUDAExtension, CppExtension, BuildExtension

def get_extensions():
    """Build the list of extension modules to compile for this package.

    The extension is CUDA-only: every ``.cu`` file matched by
    ``./cuda_ops/*.cu`` is compiled into a single ``CUDAExtension`` named
    ``atomic_max_custom``, with ``./csrc`` on the include path, the
    ``WITH_CUDA`` macro defined, and ``-O2`` passed to nvcc.

    Returns:
        list: A one-element list holding the configured ``CUDAExtension``.

    Raises:
        RuntimeError: If ``torch.cuda.is_available()`` returns False.
        FileNotFoundError: If the glob matches no ``.cu`` sources. The
            pattern is relative, so it is resolved against the current
            working directory.
    """
    ext_name = 'atomic_max_custom'  # name of the compiled extension module

    # Raise a real exception instead of `assert False`: assertions are
    # stripped under `python -O`, which previously let the build fall
    # through to a CPU-only path that was otherwise unreachable.
    if not torch.cuda.is_available():
        raise RuntimeError('CUDA is not available')

    print(f'Compiling {ext_name} with CUDA')

    # Sort the matches so the compile order does not depend on the
    # filesystem's directory-listing order.
    sources = sorted(glob.glob('./cuda_ops/*.cu'))
    if not sources:
        # An extension with no sources cannot produce a usable module;
        # fail early with a message that points at the likely cause.
        raise FileNotFoundError(
            "No CUDA sources matched './cuda_ops/*.cu'; "
            "run this script from the directory that contains cuda_ops/"
        )

    extension = CUDAExtension(
        name=ext_name,
        sources=sources,
        include_dirs=[os.path.abspath('./csrc')],  # header search path
        define_macros=[('WITH_CUDA', None)],
        extra_compile_args={
            'cxx': [],
            'nvcc': ['-O2'],  # optimisation level for the nvcc compiler
        },
    )
    return [extension]

# Package build configuration: register the compiled extension modules and
# use the BuildExtension command class for the `build_ext` step.
setup_kwargs = dict(
    name='atomic_max_extension',
    ext_modules=get_extensions(),
    cmdclass={'build_ext': BuildExtension},
)
setup(**setup_kwargs)
