SparseDiagonalProduct.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H
11 #define EIGEN_SPARSE_DIAGONAL_PRODUCT_H
12 
13 #include "./InternalHeaderCheck.h"
14 
15 namespace Eigen {
16 
17 // The product of a diagonal matrix with a sparse matrix can be easily
18 // implemented using expression template.
19 // We have to consider two very different cases:
20 // 1 - diag * row-major sparse
21 // => each inner vector <=> scalar * sparse vector product
22 // => so we can reuse CwiseUnaryOp::InnerIterator
23 // 2 - diag * col-major sparse
24 // => each inner vector <=> densevector * sparse vector cwise product
25 // => again, we can reuse specialization of CwiseBinaryOp::InnerIterator
26 // for that particular case
27 // The two other cases are symmetric.
28 
29 namespace internal {
30 
// Tags selecting which sparse_diagonal_product_evaluator specialization is used.
enum {
  // Every coefficient of one inner vector is scaled by the same diagonal
  // coefficient, the one at the outer index.
  SDP_AsScalarProduct,
  // Every coefficient of an inner vector is scaled by the diagonal
  // coefficient found at that coefficient's own inner index.
  SDP_AsCwiseProduct
};
35 
// Primary template, declared only: the implementations are the partial
// specializations on the SDP_Tag value.
template <class SparseXprType, class DiagonalCoeffType, int SDP_Tag>
struct sparse_diagonal_product_evaluator;
38 
39 template<typename Lhs, typename Rhs, int ProductTag>
40 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DiagonalShape, SparseShape>
41  : public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct>
42 {
43  typedef Product<Lhs, Rhs, DefaultProduct> XprType;
44  enum { CoeffReadCost = HugeCost, Flags = Rhs::Flags&RowMajorBit, Alignment = 0 }; // FIXME CoeffReadCost & Flags
45 
46  typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base;
47  explicit product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {}
48 };
49 
50 template<typename Lhs, typename Rhs, int ProductTag>
51 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, DiagonalShape>
52  : public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct>
53 {
54  typedef Product<Lhs, Rhs, DefaultProduct> XprType;
55  enum { CoeffReadCost = HugeCost, Flags = Lhs::Flags&RowMajorBit, Alignment = 0 }; // FIXME CoeffReadCost & Flags
56 
57  typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base;
58  explicit product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal().transpose()) {}
59 };
60 
61 template<typename SparseXprType, typename DiagonalCoeffType>
62 struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct>
63 {
64 protected:
65  typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator;
66  typedef typename SparseXprType::Scalar Scalar;
67 
68 public:
69  class InnerIterator : public SparseXprInnerIterator
70  {
71  public:
72  InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
73  : SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer),
74  m_coeff(xprEval.m_diagCoeffImpl.coeff(outer))
75  {}
76 
77  EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); }
78  protected:
79  typename DiagonalCoeffType::Scalar m_coeff;
80  };
81 
82  sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff)
83  : m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff)
84  {}
85 
86  Index nonZerosEstimate() const { return m_sparseXprImpl.nonZerosEstimate(); }
87 
88 protected:
89  evaluator<SparseXprType> m_sparseXprImpl;
90  evaluator<DiagonalCoeffType> m_diagCoeffImpl;
91 };
92 
93 
94 template<typename SparseXprType, typename DiagCoeffType>
95 struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct>
96 {
97  typedef typename SparseXprType::Scalar Scalar;
98  typedef typename SparseXprType::StorageIndex StorageIndex;
99 
100  typedef typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime
101  : SparseXprType::ColsAtCompileTime>::type DiagCoeffNested;
102 
103  class InnerIterator
104  {
105  typedef typename evaluator<SparseXprType>::InnerIterator SparseXprIter;
106  public:
107  InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
108  : m_sparseIter(xprEval.m_sparseXprEval, outer), m_diagCoeffNested(xprEval.m_diagCoeffNested)
109  {}
110 
111  inline Scalar value() const { return m_sparseIter.value() * m_diagCoeffNested.coeff(index()); }
112  inline StorageIndex index() const { return m_sparseIter.index(); }
113  inline Index outer() const { return m_sparseIter.outer(); }
114  inline Index col() const { return SparseXprType::IsRowMajor ? m_sparseIter.index() : m_sparseIter.outer(); }
115  inline Index row() const { return SparseXprType::IsRowMajor ? m_sparseIter.outer() : m_sparseIter.index(); }
116 
117  EIGEN_STRONG_INLINE InnerIterator& operator++() { ++m_sparseIter; return *this; }
118  inline operator bool() const { return m_sparseIter; }
119 
120  protected:
121  SparseXprIter m_sparseIter;
122  DiagCoeffNested m_diagCoeffNested;
123  };
124 
125  sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff)
126  : m_sparseXprEval(sparseXpr), m_diagCoeffNested(diagCoeff)
127  {}
128 
129  Index nonZerosEstimate() const { return m_sparseXprEval.nonZerosEstimate(); }
130 
131 protected:
132  evaluator<SparseXprType> m_sparseXprEval;
133  DiagCoeffNested m_diagCoeffNested;
134 };
135 
136 } // end namespace internal
137 
138 } // end namespace Eigen
139 
140 #endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H
RowXpr row(Index i)
This is the const version of row().
ColXpr col(Index i)
This is the const version of col().
Expression of the product of two arbitrary matrices or vectors.
Definition: Product.h:77
const unsigned int RowMajorBit
Definition: Constants.h:68
bfloat16 operator++(bfloat16 &a)
Definition: BFloat16.h:298
: InteropHeaders
Definition: Core:139
@ DefaultProduct
Definition: Constants.h:504
const int HugeCost
Definition: Constants.h:46
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82