from setuptools import setup, find_packages
import torch.utils.cpp_extension as torch_cpp_ext
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os
import pathlib, torch
import re

# Directory containing this setup.py, with symlinks resolved (os.path.realpath);
# used as the root for the extension's include_dirs below.
setup_dir = os.path.dirname(os.path.realpath(__file__))
# The same directory as a pathlib.Path. Note: Path.absolute() does NOT resolve
# symlinks, so this can differ from setup_dir if setup.py is reached via a link.
# Passed to CMake as the source directory in third_party_cmake().
HERE = pathlib.Path(__file__).absolute().parent
# Installed PyTorch version string; parsed for a minimum-version check in __main__.
torch_version = torch.__version__

def remove_unwanted_pytorch_nvcc_flags():
    """Drop PyTorch's default ``-D__CUDA_NO_*`` half/bfloat16 defines from nvcc flags.

    Mutates ``torch.utils.cpp_extension.COMMON_NVCC_FLAGS`` in place, removing
    the first occurrence of each listed define. Defines that are not present
    are skipped silently, so calling this more than once is harmless.
    """
    nvcc_flags = torch_cpp_ext.COMMON_NVCC_FLAGS
    unwanted_defines = (
        '-D__CUDA_NO_HALF_OPERATORS__',
        '-D__CUDA_NO_HALF_CONVERSIONS__',
        '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
        '-D__CUDA_NO_HALF2_OPERATORS__',
    )
    for define in unwanted_defines:
        # list.remove drops only the first match, as in a plain remove() call.
        if define in nvcc_flags:
            nvcc_flags.remove(define)

def get_cuda_arch_flags(capability=None):
    """Return the nvcc flags that select the GPU architectures to compile for.

    The result always contains ``-gencode`` pairs for both ``sm_100a`` and
    ``sm_120a``. It ends with ``-DTARGET_CUDA_ARCH=<cc>``, where
    ``cc = major * 10 + minor`` of the given (or detected) compute capability.

    Args:
        capability: Optional ``(major, minor)`` compute capability. When
            omitted (the default), it is read from the current CUDA device,
            which preserves the original behavior.

    Returns:
        A list of strings suitable for ``extra_compile_args['nvcc']``.
    """
    if capability is None:
        capability = torch.cuda.get_device_capability(torch.cuda.current_device())
    major, minor = capability
    # e.g. (12, 0) -> 120; presumably consumed by the CUDA sources to pick a
    # per-architecture code path -- NOTE(review): confirm against csrc.
    cc = major * 10 + minor

    return [
        '-gencode', 'arch=compute_100a,code=sm_100a',
        '-gencode', 'arch=compute_120a,code=sm_120a',
        f'-DTARGET_CUDA_ARCH={cc}',
    ]

def third_party_cmake():
    """Run a CMake configure step with the project root (HERE) as source dir.

    Raises:
        RuntimeError: if no ``cmake`` executable is found on PATH.

    If CMake exits with a non-zero status, an error line is written to
    stderr and the process exits with status 1.
    """
    import shutil
    import subprocess
    import sys

    cmake_exe = shutil.which('cmake')
    if cmake_exe is None:
        raise RuntimeError('Cannot find CMake executable.')

    # No -B given, so CMake uses the current working directory as its build tree.
    if subprocess.call([cmake_exe, HERE]) == 0:
        return
    sys.stderr.write("Error: CMake configuration failed.\n")
    sys.exit(1)

if __name__ == '__main__':
    # Validate the build environment up front. Explicit raises are used instead
    # of `assert` so these checks are not stripped when Python runs with -O.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available!")
    device = torch.cuda.current_device()
    capability = torch.cuda.get_device_capability(device)
    print(f"Current device: {torch.cuda.get_device_name(device)}")
    print(f"Current CUDA capability: {capability}")
    # FIXME: restrict to sm_120 at this point?
    if capability[0] < 10:
        raise RuntimeError(
            f"CUDA capability must be >= 10.0, yours is {capability}"
        )

    print(f"PyTorch version: {torch_version}")
    # Only the leading "major.minor" is parsed, so suffixes after the minor
    # number (e.g. ".0+cu128") do not affect the check.
    m = re.match(r'^(\d+)\.(\d+)', torch_version)
    if not m:
        raise RuntimeError(f"Cannot parse PyTorch version '{torch_version}'")
    major, minor = map(int, m.groups())
    if (major, minor) < (2, 8):
        raise RuntimeError(f"PyTorch version must be >= 2.8, but found {torch_version}")

    third_party_cmake()
    # Must run before setup() so the extension build sees the trimmed flag list.
    remove_unwanted_pytorch_nvcc_flags()
    setup(
        name='qutlass',
        packages=find_packages(),
        ext_modules=[
            CUDAExtension(
                name='qutlass._CUDA',
                sources=[
                    'qutlass/csrc/bindings.cpp',
                    'qutlass/csrc/gemm_mx.cu',
                    'qutlass/csrc/quartet.cu',
                    'qutlass/csrc/fused_quantize.cu',
                    'qutlass/csrc/fused_quantize_bwd.cu',
                ],
                include_dirs=[
                    os.path.join(setup_dir, 'qutlass/csrc/include'),
                    os.path.join(setup_dir, 'third_party/cutlass/include'),
                    os.path.join(setup_dir, 'third_party/cutlass/tools/util/include'),
                ],
                extra_compile_args={
                    'cxx': [],
                    'nvcc': get_cuda_arch_flags(),
                },
                # NOTE(review): cuda/cudart are requested both here and via
                # extra_link_args below; this looks redundant but is kept as-is.
                libraries=["cuda", "cudart"],
                extra_link_args=[
                    '-lcudart',
                    '-lcuda',
                ],
            )
        ],
        cmdclass={
            'build_ext': BuildExtension
        }
    )
