# Oldest CMake release this build description supports.
cmake_minimum_required(VERSION 3.19)

# Version of the transport package. These three components are combined into
# the project() VERSION and into the VERSION/SOVERSION properties that
# nvshmem_add_transport() sets on every transport library.
set(CUSTOM_TRANSPORT_VERSION_MAJOR "3")
set(CUSTOM_TRANSPORT_VERSION_MINOR "0")
set(CUSTOM_TRANSPORT_VERSION_PATCH "0")

# Seed the default value of each user-tunable setting from the environment,
# falling back to a conventional location/value when the variable is not
# exported. The resulting *_DEFAULT variables are only used as the initial
# values of the cache entries and options declared further down.
if(DEFINED ENV{LIBFABRIC_HOME})
  set(LIBFABRIC_HOME_DEFAULT "$ENV{LIBFABRIC_HOME}")
else()
  set(LIBFABRIC_HOME_DEFAULT "/usr/local/libfabric")
endif()

if(DEFINED ENV{UCX_HOME})
  set(UCX_HOME_DEFAULT "$ENV{UCX_HOME}")
else()
  set(UCX_HOME_DEFAULT "/usr/local/ucx")
endif()

if(DEFINED ENV{CUDA_HOME})
  set(CUDA_HOME_DEFAULT "$ENV{CUDA_HOME}")
else()
  set(CUDA_HOME_DEFAULT "/usr/local/cuda")
endif()

# GDRCOPY_HOME_DEFAULT seeds the GDRCOPY_HOME cache entry, so it has to be
# defined like the other *_HOME defaults. The fallback is one of the
# directories the gdrapi.h lookup already probes.
if(DEFINED ENV{GDRCOPY_HOME})
  set(GDRCOPY_HOME_DEFAULT "$ENV{GDRCOPY_HOME}")
else()
  set(GDRCOPY_HOME_DEFAULT "/usr/local/gdrdrv")
endif()

# GDRCopy support is enabled unless the environment says otherwise.
if(DEFINED ENV{NVSHMEM_USE_GDRCOPY})
  set(NVSHMEM_USE_GDRCOPY_DEFAULT "$ENV{NVSHMEM_USE_GDRCOPY}")
else()
  set(NVSHMEM_USE_GDRCOPY_DEFAULT ON)
endif()

# Feature switch: GDRCopy offload paths. The initial value comes from the
# NVSHMEM_USE_GDRCOPY environment variable when it is set, otherwise ON.
option(
  NVSHMEM_USE_GDRCOPY
  "Enable compilation of GDRCopy offload paths for atomics in remote transports"
  "${NVSHMEM_USE_GDRCOPY_DEFAULT}"
)

# Per-transport build switches. Every transport is opt-in; each enabled one
# pulls in the subdirectory of the same name at the end of this file.
option(
  NVSHMEM_BUILD_IBDEVX_TRANSPORT
  "Enable compilation of the ibdevx transport"
  OFF
)
option(
  NVSHMEM_BUILD_IBGDA_TRANSPORT
  "Enable compilation of the ibgda transport"
  OFF
)
option(
  NVSHMEM_BUILD_IBRC_TRANSPORT
  "Enable compilation of the ibrc transport"
  OFF
)
option(
  NVSHMEM_BUILD_LIBFABRIC_TRANSPORT
  "Enable compilation of the libfabric transport"
  OFF
)
option(
  NVSHMEM_BUILD_UCX_TRANSPORT
  "Enable compilation of the UCX transport"
  OFF
)

# Install prefixes of the external dependencies. Each is a cache entry, so a
# value given on the command line (-D<NAME>=<path>) or already stored in the
# cache takes precedence over the *_DEFAULT initial value.
set(CUDA_HOME
    "${CUDA_HOME_DEFAULT}"
    CACHE PATH "path to CUDA installation"
)
set(GDRCOPY_HOME
    "${GDRCOPY_HOME_DEFAULT}"
    CACHE PATH "path to GDRCOPY installation"
)
set(LIBFABRIC_HOME
    "${LIBFABRIC_HOME_DEFAULT}"
    CACHE PATH "path to libfabric installation"
)
set(UCX_HOME
    "${UCX_HOME_DEFAULT}"
    CACHE PATH "path to UCX installation"
)

