summaryrefslogtreecommitdiff
path: root/sci-libs/miopen/files/miopen-5.7.1-fix-miopendriver-gemm.patch
diff options
context:
space:
mode:
Diffstat (limited to 'sci-libs/miopen/files/miopen-5.7.1-fix-miopendriver-gemm.patch')
-rw-r--r--sci-libs/miopen/files/miopen-5.7.1-fix-miopendriver-gemm.patch74
1 files changed, 74 insertions, 0 deletions
diff --git a/sci-libs/miopen/files/miopen-5.7.1-fix-miopendriver-gemm.patch b/sci-libs/miopen/files/miopen-5.7.1-fix-miopendriver-gemm.patch
new file mode 100644
index 000000000000..859667f3da30
--- /dev/null
+++ b/sci-libs/miopen/files/miopen-5.7.1-fix-miopendriver-gemm.patch
@@ -0,0 +1,74 @@
+Fix uninitialized variable in MIOpenDriver gemm and restore gemmfp16 for testing
+Upstream bug: https://github.com/ROCmSoftwarePlatform/MIOpen/issues/2505
+--- a/driver/driver.hpp
++++ b/driver/driver.hpp
+@@ -141,7 +141,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
+ printf("Usage: ./driver *base_arg* *other_args*\n");
+ printf("Supported Base Arguments: conv[fp16|int8|bfp16], CBAInfer[fp16], "
+ "pool[fp16], lrn[fp16], "
+- "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm, ctc, dropout[fp16], "
++ "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], "
+ "tensorop[fp16], reduce[fp16,fp64]\n");
+ exit(0); // NOLINT (concurrency-mt-unsafe)
+ }
+@@ -160,7 +160,7 @@ inline std::string ParseBaseArg(int argc, char* argv[])
+ arg != "CBAInfer" && arg != "CBAInferfp16" && arg != "pool" && arg != "poolfp16" &&
+ arg != "lrn" && arg != "lrnfp16" && arg != "activ" && arg != "activfp16" &&
+ arg != "softmax" && arg != "softmaxfp16" && arg != "bnorm" && arg != "bnormfp16" &&
+- arg != "rnn" && arg != "rnnfp16" && arg != "gemm" /*&& arg != "gemmfp16"*/ && arg != "ctc" &&
++ arg != "rnn" && arg != "rnnfp16" && arg != "gemm" && arg != "gemmfp16" && arg != "ctc" &&
+ arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "tensoropfp16" &&
+ arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "--version")
+ {
+--- a/driver/gemm_driver.hpp
++++ b/driver/gemm_driver.hpp
+@@ -207,6 +207,19 @@ int GemmDriver<T>::GetandSetData()
+ gemm_desc.strideB = gemm_desc.k * gemm_desc.n;
+ gemm_desc.strideC = gemm_desc.m * gemm_desc.n;
+
++ if constexpr (std::is_same_v<T, float>)
++ {
++ gemm_desc.dataType = miopenFloat;
++ }
++ else if constexpr (std::is_same_v<T, float16>)
++ {
++ gemm_desc.dataType = miopenHalf;
++ }
++ else
++ {
++ static_assert(!"unsupported type");
++ }
++
+ return (0);
+ }
+
+@@ -230,9 +243,9 @@ int GemmDriver<T>::AllocateBuffersAndCopy()
+ a = std::vector<T>(a_sz);
+ b = std::vector<T>(b_sz);
+ #if GEMM_DRIVER_DEBUG
+- c = std::vector<T>(c_sz, 1.);
++ c = std::vector<T>(c_sz, static_cast<T>(1.));
+ #else
+- c = std::vector<T>(c_sz, 0.);
++ c = std::vector<T>(c_sz, static_cast<T>(0.));
+ #endif
+ chost = c;
+
+--- a/driver/main.cpp
++++ b/driver/main.cpp
+@@ -125,11 +125,10 @@ int main(int argc, char* argv[])
+ {
+ drv = new GemmDriver<float>();
+ }
+-// TODO half is not supported in gemm
+-// else if(base_arg == "gemmfp16")
+-// {
+-// drv = new GemmDriver<float16>();
+-// }
++ else if(base_arg == "gemmfp16")
++ {
++ drv = new GemmDriver<float16>();
++ }
+ #endif
+ else if(base_arg == "bnorm")
+ {