33 #ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
34 #define EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
36 #include "../InternalHeaderCheck.h"
48 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int StorageOrder>
49 struct triangular_matrix_vector_product_trmv :
50 triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,StorageOrder,BuiltIn> {};
// EIGEN_BLAS_TRMV_SPECIALIZE(Scalar)
//
// For one scalar type, declares the `Specialized` partial specializations of
// triangular_matrix_vector_product for ColMajor and RowMajor lhs storage.
// Each run() forwards every argument unchanged to
// triangular_matrix_vector_product_trmv with the same storage order.
// If a BLAS-backed trmv specialization exists for Scalar (see the
// EIGEN_BLAS_TRMV_CM / EIGEN_BLAS_TRMV_RM macros), it is the one selected.
// Otherwise the primary trmv template falls through to the BuiltIn kernel.
52 #define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \
53 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
54 struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \
55 static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
56 const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \
57 triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor>::run( \
58 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
/* Same argument forwarding, for a RowMajor lhs. */ \
61 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
62 struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor,Specialized> { \
63 static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
64 const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \
65 triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor>::run( \
66 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
76 #define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \
/* ColMajor, BLAS-backed specialization of triangular_matrix_vector_product_trmv */ \
/* for scalar EIGTYPE (BLAS scalar type BLASTYPE). Outline of run():             */ \
/*  - conjugated lhs or ZeroDiag mode: delegate to the BuiltIn kernel;           */ \
/*  - copy rhs (conjugated if ConjRhs) into x_tmp, apply BLAS trmv on the        */ \
/*    leading size-by-size triangle (size = min(rows, cols)), then add alpha     */ \
/*    times that product into res with BLAS axpy;                                */ \
/*  - if the lhs is not square, apply BLAS gemv to the rectangular remainder.    */ \
/* NOTE(review): alpha and beta are passed as the address of their real part;   */ \
/* for complex EIGTYPE this presumably relies on std::complex being laid out as  */ \
/* two contiguous reals, so the address equals that of the whole scalar.         */ \
77 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
78 struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
/* Mode decoded into compile-time flags. */ \
80 IsLower = (Mode&Lower) == Lower, \
81 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
82 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
83 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
84 LowUp = IsLower ? Lower : Upper \
86 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
87 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
/* Conjugated lhs or ZeroDiag mode: delegate to the BuiltIn ColMajor kernel. */ \
89 if (ConjLhs || IsZeroDiag) { \
90 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor,BuiltIn>::run( \
91 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
/* size = order of the triangle; rows x cols = the lhs region that is used. */ \
94 Index size = (std::min)(_rows,_cols); \
95 Index rows = IsLower ? _rows : size; \
96 Index cols = IsLower ? size : _cols; \
98 typedef VectorX##EIGPREFIX VectorRhs; \
/* Strided view over the rhs vector (cols entries, stride rhsIncr). */ \
102 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
/* Copy rhs, conjugated if ConjRhs, into the temporary x_tmp. */ \
104 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
109 char trans, uplo, diag; \
110 BlasIndex m, n, lda, incx, incy; \
115 n = convert_index<BlasIndex>(size); \
116 lda = convert_index<BlasIndex>(lhsStride); \
118 incy = convert_index<BlasIndex>(resIncr); \
122 uplo = IsLower ? 'L' : 'U'; \
123 diag = IsUnitDiag ? 'U' : 'N'; \
/* BLAS trmv: x := op(A) * x on the leading size-by-size triangle. */ \
126 BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
/* BLAS axpy: res := alpha * x + res over the first size entries. */ \
129 BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
/* Non-square lhs: the rectangular block outside the triangle goes through gemv. */ \
131 if (size<(std::max)(rows,cols)) { \
/* Refill x_tmp from rhs; presumably trmv above overwrote it in place. */ \
132 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
/* Block below the triangle: rows [size, rows) of res, (rows-size) x size. */ \
135 y = _res + size*resIncr; \
137 m = convert_index<BlasIndex>(rows-size); \
138 n = convert_index<BlasIndex>(size); \
/* Block right of the triangle: lhs columns [size, cols), size x (cols-size). */ \
143 a = _lhs + size*lda; \
144 m = convert_index<BlasIndex>(size); \
145 n = convert_index<BlasIndex>(cols-size); \
/* BLAS gemv: y := alpha * op(a) * x + beta * y, with a being m-by-n. */ \
147 BLASPREFIX##gemv##BLASPOSTFIX(&trans, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)y, &incy); \
165 #define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \
/* RowMajor, BLAS-backed specialization of triangular_matrix_vector_product_trmv */ \
/* for scalar EIGTYPE (BLAS scalar type BLASTYPE). BLAS routines expect          */ \
/* column-major storage, so the row-major lhs is handed over as its transpose:   */ \
/*  - trans is T, or C when ConjLhs, so BLAS itself applies the lhs conjugation; */ \
/*  - uplo is the opposite of the triangle named by Mode;                        */ \
/*  - gemv receives its dimensions as (n, m).                                    */ \
/* Steps: copy rhs (conjugated if ConjRhs) into x_tmp, BLAS trmv on the leading  */ \
/* size-by-size triangle (size = min(rows, cols)), axpy of alpha times that into */ \
/* res, then gemv over the rectangular remainder when the lhs is not square.     */ \
/* NOTE(review): alpha and beta are passed as the address of their real part;   */ \
/* for complex EIGTYPE this presumably relies on std::complex being laid out as  */ \
/* two contiguous reals, so the address equals that of the whole scalar.         */ \
166 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
167 struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
/* Mode decoded into compile-time flags. */ \
169 IsLower = (Mode&Lower) == Lower, \
170 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
171 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
172 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
173 LowUp = IsLower ? Lower : Upper \
175 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
176 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
/* Delegate to the BuiltIn RowMajor kernel. */ \
179 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor,BuiltIn>::run( \
180 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
/* size = order of the triangle; rows x cols = the lhs region that is used. */ \
183 Index size = (std::min)(_rows,_cols); \
184 Index rows = IsLower ? _rows : size; \
185 Index cols = IsLower ? size : _cols; \
187 typedef VectorX##EIGPREFIX VectorRhs; \
/* Strided view over the rhs vector (cols entries, stride rhsIncr). */ \
191 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
/* Copy rhs, conjugated if ConjRhs, into the temporary x_tmp. */ \
193 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
198 char trans, uplo, diag; \
199 BlasIndex m, n, lda, incx, incy; \
204 n = convert_index<BlasIndex>(size); \
205 lda = convert_index<BlasIndex>(lhsStride); \
207 incy = convert_index<BlasIndex>(resIncr); \
/* The column-major view of a row-major matrix is its transpose; undo it via trans. */ \
210 trans = ConjLhs ? 'C' : 'T'; \
/* In that transposed view a lower triangle becomes an upper one, and vice versa. */ \
211 uplo = IsLower ? 'U' : 'L'; \
212 diag = IsUnitDiag ? 'U' : 'N'; \
/* BLAS trmv: x := op(A) * x on the leading size-by-size triangle. */ \
215 BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
/* BLAS axpy: res := alpha * x + res over the first size entries. */ \
218 BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
/* Non-square lhs: the rectangular block outside the triangle goes through gemv. */ \
220 if (size<(std::max)(rows,cols)) { \
/* Refill x_tmp from rhs; presumably trmv above overwrote it in place. */ \
221 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
/* Block below the triangle: rows [size, rows) of res; in row-major storage */ \
/* lhs row number size starts size*lda elements into the buffer.            */ \
224 y = _res + size*resIncr; \
225 a = _lhs + size*lda; \
226 m = convert_index<BlasIndex>(rows-size); \
227 n = convert_index<BlasIndex>(size); \
/* Block right of the triangle: size x (cols-size). */ \
233 m = convert_index<BlasIndex>(size); \
234 n = convert_index<BlasIndex>(cols-size); \
/* Dimensions passed as (n, m): the column-major view of the m-by-n row-major */ \
/* block is n-by-m, and trans maps it back to the m-by-n product.             */ \
236 BLASPREFIX##gemv##BLASPOSTFIX(&trans, &n, &m, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)y, &incy); \
#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX)
#define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar)
#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX)
std::complex< double > dcomplex
std::complex< float > scomplex