from setuptools import setup
import torch
import os, glob
from torch.utils.cpp_extension import CUDAExtension, CppExtension, BuildExtension

def get_extensions():
    """Build the list of extension modules passed to ``setup(ext_modules=...)``.

    Returns:
        A single-element list holding the ``CUDAExtension`` named
        ``atomic_max_custom``, built from every ``.cu`` file matched by
        ``./cuda_ops/*.cu`` (relative to the current working directory).

    Raises:
        RuntimeError: If ``torch.cuda.is_available()`` returns false.
    """
    ext_name = 'atomic_max_custom'

    # Use an explicit exception instead of ``assert False``: assert statements
    # are removed under ``python -O``, which would let the build silently fall
    # through into a CPU-only path that the assert was meant to block.
    if not torch.cuda.is_available():
        raise RuntimeError(
            f'CUDA is not available; cannot build {ext_name} without CUDA'
        )

    print(f'Compiling {ext_name} with CUDA')

    # glob makes no ordering promise; sort so the source list is deterministic.
    sources = sorted(glob.glob('./cuda_ops/*.cu'))

    extension = CUDAExtension(
        name=ext_name,
        sources=sources,
        # NOTE(review): headers are searched in ./csrc while the sources are
        # collected from ./cuda_ops -- confirm this layout is intended.
        include_dirs=[os.path.abspath('./csrc')],
        # WITH_CUDA is defined (with no value) for the compiled sources.
        define_macros=[('WITH_CUDA', None)],
        extra_compile_args={'cxx': [], 'nvcc': ['-O2']},
    )
    return [extension]

# Package entry point: register the compiled extension(s) and route the
# ``build_ext`` command through PyTorch's BuildExtension.
setup(
    name='atomic_max_extension',
    cmdclass={'build_ext': BuildExtension},
    ext_modules=get_extensions(),
)
