cmake_minimum_required(VERSION 3.25)


# Default compilers for the local development machine.
# Each default is applied only when the variable has not already been defined
# (for example with -DCMAKE_CXX_COMPILER=... on the command line) and only when
# the hard-coded path actually exists; otherwise CMake's normal compiler
# detection is left to choose. These must stay above project(), which is where
# the compilers are probed.
if(NOT DEFINED CMAKE_CXX_COMPILER AND EXISTS "/usr/bin/g++")
    set(CMAKE_CXX_COMPILER "/usr/bin/g++")
endif()
if(NOT DEFINED CMAKE_C_COMPILER AND EXISTS "/usr/bin/gcc")
    set(CMAKE_C_COMPILER "/usr/bin/gcc")
endif()
if(NOT DEFINED CMAKE_CUDA_COMPILER AND EXISTS "/usr/local/cuda/bin/nvcc")
    set(CMAKE_CUDA_COMPILER "/usr/local/cuda/bin/nvcc")
endif()


# Default package search prefix, applied only when CMAKE_PREFIX_PATH has not
# been defined yet (e.g. via -DCMAKE_PREFIX_PATH=...). The value points at the
# CMake package directory inside a per-user Python 3.8 torch installation.
# NOTE(review): machine-specific path; presumably this is what lets
# find_package(Torch) succeed locally -- override it on other machines.
if(NOT DEFINED CMAKE_PREFIX_PATH)
    set(CMAKE_PREFIX_PATH "/home/owner/.local/lib/python3.8/site-packages/torch/share/cmake")
endif()

# Default build configuration, applied only when CMAKE_BUILD_TYPE has not been
# defined before this point (e.g. via -DCMAKE_BUILD_TYPE=...).
if(NOT DEFINED CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE "RelWithDebInfo")
endif()


# # NOTE: For local development only; change this when compiling on the server.
# # Unverified: whether a -DPYTHON_EXECUTABLE=... command-line flag would override this.
# set(PYTHON_EXECUTABLE "/usr/bin/python3")



# Declare the project and enable the CUDA and C++ languages (the only two
# languages requested here).
project(
    sfrp_torch
    VERSION 0.1
    DESCRIPTION "Storage-free random projections for PyTorch."
    LANGUAGES CUDA CXX)


include(FetchContent)


# Required external packages: a Python interpreter plus its development files,
# and the Torch CMake package. Configuration stops if either is missing.
find_package(Python COMPONENTS Interpreter Development REQUIRED)
find_package(Torch REQUIRED)


# pybind11 is downloaded at configure time and added to the build.
# GIT_SHALLOW limits the clone to the pinned tag instead of fetching the whole
# repository history.
FetchContent_Declare(
    pybind11
    GIT_REPOSITORY https://github.com/pybind/pybind11
    GIT_TAG        v2.11.1
    GIT_SHALLOW    TRUE
)
FetchContent_MakeAvailable(pybind11)



# Append the flags exported in TORCH_CXX_FLAGS to the global C++ compiler flags.
string(APPEND CMAKE_CXX_FLAGS " ${TORCH_CXX_FLAGS}")



# CUDA info.
# Use the include directories reported for the CUDA toolkit that was enabled
# by project(), so the headers always match the compiler in use. The previous
# hard-coded location is kept as a fallback for the case where that variable
# is empty.
if(CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES)
    set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
else()
    set(CUDA_INCLUDE_DIR "/usr/local/cuda/include")
endif()
# set(CUDA_LIB_DIR "/usr/local/cuda/lib64")

# Directory-scoped on purpose: every target declared below this line sees the
# CUDA headers, including when CUDA sources are compiled as plain C++.
include_directories(${CUDA_INCLUDE_DIR})



# Build-mode switch, OFF by default. Besides the preprocessor definition added
# below, the same option is checked further down to decide whether the library
# sources are compiled as CUDA.
option(SET_TRUE_WHEN_ACTUALLY_BUILDING "Dumb hack to get my syntax highlighting to work." OFF)


# When the switch is ON, pass a preprocessor definition of the same name to
# both the C++ and the CUDA compiler.
if(SET_TRUE_WHEN_ACTUALLY_BUILDING)
    foreach(sfrp_flag_lang IN ITEMS CXX CUDA)
        string(APPEND CMAKE_${sfrp_flag_lang}_FLAGS " -DSET_TRUE_WHEN_ACTUALLY_BUILDING")
    endforeach()
endif()

###############################################################################

