summaryrefslogtreecommitdiff
path: root/sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch
diff options
context:
space:
mode:
Diffstat (limited to 'sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch')
-rw-r--r--sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch35
1 files changed, 35 insertions, 0 deletions
diff --git a/sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch b/sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch
new file mode 100644
index 000000000000..2c65987acd85
--- /dev/null
+++ b/sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch
@@ -0,0 +1,35 @@
+Disables aotriton download when both USE_FLASH_ATTENTION and USE_MEM_EFF_ATTENTION cmake flags are OFF
+Backports upstream PR to 2.3.0: https://github.com/pytorch/pytorch/pull/130197
+--- a/cmake/Dependencies.cmake
++++ b/cmake/Dependencies.cmake
+@@ -1334,7 +1334,9 @@ if(USE_ROCM)
+ message(STATUS "Disabling Kernel Assert for ROCm")
+ endif()
+
+- include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
++ if(USE_FLASH_ATTENTION)
++ include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
++ endif()
+ if(USE_CUDA)
+ caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
+ endif()
+--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
++++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+@@ -21,7 +21,7 @@
+ #include <cmath>
+ #include <functional>
+
+-#if USE_ROCM
++#if defined(USE_ROCM) && defined(USE_FLASH_ATTENTION)
+ #include <aotriton/flash.h>
+ #endif
+
+@@ -186,7 +186,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
+ // Check that the gpu is capable of running flash attention
+ using sm80 = SMVersion<8, 0>;
+ using sm90 = SMVersion<9, 0>;
+-#if USE_ROCM
++#if defined(USE_ROCM) && defined(USE_FLASH_ATTENTION)
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
+ if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
+ auto dprops = at::cuda::getCurrentDeviceProperties();