10 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
11 #define EIGEN_TRIANGULARMATRIXVECTOR_H
13 #include "../InternalHeaderCheck.h"
// Primary template, declared but not defined here. The partial
// specializations further down fix StorageOrder to ColMajor and RowMajor.
// Version defaults to Specialized.
19 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int StorageOrder,
int Version=Specialized>
20 struct triangular_matrix_vector_product;
// Partial specialization for a column-major (ColMajor) left-hand side.
// NOTE(review): this excerpt omits several lines of the struct (braces and
// other members are not visible), so only the visible members are described.
22 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int Version>
23 struct triangular_matrix_vector_product<
Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,
ColMajor,Version>
// Scalar type of the result, as reported by ScalarBinaryOpTraits for the
// (LhsScalar, RhsScalar) pair.
25 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
// True when every bit of Lower is set in Mode.
26 static constexpr
bool IsLower = ((Mode &
Lower) ==
Lower);
// Tail of a member function declaration whose opening line is not shown
// (presumably run(...) -- TODO confirm). Note that alpha is taken as
// const RhsScalar& here, whereas the RowMajor specialization takes
// const ResScalar&.
30 const RhsScalar* _rhs,
Index rhsIncr, ResScalar* _res,
Index resIncr,
31 const RhsScalar& alpha);
// Out-of-class definition of the ColMajor kernel (the line carrying the
// member name and the leading parameters is not shown; presumably ::run).
// NOTE(review): rows, cols, size, pi, i, actualPanelWidth, res, HasUnitDiag
// and HasZeroDiag are all introduced on lines missing from this excerpt;
// the comments below describe only what the visible statements do with them.
34 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int Version>
35 EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
37 const RhsScalar* _rhs,
Index rhsIncr, ResScalar* _res,
Index resIncr,
const RhsScalar& alpha)
// View the raw lhs pointer as a rows x cols column-major matrix whose outer
// stride is lhsStride.
44 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
45 const LhsMap lhs(_lhs,
rows,
cols,OuterStride<>(lhsStride));
// Expression type chosen by conj_expr_if from ConjLhs (presumably a
// conjugated view when ConjLhs is true -- TODO confirm).
46 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
// View the raw rhs pointer as a vector of cols entries spaced rhsIncr apart.
48 typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
49 const RhsMap rhs(_rhs,
cols,InnerStride<>(rhsIncr));
50 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
// Result map type; unlike RhsMap it carries no InnerStride.
52 typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
// Mapper types handed to general_matrix_vector_product below.
55 typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
56 typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
// Walk the columns of the current panel.
61 for (
Index k=0; k<actualPanelWidth; ++k)
// s = first row of the column segment to update. Lower: row i, or i+1 when
// the diagonal is unit/zero. Upper: the panel start pi.
64 Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ?
i+1 :
i ) : pi;
// r = segment length: actualPanelWidth-k for lower, k+1 for upper.
65 Index r = IsLower ? actualPanelWidth-k : k+1;
// With a unit/zero diagonal, r is first decremented and the update is
// skipped when nothing remains.
66 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
// res[s..s+r) += alpha * rhs(i) * lhs(s..s+r, i)
67 res.segment(s,r) += (alpha * cjRhs.coeff(
i)) * cjLhs.col(
i).segment(s,r);
// Adds alpha * rhs(i) to res(i). The guarding condition is on a line not
// shown (presumably HasUnitDiag -- TODO confirm).
69 res.coeffRef(
i) += alpha * cjRhs.coeff(
i);
// Block outside the current panel: r is rows - pi - actualPanelWidth for
// lower and pi for upper; s is pi+actualPanelWidth for lower and 0 for upper.
71 Index r = IsLower ?
rows - pi - actualPanelWidth : pi;
74 Index s = IsLower ? pi+actualPanelWidth : 0;
// General matrix-vector product (BuiltIn version) on the sub-block that
// starts at lhs(s,pi), accumulating into res starting at s. The first
// arguments of run(...) are on a line not shown.
75 general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
77 LhsMapper(&lhs.coeffRef(s,pi), lhsStride),
78 RhsMapper(&rhs.coeffRef(pi), rhsIncr),
79 &
res.coeffRef(s), resIncr, alpha);
// Second general product, on the block starting at column `size`, written
// through the raw _res pointer. The condition guarding it and its leading
// arguments are not visible in this excerpt.
84 general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
86 LhsMapper(&lhs.coeffRef(0,
size), lhsStride),
87 RhsMapper(&rhs.coeffRef(
size), rhsIncr),
88 _res, resIncr, alpha);
// Partial specialization for a row-major (RowMajor) left-hand side.
// NOTE(review): this excerpt omits several lines of the struct (braces and
// other members are not visible), so only the visible members are described.
92 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int Version>
93 struct triangular_matrix_vector_product<
Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,
RowMajor,Version>
// Scalar type of the result, as reported by ScalarBinaryOpTraits for the
// (LhsScalar, RhsScalar) pair.
95 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
// True when every bit of Lower is set in Mode.
96 static constexpr
bool IsLower = ((Mode &
Lower) ==
Lower);
// Tail of the member function declaration whose opening line is not shown;
// the out-of-class definition that follows names it run. alpha is taken as
// const ResScalar& in this specialization.
100 const RhsScalar* _rhs,
Index rhsIncr, ResScalar* _res,
Index resIncr,
101 const ResScalar& alpha);
// Out-of-class definition of run() for the RowMajor specialization.
// NOTE(review): rows, cols, diagSize, PanelWidth, i, HasUnitDiag and
// HasZeroDiag are introduced on lines missing from this excerpt; the comments
// below describe only what the visible statements do with them.
104 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int Version>
105 EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
106 ::run(
Index _rows,
Index _cols,
const LhsScalar* _lhs,
Index lhsStride,
107 const RhsScalar* _rhs,
Index rhsIncr, ResScalar* _res,
Index resIncr,
const ResScalar& alpha)
// View the raw lhs pointer as a rows x cols row-major matrix whose outer
// stride is lhsStride.
114 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
115 const LhsMap lhs(_lhs,
rows,
cols,OuterStride<>(lhsStride));
// Expression type chosen by conj_expr_if from ConjLhs (presumably a
// conjugated view when ConjLhs is true -- TODO confirm).
116 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
// The rhs is mapped without any stride argument: cols entries are read
// through a plain Map, even though rhsIncr is still passed to the mappers
// further down.
118 typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
119 const RhsMap rhs(_rhs,
cols);
120 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
// The result is mapped as rows entries spaced resIncr apart.
122 typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
123 ResMap
res(_res,
rows,InnerStride<>(resIncr));
// Mapper types handed to general_matrix_vector_product below.
125 typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
126 typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
// Process the first diagSize rows in panels of at most PanelWidth rows.
128 for (
Index pi=0; pi<diagSize; pi+=PanelWidth)
// The last panel may be narrower than PanelWidth.
130 Index actualPanelWidth = (
std::min)(PanelWidth, diagSize-pi);
// Walk the rows of the current panel.
131 for (
Index k=0; k<actualPanelWidth; ++k)
// s = first column of the row segment. Lower: the panel start pi.
// Upper: column i, or i+1 when the diagonal is unit/zero.
134 Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ?
i+1 :
i);
// r = segment length: k+1 for lower, actualPanelWidth-k for upper.
135 Index r = IsLower ? k+1 : actualPanelWidth-k;
// With a unit/zero diagonal, r is first decremented and the update is
// skipped when nothing remains.
136 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
// res(i) += alpha * sum over the segment of lhs(i, s..s+r) * rhs(s..s+r)
137 res.coeffRef(
i) += alpha * (cjLhs.row(
i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
// Adds alpha * rhs(i) to res(i). The guarding condition is on a line not
// shown (presumably HasUnitDiag -- TODO confirm).
139 res.coeffRef(
i) += alpha * cjRhs.coeff(
i);
// Block outside the current panel: r is pi for lower and
// cols - pi - actualPanelWidth for upper; s is 0 for lower and
// pi + actualPanelWidth for upper.
141 Index r = IsLower ? pi :
cols - pi - actualPanelWidth;
144 Index s = IsLower ? 0 : pi + actualPanelWidth;
// General matrix-vector product (BuiltIn version) on the sub-block that
// starts at lhs(pi,s), accumulating into res starting at pi. The first
// arguments of run(...) are on a line not shown.
145 general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
147 LhsMapper(&lhs.coeffRef(pi,s), lhsStride),
148 RhsMapper(&rhs.coeffRef(s), rhsIncr),
149 &
res.coeffRef(pi), resIncr, alpha);
// Lower mode with more rows than diagSize: the rows from diagSize onward are
// handled by a general product starting at lhs(diagSize,0).
152 if(IsLower &&
rows>diagSize)
154 general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
156 LhsMapper(&lhs.coeffRef(diagSize,0), lhsStride),
157 RhsMapper(&rhs.coeffRef(0), rhsIncr),
158 &
res.coeffRef(diagSize), resIncr, alpha);
// Forward declaration; specializations for ColMajor and RowMajor storage
// appear later in this file.
166 template<
int Mode,
int StorageOrder>
167 struct trmv_selector;
// Specialization of triangular_product_impl for the flag combination
// <Mode, true, Lhs, false, Rhs, true>.
// NOTE(review): the meaning of the boolean template arguments is declared
// elsewhere; presumably "lhs is triangular, rhs is a vector" -- TODO confirm.
173 template<
int Mode,
typename Lhs,
typename Rhs>
174 struct triangular_product_impl<Mode,true,Lhs,false,Rhs,true>
// run(): the only visible statement checks that dst has as many rows as lhs
// and as many columns as rhs. The rest of the body is not part of this
// excerpt.
176 template<
typename Dest>
static void run(Dest& dst,
const Lhs &lhs,
const Rhs &rhs,
const typename Dest::Scalar& alpha)
178 eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
// Specialization of triangular_product_impl for the flag combination
// <Mode, false, Lhs, true, Rhs, false>.
// NOTE(review): the meaning of the boolean template arguments is declared
// elsewhere; presumably "rhs is triangular, lhs is a vector" -- TODO confirm.
184 template<
int Mode,
typename Lhs,
typename Rhs>
185 struct triangular_product_impl<Mode,false,Lhs,true,Rhs,false>
187 template<
typename Dest>
static void run(Dest& dst,
const Lhs &lhs,
const Rhs &rhs,
const typename Dest::Scalar& alpha)
// dst must have as many rows as lhs and as many columns as rhs.
189 eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
// Transposed view of the destination, passed to the call below.
191 Transpose<Dest> dstT(dst);
// The call is made with the operands swapped and transposed: rhs^T first,
// then lhs^T, writing into dstT. The callee's name and template arguments are
// on lines not shown in this excerpt.
194 ::run(rhs.transpose(),lhs.transpose(), dstT, alpha);
// trmv_selector specialization for ColMajor storage.
// NOTE(review): Alignment, ComplexByReal, alphaIsCompatible, actualDestPtr
// and diagSize are introduced on lines missing from this excerpt; several
// guarding conditions and braces are missing as well.
204 template<
int Mode>
struct trmv_selector<Mode,
ColMajor>
206 template<
typename Lhs,
typename Rhs,
typename Dest>
207 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
209 typedef typename Lhs::Scalar LhsScalar;
210 typedef typename Rhs::Scalar RhsScalar;
211 typedef typename Dest::Scalar ResScalar;
// blas_traits supplies the directly accessible operand types, the extract()
// accessors and the scalar factors used below.
213 typedef internal::blas_traits<Lhs> LhsBlasTraits;
214 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
215 typedef internal::blas_traits<Rhs> RhsBlasTraits;
216 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
// Map type used to view the working destination buffer as a column vector.
219 typedef Map<Matrix<ResScalar,Dynamic,1>, Alignment> MappedDest;
221 add_const_on_value_type_t<ActualLhsType> actualLhs = LhsBlasTraits::extract(lhs);
222 add_const_on_value_type_t<ActualRhsType> actualRhs = RhsBlasTraits::extract(rhs);
// Fold the scalar factors extracted from both operands into alpha.
224 LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
225 RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
226 ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
// True when Dest has a compile-time inner stride of 1.
230 constexpr
bool EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1;
// True when dest may not be usable as the kernel's output: its inner stride
// is not 1, or ComplexByReal holds.
232 constexpr
bool MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal;
// Static buffer helper, parameterized on MightCannotUseDest.
234 gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
// Write straight into dest only when its stride allows it and alpha is
// compatible.
237 bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
// Convert actualAlpha to the RhsScalar type the kernel takes.
239 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
// Tail of a statement selecting dest's own storage or the static buffer
// (presumably the declaration of actualDestPtr -- TODO confirm).
242 evalToDest ? dest.data() : static_dest.data());
246 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
248 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
// Incompatible alpha: start from a zeroed buffer and run the kernel with a
// factor of 1; the real alpha is applied afterwards.
250 if(!alphaIsCompatible)
252 MappedDest(actualDestPtr, dest.size()).setZero();
253 compatibleAlpha = RhsScalar(1);
// Copies dest into the working buffer. The branch this belongs to is not
// visible (presumably the else of the test above -- TODO confirm).
256 MappedDest(actualDestPtr, dest.size()) = dest;
// Invoke the triangular kernel; the output pointer is passed with an
// increment of 1.
259 internal::triangular_matrix_vector_product
261 LhsScalar, LhsBlasTraits::NeedToConjugate,
262 RhsScalar, RhsBlasTraits::NeedToConjugate,
264 ::run(actualLhs.rows(),actualLhs.cols(),
265 actualLhs.data(),actualLhs.outerStride(),
266 actualRhs.data(),actualRhs.innerStride(),
267 actualDestPtr,1,compatibleAlpha);
// Bring the buffer back into dest: scaled accumulation when alpha was
// incompatible, plain assignment otherwise (the else keyword and the outer
// guard are on lines not shown -- TODO confirm).
271 if(!alphaIsCompatible)
272 dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
274 dest = MappedDest(actualDestPtr, dest.size());
// Subtracts (lhs_alpha - 1) * rhs from the first diagSize entries of dest.
// The guard is not visible (presumably a unit-diagonal mode with a non-unit
// lhs factor -- TODO confirm).
280 dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
// trmv_selector specialization for RowMajor storage.
// NOTE(review): actualRhsPtr and diagSize are introduced on lines missing
// from this excerpt; several guarding conditions, braces and call arguments
// are missing as well.
285 template<
int Mode>
struct trmv_selector<Mode,
RowMajor>
287 template<
typename Lhs,
typename Rhs,
typename Dest>
288 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
290 typedef typename Lhs::Scalar LhsScalar;
291 typedef typename Rhs::Scalar RhsScalar;
292 typedef typename Dest::Scalar ResScalar;
// blas_traits supplies the directly accessible operand types, the extract()
// accessors and the scalar factors used below.
294 typedef internal::blas_traits<Lhs> LhsBlasTraits;
295 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
296 typedef internal::blas_traits<Rhs> RhsBlasTraits;
297 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
298 typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;
300 std::add_const_t<ActualLhsType> actualLhs = LhsBlasTraits::extract(lhs);
301 std::add_const_t<ActualRhsType> actualRhs = RhsBlasTraits::extract(rhs);
// Fold the scalar factors extracted from both operands into alpha.
303 LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
304 RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
305 ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
// True when the extracted rhs has a compile-time inner stride of 1.
307 constexpr
bool DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1;
// Static buffer helper, parameterized on !DirectlyUseRhs.
309 gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
// Tail of a statement selecting the rhs's own storage (const removed by the
// cast) or the static buffer (presumably the declaration of actualRhsPtr --
// TODO confirm).
312 DirectlyUseRhs ?
const_cast<RhsScalar*
>(actualRhs.data()) : static_rhs.data());
316 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
318 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
// Copies actualRhs into the buffer at actualRhsPtr. The guard is not visible
// (presumably only when !DirectlyUseRhs -- TODO confirm).
320 Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
// Invoke the triangular kernel, writing into dest through its own data
// pointer and inner stride. The rhs arguments and alpha are on lines not
// shown.
323 internal::triangular_matrix_vector_product
325 LhsScalar, LhsBlasTraits::NeedToConjugate,
326 RhsScalar, RhsBlasTraits::NeedToConjugate,
328 ::run(actualLhs.rows(),actualLhs.cols(),
329 actualLhs.data(),actualLhs.outerStride(),
331 dest.data(),dest.innerStride(),
// Subtracts (lhs_alpha - 1) * rhs from the first diagSize entries of dest.
// The guard is not visible (presumably a unit-diagonal mode with a non-unit
// lhs factor -- TODO confirm).
337 dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
const ImagReturnType imag() const
#define EIGEN_DONT_INLINE
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
#define EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH
const unsigned int RowMajorBit
bfloat16() min(const bfloat16 &a, const bfloat16 &b)
bool is_exactly_one(const X &x)
bool is_exactly_zero(const X &x)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.