# Allow users to select the CUDA toolkit through CUDA_HOME (environment or
# cache), but only when they have not already chosen one explicitly.
#
# find_package(CUDAToolkit) honors the variable spelled CUDAToolkit_ROOT;
# variable names are case-sensitive, so that spelling must be checked or an
# explicit -DCUDAToolkit_ROOT=... would be overridden by the forced compiler
# path below. The historical CUDAToolkit_Root spelling is still recognized and
# still populated so existing command lines and scripts keep working.
#
# This must run before project(): CMAKE_CUDA_COMPILER is consulted when the
# CUDA language is enabled.
if(NOT CUDAToolkit_ROOT AND NOT CUDAToolkit_Root AND NOT CMAKE_CUDA_COMPILER)
  message(STATUS "CUDA_HOME: ${CUDA_HOME}")
  # FORCE is deliberate: this branch is only reached when the entries are
  # unset or hold a false value, and a stale empty cache entry must not win.
  set(CUDAToolkit_ROOT "${CUDA_HOME}" CACHE PATH "Root of the CUDA Toolkit." FORCE)
  set(CUDAToolkit_Root "${CUDA_HOME}" CACHE PATH "Root of Cuda Toolkit." FORCE)
  set(CMAKE_CUDA_COMPILER "${CUDA_HOME}/bin/nvcc" CACHE FILEPATH "Path to the CUDA compiler (nvcc)." FORCE)
endif()

# Declare the project. Enabling the CUDA language here is what consumes the
# CMAKE_CUDA_COMPILER value prepared above.
project(
  NVSHMEM_TRANSPORTS
  VERSION ${CUSTOM_TRANSPORT_VERSION_MAJOR}.${CUSTOM_TRANSPORT_VERSION_MINOR}.${CUSTOM_TRANSPORT_VERSION_PATCH}
  LANGUAGES CUDA CXX C
)

# Every transport adds ${CUDAToolkit_INCLUDE_DIRS} to its include path and may
# link CUDA::cudart_static, so a missing toolkit is a configuration error.
# REQUIRED reports it here instead of as an unresolved target later on.
find_package(CUDAToolkit REQUIRED)

# libmlx5 is only looked up when the ibgda or ibdevx transport is enabled.
# The result is stored in the MLX5_lib cache entry
# (MLX5_lib-NOTFOUND when the library cannot be located).
if(NVSHMEM_BUILD_IBGDA_TRANSPORT OR NVSHMEM_BUILD_IBDEVX_TRANSPORT)
  find_library(
    MLX5_lib
    NAMES mlx5
  )
endif()

# Locate the directory containing gdrapi.h when GDRCopy support is enabled.
# The result is stored in the GDRCOPY_INCLUDE cache entry.
#
# PROJECT_SOURCE_DIR (this project's root) is used rather than
# CMAKE_SOURCE_DIR so the lookup still probes this tree when the project is
# pulled into a larger build with add_subdirectory(); for a standalone build
# the two are the same directory.
if(NVSHMEM_USE_GDRCOPY)
  find_path(
    GDRCOPY_INCLUDE gdrapi.h
    HINTS ${PROJECT_SOURCE_DIR} /usr/local/gdrcopy /usr/local/gdrdrv ${GDRCOPY_HOME}
    PATHS /usr/local/gdrcopy /usr/local/gdrdrv ${PROJECT_SOURCE_DIR}
    # Each candidate directory is tried with these sub-directories appended.
    PATH_SUFFIXES include_gdrcopy include
  )
endif()

# Process the shared code in common/ before any transport subdirectory.
# NOTE(review): the helper macros below reference nvshmem_transport_common,
# nvshmem_transport_gdr_common and nvshmem_transport_ib_common, none of which
# are defined in this file; presumably common/ creates them. Confirm.
add_subdirectory(common)

