33 #ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H
34 #define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H
36 #include "../InternalHeaderCheck.h"
// Generic (non-BLAS) rank-update kernel.
//
// Computes a triangular product in which one matrix A is used as both the
// left- and right-hand operand: the base below is instantiated with identical
// Scalar / AStorageOrder / ConjugateA arguments for both sides and with the
// BuiltIn implementation tag, so this primary template always uses Eigen's own
// kernel. It adds nothing of its own; run() is inherited from the base.
//
// Template parameters:
//   Index           - integer type for sizes and strides.
//   Scalar          - element type of A and of the result.
//   AStorageOrder   - storage order of A.
//   ConjugateA      - conjugation flag for A.
//   ResStorageOrder - storage order of the result.
//   UpLo            - triangle selector; presumably Lower/Upper (the BLAS
//                     specializations in this file derive 'L'/'U' from it).
//
// EIGEN_BLAS_RANKUPDATE_R / EIGEN_BLAS_RANKUPDATE_C provide ColMajor-result
// specializations of this template that call BLAS instead.
// NOTE(review): the literal `1` passed to the base occupies the same slot as
// the `1` in the Specialized triangular-product specialization; presumably the
// result inner stride — confirm against the base template's declaration.
42 template <
typename Index,
typename Scalar,
int AStorageOrder,
bool ConjugateA,
int ResStorageOrder,
int UpLo>
43 struct general_matrix_matrix_rankupdate :
44 general_matrix_matrix_triangular_product<
45 Index,Scalar,AStorageOrder,ConjugateA,Scalar,AStorageOrder,ConjugateA,ResStorageOrder,1,UpLo,BuiltIn> {};
// EIGEN_BLAS_RANKUPDATE_SPECIALIZE(Scalar)
//
// Defines the `Specialized` partial specialization of
// general_matrix_matrix_triangular_product for a Scalar x Scalar product with a
// ColMajor result and an inner-stride template argument of 1.
//
// run() takes the rank-update path when both operand pointers are equal
// (lhs == rhs) and UpLo contains no bits besides Lower/Upper. In that case it
// forwards to general_matrix_matrix_rankupdate<..., ColMajor, UpLo>, which can
// resolve to a BLAS-backed specialization produced by
// EIGEN_BLAS_RANKUPDATE_R / EIGEN_BLAS_RANKUPDATE_C. The second call, to the
// BuiltIn triangular product, is presumably the fallback for every other case.
//
// NOTE(review): the rank-update call does not forward resIncr (the BuiltIn call
// does), so that path assumes a unit result increment — confirm callers
// guarantee resIncr == 1 here.
// NOTE(review): the opening brace of run(), the `else` between the two calls,
// and the closing braces of run() and the struct are not present in the text
// shown here — verify against the complete header.
49 #define EIGEN_BLAS_RANKUPDATE_SPECIALIZE(Scalar) \
50 template <typename Index, int LhsStorageOrder, bool ConjugateLhs, \
51 int RhsStorageOrder, bool ConjugateRhs, int UpLo> \
52 struct general_matrix_matrix_triangular_product<Index,Scalar,LhsStorageOrder,ConjugateLhs, \
53 Scalar,RhsStorageOrder,ConjugateRhs,ColMajor,1,UpLo,Specialized> { \
54 static EIGEN_STRONG_INLINE void run(Index size, Index depth,const Scalar* lhs, Index lhsStride, \
55 const Scalar* rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) \
57 if ( lhs==rhs && ((UpLo&(Lower|Upper))==UpLo) ) { /* same operand on both sides and UpLo is purely Lower/Upper */ \
58 general_matrix_matrix_rankupdate<Index,Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,UpLo> \
59 ::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); /* resIncr is not forwarded on this path */ \
61 general_matrix_matrix_triangular_product<Index, \
62 Scalar, LhsStorageOrder, ConjugateLhs, \
63 Scalar, RhsStorageOrder, ConjugateRhs, \
64 ColMajor, 1, UpLo, BuiltIn> \
65 ::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resIncr,resStride,alpha,blocking); \
/* EIGEN_BLAS_RANKUPDATE_R(EIGTYPE, BLASTYPE, BLASFUNC)
 *
 * Defines a partial specialization of general_matrix_matrix_rankupdate for
 * element type EIGTYPE and a ColMajor result that delegates to the BLAS
 * routine BLASFUNC (presumably ssyrk/dsyrk, whose prototypes appear at the
 * end of this file).
 *
 * Compile-time constants:
 *   IsLower - true when UpLo has the Lower bit; selects uplo = 'L', else 'U'.
 *   LowUp   - Lower or Upper according to IsLower.
 *   conjA   - 1 when A is ColMajor and ConjugateA is set, otherwise 0.
 *   NOTE(review): LowUp and conjA are not referenced in the lines shown here
 *   (presumably conjugation is a no-op for the real types this targets).
 *
 * run():
 *   - ignores the rhs pointer, rhsStride and the blocking object (their
 *     parameters are unnamed); only lhs is read, as the operand A.
 *   - converts lhsStride, resStride, size and depth to BlasIndex with
 *     convert_index.
 *   - passes trans = 'T' when A is RowMajor, 'N' otherwise.
 *   - calls BLASFUNC with arguments in the order uplo, trans, n, k, alpha,
 *     A (lhs), lda, beta, C (res), ldc; alpha and beta go through
 *     numext::real_ref and are cast to const BLASTYPE pointers.
 *
 * NOTE(review): `beta` is used in the BLASFUNC call but is not declared in the
 * lines shown here; the `enum {` opener, the run() opening brace and the
 * closing braces are also absent — verify against the complete header.
 */
