summaryrefslogtreecommitdiff
path: root/sci-libs/tensorflow/files/tensorflow-2.13.0-0009-fix-sparse-transpose-op2.patch
blob: 26b61ac3e5fd3ac0ced3c278ebc6f7d77e2b66d4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
From 7961afc6f67a4278409f7bdb710180daeb91c106 Mon Sep 17 00:00:00 2001
From: wangjiezhe <wangjiezhe@gmail.com>
Date: Sun, 26 Nov 2023 10:31:31 +0800
Subject: [PATCH 09/12] fix sparse transpose op2

---
 tensorflow/core/kernels/sparse/transpose_op.cc | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tensorflow/core/kernels/sparse/transpose_op.cc b/tensorflow/core/kernels/sparse/transpose_op.cc
index 4fe99013480..a247d417504 100644
--- a/tensorflow/core/kernels/sparse/transpose_op.cc
+++ b/tensorflow/core/kernels/sparse/transpose_op.cc
@@ -208,6 +208,13 @@ Status CSRSparseMatrixTranspose<Device, T>::operator()(
   return OkStatus();
 }
 
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+template struct CSRSparseMatrixTranspose<GPUDevice, float>;
+template struct CSRSparseMatrixTranspose<GPUDevice, double>;
+template struct CSRSparseMatrixTranspose<GPUDevice, std::complex<float>>;
+template struct CSRSparseMatrixTranspose<GPUDevice, std::complex<double>>;
+#endif
+
 // CPU kernel for transposing a single component of a CSR SparseMatrix.
 template <typename T>
 struct CSRSparseMatrixTransposeComponent<CPUDevice, T> {
-- 
2.41.0