# nvshmem_transport_set_base_config(<target>)
#
# Attaches the private compile definitions shared by all transports:
#   * _NVSHMEM_DEBUG and NVSHMEM_IBGDA_DEBUG in Debug configurations;
#   * one architecture block selected by CMAKE_HOST_SYSTEM_PROCESSOR
#     (x86_64, ppc64le or aarch64). Any other value adds nothing.
macro(nvshmem_transport_set_base_config LIBNAME)
  target_compile_definitions(
    ${LIBNAME}
    PRIVATE
      "$<$<CONFIG:Debug>:_NVSHMEM_DEBUG;NVSHMEM_IBGDA_DEBUG>"
      "$<$<STREQUAL:${CMAKE_HOST_SYSTEM_PROCESSOR},x86_64>:NVSHMEM_X86_64>"
      "$<$<STREQUAL:${CMAKE_HOST_SYSTEM_PROCESSOR},ppc64le>:__STDC_LIMIT_MACROS;__STDC_CONSTANT_MACROS;NVSHMEM_PPC64LE>"
      "$<$<STREQUAL:${CMAKE_HOST_SYSTEM_PROCESSOR},aarch64>:NVSHMEM_AARCH64>"
  )
endmacro()

# nvshmem_transport_set_gdr_config(<target>)
#
# Publishes the GDRCopy header directory on nvshmem_transport_gdr_common and
# links <target> privately against that library.
macro(nvshmem_transport_set_gdr_config TRANSPORT_NAME)
  target_include_directories(
    nvshmem_transport_gdr_common
    PUBLIC ${GDRCOPY_INCLUDE}
  )
  target_link_libraries(
    ${TRANSPORT_NAME}
    PRIVATE nvshmem_transport_gdr_common
  )
endmacro()

# nvshmem_transport_set_ib_config(<target>)
#
# Links <target> privately against nvshmem_transport_ib_common.
macro(nvshmem_transport_set_ib_config TRANSPORT_NAME)
  target_link_libraries(
    ${TRANSPORT_NAME}
    PRIVATE nvshmem_transport_ib_common
  )
endmacro()

# nvshmem_transport_set_mlx5_config(<target>)
#
# Links <target> privately against libmlx5, using the location stored in the
# MLX5_lib cache entry by find_library().
macro(nvshmem_transport_set_mlx5_config TRANSPORT_NAME)
  # MLX5_lib is a find_library() result variable, not a target. Fail with a
  # clear message when it is unset or *-NOTFOUND instead of handing a bogus
  # item to the linker.
  if(NOT MLX5_lib)
    message(FATAL_ERROR
      "nvshmem_transport_set_mlx5_config(${TRANSPORT_NAME}): libmlx5 was not found "
      "(MLX5_lib='${MLX5_lib}'). Install the mlx5 provider library or set "
      "MLX5_lib to its full path.")
  endif()

  # NOTE(review): this line is identical to the one in
  # nvshmem_transport_set_gdr_config and looks like a copy-paste; it is kept
  # to preserve behavior. Confirm whether mlx5 transports really need it.
  target_include_directories(nvshmem_transport_gdr_common PUBLIC ${GDRCOPY_INCLUDE})

  # Expand the variable: the bare word MLX5_lib would be passed to the linker
  # as a library *name* (-lMLX5_lib), which does not exist.
  target_link_libraries(${TRANSPORT_NAME} PRIVATE "${MLX5_lib}")
endmacro()

