1 #ifndef EIGEN_ACCELERATESUPPORT_H
2 #define EIGEN_ACCELERATESUPPORT_H
4 #include <Accelerate/Accelerate.h>
10 template <
typename MatrixType_,
int UpLo_, SparseFactorization_t Solver_,
bool EnforceSquare_>
24 template <
typename MatrixType,
int UpLo = Lower>
38 template <
typename MatrixType,
int UpLo = Lower>
52 template <
typename MatrixType,
int UpLo = Lower>
66 template <
typename MatrixType,
int UpLo = Lower>
80 template <
typename MatrixType,
int UpLo = Lower>
93 template <
typename MatrixType>
106 template <
typename MatrixType>
110 template <
typename T>
111 struct AccelFactorizationDeleter {
121 template <
typename DenseVecT,
typename DenseMatT,
typename SparseMatT,
typename NumFactT>
122 struct SparseTypesTraitBase {
123 typedef DenseVecT AccelDenseVector;
124 typedef DenseMatT AccelDenseMatrix;
125 typedef SparseMatT AccelSparseMatrix;
127 typedef SparseOpaqueSymbolicFactorization SymbolicFactorization;
128 typedef NumFactT NumericFactorization;
130 typedef AccelFactorizationDeleter<SymbolicFactorization> SymbolicFactorizationDeleter;
131 typedef AccelFactorizationDeleter<NumericFactorization> NumericFactorizationDeleter;
// Maps an Eigen scalar type to the matching Accelerate sparse-solver types
// (dense vector/matrix, sparse matrix, symbolic/numeric factorization and
// their deleters), exposed as member typedefs.
// The primary template is deliberately empty: only the `double` and `float`
// specializations provide those typedefs. Naming a member such as
// `SparseTypesTrait<Scalar>::AccelSparseMatrix` for any other Scalar is
// therefore ill-formed and fails at compile time.
134 template <
typename Scalar>
135 struct SparseTypesTrait {};
138 struct SparseTypesTrait<double> : SparseTypesTraitBase<DenseVector_Double, DenseMatrix_Double, SparseMatrix_Double,
139 SparseOpaqueFactorization_Double> {};
142 struct SparseTypesTrait<float>
143 : SparseTypesTraitBase<DenseVector_Float, DenseMatrix_Float, SparseMatrix_Float, SparseOpaqueFactorization_Float> {
148 template <
typename MatrixType_,
int UpLo_, SparseFactorization_t Solver_,
bool EnforceSquare_>
156 using Base::_solve_impl;
159 typedef typename MatrixType::Scalar
Scalar;
175 auto check_flag_set = [](
int value,
int flag) {
return ((value & flag) == flag); };
179 m_triType = (UpLo_ &
Lower) ? SparseLowerTriangle : SparseUpperTriangle;
180 }
else if (check_flag_set(UpLo_,
UnitLower)) {
183 }
else if (check_flag_set(UpLo_,
UnitUpper)) {
192 }
else if (check_flag_set(UpLo_,
Lower)) {
195 }
else if (check_flag_set(UpLo_,
Upper)) {
200 m_triType = (UpLo_ &
Lower) ? SparseLowerTriangle : SparseUpperTriangle;
224 template <
typename Rhs,
typename Dest>
231 template <
typename T>
233 const Index nColumnsStarts =
a.cols() + 1;
235 columnStarts.resize(nColumnsStarts);
237 for (
Index i = 0;
i < nColumnsStarts;
i++) columnStarts[
i] =
a.outerIndexPtr()[
i];
239 SparseAttributes_t attributes{};
240 attributes.transpose =
false;
244 SparseMatrixStructure structure{};
245 structure.attributes = attributes;
246 structure.rowCount =
static_cast<int>(
a.rows());
247 structure.columnCount =
static_cast<int>(
a.cols());
248 structure.blockSize = 1;
249 structure.columnStarts = columnStarts.data();
250 structure.rowIndices =
const_cast<int*
>(
a.innerIndexPtr());
252 A.structure = structure;
253 A.data =
const_cast<T*
>(
a.valuePtr());
259 SparseSymbolicFactorOptions opts{};
260 opts.control = SparseDefaultControl;
262 opts.order =
nullptr;
263 opts.ignoreRowsAndColumns =
nullptr;
264 opts.malloc = malloc;
266 opts.reportError =
nullptr;
278 SparseStatus_t status = SparseStatusReleased;
297 case SparseFactorizationFailed:
298 case SparseMatrixIsSingular:
301 case SparseInternalError:
302 case SparseParameterError:
303 case SparseStatusReleased:
320 template <
typename MatrixType_,
int UpLo_, SparseFactorization_t Solver_,
bool EnforceSquare_>
328 std::vector<long> columnStarts;
330 buildAccelSparseMatrix(
a,
A, columnStarts);
334 if (m_symbolicFactorization) doFactorization(
A);
336 m_isInitialized =
true;
345 template <
typename MatrixType_,
int UpLo_, SparseFactorization_t Solver_,
bool EnforceSquare_>
353 std::vector<long> columnStarts;
355 buildAccelSparseMatrix(
a,
A, columnStarts);
359 m_isInitialized =
true;
368 template <
typename MatrixType_,
int UpLo_, SparseFactorization_t Solver_,
bool EnforceSquare_>
370 eigen_assert(m_symbolicFactorization &&
"You must first call analyzePattern()");
376 std::vector<long> columnStarts;
378 buildAccelSparseMatrix(
a,
A, columnStarts);
383 template <
typename MatrixType_,
int UpLo_, SparseFactorization_t Solver_,
bool EnforceSquare_>
384 template <
typename Rhs,
typename Dest>
387 if (!m_numericFactorization) {
395 SparseStatus_t status = SparseStatusOK;
401 xmat.attributes = SparseAttributes_t();
402 xmat.columnCount =
static_cast<int>(
x.cols());
403 xmat.rowCount =
static_cast<int>(
x.rows());
404 xmat.columnStride = xmat.rowCount;
408 bmat.attributes = SparseAttributes_t();
409 bmat.columnCount =
static_cast<int>(
b.cols());
410 bmat.rowCount =
static_cast<int>(
b.rows());
411 bmat.columnStride = bmat.rowCount;
414 SparseSolve(*m_numericFactorization, bmat, xmat);
416 updateInfoStatus(status);
IndexedView_or_Block operator()(const RowIndices &rowIndices, const ColIndices &colIndices)
Matrix< float, 1, Dynamic > MatrixType
typename internal::SparseTypesTrait< Scalar >::SymbolicFactorizationDeleter SymbolicFactorizationDeleter
MatrixType::StorageIndex StorageIndex
typename internal::SparseTypesTrait< Scalar >::SymbolicFactorization SymbolicFactorization
std::unique_ptr< SymbolicFactorization, SymbolicFactorizationDeleter > m_symbolicFactorization
SparseKind_t m_sparseKind
void analyzePattern(const MatrixType &matrix)
ComputationInfo info() const
std::unique_ptr< NumericFactorization, NumericFactorizationDeleter > m_numericFactorization
void doFactorization(AccelSparseMatrix &A)
typename internal::SparseTypesTrait< Scalar >::AccelDenseVector AccelDenseVector
void factorize(const MatrixType &matrix)
typename internal::SparseTypesTrait< Scalar >::AccelSparseMatrix AccelSparseMatrix
typename internal::SparseTypesTrait< Scalar >::NumericFactorizationDeleter NumericFactorizationDeleter
void doAnalysis(AccelSparseMatrix &A)
typename internal::SparseTypesTrait< Scalar >::AccelDenseMatrix AccelDenseMatrix
MatrixType::Scalar Scalar
SparseTriangle_t m_triType
AccelerateImpl(const MatrixType &matrix)
void setOrder(SparseOrder_t order)
typename internal::SparseTypesTrait< Scalar >::NumericFactorization NumericFactorization
void buildAccelSparseMatrix(const SparseMatrix< T > &a, AccelSparseMatrix &A, std::vector< long > &columnStarts)
void compute(const MatrixType &matrix)
void _solve_impl(const MatrixBase< Rhs > &b, MatrixBase< Dest > &dest) const
void updateInfoStatus(SparseStatus_t status) const
Base class for all dense matrices, vectors, and expressions.
EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT
EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT
A versatile sparse matrix representation.
A base class for sparse solvers.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.