# Source files of the sfrp_torch_cuda library, collected in one list variable.
# The list is built up in groups that follow the directory layout; the overall
# order of the entries is significant only in that it is kept stable.

# Sources under src/cuda/bernoulli/dense.
list(APPEND sfrp_torch_src_cuda_files
    src/cuda/bernoulli/dense/dn_v2_alg3.cc
    src/cuda/bernoulli/dense/dn_v3_alg3.cc
    src/cuda/bernoulli/dense/trp_dn_v3_alg1.cc)

# Sources under src/cuda/bernoulli/sparse.
list(APPEND sfrp_torch_src_cuda_files
    src/cuda/bernoulli/sparse/sp_v1_alg3.cc
    src/cuda/bernoulli/sparse/sp_v2_alg1.cc
    src/cuda/bernoulli/sparse/trp_sp_v2_alg1.cc)

# Shared helpers under src/cuda/bernoulli and src/cuda/util.
list(APPEND sfrp_torch_src_cuda_files
    src/cuda/bernoulli/bernoulli_util.cc
    src/cuda/util/args_validation.cc
    src/cuda/util/hashing.cc
    src/cuda/util/misc_util.cc)

# Dummy files, needed to get my editor to work properly with header files in these folders.
list(APPEND sfrp_torch_src_cuda_files
    src/cuda/util/device/dummy.cc)


# Locate the torch_python library inside the Torch installation.
# TORCH_INSTALL_PREFIX is presumably provided by find_package(Torch) -- confirm.
# PATHS is the documented keyword for extra search directories (the former
# "PATH" token was only accepted as an additional directory name by the
# short-hand signature). REQUIRED stops configuration with a clear message if
# the library cannot be found.
find_library(TORCH_PYTHON_LIBRARY
    NAMES torch_python
    PATHS "${TORCH_INSTALL_PREFIX}/lib"
    REQUIRED
)

# Shared library containing all kernels listed in sfrp_torch_src_cuda_files.
add_library(sfrp_torch_cuda SHARED
    ${sfrp_torch_src_cuda_files}
)

# PUBLIC so that targets linking against this library inherit both the project
# headers and the CUDA headers.
target_include_directories(sfrp_torch_cuda
    PUBLIC
        ./src
        ${CUDA_INCLUDE_DIR}
)

# Position-independent code is requested through the target property instead
# of a raw -fPIC flag, so CMake emits the correct spelling for each compiler.
set_target_properties(sfrp_torch_cuda PROPERTIES
    CXX_STANDARD 17
    POSITION_INDEPENDENT_CODE ON
)

target_link_libraries(sfrp_torch_cuda
    PRIVATE
        ${TORCH_LIBRARIES}
        ${TORCH_PYTHON_LIBRARY}
        Python::Python
)

# -std and -O3 are passed unchanged for every language, as before. The
# GCC-style floating-point flags are host-compiler options: they are given
# directly when a file is compiled as C++ and wrapped in -Xcompiler= when a
# file is compiled as CUDA, so the CUDA compiler driver forwards them to the
# host compiler rather than having to parse them itself.
target_compile_options(sfrp_torch_cuda
    PRIVATE
        -std=c++17
        -O3
        "$<$<COMPILE_LANGUAGE:CXX>:-fno-math-errno;-fno-trapping-math>"
        "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fno-math-errno;-Xcompiler=-fno-trapping-math>"
)

# In a real build the .cc sources are compiled as CUDA; otherwise they stay
# plain C++ (the default language for their extension).
if(SET_TRUE_WHEN_ACTUALLY_BUILDING)
    set_source_files_properties(${sfrp_torch_src_cuda_files} PROPERTIES LANGUAGE CUDA)
endif()

###############################################################################

# Python extension module built from src/bindings.cc via pybind11. It links
# publicly against the sfrp_torch_cuda library and privately against the Torch
# and Python libraries.
pybind11_add_module(sfrp_torch src/bindings.cc)

set_target_properties(sfrp_torch PROPERTIES CXX_STANDARD 17)

target_include_directories(sfrp_torch PUBLIC ./src)

target_link_libraries(sfrp_torch
    PUBLIC
        sfrp_torch_cuda
    PRIVATE
        ${TORCH_LIBRARIES}
        ${TORCH_PYTHON_LIBRARY}
        Python::Python
)

target_compile_options(sfrp_torch
    PRIVATE
        -std=c++17
        -O3
        -fno-math-errno
        -fno-trapping-math
        -fPIC
)
