TriangularMatrixVector.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
11 #define EIGEN_TRIANGULARMATRIXVECTOR_H
12 
13 #include "../InternalHeaderCheck.h"
14 
15 namespace Eigen {
16 
17 namespace internal {
18 
19 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
20 struct triangular_matrix_vector_product;
21 
22 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
23 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
24 {
25  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
26  static constexpr bool IsLower = ((Mode & Lower) == Lower);
27  static constexpr bool HasUnitDiag = (Mode & UnitDiag) == UnitDiag;
28  static constexpr bool HasZeroDiag = (Mode & ZeroDiag) == ZeroDiag;
29  static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
30  const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr,
31  const RhsScalar& alpha);
32 };
33 
34 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
35 EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
36  ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
37  const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha)
38  {
39  static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
40  Index size = (std::min)(_rows,_cols);
41  Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
42  Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
43 
44  typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
45  const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
46  typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
47 
48  typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
49  const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
50  typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
51 
52  typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
53  ResMap res(_res,rows);
54 
55  typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
56  typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
57 
58  for (Index pi=0; pi<size; pi+=PanelWidth)
59  {
60  Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
61  for (Index k=0; k<actualPanelWidth; ++k)
62  {
63  Index i = pi + k;
64  Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
65  Index r = IsLower ? actualPanelWidth-k : k+1;
66  if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
67  res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
68  if (HasUnitDiag)
69  res.coeffRef(i) += alpha * cjRhs.coeff(i);
70  }
71  Index r = IsLower ? rows - pi - actualPanelWidth : pi;
72  if (r>0)
73  {
74  Index s = IsLower ? pi+actualPanelWidth : 0;
75  general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
76  r, actualPanelWidth,
77  LhsMapper(&lhs.coeffRef(s,pi), lhsStride),
78  RhsMapper(&rhs.coeffRef(pi), rhsIncr),
79  &res.coeffRef(s), resIncr, alpha);
80  }
81  }
82  if((!IsLower) && cols>size)
83  {
84  general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
85  rows, cols-size,
86  LhsMapper(&lhs.coeffRef(0,size), lhsStride),
87  RhsMapper(&rhs.coeffRef(size), rhsIncr),
88  _res, resIncr, alpha);
89  }
90  }
91 
92 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
93 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
94 {
95  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
96  static constexpr bool IsLower = ((Mode & Lower) == Lower);
97  static constexpr bool HasUnitDiag = (Mode & UnitDiag) == UnitDiag;
98  static constexpr bool HasZeroDiag = (Mode & ZeroDiag) == ZeroDiag;
99  static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
100  const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr,
101  const ResScalar& alpha);
102 };
103 
104 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
105 EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
106  ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
107  const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
108  {
109  static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
110  Index diagSize = (std::min)(_rows,_cols);
111  Index rows = IsLower ? _rows : diagSize;
112  Index cols = IsLower ? diagSize : _cols;
113 
114  typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
115  const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
116  typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
117 
118  typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
119  const RhsMap rhs(_rhs,cols);
120  typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
121 
122  typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
123  ResMap res(_res,rows,InnerStride<>(resIncr));
124 
125  typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
126  typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
127 
128  for (Index pi=0; pi<diagSize; pi+=PanelWidth)
129  {
130  Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
131  for (Index k=0; k<actualPanelWidth; ++k)
132  {
133  Index i = pi + k;
134  Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
135  Index r = IsLower ? k+1 : actualPanelWidth-k;
136  if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
137  res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
138  if (HasUnitDiag)
139  res.coeffRef(i) += alpha * cjRhs.coeff(i);
140  }
141  Index r = IsLower ? pi : cols - pi - actualPanelWidth;
142  if (r>0)
143  {
144  Index s = IsLower ? 0 : pi + actualPanelWidth;
145  general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
146  actualPanelWidth, r,
147  LhsMapper(&lhs.coeffRef(pi,s), lhsStride),
148  RhsMapper(&rhs.coeffRef(s), rhsIncr),
149  &res.coeffRef(pi), resIncr, alpha);
150  }
151  }
152  if(IsLower && rows>diagSize)
153  {
154  general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
155  rows-diagSize, cols,
156  LhsMapper(&lhs.coeffRef(diagSize,0), lhsStride),
157  RhsMapper(&rhs.coeffRef(0), rhsIncr),
158  &res.coeffRef(diagSize), resIncr, alpha);
159  }
160  }
161 
162 
/** \internal Front-end of the triangular matrix * vector product: prepares the operands
  * and calls triangular_matrix_vector_product with the given StorageOrder of the
  * triangular lhs. Specialized for ColMajor and RowMajor. */