77 #define EIGEN_BLAS_RANKUPDATE_R(EIGTYPE, BLASTYPE, BLASFUNC) \
78 template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \
79 struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \
81 IsLower = (UpLo&Lower) == Lower, \
82 LowUp = IsLower ? Lower : Upper, \
83 conjA = ((AStorageOrder==ColMajor) && ConjugateA) ? 1 : 0 \
85 static EIGEN_STRONG_INLINE void run(Index size, Index depth,const EIGTYPE* lhs, Index lhsStride, \
86 const EIGTYPE* , Index , EIGTYPE* res, Index resStride, EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& ) \
90 BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
91 char uplo=((IsLower) ? 'L' : 'U'), trans=((AStorageOrder==RowMajor) ? 'T':'N'); \
93 BLASFUNC(&uplo, &trans, &n, &k, (const BLASTYPE*)&numext::real_ref(alpha), lhs, &lda, (const BLASTYPE*)&numext::real_ref(beta), res, &ldc); \
/* EIGEN_BLAS_RANKUPDATE_C(EIGTYPE, BLASTYPE, RTYPE, BLASFUNC)
 *
 * Defines a partial specialization of general_matrix_matrix_rankupdate for
 * element type EIGTYPE and a ColMajor result that delegates to the BLAS
 * routine BLASFUNC, passing alpha and beta as RTYPE values. (Presumably
 * EIGTYPE is a complex type and BLASFUNC a ?herk routine — confirm at the
 * instantiation sites.)
 *
 * Compile-time constants:
 *   IsLower - true when UpLo has the Lower bit; selects uplo = 'L', else 'U'.
 *   LowUp   - Lower or Upper according to IsLower.
 *   conjA   - 1 when (A is ColMajor and ConjugateA) or (A is RowMajor and not
 *             ConjugateA). A RowMajor A is passed with trans = 'C', so this
 *             presumably marks the cases where A must be explicitly
 *             conjugated to obtain the requested conjugation.
 *
 * run():
 *   - ignores the rhs pointer, rhsStride and the blocking object (their
 *     parameters are unnamed).
 *   - uses only the real part of alpha (alpha_ = alpha.real()); any imaginary
 *     part is discarded.
 *   - on the conjugated-copy path (its guard is not visible here; presumably
 *     conjA), maps lhs as an n x k MatrixType with outer stride lhsStride,
 *     stores its conjugate in `a`, and resets lda to a.outerStride().
 *   - calls BLASFUNC(uplo, trans, n, k, alpha_, A, lda, beta_, C, ldc).
 *
 * NOTE(review): in the lines shown here `a` is never declared, and neither
 * beta_ nor a_ptr is assigned before the BLASFUNC call. The `enum {` opener,
 * the run() opening brace, the branch guarding the copy and the closing braces
 * are also absent — verify against the complete header.
 * NOTE(review): a_ptr is a pointer to const EIGTYPE but is C-style cast to a
 * non-const BLASTYPE pointer; if BLASFUNC takes a const pointer for A, a
 * const-preserving cast would be safer.
 */
98 #define EIGEN_BLAS_RANKUPDATE_C(EIGTYPE, BLASTYPE, RTYPE, BLASFUNC) \
99 template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \
100 struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \
102 IsLower = (UpLo&Lower) == Lower, \
103 LowUp = IsLower ? Lower : Upper, \
104 conjA = (((AStorageOrder==ColMajor) && ConjugateA) || ((AStorageOrder==RowMajor) && !ConjugateA)) ? 1 : 0 \
106 static EIGEN_STRONG_INLINE void run(Index size, Index depth,const EIGTYPE* lhs, Index lhsStride, \
107 const EIGTYPE* , Index , EIGTYPE* res, Index resStride, EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& ) \
109 typedef Matrix<EIGTYPE, Dynamic, Dynamic, AStorageOrder> MatrixType; \
111 BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
112 char uplo=((IsLower) ? 'L' : 'U'), trans=((AStorageOrder==RowMajor) ? 'C':'N'); \
113 RTYPE alpha_, beta_; \
114 const EIGTYPE* a_ptr; \
116 alpha_ = alpha.real(); /* imaginary part of alpha is dropped */ \
121 Map<const MatrixType, 0, OuterStride<> > mapA(lhs,n,k,OuterStride<>(lhsStride)); \
122 a = mapA.conjugate(); \
123 lda = a.outerStride(); /* leading dimension of the conjugated copy, not of lhs */ \
126 BLASFUNC(&uplo, &trans, &n, &k, &alpha_, (BLASTYPE*)a_ptr, &lda, &beta_, (BLASTYPE*)res, &ldc); \
/* NOTE(review): the four lines below do not read as compilable code.
 *   - EIGEN_BLAS_RANKUPDATE_R and EIGEN_BLAS_RANKUPDATE_SPECIALIZE are
 *     redefined here with empty replacement lists, although non-empty
 *     definitions of both macros appear earlier in this file. Without an
 *     intervening #undef, redefining a macro with a different replacement
 *     list is ill-formed.
 *   - the ssyrk/dsyrk declarations have no terminating ';' and place
 *     `BLASFUNC()`, with an empty argument, before the function name.
 * They look like an outline of the header's symbols rather than real
 * declarations. Their ten-parameter lists match, in count and order, the
 * BLASFUNC call in EIGEN_BLAS_RANKUPDATE_R (uplo, trans, n, k, alpha, A, lda,
 * beta, C, ldc). Confirm intent before relying on them.
 */
#define EIGEN_BLAS_RANKUPDATE_R(EIGTYPE, BLASTYPE, BLASFUNC)
#define EIGEN_BLAS_RANKUPDATE_SPECIALIZE(Scalar)
int BLASFUNC() ssyrk(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *)
int BLASFUNC() dsyrk(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *)