Diffstat (limited to 'sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch')
-rw-r--r-- | sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch | 35
1 file changed, 35 insertions(+), 0 deletions(-)
diff --git a/sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch b/sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch
new file mode 100644
index 000000000000..2c65987acd85
--- /dev/null
+++ b/sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch
@@ -0,0 +1,35 @@
+Disables aotriton download when both USE_FLASH_ATTENTION and USE_MEM_EFF_ATTENTION cmake flags are OFF
+Backports upstream PR to 2.3.0: https://github.com/pytorch/pytorch/pull/130197
+--- a/cmake/Dependencies.cmake
++++ b/cmake/Dependencies.cmake
+@@ -1334,7 +1334,9 @@ if(USE_ROCM)
+     message(STATUS "Disabling Kernel Assert for ROCm")
+   endif()
+
+-  include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
++  if(USE_FLASH_ATTENTION)
++    include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
++  endif()
+   if(USE_CUDA)
+     caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
+   endif()
+--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
++++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+@@ -21,7 +21,7 @@
+ #include <cmath>
+ #include <functional>
+
+-#if USE_ROCM
++#if defined(USE_ROCM) && defined(USE_FLASH_ATTENTION)
+ #include <aotriton/flash.h>
+ #endif
+
+@@ -186,7 +186,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
+   // Check that the gpu is capable of running flash attention
+   using sm80 = SMVersion<8, 0>;
+   using sm90 = SMVersion<9, 0>;
+-#if USE_ROCM
++#if defined(USE_ROCM) && defined(USE_FLASH_ATTENTION)
+   auto stream = at::cuda::getCurrentCUDAStream().stream();
+   if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
+     auto dprops = at::cuda::getCurrentDeviceProperties();
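
With this backport applied, a ROCm build that disables flash attention no longer fetches aotriton at all. A minimal sketch of such a configure step, assuming a plain CMake build of the pytorch 2.3.0 sources (the three USE_* cache variables are the ones the patch keys on; the rest of the command line is illustrative):

    # hypothetical configure step; only the three USE_* flags come from the patch
    cmake -DUSE_ROCM=ON \
          -DUSE_FLASH_ATTENTION=OFF \
          -DUSE_MEM_EFF_ATTENTION=OFF \
          /path/to/pytorch-2.3.0

With USE_FLASH_ATTENTION set to OFF, cmake/Dependencies.cmake skips External/aotriton.cmake (so no download is attempted), and the defined(USE_ROCM) && defined(USE_FLASH_ATTENTION) guards in sdp_utils.cpp compile out both the <aotriton/flash.h> include and the aotriton::v2::flash::check_gpu() call.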