from cupy.cuda.compiler import _hash_hexdigest, _get_arch
import os
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import subprocess
import cupy as cp
from dotenv import load_dotenv

# Helpers for interfacing with raw CUDA kernels

load_dotenv()

# Grandparent directory of this file; the kernel cache (and the kernel
# sources) are located relative to it.
project_dir = Path(__file__).resolve().parent.parent

# Compiled cubins are stored per CuPy version, since kernels are built against
# the headers bundled with the installed CuPy (see `_includes`).
_kernel_cache = project_dir / 'kernel_cache' / cp.__version__
os.makedirs(_kernel_cache, exist_ok=True)

# Toolkit locations. Both can be overridden from the environment or from a
# .env file (loaded by load_dotenv() above).
_nvcc_path = os.getenv('NVCC_PATH', '/opt/cuda/bin/nvcc')
_cuda_include_dir = os.getenv('CUDA_INCLUDE_PATH', '/opt/cuda/include')

# Architecture string reported by CuPy; interpolated into -gencode below.
_arch = _get_arch()

# CuPy's package directory with a trailing separator, so include sub-paths can
# be appended directly. Derived from the directory name rather than by slicing
# a fixed number of characters off the module file name.
_cupy_path = os.path.join(os.path.dirname(cp.__file__), '')
_includes = [f'-I{_cupy_path}{incl_dir}' for incl_dir in (
    "_core/include/",
    "_core/include/cupy/_cccl/cub",
    "_core/include/cupy/_cccl/thrust",
    "_core/include/cupy/_cccl/libcudacxx",
    "_core/include/cupy/_cuda/cuda-11",
)] + [f"-I{_cuda_include_dir}"]

opt_flags = [
    '-O3',
    '--extra-device-vectorization',
    # Trades floating-point accuracy for speed.
    '--use_fast_math',
    # Emits SASS for exactly this architecture (no embedded PTX), so the
    # resulting cubin only runs on GPUs of this architecture.
    f'-gencode=arch=compute_{_arch},code=sm_{_arch}',
    '--ptxas-options=--allow-expensive-optimizations=true',
]

def compile_kernel_with_replacements(name, replacements, debug):
    """Compile ``model/kernels/raw_cuda/{name}.cu`` into a cached cubin.

    Every ``__REPL_<placeholder>`` token in the kernel source is replaced by
    its paired string, in the order given, and the result is compiled with
    nvcc. The cubin is cached in ``_kernel_cache`` under a hash of the
    substituted source *and* the full nvcc command line. Changing the target
    architecture, nvcc path, include dirs or optimisation flags therefore
    triggers a fresh build instead of reusing an incompatible binary.

    Args:
        name: Kernel file stem under ``model/kernels/raw_cuda``.
        replacements: Iterable of ``(placeholder, replacement)`` string pairs.
            They are applied sequentially with ``str.replace``. A placeholder
            that is a prefix of another (``FOO`` vs ``FOOBAR``) must therefore
            come after the longer one.
        debug: If true, keep the generated ``.cu`` file in ``_kernel_cache``
            after a successful compile so it can be inspected.

    Returns:
        Path (str) of the cached ``.cubin`` file.

    Raises:
        RuntimeError: If nvcc exits with a non-zero status. The generated
            source file is kept so nvcc's error locations can be looked up.
    """
    src_path = project_dir / f'model/kernels/raw_cuda/{name}.cu'
    with open(src_path, 'r', encoding='utf-8') as f:
        code = f.read()
    for placeholder, replacement in replacements:
        code = code.replace(f"__REPL_{placeholder}", replacement)

    base_cmd = [_nvcc_path, '--std=c++14', '--cubin'] + opt_flags + _includes
    # Key on the compile command as well as the code: the cubin is
    # architecture- and flag-specific, so the source text alone is not enough.
    key_material = '\0'.join(base_cmd + [code])
    # _hash_hexdigest returns a hex str; use it directly in the file name.
    cubin_hash = _hash_hexdigest(key_material.encode('utf-8'))
    cubin_path = os.path.join(_kernel_cache, f"{cubin_hash}.cubin")
    if os.path.exists(cubin_path):
        return cubin_path

    # Scratch files are unique per variant and per process, so concurrent
    # compilations never overwrite each other's inputs or outputs.
    scratch = f"tmp_{name}_{cubin_hash[:16]}_{os.getpid()}"
    tmp_file_path = os.path.join(_kernel_cache, f"{scratch}.cu")
    tmp_cubin_path = os.path.join(_kernel_cache, f"{scratch}.cubin")
    with open(tmp_file_path, 'w', encoding='utf-8') as out_f:
        out_f.write(code)

    cmd = base_cmd + ['-o', tmp_cubin_path, tmp_file_path]
    try:
        ret = subprocess.run(cmd, check=False).returncode
        if ret != 0:
            raise RuntimeError(
                f"Failed to compile {name} (nvcc exit code {ret}); "
                f"generated source kept at {tmp_file_path}")
        # Atomic publish: readers never observe a half-written cubin in the
        # cache, even if this process is interrupted mid-compile.
        os.replace(tmp_cubin_path, cubin_path)
    finally:
        # Already moved on success; on failure drop any partial nvcc output.
        try:
            os.remove(tmp_cubin_path)
        except FileNotFoundError:
            pass

    if not debug:
        os.remove(tmp_file_path)
    return cubin_path