TriangularMatrixMatrix_BLAS.h
Go to the documentation of this file.
1 /*
2  Copyright (c) 2011, Intel Corporation. All rights reserved.
3 
4  Redistribution and use in source and binary forms, with or without modification,
5  are permitted provided that the following conditions are met:
6 
7  * Redistributions of source code must retain the above copyright notice, this
8  list of conditions and the following disclaimer.
9  * Redistributions in binary form must reproduce the above copyright notice,
10  this list of conditions and the following disclaimer in the documentation
11  and/or other materials provided with the distribution.
12  * Neither the name of Intel Corporation nor the names of its contributors may
13  be used to endorse or promote products derived from this software without
14  specific prior written permission.
15 
16  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
23  ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 
27  ********************************************************************************
28  * Content : Eigen bindings to BLAS F77
29  * Triangular matrix * matrix product functionality based on ?TRMM.
30  ********************************************************************************
31 */
32 
33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
35 
36 #include "../InternalHeaderCheck.h"
37 
38 namespace Eigen {
39 
40 namespace internal {
41 
42 
43 template <typename Scalar, typename Index,
44  int Mode, bool LhsIsTriangular,
45  int LhsStorageOrder, bool ConjugateLhs,
46  int RhsStorageOrder, bool ConjugateRhs,
47  int ResStorageOrder>
48 struct product_triangular_matrix_matrix_trmm :
49  product_triangular_matrix_matrix<Scalar,Index,Mode,
50  LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
51  RhsStorageOrder, ConjugateRhs, ResStorageOrder, 1, BuiltIn> {};
52 
53 
54 // try to go to BLAS specialization
55 #define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
56 template <typename Index, int Mode, \
57  int LhsStorageOrder, bool ConjugateLhs, \
58  int RhsStorageOrder, bool ConjugateRhs> \
59 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
60  LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,1,Specialized> { \
61  static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\
62  const Scalar* _rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \
63  EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
64  eigen_assert(resIncr == 1); \
65  product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \
66  LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \
67  RhsStorageOrder, ConjugateRhs, ColMajor>::run( \
68  _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
69  } \
70 };
71 
72 EIGEN_BLAS_TRMM_SPECIALIZE(double, true)
73 EIGEN_BLAS_TRMM_SPECIALIZE(double, false)
76 EIGEN_BLAS_TRMM_SPECIALIZE(float, true)
77 EIGEN_BLAS_TRMM_SPECIALIZE(float, false)
80 
// Implements col-major  res += alpha * op(triangular) * op(general)  where the
// triangular factor is the left-hand side.
//   - square triangular part           -> BLAS ?TRMM on a copy of the rhs
//   - mildly non-square triangular part -> Eigen's built-in kernel
//   - strongly non-square               -> dense copy of the triangle + ?GEMM
// Macro parameters:
//   EIGTYPE   Eigen scalar type (double, float, dcomplex, scomplex)
//   BLASTYPE  scalar pointer type expected by the BLAS prototype
//   EIGPREFIX suffix of the matching MatrixX* typedef (d, f, cd, cf)
//   BLASFUNC  the ?trmm routine to call
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
         LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
{ \
  enum { \
    IsLower = (Mode&Lower) == Lower, \
    SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
    IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
    IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
    LowUp = IsLower ? Lower : Upper, \
    /* A row-major lhs is conjugated through transa='C'; only a col-major */ \
    /* lhs needs an explicitly conjugated copy. */ \
    conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \
  }; \
\
  static void run( \
    Index _rows, Index _cols, Index _depth, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res,        Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
  { \
    /* Empty product: res is left unchanged. Returning here also avoids a */ \
    /* division by a zero diagSize below and a zero leading dimension */ \
    /* (ldb < max(1,m)) being handed to ?TRMM. */ \
    if (_rows == 0 || _cols == 0 || _depth == 0) return; \
\
    /* For a lower (upper) factor the columns (rows) beyond diagSize are */ \
    /* structurally zero, so that dimension is trimmed to diagSize. */ \
    Index diagSize = (std::min)(_rows,_depth); \
    Index rows  = IsLower ? _rows : diagSize; \
    Index depth = IsLower ? diagSize : _depth; \
    Index cols  = _cols; \
\
    typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
    typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\
    /* Non-square case - doesn't fit BLAS ?TRMM. Fall back to the default */ \
    /* triangular product or call BLAS ?GEMM. */ \
    if (rows != depth) { \
\
      /* FIXME handle mkl_domain_get_max_threads */ \
      /*int nthr = mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS);*/ int nthr = 1; \
\
      if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
        /* Most likely no benefit to call TRMM or GEMM from BLAS */ \
        product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
          LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \
            _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \
      } else { \
        /* Makes sense to call GEMM on a dense copy of the triangular part */ \
        Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
        MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
        BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
        gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
        general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
          rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, 1, resStride, alpha, gemm_blocking, 0); \
      } \
      return; \
    } \
    char side = 'L', transa, uplo, diag = 'N'; \
    EIGTYPE *b; \
    const EIGTYPE *a; \
    BlasIndex m, n, lda, ldb; \
\
    /* Set m, n */ \
    m = convert_index<BlasIndex>(diagSize); \
    n = convert_index<BlasIndex>(cols); \
\
    /* Set trans: a row-major lhs is seen by BLAS as its (conjugate) transpose */ \
    transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
\
    /* Set b, ldb: ?TRMM overwrites B in place, so work on a col-major copy */ \
    Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \
    MatrixX##EIGPREFIX b_tmp; \
\
    if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
    b = b_tmp.data(); \
    ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\
    /* Set uplo: transposing a row-major lhs swaps lower and upper */ \
    uplo = IsLower ? 'L' : 'U'; \
    if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
    /* Set a, lda: copy the lhs when it must be conjugated or its diagonal */ \
    /* forced to zero/one; otherwise pass the caller's storage directly. */ \
    Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
    MatrixLhs a_tmp; \
\
    if ((conjA!=0) || (SetDiag==0)) { \
      if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \
      if (IsZeroDiag) \
        a_tmp.diagonal().setZero(); \
      else if (IsUnitDiag) \
        a_tmp.diagonal().setOnes(); \
      a = a_tmp.data(); \
      lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
    } else { \
      a = _lhs; \
      lda = convert_index<BlasIndex>(lhsStride); \
    } \
    /* call ?trmm: b_tmp := alpha * op(a) * b_tmp */ \
    BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\
    /* Add alpha*op(a_triangular)*b into res */ \
    Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
    res_tmp=res_tmp+b_tmp; \
  } \
};
186 
187 #ifdef EIGEN_USE_MKL
188 EIGEN_BLAS_TRMM_L(double, double, d, dtrmm)
189 EIGEN_BLAS_TRMM_L(dcomplex, MKL_Complex16, cd, ztrmm)
190 EIGEN_BLAS_TRMM_L(float, float, f, strmm)
191 EIGEN_BLAS_TRMM_L(scomplex, MKL_Complex8, cf, ctrmm)
192 #else
193 EIGEN_BLAS_TRMM_L(double, double, d, dtrmm_)
194 EIGEN_BLAS_TRMM_L(dcomplex, double, cd, ztrmm_)
195 EIGEN_BLAS_TRMM_L(float, float, f, strmm_)
196 EIGEN_BLAS_TRMM_L(scomplex, float, cf, ctrmm_)
197 #endif
198 
// Implements col-major  res += alpha * op(general) * op(triangular)  where the
// triangular factor is the right-hand side.
//   - square triangular part           -> BLAS ?TRMM on a copy of the lhs
//   - mildly non-square triangular part -> Eigen's built-in kernel
//   - strongly non-square               -> dense copy of the triangle + ?GEMM
// Macro parameters:
//   EIGTYPE   Eigen scalar type (double, float, dcomplex, scomplex)
//   BLASTYPE  scalar pointer type expected by the BLAS prototype
//   EIGPREFIX suffix of the matching MatrixX* typedef (d, f, cd, cf)
//   BLASFUNC  the ?trmm routine to call
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
         LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
{ \
  enum { \
    IsLower = (Mode&Lower) == Lower, \
    SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
    IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
    IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
    LowUp = IsLower ? Lower : Upper, \
    /* A row-major rhs is conjugated through transa='C'; only a col-major */ \
    /* rhs needs an explicitly conjugated copy. */ \
    conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \
  }; \
\
  static void run( \
    Index _rows, Index _cols, Index _depth, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res,        Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
  { \
    /* Empty product: res is left unchanged. Returning here also avoids a */ \
    /* division by a zero diagSize below and a zero-sized ?TRMM call. */ \
    if (_rows == 0 || _cols == 0 || _depth == 0) return; \
\
    /* For a lower (upper) factor the columns (rows) of the rhs beyond */ \
    /* diagSize are structurally zero, so that dimension is trimmed. */ \
    Index diagSize = (std::min)(_cols,_depth); \
    Index rows  = _rows; \
    Index depth = IsLower ? _depth : diagSize; \
    Index cols  = IsLower ? diagSize : _cols; \
\
    typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
    typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\
    /* Non-square case - doesn't fit BLAS ?TRMM. Fall back to the default */ \
    /* triangular product or call BLAS ?GEMM. */ \
    if (cols != depth) { \
\
      int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS)*/; \
\
      if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
        /* Most likely no benefit to call TRMM or GEMM from BLAS */ \
        product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
          LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \
            _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \
      } else { \
        /* Makes sense to call GEMM on a dense copy of the triangular part */ \
        Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
        MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
        BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
        gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
        general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
          rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, 1, resStride, alpha, gemm_blocking, 0); \
      } \
      return; \
    } \
    char side = 'R', transa, uplo, diag = 'N'; \
    EIGTYPE *b; \
    const EIGTYPE *a; \
    BlasIndex m, n, lda, ldb; \
\
    /* Set m, n */ \
    m = convert_index<BlasIndex>(rows); \
    n = convert_index<BlasIndex>(diagSize); \
\
    /* Set trans: a row-major rhs is seen by BLAS as its (conjugate) transpose */ \
    transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
\
    /* Set b, ldb: ?TRMM overwrites B in place, so work on a col-major copy */ \
    Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
    MatrixX##EIGPREFIX b_tmp; \
\
    if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
    b = b_tmp.data(); \
    ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\
    /* Set uplo: transposing a row-major rhs swaps lower and upper */ \
    uplo = IsLower ? 'L' : 'U'; \
    if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
    /* Set a, lda: copy the rhs when it must be conjugated or its diagonal */ \
    /* forced to zero/one; otherwise pass the caller's storage directly. */ \
    Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \
    MatrixRhs a_tmp; \
\
    if ((conjA!=0) || (SetDiag==0)) { \
      if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \
      if (IsZeroDiag) \
        a_tmp.diagonal().setZero(); \
      else if (IsUnitDiag) \
        a_tmp.diagonal().setOnes(); \
      a = a_tmp.data(); \
      lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
    } else { \
      a = _rhs; \
      lda = convert_index<BlasIndex>(rhsStride); \
    } \
    /* call ?trmm: b_tmp := alpha * b_tmp * op(a) */ \
    BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\
    /* Add alpha*b*op(a_triangular) into res */ \
    Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
    res_tmp=res_tmp+b_tmp; \
  } \
};
303 
304 #ifdef EIGEN_USE_MKL
305 EIGEN_BLAS_TRMM_R(double, double, d, dtrmm)
306 EIGEN_BLAS_TRMM_R(dcomplex, MKL_Complex16, cd, ztrmm)
307 EIGEN_BLAS_TRMM_R(float, float, f, strmm)
308 EIGEN_BLAS_TRMM_R(scomplex, MKL_Complex8, cf, ctrmm)
309 #else
310 EIGEN_BLAS_TRMM_R(double, double, d, dtrmm_)
311 EIGEN_BLAS_TRMM_R(dcomplex, double, cd, ztrmm_)
312 EIGEN_BLAS_TRMM_R(float, float, f, strmm_)
313 EIGEN_BLAS_TRMM_R(scomplex, float, cf, ctrmm_)
314 #endif
315 } // end namespace internal
316 
317 } // end namespace Eigen
318 
319 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
int BLASFUNC() ctrmm(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *)
int BLASFUNC() ztrmm(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *)
int BLASFUNC() strmm(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *)
int BLASFUNC() dtrmm(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *)
: InteropHeaders
Definition: Core:139
std::complex< double > dcomplex
Definition: MKL_support.h:127
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
std::complex< float > scomplex
Definition: MKL_support.h:128