10 #ifndef EIGEN_SPARSEDENSEPRODUCT_H
11 #define EIGEN_SPARSEDENSEPRODUCT_H
// Storage-kind promotion for an OuterProduct whose left operand is Sparse and
// whose right operand is Dense: the promoted storage kind `ret` is Sparse.
// NOTE(review): the leading "19" token below is a fused line number left over
// from the listing this text was extracted from; it is not C++. The primary
// template product_promote_storage_type is not visible in this view.
19 template <>
struct product_promote_storage_type<Sparse,Dense,
OuterProduct> {
typedef Sparse ret; };
// Storage-kind promotion for an OuterProduct whose left operand is Dense and
// whose right operand is Sparse: the promoted storage kind `ret` is Sparse,
// mirroring the (Sparse, Dense) specialization.
// NOTE(review): the leading "20" token below is a fused line number left over
// from the listing this text was extracted from; it is not C++.
20 template <>
struct product_promote_storage_type<Dense,Sparse,
OuterProduct> {
typedef Sparse ret; };
22 template<
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType,
25 bool ColPerCol = ((DenseRhsType::Flags&
RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
26 struct sparse_time_dense_product_impl;
28 template<
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType>
29 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar,
RowMajor, true>
31 typedef internal::remove_all_t<SparseLhsType> Lhs;
32 typedef internal::remove_all_t<DenseRhsType> Rhs;
33 typedef internal::remove_all_t<DenseResType> Res;
34 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
35 typedef evaluator<Lhs> LhsEval;
36 static void run(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType&
res,
const typename Res::Scalar& alpha)
41 #ifdef EIGEN_HAS_OPENMP
48 #ifdef EIGEN_HAS_OPENMP
51 if(threads>1 && lhsEval.nonZerosEstimate() > 20000)
53 #pragma omp parallel for schedule(dynamic,(n+threads*4-1)/(threads*4)) num_threads(threads)
55 processRow(lhsEval,rhs,
res,alpha,
i,
c);
61 processRow(lhsEval,rhs,
res,alpha,
i,
c);
66 static void processRow(
const LhsEval& lhsEval,
const DenseRhsType& rhs, DenseResType&
res,
const typename Res::Scalar& alpha,
Index i,
Index col)
70 typename Res::Scalar tmp_a(0);
71 typename Res::Scalar tmp_b(0);
72 for(LhsInnerIterator it(lhsEval,
i); it ;++it) {
73 tmp_a += it.value() * rhs.coeff(it.index(),
col);
76 tmp_b += it.value() * rhs.coeff(it.index(),
col);
79 res.coeffRef(
i,
col) += alpha * (tmp_a + tmp_b);
95 template<
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType,
typename AlphaType>
96 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType,
ColMajor, true>
98 typedef internal::remove_all_t<SparseLhsType> Lhs;
99 typedef internal::remove_all_t<DenseRhsType> Rhs;
100 typedef internal::remove_all_t<DenseResType> Res;
101 typedef evaluator<Lhs> LhsEval;
102 typedef typename LhsEval::InnerIterator LhsInnerIterator;
103 static void run(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType&
res,
const AlphaType& alpha)
105 LhsEval lhsEval(lhs);
108 for(
Index j=0;
j<lhs.outerSize(); ++
j)
111 typename ScalarBinaryOpTraits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(
j,
c));
112 for(LhsInnerIterator it(lhsEval,
j); it ;++it)
113 res.coeffRef(it.index(),
c) += it.value() * rhs_j;
119 template<
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType>
120 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar,
RowMajor, false>
122 typedef internal::remove_all_t<SparseLhsType> Lhs;
123 typedef internal::remove_all_t<DenseRhsType> Rhs;
124 typedef internal::remove_all_t<DenseResType> Res;
125 typedef evaluator<Lhs> LhsEval;
126 typedef typename LhsEval::InnerIterator LhsInnerIterator;
127 static void run(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType&
res,
const typename Res::Scalar& alpha)
130 LhsEval lhsEval(lhs);
132 #ifdef EIGEN_HAS_OPENMP
137 if(threads>1 && lhsEval.nonZerosEstimate()*rhs.cols() > 20000)
139 #pragma omp parallel for schedule(dynamic,(n+threads*4-1)/(threads*4)) num_threads(threads)
141 processRow(lhsEval,rhs,
res,alpha,
i);
147 processRow(lhsEval, rhs,
res, alpha,
i);
151 static void processRow(
const LhsEval& lhsEval,
const DenseRhsType& rhs, Res&
res,
const typename Res::Scalar& alpha,
Index i)
153 typename Res::RowXpr res_i(
res.row(
i));
154 for(LhsInnerIterator it(lhsEval,
i); it ;++it)
155 res_i += (alpha*it.value()) * rhs.row(it.index());
159 template<
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType>
160 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar,
ColMajor, false>
162 typedef internal::remove_all_t<SparseLhsType> Lhs;
163 typedef internal::remove_all_t<DenseRhsType> Rhs;
164 typedef internal::remove_all_t<DenseResType> Res;
165 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
166 static void run(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType&
res,
const typename Res::Scalar& alpha)
168 evaluator<Lhs> lhsEval(lhs);
169 for(
Index j=0;
j<lhs.outerSize(); ++
j)
171 typename Rhs::ConstRowXpr rhs_j(rhs.row(
j));
172 for(LhsInnerIterator it(lhsEval,
j); it ;++it)
173 res.row(it.index()) += (alpha*it.value()) * rhs_j;
178 template<
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType,
typename AlphaType>
181 sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs,
res, alpha);
188 template<
typename Lhs,
typename Rhs,
int ProductType>
189 struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType>
190 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SparseShape,DenseShape,ProductType> >
192 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
194 template<
typename Dest>
195 static void scaleAndAddTo(Dest& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha)
197 typedef typename nested_eval<Lhs,((Rhs::Flags&
RowMajorBit)==0) ? 1 : Rhs::ColsAtCompileTime>::type LhsNested;
198 typedef typename nested_eval<Rhs,((Lhs::Flags&
RowMajorBit)==0) ? 1 :
Dynamic>::type RhsNested;
199 LhsNested lhsNested(lhs);
200 RhsNested rhsNested(rhs);
205 template<
typename Lhs,
typename Rhs,
int ProductType>
206 struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, DenseShape, ProductType>
207 : generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType>
210 template<
typename Lhs,
typename Rhs,
int ProductType>
211 struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType>
212 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SparseShape,ProductType> >
214 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
216 template<
typename Dst>
217 static void scaleAndAddTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha)
219 typedef typename nested_eval<Lhs,((Rhs::Flags&
RowMajorBit)==0) ?
Dynamic : 1>::type LhsNested;
220 typedef typename nested_eval<Rhs,((Lhs::Flags&
RowMajorBit)==
RowMajorBit) ? 1 : Lhs::RowsAtCompileTime>::type RhsNested;
221 LhsNested lhsNested(lhs);
222 RhsNested rhsNested(rhs);
225 Transpose<Dst> dstT(dst);
230 template<
typename Lhs,
typename Rhs,
int ProductType>
231 struct generic_product_impl<Lhs, Rhs, DenseShape, SparseTriangularShape, ProductType>
232 : generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType>
235 template<
typename LhsT,
typename RhsT,
bool NeedToTranspose>
236 struct sparse_dense_outer_product_evaluator
239 typedef std::conditional_t<NeedToTranspose,RhsT,LhsT> Lhs1;
240 typedef std::conditional_t<NeedToTranspose,LhsT,RhsT> ActualRhs;
241 typedef Product<LhsT,RhsT,DefaultProduct> ProdXprType;
245 typedef std::conditional_t<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value,
246 Lhs1, SparseView<Lhs1> > ActualLhs;
247 typedef std::conditional_t<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value,
248 Lhs1
const&, SparseView<Lhs1> > LhsArg;
250 typedef evaluator<ActualLhs> LhsEval;
251 typedef evaluator<ActualRhs> RhsEval;
252 typedef typename evaluator<ActualLhs>::InnerIterator
LhsIterator;
253 typedef typename ProdXprType::Scalar Scalar;
264 InnerIterator(
const sparse_dense_outer_product_evaluator &xprEval,
Index outer)
268 m_factor(get(xprEval.m_rhsXprImpl, outer, typename
internal::traits<ActualRhs>::StorageKind() ))
271 EIGEN_STRONG_INLINE
Index outer()
const {
return m_outer; }
272 EIGEN_STRONG_INLINE
Index row()
const {
return NeedToTranspose ? m_outer : LhsIterator::index(); }
273 EIGEN_STRONG_INLINE
Index col()
const {
return NeedToTranspose ? LhsIterator::index() : m_outer; }
275 EIGEN_STRONG_INLINE Scalar value()
const {
return LhsIterator::value() * m_factor; }
276 EIGEN_STRONG_INLINE
operator bool()
const {
return LhsIterator::operator
bool() && (!m_empty); }
279 Scalar get(
const RhsEval &rhs,
Index outer, Dense = Dense())
const
281 return rhs.coeff(outer);
284 Scalar get(
const RhsEval &rhs,
Index outer, Sparse = Sparse())
286 typename RhsEval::InnerIterator it(rhs, outer);
287 if (it && it.index()==0 && it.value()!=Scalar(0))
298 sparse_dense_outer_product_evaluator(
const Lhs1 &lhs,
const ActualRhs &rhs)
299 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs)
305 sparse_dense_outer_product_evaluator(
const ActualRhs &rhs,
const Lhs1 &lhs)
306 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs)
313 evaluator<ActualLhs> m_lhsXprImpl;
314 evaluator<ActualRhs> m_rhsXprImpl;
318 template<
typename Lhs,
typename Rhs>
320 : sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor>
322 typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base;
324 typedef Product<Lhs, Rhs> XprType;
325 typedef typename XprType::PlainObject PlainObject;
327 explicit product_evaluator(
const XprType& xpr)
328 : Base(xpr.lhs(), xpr.rhs())
333 template<
typename Lhs,
typename Rhs>
335 : sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor>
337 typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base;
339 typedef Product<Lhs, Rhs> XprType;
340 typedef typename XprType::PlainObject PlainObject;
342 explicit product_evaluator(
const XprType& xpr)
343 : Base(xpr.lhs(), xpr.rhs())
RowXpr row(Index i)
This is the const version of row().
ColXpr col(Index i)
This is the const version of col().
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
#define EIGEN_INTERNAL_CHECK_COST_VALUE(C)
const unsigned int RowMajorBit
void sparse_time_dense_product(const SparseLhsType &lhs, const DenseRhsType &rhs, DenseResType &res, const AlphaType &alpha)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.