10 #ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H
11 #define EIGEN_SPARSE_DIAGONAL_PRODUCT_H
// Forward declaration of the evaluator that both product_evaluator
// specializations in this file derive from. The third parameter, SDP_Tag,
// selects one of the partial specializations further down
// (SDP_AsScalarProduct or SDP_AsCwiseProduct).
36 template<
typename SparseXprType,
typename DiagonalCoeffType,
int SDP_Tag>
37 struct sparse_diagonal_product_evaluator;
39 template<
typename Lhs,
typename Rhs,
int ProductTag>
41 :
public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct>
46 typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base;
// Constructs the evaluator for a product whose left operand supplies the
// diagonal: xpr.rhs() is handed to Base as the sparse expression and
// xpr.lhs().diagonal() as the diagonal coefficients.
47 explicit product_evaluator(
const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {}
50 template<
typename Lhs,
typename Rhs,
int ProductTag>
51 struct product_evaluator<Product<Lhs, Rhs,
DefaultProduct>, ProductTag, SparseShape, DiagonalShape>
52 :
public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct>
54 typedef Product<Lhs, Rhs, DefaultProduct> XprType;
// Constructs the evaluator for a product whose right operand supplies the
// diagonal: xpr.lhs() is forwarded to Base as the first argument and the
// transposed diagonal vector of xpr.rhs() as the second.
// NOTE(review): the typedef of Base is not visible in this excerpt; presumably
// it names the sparse_diagonal_product_evaluator base listed above - confirm.
58 explicit product_evaluator(
const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal().transpose()) {}
61 template<
typename SparseXprType,
typename DiagonalCoeffType>
62 struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType,
SDP_AsScalarProduct>
65 typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator;
66 typedef typename SparseXprType::Scalar Scalar;
69 class InnerIterator :
public SparseXprInnerIterator
72 InnerIterator(
const sparse_diagonal_product_evaluator &xprEval,
Index outer)
73 : SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer),
74 m_coeff(xprEval.m_diagCoeffImpl.coeff(outer))
// Returns the base sparse iterator's value multiplied on the left by m_coeff.
// m_coeff is the diagonal coefficient read once for this outer index in the
// constructor's initializer list, so no per-element diagonal lookup happens here.
77 EIGEN_STRONG_INLINE Scalar value()
const {
return m_coeff * SparseXprInnerIterator::value(); }
79 typename DiagonalCoeffType::Scalar m_coeff;
82 sparse_diagonal_product_evaluator(
const SparseXprType &sparseXpr,
const DiagonalCoeffType &diagCoeff)
83 : m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff)
// Forwards the estimate reported by the sparse operand's evaluator unchanged;
// the diagonal operand plays no part in it.
86 Index nonZerosEstimate()
const {
return m_sparseXprImpl.nonZerosEstimate(); }
89 evaluator<SparseXprType> m_sparseXprImpl;
90 evaluator<DiagonalCoeffType> m_diagCoeffImpl;
94 template<
typename SparseXprType,
typename DiagCoeffType>
95 struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType,
SDP_AsCwiseProduct>
97 typedef typename SparseXprType::Scalar Scalar;
98 typedef typename SparseXprType::StorageIndex StorageIndex;
100 typedef typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime
101 : SparseXprType::ColsAtCompileTime>::type DiagCoeffNested;
105 typedef typename evaluator<SparseXprType>::InnerIterator SparseXprIter;
107 InnerIterator(
const sparse_diagonal_product_evaluator &xprEval,
Index outer)
108 : m_sparseIter(xprEval.m_sparseXprEval, outer), m_diagCoeffNested(xprEval.m_diagCoeffNested)
// Returns the wrapped sparse iterator's value multiplied on the right by the
// diagonal coefficient looked up at the iterator's current index().
// Unlike the scalar-product specialization, the coefficient is fetched on
// every call because it depends on the current entry's index.
111 inline Scalar value()
const {
return m_sparseIter.value() * m_diagCoeffNested.coeff(index()); }
// Index reported by the wrapped sparse iterator, returned as StorageIndex.
112 inline StorageIndex index()
const {
return m_sparseIter.index(); }
// Outer index reported by the wrapped sparse iterator.
113 inline Index outer()
const {
return m_sparseIter.outer(); }
// Column of the current entry: the wrapped iterator's index() when
// SparseXprType::IsRowMajor is set, otherwise its outer().
114 inline Index col()
const {
return SparseXprType::IsRowMajor ? m_sparseIter.index() : m_sparseIter.outer(); }
// Row of the current entry: the wrapped iterator's outer() when
// SparseXprType::IsRowMajor is set, otherwise its index().
115 inline Index row()
const {
return SparseXprType::IsRowMajor ? m_sparseIter.outer() : m_sparseIter.index(); }
// Pre-increment: advances the wrapped sparse iterator and returns *this.
117 EIGEN_STRONG_INLINE InnerIterator&
operator++() { ++m_sparseIter;
return *
this; }
// Converts to bool by converting the wrapped sparse iterator.
// NOTE(review): presumably true while the wrapped iterator still points at a
// valid entry - confirm against the wrapped iterator's own conversion.
118 inline operator bool()
const {
return m_sparseIter; }
121 SparseXprIter m_sparseIter;
122 DiagCoeffNested m_diagCoeffNested;
125 sparse_diagonal_product_evaluator(
const SparseXprType &sparseXpr,
const DiagCoeffType &diagCoeff)
126 : m_sparseXprEval(sparseXpr), m_diagCoeffNested(diagCoeff)
// Forwards the estimate reported by the sparse operand's evaluator unchanged;
// the nested diagonal coefficients play no part in it.
129 Index nonZerosEstimate()
const {
return m_sparseXprEval.nonZerosEstimate(); }
132 evaluator<SparseXprType> m_sparseXprEval;
133 DiagCoeffNested m_diagCoeffNested;
RowXpr row(Index i)
This is the const version of row().
ColXpr col(Index i)
This is the const version of col().
Expression of the product of two arbitrary matrices or vectors.
const unsigned int RowMajorBit
bfloat16 operator++(bfloat16 &a)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.