
import os

from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension


# Build configuration for the `fselu` PyTorch C++/CUDA extension, compiled
# from src/fselu.cpp and src/fselu_kernel.cu.
#
# Optional environment overrides:
#   CUB_HOME         Directory containing the CUB headers (e.g. .../cub-1.8.0/).
#                    When unset, no extra include directory is passed. CUDA
#                    toolkits >= 11.0 bundle CUB, so this is only needed when
#                    building against an older toolkit.
#   FSELU_CUDA_ARCH  Value for nvcc's -arch flag (default: sm_70).
_cub_home = os.environ.get('CUB_HOME', '')
_cuda_arch = os.environ.get('FSELU_CUDA_ARCH', 'sm_70')

setup(
    name='fselu',
    version='0.0.1',
    ext_modules=[
        CUDAExtension(
            'fselu',
            ['src/fselu.cpp', 'src/fselu_kernel.cu'],
            extra_compile_args={'cxx': [], 'nvcc': ['-arch=' + _cuda_arch]},
            # Only add an include path when one was actually provided; the old
            # hard-coded placeholder path never existed on disk.
            include_dirs=[_cub_home] if _cub_home else [],
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
    install_requires=['torch'],
)