template<int Mode, int StorageOrder>
struct trmv_selector;
168 
169 } // end namespace internal
170 
171 namespace internal {
172 
173 template<int Mode, typename Lhs, typename Rhs>
174 struct triangular_product_impl<Mode,true,Lhs,false,Rhs,true>
175 {
176  template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
177  {
178  eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
179 
180  internal::trmv_selector<Mode,(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(lhs, rhs, dst, alpha);
181  }
182 };
183 
184 template<int Mode, typename Lhs, typename Rhs>
185 struct triangular_product_impl<Mode,false,Lhs,true,Rhs,false>
186 {
187  template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
188  {
189  eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
190 
191  Transpose<Dest> dstT(dst);
192  internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),
193  (int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
194  ::run(rhs.transpose(),lhs.transpose(), dstT, alpha);
195  }
196 };
197 
198 } // end namespace internal
199 
200 namespace internal {
201 
202 // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
203 
204 template<int Mode> struct trmv_selector<Mode,ColMajor>
205 {
206  template<typename Lhs, typename Rhs, typename Dest>
207  static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
208  {
209  typedef typename Lhs::Scalar LhsScalar;
210  typedef typename Rhs::Scalar RhsScalar;
211  typedef typename Dest::Scalar ResScalar;
212 
213  typedef internal::blas_traits<Lhs> LhsBlasTraits;
214  typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
215  typedef internal::blas_traits<Rhs> RhsBlasTraits;
216  typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
217  constexpr int Alignment = (std::min)(int(AlignedMax), int(internal::packet_traits<ResScalar>::size));
218 
219  typedef Map<Matrix<ResScalar,Dynamic,1>, Alignment> MappedDest;
220 
221  add_const_on_value_type_t<ActualLhsType> actualLhs = LhsBlasTraits::extract(lhs);
222  add_const_on_value_type_t<ActualRhsType> actualRhs = RhsBlasTraits::extract(rhs);
223 
224  LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
225  RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
226  ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
227 
228  // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
229  // on, the other hand it is good for the cache to pack the vector anyways...
230  constexpr bool EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1;
231  constexpr bool ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex);
232  constexpr bool MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal;
233 
234  gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
235 
236  bool alphaIsCompatible = (!ComplexByReal) || numext::is_exactly_zero(numext::imag(actualAlpha));
237  bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
238 
239  RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
240 
241  ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
242  evalToDest ? dest.data() : static_dest.data());
243 
244  if(!evalToDest)
245  {
246  #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
247  Index size = dest.size();
248  EIGEN_DENSE_STORAGE_CTOR_PLUGIN
249  #endif
250  if(!alphaIsCompatible)
251  {
252  MappedDest(actualDestPtr, dest.size()).setZero();
253  compatibleAlpha = RhsScalar(1);
254  }
255  else
256  MappedDest(actualDestPtr, dest.size()) = dest;
257  }
258 
259  internal::triangular_matrix_vector_product
260  <Index,Mode,
261  LhsScalar, LhsBlasTraits::NeedToConjugate,
262  RhsScalar, RhsBlasTraits::NeedToConjugate,
263  ColMajor>
264  ::run(actualLhs.rows(),actualLhs.cols(),
265  actualLhs.data(),actualLhs.outerStride(),
266  actualRhs.data(),actualRhs.innerStride(),
267  actualDestPtr,1,compatibleAlpha);
268 
269  if (!evalToDest)
270  {
271  if(!alphaIsCompatible)
272  dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
273  else
274  dest = MappedDest(actualDestPtr, dest.size());
275  }
276 
277  if ( ((Mode&UnitDiag)==UnitDiag) && !numext::is_exactly_one(lhs_alpha) )
278  {
279  Index diagSize = (std::min)(lhs.rows(),lhs.cols());
280  dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
281  }
282  }
283 };
284 
285 template<int Mode> struct trmv_selector<Mode,RowMajor>
286 {
287  template<typename Lhs, typename Rhs, typename Dest>
288  static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
289  {
290  typedef typename Lhs::Scalar LhsScalar;
291  typedef typename Rhs::Scalar RhsScalar;
292  typedef typename Dest::Scalar ResScalar;
293 
294  typedef internal::blas_traits<Lhs> LhsBlasTraits;
295  typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
296  typedef internal::blas_traits<Rhs> RhsBlasTraits;
297  typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
298  typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;
299 
300  std::add_const_t<ActualLhsType> actualLhs = LhsBlasTraits::extract(lhs);
301  std::add_const_t<ActualRhsType> actualRhs = RhsBlasTraits::extract(rhs);
302 
303  LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
304  RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
305  ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
306 
307  constexpr bool DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1;
308 
309  gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
310 
311  ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
312  DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
313 
314  if(!DirectlyUseRhs)
315  {
316  #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
317  Index size = actualRhs.size();
318  EIGEN_DENSE_STORAGE_CTOR_PLUGIN
319  #endif
320  Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
321  }
322 
323  internal::triangular_matrix_vector_product
324  <Index,Mode,
325  LhsScalar, LhsBlasTraits::NeedToConjugate,
326  RhsScalar, RhsBlasTraits::NeedToConjugate,
327  RowMajor>
328  ::run(actualLhs.rows(),actualLhs.cols(),
329  actualLhs.data(),actualLhs.outerStride(),
330  actualRhsPtr,1,
331  dest.data(),dest.innerStride(),
332  actualAlpha);
333 
334  if ( ((Mode&UnitDiag)==UnitDiag) && !numext::is_exactly_one(lhs_alpha) )
335  {
336  Index diagSize = (std::min)(lhs.rows(),lhs.cols());
337  dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
338  }
339  }
340 };
341 
342 } // end namespace internal
343 
344 } // end namespace Eigen
345 
346 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H
const ImagReturnType imag() const
#define EIGEN_DONT_INLINE
Definition: Macros.h:844
#define eigen_assert(x)
Definition: Macros.h:902
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
Definition: Memory.h:847
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
#define EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH
Definition: Settings.h:38
@ UnitDiag
Definition: Constants.h:215
@ ZeroDiag
Definition: Constants.h:217
@ Lower
Definition: Constants.h:211
@ Upper
Definition: Constants.h:213
@ AlignedMax
Definition: Constants.h:254
@ ColMajor
Definition: Constants.h:321
@ RowMajor
Definition: Constants.h:323
const unsigned int RowMajorBit
Definition: Constants.h:68
bfloat16() min(const bfloat16 &a, const bfloat16 &b)
Definition: BFloat16.h:684
bool is_exactly_one(const X &x)
Definition: Meta.h:482
bool is_exactly_zero(const X &x)
Definition: Meta.h:475
: InteropHeaders
Definition: Core:139
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82