SolveTriangular.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SOLVETRIANGULAR_H
11 #define EIGEN_SOLVETRIANGULAR_H
12 
13 #include "./InternalHeaderCheck.h"
14 
15 namespace Eigen {
16 
17 namespace internal {
18 
// Forward declarations of the low-level kernels performing the actual
// triangular solves. Their definitions are provided by the
// products/TriangularSolver*.h headers.

// Kernel for a triangular system whose right-hand side is a single vector.
template <typename LhsScalar, typename RhsScalar, typename Index,
          int Side, int Mode, bool Conjugate, int StorageOrder>
struct triangular_solve_vector;

// Kernel for a triangular system whose right-hand side is a matrix
// (several right-hand-side vectors at once).
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate,
          int TriStorageOrder, int OtherStorageOrder, int OtherInnerStride>
struct triangular_solve_matrix;
26 
27 // small helper struct extracting some traits on the underlying solver operation
28 template<typename Lhs, typename Rhs, int Side>
29 class trsolve_traits
30 {
31  private:
32  enum {
33  RhsIsVectorAtCompileTime = (Side==OnTheLeft ? Rhs::ColsAtCompileTime : Rhs::RowsAtCompileTime)==1
34  };
35  public:
36  enum {
37  Unrolling = (RhsIsVectorAtCompileTime && Rhs::SizeAtCompileTime != Dynamic && Rhs::SizeAtCompileTime <= 8)
39  RhsVectors = RhsIsVectorAtCompileTime ? 1 : Dynamic
40  };
41 };
42 
43 template<typename Lhs, typename Rhs,
44  int Side, // can be OnTheLeft/OnTheRight
45  int Mode, // can be Upper/Lower | UnitDiag
46  int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling,
47  int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors
48  >
49 struct triangular_solver_selector;
50 
51 template<typename Lhs, typename Rhs, int Side, int Mode>
52 struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1>
53 {
54  typedef typename Lhs::Scalar LhsScalar;
55  typedef typename Rhs::Scalar RhsScalar;
56  typedef blas_traits<Lhs> LhsProductTraits;
57  typedef typename LhsProductTraits::ExtractType ActualLhsType;
58  typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs;
59  static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
60  {
61  ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
62 
63  // FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1
64 
65  bool useRhsDirectly = Rhs::InnerStrideAtCompileTime==1 || rhs.innerStride()==1;
66 
67  ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhs,rhs.size(),
68  (useRhsDirectly ? rhs.data() : 0));
69 
70  if(!useRhsDirectly)
71  MappedRhs(actualRhs,rhs.size()) = rhs;
72 
73  triangular_solve_vector<LhsScalar, RhsScalar, Index, Side, Mode, LhsProductTraits::NeedToConjugate,
74  (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor>
75  ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs);
76 
77  if(!useRhsDirectly)
78  rhs = MappedRhs(actualRhs, rhs.size());
79  }
80 };
81 
82 // the rhs is a matrix
83 template<typename Lhs, typename Rhs, int Side, int Mode>
84 struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic>
85 {
86  typedef typename Rhs::Scalar Scalar;
87  typedef blas_traits<Lhs> LhsProductTraits;
88  typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
89 
90  static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
91  {
92  add_const_on_value_type_t<ActualLhsType> actualLhs = LhsProductTraits::extract(lhs);
93 
94  const Index size = lhs.rows();
95  const Index othersize = Side==OnTheLeft? rhs.cols() : rhs.rows();
96 
97  typedef internal::gemm_blocking_space<(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar,
98  Rhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxRowsAtCompileTime,4> BlockingType;
99 
100  BlockingType blocking(rhs.rows(), rhs.cols(), size, 1, false);
101 
102  triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
103  (Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor, Rhs::InnerStrideAtCompileTime>
104  ::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.innerStride(), rhs.outerStride(), blocking);
105  }
106 };
107 
108 
112 template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size,
113  bool Stop = LoopIndex==Size>
114 struct triangular_solver_unroller;
115 
116 template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size>
117 struct triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex,Size,false> {
118  enum {
119  IsLower = ((Mode&Lower)==Lower),
120  DiagIndex = IsLower ? LoopIndex : Size - LoopIndex - 1,
121  StartIndex = IsLower ? 0 : DiagIndex+1
122  };
123  static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
124  {
125  if (LoopIndex>0)
126  rhs.coeffRef(DiagIndex) -= lhs.row(DiagIndex).template segment<LoopIndex>(StartIndex).transpose()
127  .cwiseProduct(rhs.template segment<LoopIndex>(StartIndex)).sum();
128 
129  if(!(Mode & UnitDiag))
130  rhs.coeffRef(DiagIndex) /= lhs.coeff(DiagIndex,DiagIndex);
131 
132  triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex+1,Size>::run(lhs,rhs);
133  }
134 };
135 
136 template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size>
137 struct triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex,Size,true> {
138  static EIGEN_DEVICE_FUNC void run(const Lhs&, Rhs&) {}
139 };
140 
141 template<typename Lhs, typename Rhs, int Mode>
142 struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,1> {
143  static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
144  { triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
145 };
146 
147 template<typename Lhs, typename Rhs, int Mode>
148 struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,CompleteUnrolling,1> {
149  static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
150  {
151  Transpose<const Lhs> trLhs(lhs);
152  Transpose<Rhs> trRhs(rhs);
153 
154  triangular_solver_unroller<Transpose<const Lhs>,Transpose<Rhs>,
155  ((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag),
156  0,Rhs::SizeAtCompileTime>::run(trLhs,trRhs);
157  }
158 };
159 
160 } // end namespace internal
161 
162 
166 #ifndef EIGEN_PARSED_BY_DOXYGEN
167 template<typename MatrixType, unsigned int Mode>
168 template<int Side, typename OtherDerived>
169 EIGEN_DEVICE_FUNC void TriangularViewImpl<MatrixType,Mode,Dense>::solveInPlace(const MatrixBase<OtherDerived>& _other) const
170 {
171  OtherDerived& other = _other.const_cast_derived();
172  eigen_assert( derived().cols() == derived().rows() && ((Side==OnTheLeft && derived().cols() == other.rows()) || (Side==OnTheRight && derived().cols() == other.cols())) );
173  eigen_assert((!(int(Mode) & int(ZeroDiag))) && bool(int(Mode) & (int(Upper) | int(Lower))));
174  // If solving for a 0x0 matrix, nothing to do, simply return.
175  if (derived().cols() == 0)
176  return;
177 
178  enum { copy = (internal::traits<OtherDerived>::Flags & RowMajorBit) && OtherDerived::IsVectorAtCompileTime && OtherDerived::SizeAtCompileTime!=1};
179  typedef std::conditional_t<copy,
180  typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&> OtherCopy;
181  OtherCopy otherCopy(other);
182 
183  internal::triangular_solver_selector<MatrixType, std::remove_reference_t<OtherCopy>,
184  Side, Mode>::run(derived().nestedExpression(), otherCopy);
185 
186  if (copy)
187  other = otherCopy;
188 }
189 
190 template<typename Derived, unsigned int Mode>
191 template<int Side, typename Other>
192 const internal::triangular_solve_retval<Side,TriangularView<Derived,Mode>,Other>
193 TriangularViewImpl<Derived,Mode,Dense>::solve(const MatrixBase<Other>& other) const
194 {
195  return internal::triangular_solve_retval<Side,TriangularViewType,Other>(derived(), other.derived());
196 }
197 #endif
198 
199 namespace internal {
200 
201 
202 template<int Side, typename TriangularType, typename Rhs>
203 struct traits<triangular_solve_retval<Side, TriangularType, Rhs> >
204 {
205  typedef typename internal::plain_matrix_type_column_major<Rhs>::type ReturnType;
206 };
207 
208 template<int Side, typename TriangularType, typename Rhs> struct triangular_solve_retval
209  : public ReturnByValue<triangular_solve_retval<Side, TriangularType, Rhs> >
210 {
211  typedef remove_all_t<typename Rhs::Nested> RhsNestedCleaned;
212  typedef ReturnByValue<triangular_solve_retval> Base;
213 
214  triangular_solve_retval(const TriangularType& tri, const Rhs& rhs)
215  : m_triangularMatrix(tri), m_rhs(rhs)
216  {}
217 
218  inline EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_rhs.rows(); }
219  inline EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_rhs.cols(); }
220 
221  template<typename Dest> inline void evalTo(Dest& dst) const
222  {
223  if(!is_same_dense(dst,m_rhs))
224  dst = m_rhs;
225  m_triangularMatrix.template solveInPlace<Side>(dst);
226  }
227 
228  protected:
229  const TriangularType& m_triangularMatrix;
230  typename Rhs::Nested m_rhs;
231 };
232 
233 } // namespace internal
234 
235 } // end namespace Eigen
236 
237 #endif // EIGEN_SOLVETRIANGULAR_H
#define EIGEN_NOEXCEPT
Definition: Macros.h:1260
#define EIGEN_CONSTEXPR
Definition: Macros.h:747
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:883
#define eigen_assert(x)
Definition: Macros.h:902
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
Definition: Memory.h:847
Tridiagonalization< MatrixXf > tri
@ UnitDiag
Definition: Constants.h:215
@ ZeroDiag
Definition: Constants.h:217
@ Lower
Definition: Constants.h:211
@ Upper
Definition: Constants.h:213
@ Aligned
Definition: Constants.h:242
@ ColMajor
Definition: Constants.h:321
@ RowMajor
Definition: Constants.h:323
@ OnTheLeft
Definition: Constants.h:334
@ OnTheRight
Definition: Constants.h:336
const unsigned int RowMajorBit
Definition: Constants.h:68
bool is_same_dense(const T1 &mat1, const T2 &mat2, std::enable_if_t< possibly_same_dense< T1, T2 >::value > *=0)
Definition: XprHelper.h:745
: InteropHeaders
Definition: Core:139
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
@ CompleteUnrolling
Definition: Constants.h:306
@ NoUnrolling
Definition: Constants.h:301
const int Dynamic
Definition: Constants.h:24