KLUSupport.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2017 Kyle Macfarlan <kyle.macfarlan@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_KLUSUPPORT_H
11 #define EIGEN_KLUSUPPORT_H
12 
13 #include "./InternalHeaderCheck.h"
14 
15 namespace Eigen {
16 
17 /* TODO extract L, extract U, compute det, etc... */
18 
36 inline int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B [ ], klu_common *Common, double) {
37  return klu_solve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs), B, Common);
38 }
39 
40 inline int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, std::complex<double>B[], klu_common *Common, std::complex<double>) {
41  return klu_z_solve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs), &numext::real_ref(B[0]), Common);
42 }
43 
44 inline int klu_tsolve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B[], klu_common *Common, double) {
45  return klu_tsolve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs), B, Common);
46 }
47 
48 inline int klu_tsolve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, std::complex<double>B[], klu_common *Common, std::complex<double>) {
49  return klu_z_tsolve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs), &numext::real_ref(B[0]), 0, Common);
50 }
51 
52 inline klu_numeric* klu_factor(int Ap [ ], int Ai [ ], double Ax [ ], klu_symbolic *Symbolic, klu_common *Common, double) {
53  return klu_factor(Ap, Ai, Ax, Symbolic, Common);
54 }
55 
56 inline klu_numeric* klu_factor(int Ap[], int Ai[], std::complex<double> Ax[], klu_symbolic *Symbolic, klu_common *Common, std::complex<double>) {
57  return klu_z_factor(Ap, Ai, &numext::real_ref(Ax[0]), Symbolic, Common);
58 }
59 
60 
61 template<typename MatrixType_>
62 class KLU : public SparseSolverBase<KLU<MatrixType_> >
63 {
64  protected:
67  public:
68  using Base::_solve_impl;
69  typedef MatrixType_ MatrixType;
70  typedef typename MatrixType::Scalar Scalar;
79  enum {
80  ColsAtCompileTime = MatrixType::ColsAtCompileTime,
81  MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
82  };
83 
84  public:
85 
86  KLU()
87  : m_dummy(0,0), mp_matrix(m_dummy)
88  {
89  init();
90  }
91 
92  template<typename InputMatrixType>
93  explicit KLU(const InputMatrixType& matrix)
94  : mp_matrix(matrix)
95  {
96  init();
97  compute(matrix);
98  }
99 
101  {
102  if(m_symbolic) klu_free_symbolic(&m_symbolic,&m_common);
103  if(m_numeric) klu_free_numeric(&m_numeric,&m_common);
104  }
105 
106  EIGEN_CONSTEXPR inline Index rows() const EIGEN_NOEXCEPT { return mp_matrix.rows(); }
107  EIGEN_CONSTEXPR inline Index cols() const EIGEN_NOEXCEPT { return mp_matrix.cols(); }
108 
115  {
116  eigen_assert(m_isInitialized && "Decomposition is not initialized.");
117  return m_info;
118  }
#if 0 // not implemented yet
// Accessors for the cached factors and permutations. Each one refreshes the cached copies
// through extractData() when they are marked dirty, then returns the cached member.
// Disabled because extractData() is not implemented for KLU.
inline const LUMatrixType& matrixL() const
{
  if (m_extractedDataAreDirty)
    extractData();
  return m_l;
}

inline const LUMatrixType& matrixU() const
{
  if (m_extractedDataAreDirty)
    extractData();
  return m_u;
}

inline const IntColVectorType& permutationP() const
{
  if (m_extractedDataAreDirty)
    extractData();
  return m_p;
}

inline const IntRowVectorType& permutationQ() const
{
  if (m_extractedDataAreDirty)
    extractData();
  return m_q;
}
#endif
148  template<typename InputMatrixType>
149  void compute(const InputMatrixType& matrix)
150  {
151  if(m_symbolic) klu_free_symbolic(&m_symbolic, &m_common);
152  if(m_numeric) klu_free_numeric(&m_numeric, &m_common);
153  grab(matrix.derived());
155  factorize_impl();
156  }
157 
164  template<typename InputMatrixType>
165  void analyzePattern(const InputMatrixType& matrix)
166  {
167  if(m_symbolic) klu_free_symbolic(&m_symbolic, &m_common);
168  if(m_numeric) klu_free_numeric(&m_numeric, &m_common);
169 
170  grab(matrix.derived());
171 
173  }
174 
175 
180  inline const klu_common& kluCommon() const
181  {
182  return m_common;
183  }
184 
191  inline klu_common& kluCommon()
192  {
193  return m_common;
194  }
195 
202  template<typename InputMatrixType>
203  void factorize(const InputMatrixType& matrix)
204  {
205  eigen_assert(m_analysisIsOk && "KLU: you must first call analyzePattern()");
206  if(m_numeric)
207  klu_free_numeric(&m_numeric,&m_common);
208 
209  grab(matrix.derived());
210 
211  factorize_impl();
212  }
213 
215  template<typename BDerived,typename XDerived>
217 
218 #if 0 // not implemented yet
219  Scalar determinant() const;
220 
221  void extractData() const;
222 #endif
223 
224  protected:
225 
226  void init()
227  {
229  m_isInitialized = false;
230  m_numeric = 0;
231  m_symbolic = 0;
233 
234  klu_defaults(&m_common);
235  }
236 
238  {
240  m_analysisIsOk = false;
241  m_factorizationIsOk = false;
242  m_symbolic = klu_analyze(internal::convert_index<int>(mp_matrix.rows()),
243  const_cast<StorageIndex*>(mp_matrix.outerIndexPtr()), const_cast<StorageIndex*>(mp_matrix.innerIndexPtr()),
244  &m_common);
245  if (m_symbolic) {
246  m_isInitialized = true;
247  m_info = Success;
248  m_analysisIsOk = true;
250  }
251  }
252 
254  {
255 
256  m_numeric = klu_factor(const_cast<StorageIndex*>(mp_matrix.outerIndexPtr()), const_cast<StorageIndex*>(mp_matrix.innerIndexPtr()), const_cast<Scalar*>(mp_matrix.valuePtr()),
257  m_symbolic, &m_common, Scalar());
258 
259 
261  m_factorizationIsOk = m_numeric ? 1 : 0;
263  }
264 
265  template<typename MatrixDerived>
267  {
269  internal::construct_at(&mp_matrix, A.derived());
270  }
271 
272  void grab(const KLUMatrixRef &A)
273  {
274  if(&(A.derived()) != &mp_matrix)
275  {
278  }
279  }
280 
281  // cached data to reduce reallocation, etc.
282 #if 0 // not implemented yet
283  mutable LUMatrixType m_l;
284  mutable LUMatrixType m_u;
285  mutable IntColVectorType m_p;
286  mutable IntRowVectorType m_q;
287 #endif
288 
291 
292  klu_numeric* m_numeric;
293  klu_symbolic* m_symbolic;
294  klu_common m_common;
299 
300  private:
301  KLU(const KLU& ) { }
302 };
303 
#if 0 // not implemented yet
// NOTE(review): this body was carried over from the UmfPack wrapper. umfpack_get_lunz() and
// umfpack_get_numeric() expect an UmfPack numeric object, not a klu_numeric handle, so this
// cannot work as written. A real implementation would presumably use KLU's own factor
// extraction routine (klu_extract) — confirm before enabling.
template<typename MatrixType>
void KLU<MatrixType>::extractData() const
{
  if (m_extractedDataAreDirty)
  {
    eigen_assert(false && "KLU: extractData Not Yet Implemented");

    // get size of the data
    int lnz, unz, rows, cols, nz_udiag;
    umfpack_get_lunz(&lnz, &unz, &rows, &cols, &nz_udiag, m_numeric, Scalar());

    // allocate data
    m_l.resize(rows,(std::min)(rows,cols));
    m_l.resizeNonZeros(lnz);

    m_u.resize((std::min)(rows,cols),cols);
    m_u.resizeNonZeros(unz);

    m_p.resize(rows);
    m_q.resize(cols);

    // extract
    umfpack_get_numeric(m_l.outerIndexPtr(), m_l.innerIndexPtr(), m_l.valuePtr(),
                        m_u.outerIndexPtr(), m_u.innerIndexPtr(), m_u.valuePtr(),
                        m_p.data(), m_q.data(), 0, 0, 0, m_numeric);

    m_extractedDataAreDirty = false;
  }
}

