# Importing torch first loads PyTorch's shared libraries (e.g. libc10.so) into
# the process; libc10.so is needed by hinv_cuda, which otherwise fails to load.
import torch

import os
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

# The extension name embeds the compute capability of GPU 0 (e.g. "sm86"), so
# the build has to run where a CUDA device is visible. Fail early with a clear
# message instead of the opaque error torch raises from get_device_capability.
if not torch.cuda.is_available():
    raise RuntimeError(
        'Building this extension requires a visible CUDA device: the module '
        'name is derived from the compute capability of GPU 0.'
    )

# NOTE(review): CUDAExtension compiles for TORCH_CUDA_ARCH_LIST when it is set,
# otherwise for every GPU visible at build time. The "sm" suffix below only
# reflects device 0, so on mixed-GPU machines (or with TORCH_CUDA_ARCH_LIST
# set) the name may not describe every compiled architecture.
capability = torch.cuda.get_device_capability(0)  # (major, minor)
cuda_cc = f'{capability[0]}{capability[1]}'

# Module / distribution name, e.g. "cuda_cadam_sm86".
name = f'cuda_cadam_sm{cuda_cc}'
cuda_ext = CUDAExtension(
    name=name,
    sources=[
        'cadam.cpp',
        'cadam_update.cu',
        'cadam_tools.cu',
        'cadam_symm_block_quant.cu',
        'cadam_symm_block_quant_inv.cu',
        'cadam_asymm_block_quant.cu',
        'cadam_asymm_block_quant_inv.cu',
        'cadam_asymm_global_quant.cu',
        'cadam_asymm_global_quant_inv.cu',
    ],
)
# BuildExtension supplies the nvcc/C++ compiler handling for CUDAExtension.
setup(name=name, ext_modules=[cuda_ext], cmdclass={'build_ext': BuildExtension})
