KroneckerTensorProduct.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2011 Kolja Brix <brix@igpm.rwth-aachen.de>
5 // Copyright (C) 2011 Andreas Platen <andiplaten@gmx.de>
6 // Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net>
7 //
8 // This Source Code Form is subject to the terms of the Mozilla
9 // Public License v. 2.0. If a copy of the MPL was not distributed
10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11 
12 #ifndef KRONECKER_TENSOR_PRODUCT_H
13 #define KRONECKER_TENSOR_PRODUCT_H
14 
15 #include "./InternalHeaderCheck.h"
16 
17 namespace Eigen {
18 
26 template<typename Derived>
27 class KroneckerProductBase : public ReturnByValue<Derived>
28 {
29  private:
30  typedef typename internal::traits<Derived> Traits;
31  typedef typename Traits::Scalar Scalar;
32 
33  protected:
34  typedef typename Traits::Lhs Lhs;
35  typedef typename Traits::Rhs Rhs;
36 
37  public:
39  KroneckerProductBase(const Lhs& A, const Rhs& B)
40  : m_A(A), m_B(B)
41  {}
42 
43  inline Index rows() const { return m_A.rows() * m_B.rows(); }
44  inline Index cols() const { return m_A.cols() * m_B.cols(); }
45 
50  Scalar coeff(Index row, Index col) const
51  {
52  return m_A.coeff(row / m_B.rows(), col / m_B.cols()) *
53  m_B.coeff(row % m_B.rows(), col % m_B.cols());
54  }
55 
60  Scalar coeff(Index i) const
61  {
63  return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
64  }
65 
66  protected:
67  typename Lhs::Nested m_A;
68  typename Rhs::Nested m_B;
69 };
70 
83 template<typename Lhs, typename Rhs>
84 class KroneckerProduct : public KroneckerProductBase<KroneckerProduct<Lhs,Rhs> >
85 {
86  private:
88  using Base::m_A;
89  using Base::m_B;
90 
91  public:
93  KroneckerProduct(const Lhs& A, const Rhs& B)
94  : Base(A, B)
95  {}
96 
98  template<typename Dest> void evalTo(Dest& dst) const;
99 };
100 
116 template<typename Lhs, typename Rhs>
117 class KroneckerProductSparse : public KroneckerProductBase<KroneckerProductSparse<Lhs,Rhs> >
118 {
119  private:
121  using Base::m_A;
122  using Base::m_B;
123 
124  public:
126  KroneckerProductSparse(const Lhs& A, const Rhs& B)
127  : Base(A, B)
128  {}
129 
131  template<typename Dest> void evalTo(Dest& dst) const;
132 };
133 
134 template<typename Lhs, typename Rhs>
135 template<typename Dest>
137 {
138  const int BlockRows = Rhs::RowsAtCompileTime,
139  BlockCols = Rhs::ColsAtCompileTime;
140  const Index Br = m_B.rows(),
141  Bc = m_B.cols();
142  for (Index i=0; i < m_A.rows(); ++i)
143  for (Index j=0; j < m_A.cols(); ++j)
144  Block<Dest,BlockRows,BlockCols>(dst,i*Br,j*Bc,Br,Bc) = m_A.coeff(i,j) * m_B;
145 }
146 
147 template<typename Lhs, typename Rhs>
148 template<typename Dest>
150 {
151  Index Br = m_B.rows(), Bc = m_B.cols();
152  dst.resize(this->rows(), this->cols());
153  dst.resizeNonZeros(0);
154 
155  // 1 - evaluate the operands if needed:
156  typedef typename internal::nested_eval<Lhs,Dynamic>::type Lhs1;
157  typedef internal::remove_all_t<Lhs1> Lhs1Cleaned;
158  const Lhs1 lhs1(m_A);
159  typedef typename internal::nested_eval<Rhs,Dynamic>::type Rhs1;
160  typedef internal::remove_all_t<Rhs1> Rhs1Cleaned;
161  const Rhs1 rhs1(m_B);
162 
163  // 2 - construct respective iterators
164  typedef Eigen::InnerIterator<Lhs1Cleaned> LhsInnerIterator;
165  typedef Eigen::InnerIterator<Rhs1Cleaned> RhsInnerIterator;
166 
167  // compute number of non-zeros per innervectors of dst
168  {
169  // TODO VectorXi is not necessarily big enough!
170  VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
171  for (Index kA=0; kA < m_A.outerSize(); ++kA)
172  for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
173  nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
174 
175  VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
176  for (Index kB=0; kB < m_B.outerSize(); ++kB)
177  for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
178  nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
179 
180  Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
181  dst.reserve(VectorXi::Map(nnzAB.data(), nnzAB.size()));
182  }
183 
184  for (Index kA=0; kA < m_A.outerSize(); ++kA)
185  {
186  for (Index kB=0; kB < m_B.outerSize(); ++kB)
187  {
188  for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
189  {
190  for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
191  {
192  Index i = itA.row() * Br + itB.row(),
193  j = itA.col() * Bc + itB.col();
194  dst.insert(i,j) = itA.value() * itB.value();
195  }
196  }
197  }
198  }
199 }
200 
201 namespace internal {
202 
203 template<typename Lhs_, typename Rhs_>
204 struct traits<KroneckerProduct<Lhs_,Rhs_> >
205 {
206  typedef remove_all_t<Lhs_> Lhs;
207  typedef remove_all_t<Rhs_> Rhs;
209  typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex;
210 
211  enum {
216  };
217 
218  typedef Matrix<Scalar,Rows,Cols> ReturnType;
219 };
220 
221 template<typename Lhs_, typename Rhs_>
222 struct traits<KroneckerProductSparse<Lhs_,Rhs_> >
223 {
224  typedef MatrixXpr XprKind;
225  typedef remove_all_t<Lhs_> Lhs;
226  typedef remove_all_t<Rhs_> Rhs;
227  typedef typename ScalarBinaryOpTraits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
228  typedef typename cwise_promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind, scalar_product_op<typename Lhs::Scalar, typename Rhs::Scalar> >::ret StorageKind;
229  typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex;
230 
231  enum {
232  LhsFlags = Lhs::Flags,
233  RhsFlags = Rhs::Flags,
234 
239 
240  EvalToRowMajor = (int(LhsFlags) & int(RhsFlags) & RowMajorBit),
241  RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
242 
243  Flags = ((int(LhsFlags) | int(RhsFlags)) & HereditaryBits & RemovedBits)
245  CoeffReadCost = HugeCost
246  };
247 
248  typedef SparseMatrix<Scalar, 0, StorageIndex> ReturnType;
249 };
250 
251 } // end namespace internal
252 
272 template<typename A, typename B>
274 {
275  return KroneckerProduct<A, B>(a.derived(), b.derived());
276 }
277 
299 template<typename A, typename B>
301 {
302  return KroneckerProductSparse<A,B>(a.derived(), b.derived());
303 }
304 
305 } // end namespace Eigen
306 
307 #endif // KRONECKER_TENSOR_PRODUCT_H
ArrayXXi a
SparseMatrix< double > A(n, n)
int i
RowXpr row(Index i) const
ColXpr col(Index i) const
MatrixXf B
#define EIGEN_STATIC_ASSERT_VECTOR_ONLY(TYPE)
static const ConstantReturnType Zero()
TransposeReturnType transpose()
EIGEN_CONSTEXPR Index size() const EIGEN_NOEXCEPT
The base class of dense and sparse Kronecker product.
Scalar coeff(Index row, Index col) const
KroneckerProductBase(const Lhs &A, const Rhs &B)
Constructor.
internal::traits< Derived > Traits
Kronecker tensor product helper class for sparse matrices.
KroneckerProductBase< KroneckerProductSparse > Base
void evalTo(Dest &dst) const
Evaluate the Kronecker tensor product.
KroneckerProductSparse(const Lhs &A, const Rhs &B)
Constructor.
Kronecker tensor product helper class for dense matrices.
KroneckerProduct(const Lhs &A, const Rhs &B)
Constructor.
void evalTo(Dest &dst) const
Evaluate the Kronecker tensor product.
KroneckerProductBase< KroneckerProduct > Base
internal::dense_xpr_base< ReturnByValue >::type Base
KroneckerProduct< A, B > kroneckerProduct(const MatrixBase< A > &a, const MatrixBase< B > &b)
const unsigned int EvalBeforeNestingBit
const unsigned int RowMajorBit
constexpr int size_at_compile_time(int rows, int cols)
typename remove_all< T >::type remove_all_t
: TensorContractionSycl.h, provides various tensor contraction kernel for SYCL backend
const unsigned int HereditaryBits
const int HugeCost
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
Derived::Index cols
Derived::Index rows
std::ptrdiff_t j