// Placeholder: always asserts in debug builds and returns a value-initialized Scalar.
template<typename MatrixType>
typename KLU<MatrixType>::Scalar KLU<MatrixType>::determinant() const
{
  // The message previously (wrongly) named extractData.
  eigen_assert(false && "KLU: determinant Not Yet Implemented");
  return Scalar();
}
#endif
342 
343 template<typename MatrixType>
344 template<typename BDerived,typename XDerived>
346 {
347  Index rhsCols = b.cols();
348  EIGEN_STATIC_ASSERT((XDerived::Flags&RowMajorBit)==0, THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
349  eigen_assert(m_factorizationIsOk && "The decomposition is not in a valid state for solving, you must first call either compute() or analyzePattern()/factorize()");
350 
351  x = b;
352  int info = klu_solve(m_symbolic, m_numeric, b.rows(), rhsCols, x.const_cast_derived().data(), const_cast<klu_common*>(&m_common), Scalar());
353 
354  m_info = info!=0 ? Success : NumericalIssue;
355  return true;
356 }
357 
358 } // end namespace Eigen
359 
360 #endif // EIGEN_KLUSUPPORT_H
Array< int, 3, 1 > b
MatrixXcf A
MatrixXf B
#define EIGEN_NOEXCEPT
Definition: Macros.h:1260
#define EIGEN_CONSTEXPR
Definition: Macros.h:747
#define eigen_assert(x)
Definition: Macros.h:902
#define EIGEN_STATIC_ASSERT(X, MSG)
Definition: StaticAssert.h:26
NumTraits< Scalar >::Real RealScalar
Definition: DenseBase.h:68
internal::traits< Derived >::StorageIndex StorageIndex
The type used to store indices.
Definition: DenseBase.h:58
internal::traits< Derived >::Scalar Scalar
Definition: DenseBase.h:61
MatrixType_ MatrixType
Definition: KLUSupport.h:69
klu_symbolic * m_symbolic
Definition: KLUSupport.h:293
EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT
Definition: KLUSupport.h:106
const klu_common & kluCommon() const
Definition: KLUSupport.h:180
void compute(const InputMatrixType &matrix)
Definition: KLUSupport.h:149
@ ColsAtCompileTime
Definition: KLUSupport.h:80
@ MaxColsAtCompileTime
Definition: KLUSupport.h:81
KLUMatrixRef mp_matrix
Definition: KLUSupport.h:290
void factorize(const InputMatrixType &matrix)
Definition: KLUSupport.h:203
void factorize_impl()
Definition: KLUSupport.h:253
int m_analysisIsOk
Definition: KLUSupport.h:297
SparseSolverBase< KLU< MatrixType_ > > Base
Definition: KLUSupport.h:65
void grab(const EigenBase< MatrixDerived > &A)
Definition: KLUSupport.h:266
Ref< const KLUMatrixType, StandardCompressedFormat > KLUMatrixRef
Definition: KLUSupport.h:78
KLU(const KLU &)
Definition: KLUSupport.h:301
SparseMatrix< Scalar, ColMajor, int > KLUMatrixType
Definition: KLUSupport.h:77
EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT
Definition: KLUSupport.h:107
klu_numeric * m_numeric
Definition: KLUSupport.h:292
MatrixType::StorageIndex StorageIndex
Definition: KLUSupport.h:72
ComputationInfo info() const
Reports whether previous computation was successful.
Definition: KLUSupport.h:114
KLUMatrixType m_dummy
Definition: KLUSupport.h:289
KLU(const InputMatrixType &matrix)
Definition: KLUSupport.h:93
bool _solve_impl(const MatrixBase< BDerived > &b, MatrixBase< XDerived > &x) const
Definition: KLUSupport.h:345
MatrixType::RealScalar RealScalar
Definition: KLUSupport.h:71
bool m_extractedDataAreDirty
Definition: KLUSupport.h:298
void analyzePattern_impl()
Definition: KLUSupport.h:237
int m_factorizationIsOk
Definition: KLUSupport.h:296
Matrix< Scalar, Dynamic, 1 > Vector
Definition: KLUSupport.h:73
Matrix< int, 1, MatrixType::ColsAtCompileTime > IntRowVectorType
Definition: KLUSupport.h:74
ComputationInfo m_info
Definition: KLUSupport.h:295
klu_common m_common
Definition: KLUSupport.h:294
klu_common & kluCommon()
Definition: KLUSupport.h:191
Matrix< int, MatrixType::RowsAtCompileTime, 1 > IntColVectorType
Definition: KLUSupport.h:75
MatrixType::Scalar Scalar
Definition: KLUSupport.h:70
SparseMatrix< Scalar > LUMatrixType
Definition: KLUSupport.h:76
void grab(const KLUMatrixRef &A)
Definition: KLUSupport.h:272
void init()
Definition: KLUSupport.h:226
void analyzePattern(const InputMatrixType &matrix)
Definition: KLUSupport.h:165
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:52
A base class for sparse solvers.
int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B[], klu_common *Common, double)
A sparse LU factorization and solver based on KLU.
Definition: KLUSupport.h:36
ComputationInfo
Definition: Constants.h:444
@ NumericalIssue
Definition: Constants.h:448
@ InvalidInput
Definition: Constants.h:453
@ Success
Definition: Constants.h:446
const unsigned int RowMajorBit
Definition: Constants.h:68
bfloat16() min(const bfloat16 &a, const bfloat16 &b)
Definition: BFloat16.h:684
void destroy_at(T *p)
Definition: Memory.h:1264
T * construct_at(T *p, Args &&... args)
Definition: Memory.h:1248
internal::add_const_on_value_type_t< EIGEN_MATHFUNC_RETVAL(real_ref, Scalar) > real_ref(const Scalar &x)
: InteropHeaders
Definition: Core:139
int umfpack_get_numeric(int Lp[], int Lj[], double Lx[], int Up[], int Ui[], double Ux[], int P[], int Q[], double Dx[], int *do_recip, double Rs[], void *Numeric)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
klu_numeric * klu_factor(int Ap[], int Ai[], double Ax[], klu_symbolic *Symbolic, klu_common *Common, double)
Definition: KLUSupport.h:52
int klu_tsolve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B[], klu_common *Common, double)
Definition: KLUSupport.h:44
int umfpack_get_lunz(int *lnz, int *unz, int *n_row, int *n_col, int *nz_udiag, void *Numeric, double)