import os
import torch
import setuptools
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension

# Directories (relative to this script) holding the extension sources.
cpu_dir = os.path.join('src', 'torch_wnn', 'cpu')
cuda_dir = os.path.join('src', 'torch_wnn', 'cuda')


def _extension_stems(directory, suffix='.cc'):
    """Return the sorted module names of the ``*<suffix>`` sources in *directory*.

    Only entries that really end with *suffix* are considered, so stray
    files (headers, editor backups, ``__pycache__``, CUDA kernels, ...) are
    never turned into extension modules.  The result is sorted because
    ``os.listdir`` returns entries in arbitrary order and the build should
    be reproducible.

    Args:
        directory: Path of the directory to scan (not recursive).
        suffix: File-name suffix, including the dot, that marks a source.

    Returns:
        A sorted list of file names with *suffix* stripped.

    Raises:
        OSError: If *directory* cannot be listed (e.g. it does not exist).
    """
    return sorted(
        entry[:-len(suffix)]
        for entry in os.listdir(directory)
        # The length check skips a file named exactly like the suffix,
        # which would otherwise yield an empty module name.
        if entry.endswith(suffix) and len(entry) > len(suffix)
    )


# One C++ extension per ``<name>.cc`` file in the CPU directory, named after
# the file.
ext_modules = [
    CppExtension(stem, [os.path.join(cpu_dir, stem + '.cc')])
    for stem in _extension_stems(cpu_dir)
]

# CUDA extensions are only added when CUDA is usable at build time.  Each
# ``<name>.cc`` in the CUDA directory is compiled together with its companion
# ``<name>_kernel.cu``; the kernel files themselves are not picked up as
# modules because they do not end in ``.cc``.
if torch.cuda.is_available():
    ext_modules.extend(
        CUDAExtension(
            stem,
            [
                os.path.join(cuda_dir, stem + '.cc'),
                os.path.join(cuda_dir, stem + '_kernel.cu'),
            ],
        )
        for stem in _extension_stems(cuda_dir)
    )

# Python packages live under ``src``; discover them once up front.
source_packages = setuptools.find_packages(where="src")

setup(
    # --- Distribution metadata ---
    name='torch_wnn',
    version="1.0.1",
    description="Weightless Neural Networks (WNN) module for pytorch",
    author="Alan T. L. Bacellar",
    author_email="alanbacellar@gmail.com",
    url="https://github.com/Alantlb/torch-wnn",
    # --- Package layout ---
    package_dir={"": "src"},
    packages=source_packages,
    # --- Native extensions: built through PyTorch's BuildExtension command ---
    ext_modules=ext_modules,
    cmdclass={'build_ext': BuildExtension},
)