# pybind11 uses the variable PYTHON_EXECUTABLE (case-sensitive) to select the
# Python interpreter, so forward the interpreter found through Python3_* to it.
# NOTE(review): Python3_EXECUTABLE is not set in this file; presumably a parent
# CMakeLists runs find_package(Python3 ...) before this directory — confirm.
# Because the expansion is unquoted, an empty/undefined Python3_EXECUTABLE
# unsets the normal variable (exposing any cached PYTHON_EXECUTABLE) rather
# than setting it to an empty string.
set(PYTHON_EXECUTABLE ${Python3_EXECUTABLE})

# Core C++ implementation, built as a static library and linked into the
# chunk_attn_c Python extension module defined further below.
add_library(chunk_attn_lib STATIC)
target_sources(chunk_attn_lib PRIVATE
        "task.h"
        "chunk.h"
        "chunk.cpp"
        "chunk_allocator.h"
        "chunk_allocator.cpp"
        "attention.h"
        "attention.cpp"
        "trace.h"
        "logging.h"
        "logging.cpp"
        "str_utils.h"
        "str_utils.cpp")

# This static library ends up inside a shared Python extension module, so its
# objects must be position-independent; without this, linking chunk_attn_c
# fails with relocation errors on toolchains that do not default to PIC/PIE.
set_target_properties(chunk_attn_lib PROPERTIES POSITION_INDEPENDENT_CODE ON)

# TORCH_* and SPDLOG_INCLUDE_DIR are not set in this file; presumably they come
# from find_package(Torch) / spdlog setup in a parent scope — confirm.
# PUBLIC so consumers such as chunk_attn_c inherit the same headers and libraries.
target_include_directories(chunk_attn_lib PUBLIC
        ${CMAKE_CURRENT_SOURCE_DIR}
        ${TORCH_INCLUDE_DIRS}
        ${SPDLOG_INCLUDE_DIR})
target_link_libraries(chunk_attn_lib PUBLIC
        ${TORCH_LIBRARIES}
        ${TORCH_PYTHON_LIBRARY})
message(STATUS "TORCH_LIBRARIES=${TORCH_LIBRARIES}")
target_link_directories(chunk_attn_lib PUBLIC
        ${TORCH_INSTALL_PREFIX}/lib)
# Torch may link MKL automatically when MKL is installed, so expose MKL's
# library directory to the linker. MKL_H presumably points at MKL's include
# directory (making ../lib/intel64 its library directory) — confirm. Only add
# the path when MKL_H is actually set; otherwise the expansion would produce a
# bogus "/../lib/intel64" (or "MKL_H-NOTFOUND/...") link directory.
if (MKL_H)
    target_link_directories(chunk_attn_lib PUBLIC
            "${MKL_H}/../lib/intel64")
endif ()

# Optional MKL-backed CPU kernel, enabled with -DUSE_MKL=ON.
# $<LINK_ONLY:MKL::MKL> links MKL without propagating its usage requirements;
# its compile options and include directories are instead forwarded explicitly
# through $<TARGET_PROPERTY:MKL::MKL,...>.
if (USE_MKL)
    find_package(MKL CONFIG REQUIRED)
    # OpenMP::OpenMP_CXX is linked unconditionally below, so a missing OpenMP
    # must stop configuration here instead of surfacing later as a
    # generate-time "target was not found" error.
    find_package(OpenMP REQUIRED COMPONENTS CXX)
    message(STATUS "${MKL_IMPORTED_TARGETS}")
    message(STATUS "MKL_H=${MKL_H}")
    target_compile_options(chunk_attn_lib PUBLIC $<TARGET_PROPERTY:MKL::MKL,INTERFACE_COMPILE_OPTIONS>)
    target_sources(chunk_attn_lib PRIVATE
            "kernel_cpu_mkl.h"
            "kernel_cpu_mkl.cpp"
            "kernel_cpu_tls.h"
            "small_vector.h"
            "spin_lock.h")
    target_include_directories(chunk_attn_lib PUBLIC
            $<TARGET_PROPERTY:MKL::MKL,INTERFACE_INCLUDE_DIRECTORIES>)
    target_link_libraries(chunk_attn_lib PUBLIC
            $<LINK_ONLY:MKL::MKL>
            OpenMP::OpenMP_CXX
    )
    # Exposed PUBLIC so code including this library's headers sees USE_MKL=1.
    target_compile_definitions(chunk_attn_lib PUBLIC USE_MKL=1)
endif ()
# Optional CUDA kernel, enabled with -DUSE_CUDA=ON and compiled through CMake's
# first-class CUDA language support.
if (USE_CUDA)
    enable_language(CUDA)
    target_sources(chunk_attn_lib PRIVATE
            kernel_cuda.cu
            kernel_cuda.h
    )
    target_compile_definitions(chunk_attn_lib PUBLIC USE_CUDA=1)
    # Per-configuration nvcc flags are generator expressions rather than
    # CMAKE_BUILD_TYPE checks so they also apply under multi-config generators
    # (where CMAKE_BUILD_TYPE is empty) and only to CUDA sources.
    # -G: device-side debug information for Debug builds.
    target_compile_options(chunk_attn_lib PUBLIC
            $<$<AND:$<CONFIG:Debug>,$<COMPILE_LANGUAGE:CUDA>>:-G>)
    # -O3 for Release builds of CUDA sources. CUDA_NVCC_FLAGS is deliberately
    # not used: it is only read by the legacy FindCUDA cuda_add_* commands and
    # has no effect on targets built via enable_language(CUDA).
    target_compile_options(chunk_attn_lib PRIVATE
            $<$<AND:$<CONFIG:Release>,$<COMPILE_LANGUAGE:CUDA>>:-O3>)
endif ()

# Python extension module built with pybind11. python_exports.cpp presumably
# holds the binding definitions; the implementation itself comes from
# chunk_attn_lib, whose headers and libraries are inherited via the link below.
pybind11_add_module(chunk_attn_c
        "python_exports.cpp")
target_link_libraries(chunk_attn_c
        PUBLIC
        chunk_attn_lib)

# When another part of the project defines a BUILD_INFO target, make sure it is
# built before the code in this directory.
# TARGET is not set anywhere in this file. If it is undefined, the call
# add_dependencies(${TARGET} BUILD_INFO) expands to add_dependencies(BUILD_INFO),
# which attaches no dependency at all. Honour TARGET when a parent scope sets it
# to an existing target; otherwise attach the dependency to chunk_attn_lib, which
# chunk_attn_c links against.
# NOTE(review): confirm which target BUILD_INFO is meant to precede.
if (TARGET BUILD_INFO)
    if (TARGET "${TARGET}")
        add_dependencies(${TARGET} BUILD_INFO)
    else ()
        add_dependencies(chunk_attn_lib BUILD_INFO)
    endif ()
endif ()
