33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
36 #include "../InternalHeaderCheck.h"
// Generic (non-BLAS) form of product_triangular_matrix_matrix_trmm.
// It adds no members of its own: it only inherits from
// product_triangular_matrix_matrix instantiated with the same template
// arguments followed by `1, BuiltIn`.
// NOTE(review): `ResStorageOrder` is used in the base-class argument list but
// its declaration and the closing '>' of the template parameter list are not
// visible in this listing - presumably a dropped line; confirm against the
// original header.
43 template <
typename Scalar,
typename Index,
44 int Mode,
bool LhsIsTriangular,
45 int LhsStorageOrder,
bool ConjugateLhs,
46 int RhsStorageOrder,
bool ConjugateRhs,
48 struct product_triangular_matrix_matrix_trmm :
49 product_triangular_matrix_matrix<Scalar,Index,Mode,
50 LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
51 RhsStorageOrder, ConjugateRhs, ResStorageOrder, 1, BuiltIn> {};
// EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
// Emits a partial specialization of product_triangular_matrix_matrix for the
// given Scalar / LhsIsTriangular pair with a ColMajor result and the trailing
// arguments `1, Specialized`.
// Its static run():
//   - asserts (eigen_assert) that resIncr == 1; resIncr is marked as used for
//     debug only and is not passed on;
//   - forwards every other argument unchanged to
//     product_triangular_matrix_matrix_trmm<..., ColMajor>::run.
// NOTE(review): the closing braces of run() and of the struct are not visible
// in this listing - presumably dropped lines; confirm against the original.
55 #define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
56 template <typename Index, int Mode, \
57 int LhsStorageOrder, bool ConjugateLhs, \
58 int RhsStorageOrder, bool ConjugateRhs> \
59 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
60 LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,1,Specialized> { \
61 static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\
62 const Scalar* _rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \
63 EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
64 eigen_assert(resIncr == 1); \
65 product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \
66 LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \
67 RhsStorageOrder, ConjugateRhs, ColMajor>::run( \
68 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
// EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
// Emits the specialization of product_triangular_matrix_matrix_trmm for scalar
// EIGTYPE where the LEFT operand is the triangular one (LhsIsTriangular ==
// true) and the result is ColMajor.
//
// Compile-time flags derived from Mode:
//   IsLower    - Mode contains Lower.
//   SetDiag    - 0 when Mode contains ZeroDiag or UnitDiag, 1 otherwise.
//   IsUnitDiag - Mode contains UnitDiag.
//   IsZeroDiag - Mode contains ZeroDiag.
//   LowUp      - Lower or Upper, following IsLower.
//   conjA      - 1 only when the lhs is ColMajor and ConjugateLhs is set.
//
// Visible steps of the generated function:
//   1. diagSize = min(_rows, _depth). For a lower-triangular lhs, depth is
//      clipped to diagSize; otherwise rows is clipped to diagSize.
//   2. If rows != depth, two alternatives appear:
//        a) when nthr == 1 and (max(rows, depth) - diagSize) / diagSize < 0.5,
//           call the BuiltIn product_triangular_matrix_matrix kernel;
//        b) copy the triangular view of the lhs into a dense temporary
//           (aa_tmp) and run general_matrix_matrix_product on it.
//      NOTE(review): the definition of nthr and the else/return statements
//      separating a), b) and the BLAS path below are not visible in this
//      listing; presumably b) is the else branch of a) and both return early -
//      confirm against the original.
//   3. BLAS path: side = 'L', diag = 'N', m = diagSize, n = cols.
//      transa is 'N' for a ColMajor lhs, otherwise 'C' when ConjugateLhs and
//      'T' when not. uplo follows IsLower and is swapped for a RowMajor lhs.
//   4. The rhs is copied into b_tmp (conjugated when ConjugateRhs); ldb is
//      b_tmp's outer stride.
//   5. When conjA != 0 or SetDiag == 0, the lhs (conjugated when conjA) is
//      copied into a_tmp, whose diagonal is then overwritten with zeros or,
//      for IsUnitDiag, with ones, and lda is a_tmp's outer stride. Another
//      visible assignment sets lda = lhsStride, presumably on the else branch
//      where the lhs is used in place.
//      NOTE(review): the declarations of a_tmp, a and b and the guard in front
//      of setZero() are not visible here - confirm against the original.
//   6. BLASFUNC is called with alpha passed by address through
//      numext::real_ref(alpha), then b_tmp is added to the result map:
//      res_tmp = res_tmp + b_tmp.
82 #define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
83 template <typename Index, int Mode, \
84 int LhsStorageOrder, bool ConjugateLhs, \
85 int RhsStorageOrder, bool ConjugateRhs> \
86 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
87 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
90 IsLower = (Mode&Lower) == Lower, \
91 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
92 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
93 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
94 LowUp = IsLower ? Lower : Upper, \
95 conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \
99 Index _rows, Index _cols, Index _depth, \
100 const EIGTYPE* _lhs, Index lhsStride, \
101 const EIGTYPE* _rhs, Index rhsStride, \
102 EIGTYPE* res, Index resStride, \
103 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
105 Index diagSize = (std::min)(_rows,_depth); \
106 Index rows = IsLower ? _rows : diagSize; \
107 Index depth = IsLower ? diagSize : _depth; \
108 Index cols = _cols; \
110 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
111 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
114 if (rows != depth) { \
119 if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
121 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
122 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \
123 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \
127 Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
128 MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
129 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
130 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
131 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
132 rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, 1, resStride, alpha, gemm_blocking, 0); \
138 char side = 'L', transa, uplo, diag = 'N'; \
141 BlasIndex m, n, lda, ldb; \
144 m = convert_index<BlasIndex>(diagSize); \
145 n = convert_index<BlasIndex>(cols); \
148 transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
151 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \
152 MatrixX##EIGPREFIX b_tmp; \
154 if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
156 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
159 uplo = IsLower ? 'L' : 'U'; \
160 if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
162 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
165 if ((conjA!=0) || (SetDiag==0)) { \
166 if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \
168 a_tmp.diagonal().setZero(); \
169 else if (IsUnitDiag) \
170 a_tmp.diagonal().setOnes();\
172 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
175 lda = convert_index<BlasIndex>(lhsStride); \
179 BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
182 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
183 res_tmp=res_tmp+b_tmp; \
// EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
// Emits the specialization of product_triangular_matrix_matrix_trmm for scalar
// EIGTYPE where the RIGHT operand is the triangular one (LhsIsTriangular ==
// false) and the result is ColMajor.
//
// Compile-time flags derived from Mode:
//   IsLower    - Mode contains Lower.
//   SetDiag    - 0 when Mode contains ZeroDiag or UnitDiag, 1 otherwise.
//   IsUnitDiag - Mode contains UnitDiag.
//   IsZeroDiag - Mode contains ZeroDiag.
//   LowUp      - Lower or Upper, following IsLower.
//   conjA      - 1 only when the rhs is ColMajor and ConjugateRhs is set.
//
// Visible steps of the generated function:
//   1. diagSize = min(_cols, _depth). For a lower-triangular rhs, cols is
//      clipped to diagSize; otherwise depth is clipped to diagSize.
//   2. If cols != depth, two alternatives appear:
//        a) when nthr == 1 and (max(cols, depth) - diagSize) / diagSize < 0.5,
//           call the BuiltIn product_triangular_matrix_matrix kernel;
//        b) copy the triangular view of the rhs into a dense temporary
//           (aa_tmp) and run general_matrix_matrix_product with it as the
//           right-hand operand.
//      NOTE(review): the definition of nthr and the else/return statements
//      separating a), b) and the BLAS path below are not visible in this
//      listing; presumably b) is the else branch of a) and both return early -
//      confirm against the original.
//   3. BLAS path: side = 'R', diag = 'N', m = rows, n = diagSize.
//      transa is 'N' for a ColMajor rhs, otherwise 'C' when ConjugateRhs and
//      'T' when not. uplo follows IsLower and is swapped for a RowMajor rhs.
//   4. The lhs is copied into b_tmp (conjugated when ConjugateLhs); ldb is
//      b_tmp's outer stride.
//   5. When conjA != 0 or SetDiag == 0, the rhs (conjugated when conjA) is
//      copied into a_tmp, whose diagonal is then overwritten with zeros or,
//      for IsUnitDiag, with ones, and lda is a_tmp's outer stride. Another
//      visible assignment sets lda = rhsStride, presumably on the else branch
//      where the rhs is used in place.
//      NOTE(review): the declarations of a_tmp, a and b and the guard in front
//      of setZero() are not visible here - confirm against the original.
//   6. BLASFUNC is called with alpha passed by address through
//      numext::real_ref(alpha), then b_tmp is added to the result map:
//      res_tmp = res_tmp + b_tmp.
200 #define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
201 template <typename Index, int Mode, \
202 int LhsStorageOrder, bool ConjugateLhs, \
203 int RhsStorageOrder, bool ConjugateRhs> \
204 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
205 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
208 IsLower = (Mode&Lower) == Lower, \
209 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
210 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
211 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
212 LowUp = IsLower ? Lower : Upper, \
213 conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \
217 Index _rows, Index _cols, Index _depth, \
218 const EIGTYPE* _lhs, Index lhsStride, \
219 const EIGTYPE* _rhs, Index rhsStride, \
220 EIGTYPE* res, Index resStride, \
221 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
223 Index diagSize = (std::min)(_cols,_depth); \
224 Index rows = _rows; \
225 Index depth = IsLower ? _depth : diagSize; \
226 Index cols = IsLower ? diagSize : _cols; \
228 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
229 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
232 if (cols != depth) { \
236 if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
238 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
239 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \
240 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \
244 Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
245 MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
246 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
247 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
248 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
249 rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, 1, resStride, alpha, gemm_blocking, 0); \
255 char side = 'R', transa, uplo, diag = 'N'; \
258 BlasIndex m, n, lda, ldb; \
261 m = convert_index<BlasIndex>(rows); \
262 n = convert_index<BlasIndex>(diagSize); \
265 transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
268 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
269 MatrixX##EIGPREFIX b_tmp; \
271 if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
273 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
276 uplo = IsLower ? 'L' : 'U'; \
277 if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
279 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \
282 if ((conjA!=0) || (SetDiag==0)) { \
283 if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \
285 a_tmp.diagonal().setZero(); \
286 else if (IsUnitDiag) \
287 a_tmp.diagonal().setOnes();\
289 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
292 lda = convert_index<BlasIndex>(rhsStride); \
296 BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
299 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
300 res_tmp=res_tmp+b_tmp; \
// NOTE(review): the remaining lines carry no bodies - three macro names with
// empty replacement lists, four *trmm prototypes without terminating
// semicolons, and bare type/alias fragments. They look like cross-reference
// summary entries from a documentation listing rather than part of the header.
// As written, the empty #define lines would redefine the macros of the same
// name; confirm and drop them before attempting to compile this file.
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
int BLASFUNC() ctrmm(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *)
int BLASFUNC() ztrmm(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *)
int BLASFUNC() strmm(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *)
int BLASFUNC() dtrmm(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *)
std::complex< double > dcomplex
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
std::complex< float > scomplex