# Require CMake 3.18 or newer (the FATAL_ERROR keyword is accepted but ignored
# by every CMake >= 2.6, so it is omitted here).
cmake_minimum_required(VERSION 3.18)
# C and CXX are the languages project() enables when none are listed.
project(token_module LANGUAGES C CXX)

# Locate the Python interpreter, plus the extension-module development files
# that pybind11 needs, up front. pybind11 (>= 2.6) reuses this FindPython
# result, so the torch query below runs under the same interpreter the module
# is built against rather than whatever `python` happens to be first on PATH.
# Override with -DPython_EXECUTABLE=/path/to/python if required.
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

# Ask the installed torch package where its CMake config files live.
execute_process(
  COMMAND "${Python_EXECUTABLE}" -c "import torch; print(torch.utils.cmake_prefix_path)"
  OUTPUT_VARIABLE TorchCmakePrefixPath
  OUTPUT_STRIP_TRAILING_WHITESPACE
  RESULT_VARIABLE torch_query_rc
)
# Fail here with a clear message instead of continuing with an empty prefix
# path and failing later inside find_package(Torch).
if(NOT torch_query_rc EQUAL 0 OR "${TorchCmakePrefixPath}" STREQUAL "")
  message(FATAL_ERROR
    "Could not query torch.utils.cmake_prefix_path using "
    "${Python_EXECUTABLE} (result: ${torch_query_rc}). Make sure torch is "
    "installed for this interpreter; see the Python error output above.")
endif()

# Write compile_commands.json into the build tree for IDE / clangd tooling.
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Use ccache when it is installed. CMAKE_<LANG>_COMPILER_LAUNCHER initializes
# the <LANG>_COMPILER_LAUNCHER property of every target created after this
# point (the RULE_LAUNCH_COMPILE global property is documented as internal to
# ctest). A launcher the user already supplied, e.g. via
# -DCMAKE_CXX_COMPILER_LAUNCHER=..., is left untouched.
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
  foreach(lang IN ITEMS C CXX)
    if(NOT CMAKE_${lang}_COMPILER_LAUNCHER)
      set(CMAKE_${lang}_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
    endif()
  endforeach()
  message(STATUS "Found ccache: ${CCACHE_PROGRAM}")
else()
  message(STATUS "ccache not found, compilation will not be cached.")
endif()

message(STATUS "Torch CMake prefix path: ${TorchCmakePrefixPath}")
# Let find_package(Torch) below find TorchConfig.cmake inside the torch package.
list(APPEND CMAKE_PREFIX_PATH ${TorchCmakePrefixPath})

# Vendored pybind11; provides pybind11_add_module().
add_subdirectory(third-party/pybind11)

# Find libTorch. Its config provides TORCH_LIBRARIES, TORCH_INCLUDE_DIRS and
# TORCH_INSTALL_PREFIX, which are used further down.
find_package(Torch REQUIRED)

# Locate libtorch_python inside the torch installation. HINTS directories are
# searched before the system locations, so the copy belonging to this torch
# install is preferred. REQUIRED (CMake >= 3.18) turns a miss into an
# immediate configure error instead of a TORCH_PYTHON_LIBRARY-NOTFOUND value
# that only fails at generate/link time.
find_library(TORCH_PYTHON_LIBRARY
  NAMES torch_python
  HINTS "${TORCH_INSTALL_PREFIX}/lib"
  REQUIRED
)

# The Python extension module, built from the C++ sources under cpp/.
pybind11_add_module(token_module
  cpp/token_module.cpp
  cpp/kv_cache.cpp
  cpp/tree_manager.cpp
)

# Per-configuration settings for token_module.
#
# Generator expressions are evaluated separately for each build configuration.
# That makes this work for single-config generators (CMAKE_BUILD_TYPE) and for
# multi-config generators (Visual Studio, Xcode, Ninja Multi-Config), where
# CMAKE_BUILD_TYPE is empty. The settings are attached to this target only
# instead of being appended to the directory-wide CMAKE_CXX_FLAGS.

# Debug builds define DEBUG (presumably checked by the C++ sources; verify)
# and PYBIND11_DETAILED_ERROR_MESSAGES (pybind11's more verbose error output).
target_compile_definitions(token_module PRIVATE
  "$<$<CONFIG:Debug>:DEBUG;PYBIND11_DETAILED_ERROR_MESSAGES>"
)

# GCC/Clang-style optimization flags; other compilers keep their defaults.
# Target compile options are placed after CMAKE_CXX_FLAGS_<CONFIG> on the
# command line. -O3 is therefore skipped for RelWithDebInfo and MinSizeRel so
# that their own -O2 / -Os settings still take effect.
target_compile_options(token_module PRIVATE
  "$<$<AND:$<CXX_COMPILER_ID:GNU,Clang,AppleClang>,$<CONFIG:Debug>>:-g;-O0>"
  "$<$<AND:$<CXX_COMPILER_ID:GNU,Clang,AppleClang>,$<NOT:$<OR:$<CONFIG:Debug>,$<CONFIG:RelWithDebInfo>,$<CONFIG:MinSizeRel>>>>:-O3>"
)

# Header search paths: libtorch's headers and the project's own cpp/ directory.
target_include_directories(token_module
  PRIVATE
    ${TORCH_INCLUDE_DIRS}
    "${CMAKE_CURRENT_SOURCE_DIR}/cpp/"
)
# Link against libtorch and its Python-binding library.
target_link_libraries(token_module
  PRIVATE
    ${TORCH_LIBRARIES}
    ${TORCH_PYTHON_LIBRARY}
)

# OpenMP for the C++ sources. Only the CXX component is requested because it
# is the only one used. Linking the OpenMP::OpenMP_CXX imported target applies
# both the compile flags and the link libraries. FindOpenMP splits multi-word
# flags (e.g. AppleClang's "-Xclang -fopenmp") into separate options for that
# target, whereas passing the raw OpenMP_CXX_FLAGS string to
# target_compile_options would send them as a single argument.
find_package(OpenMP REQUIRED COMPONENTS CXX)
target_link_libraries(token_module PRIVATE OpenMP::OpenMP_CXX)