# nvshmem_add_transport(<name> <sources> <cuda> <gdr> <ib> <mlx5>)
#
# Creates the shared library target <name> for one transport, applies the
# common configuration and installs the library into lib/.
#
#   TRANSPORT_NAME  name of the library target to create
#   SOURCE_LIST     source file(s) added privately to the target
#   CUDA_FLAG       true: link CUDA::cudart_static
#   GDR_FLAG        true: apply nvshmem_transport_set_gdr_config(), but only
#                   when NVSHMEM_USE_GDRCOPY is also enabled
#   IB_FLAG         true: apply nvshmem_transport_set_ib_config()
#   MLX5_FLAG       true: apply nvshmem_transport_set_mlx5_config()
#
# This is a macro: arguments are substituted textually, so each *_FLAG must be
# a non-empty boolean constant (e.g. TRUE/FALSE, ON/OFF).
#
# Paths are anchored at NVSHMEM_TRANSPORTS_SOURCE_DIR, the directory recorded
# by project(NVSHMEM_TRANSPORTS). Unlike CMAKE_SOURCE_DIR it keeps pointing at
# this project when it is embedded in a larger build; for a standalone build
# both are the same directory.
macro(nvshmem_add_transport TRANSPORT_NAME SOURCE_LIST CUDA_FLAG GDR_FLAG IB_FLAG MLX5_FLAG)
  add_library(${TRANSPORT_NAME} SHARED)

  nvshmem_transport_set_base_config(${TRANSPORT_NAME})

  target_sources(${TRANSPORT_NAME} PRIVATE ${SOURCE_LIST})
  target_include_directories(${TRANSPORT_NAME} PRIVATE
                             ${NVSHMEM_TRANSPORTS_SOURCE_DIR}/common
                             ${NVSHMEM_TRANSPORTS_SOURCE_DIR}/include
                             ${CUDAToolkit_INCLUDE_DIRS}
  )
  # PREFIX "" drops the platform's default library prefix, so the output file
  # name starts with the target name itself.
  set_target_properties(${TRANSPORT_NAME}
    PROPERTIES PREFIX ""
    VERSION ${CUSTOM_TRANSPORT_VERSION_MAJOR}.${CUSTOM_TRANSPORT_VERSION_MINOR}.${CUSTOM_TRANSPORT_VERSION_PATCH}
    SOVERSION ${CUSTOM_TRANSPORT_VERSION_MAJOR}
  )
  # Restrict the exported symbols to those listed in the linker version script.
  target_link_options(${TRANSPORT_NAME} PRIVATE
                      "-Wl,--version-script=${NVSHMEM_TRANSPORTS_SOURCE_DIR}/nvshmem_transport.sym")

  target_link_libraries(${TRANSPORT_NAME} PRIVATE nvshmem_transport_common)

  if(${CUDA_FLAG})
    target_link_libraries(${TRANSPORT_NAME} PRIVATE CUDA::cudart_static)
  endif()

  # The option is tested by name so the condition stays well-formed whatever
  # value it holds.
  if(${GDR_FLAG} AND NVSHMEM_USE_GDRCOPY)
    nvshmem_transport_set_gdr_config(${TRANSPORT_NAME})
  endif()

  if(${IB_FLAG})
    nvshmem_transport_set_ib_config(${TRANSPORT_NAME})
  endif()

  if(${MLX5_FLAG})
    nvshmem_transport_set_mlx5_config(${TRANSPORT_NAME})
  endif()

  install(TARGETS ${TRANSPORT_NAME}
    LIBRARY DESTINATION lib
  )
endmacro()

# Add the subdirectory of every enabled IB-based transport, in the order
# ibdevx, ibgda, ibrc. Each directory <name> is controlled by the option
# NVSHMEM_BUILD_<NAME>_TRANSPORT.
foreach(nvshmem_ib_transport_dir IN ITEMS ibdevx ibgda ibrc)
  string(TOUPPER "${nvshmem_ib_transport_dir}" nvshmem_ib_transport_key)
  if(NVSHMEM_BUILD_${nvshmem_ib_transport_key}_TRANSPORT)
    add_subdirectory(${nvshmem_ib_transport_dir})
  endif()
endforeach()

# libfabric transport: look for libfabric under LIBFABRIC_HOME (lib/ first,
# then lib64/) and store the result in FABRIC_lib before descending.
if(NVSHMEM_BUILD_LIBFABRIC_TRANSPORT)
  find_library(
    FABRIC_lib
    NAMES fabric
    HINTS
      "${LIBFABRIC_HOME}/lib"
      "${LIBFABRIC_HOME}/lib64"
  )
  add_subdirectory(libfabric)
endif()

# UCX transport: the UCX package is mandatory, with UCX_HOME offered as an
# additional search location.
if(NVSHMEM_BUILD_UCX_TRANSPORT)
  find_package(
    UCX
    REQUIRED
    PATHS ${UCX_HOME}
  )
  add_subdirectory(ucx)
endif()
