from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Build configuration for the ``stnls`` package.
#
# Python packages live under ``lib/`` (mapped to the root of the import
# namespace through ``package_dir``). Every C++/CUDA source listed in
# ``CUDA_SOURCES`` is compiled into one extension module named ``stnls_cuda``.
# Commented-out entries are currently excluded from the build.

CUDA_SOURCES = [
    # lib/csrc/nn
    'lib/csrc/nn/shared_nn_utils.cu',
    # 'lib/csrc/nn/shared_tile_kernels.cu',
    'lib/csrc/nn/topk_pwd.cpp',
    'lib/csrc/nn/topk_pwd_kernel.cu',
    'lib/csrc/nn/pfc.cpp',
    'lib/csrc/nn/pfc_kernel.cu',
    'lib/csrc/nn/temporal_inds.cpp',
    'lib/csrc/nn/temporal_inds_kernel.cu',
    'lib/csrc/nn/non_local_inds.cpp',
    'lib/csrc/nn/non_local_inds_kernel.cu',
    'lib/csrc/nn/accumulate_flow.cpp',
    'lib/csrc/nn/accumulate_flow_kernel.cu',
    'lib/csrc/nn/interpolate_inds.cpp',
    'lib/csrc/nn/interpolate_inds_kernel.cu',
    'lib/csrc/nn/unique_topk.cpp',
    'lib/csrc/nn/unique_topk_kernel.cu',
    'lib/csrc/nn/anchor_self.cpp',
    'lib/csrc/nn/anchor_self_kernel.cu',
    'lib/csrc/nn/jitter_unique_inds.cpp',
    'lib/csrc/nn/jitter_unique_inds_kernel.cu',
    # lib/csrc/search
    'lib/csrc/search/shared_kernel.cu',
    # 'lib/csrc/search/nls_bilin2d.cu',
    # 'lib/csrc/search/nls_bilin3d.cu',
    'lib/csrc/search/non_local_search.cpp',
    'lib/csrc/search/non_local_search_kernel.cu',
    'lib/csrc/search/non_local_search_bilin2d_kernel.cu',
    'lib/csrc/search/non_local_search_offsets.cpp',
    'lib/csrc/search/non_local_search_offsets3d_kernel.cu',
    'lib/csrc/search/refinement.cpp',
    'lib/csrc/search/refinement_kernel.cu',
    'lib/csrc/search/paired_search.cpp',
    'lib/csrc/search/paired_search_kernel.cu',
    'lib/csrc/search/ref_bwd_kernel.cu',
    # 'lib/csrc/search/quadref.cpp',
    # 'lib/csrc/search/quadref_kernel.cu',
    'lib/csrc/search/window_search.cpp',
    'lib/csrc/search/window_search_kernel.cu',
    # lib/csrc/dev/search
    'lib/csrc/dev/search/prod_dists.cpp',
    'lib/csrc/dev/search/prod_dists_kernel.cu',
    'lib/csrc/dev/search/prod_refine.cpp',
    'lib/csrc/dev/search/prod_refine_kernel.cu',
    'lib/csrc/dev/search/prod_with_index_cuda.cpp',
    'lib/csrc/dev/search/prod_with_index_kernel.cu',
    'lib/csrc/dev/search/prod_pf_with_index_cuda.cpp',
    'lib/csrc/dev/search/prod_pf_with_index_kernel.cu',
    'lib/csrc/dev/search/l2_cuda.cpp',
    'lib/csrc/dev/search/l2_kernel.cu',
    'lib/csrc/dev/search/l2_dists_cuda.cpp',
    'lib/csrc/dev/search/l2_dists_kernel.cu',
    'lib/csrc/dev/search/l2_with_index_cuda.cpp',
    'lib/csrc/dev/search/l2_with_index_kernel.cu',
    'lib/csrc/dev/search/l2_search_with_heads.cpp',
    'lib/csrc/dev/search/l2_search_with_heads_kernel.cu',
    'lib/csrc/dev/search/prod_cuda.cpp',
    'lib/csrc/dev/search/prod_kernel.cu',
    'lib/csrc/dev/search/prod_search_with_heads.cpp',
    'lib/csrc/dev/search/prod_search_with_heads_kernel.cu',
    # 'lib/csrc/search/prod_search_patches_with_heads.cpp',
    # 'lib/csrc/search/prod_search_patches_with_heads_kernel.cu',
    # lib/csrc/tile_k
    'lib/csrc/tile_k/foldk_cuda.cpp',
    'lib/csrc/tile_k/foldk_kernel.cu',
    'lib/csrc/tile_k/unfoldk_cuda.cpp',
    'lib/csrc/tile_k/unfoldk_kernel.cu',
    # lib/csrc/tile
    'lib/csrc/tile/nlstack_bilin2d.cu',
    'lib/csrc/tile/nlstack_bilin3d.cu',
    'lib/csrc/tile/non_local_stack.cpp',
    'lib/csrc/tile/non_local_stack_kernel.cu',
    'lib/csrc/tile/non_local_stack_bilin2d_kernel.cu',
    'lib/csrc/tile/non_local_stack_bilin3d_kernel.cu',
    'lib/csrc/tile/fold_cuda.cpp',
    'lib/csrc/tile/fold_kernel.cu',
    'lib/csrc/tile/ifold_cuda.cpp',
    'lib/csrc/tile/ifold_kernel.cu',
    'lib/csrc/tile/ifoldz_cuda.cpp',
    'lib/csrc/tile/ifoldz_kernel.cu',
    'lib/csrc/tile/unfold_cuda.cpp',
    'lib/csrc/tile/unfold_kernel.cu',
    'lib/csrc/tile/iunfold_cuda.cpp',
    'lib/csrc/tile/iunfold_kernel.cu',
    'lib/csrc/tile/nlfold.cpp',
    'lib/csrc/tile/nlfold_kernel.cu',
    # lib/csrc/reducer
    'lib/csrc/reducer/wpsum.cpp',
    'lib/csrc/reducer/wpsum_kernel.cu',
    'lib/csrc/reducer/iwpsum.cpp',
    'lib/csrc/reducer/iwpsum_kernel.cu',
    # lib/csrc/n3net_ops
    'lib/csrc/n3net_ops/mat_mult1.cpp',
    'lib/csrc/n3net_ops/mat_mult1_kernel.cu',
    # Python bindings for the compiled ops.
    'lib/csrc/pybind.cpp',
]

setup(
    name='stnls',
    # Packages live under lib/, so discovery must search lib/ as well.
    # find_packages(".") searched the repo root instead: it does not descend
    # into lib/ unless lib/ is itself a package, and whatever names it did
    # return would be resolved against lib/ by package_dir and not be found.
    package_dir={"": "lib"},
    packages=find_packages("lib"),
    package_data={'': ['*.so']},
    include_package_data=True,
    ext_modules=[
        CUDAExtension(
            'stnls_cuda',
            CUDA_SOURCES,
            # Suppress all compiler warnings. Given as a plain list (not a
            # {'cxx': ..., 'nvcc': ...} dict), torch's BuildExtension passes
            # these flags to both the host C++ compiler and nvcc.
            extra_compile_args=['-w'],
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)

