11 #ifndef EIGEN_GENERAL_PRODUCT_H
12 #define EIGEN_GENERAL_PRODUCT_H
28 #ifndef EIGEN_GEMM_TO_COEFFBASED_THRESHOLD
30 #define EIGEN_GEMM_TO_COEFFBASED_THRESHOLD 20
/* Compile-time selector of the product implementation kind.
 *
 * The three arguments are the size categories of the product's rows, columns
 * and depth (see product_type, which instantiates this with
 * product_size_category<...>::value for each dimension).
 *
 * The primary template is intentionally declared but never defined: only its
 * specializations provide the `ret` enumerator that product_type reads as
 * `selector::ret`. An unsupported combination therefore fails to compile
 * instead of silently picking a default.
 */
template<int Rows, int Cols, int Depth> struct product_type_selector;
37 template<
int Size,
int MaxSize>
struct product_size_category
40 #ifndef EIGEN_GPU_COMPILE_PHASE
41 is_large = MaxSize ==
Dynamic ||
47 value = is_large ?
Large
53 template<
typename Lhs,
typename Rhs>
struct product_type
55 typedef remove_all_t<Lhs> Lhs_;
56 typedef remove_all_t<Rhs> Rhs_;
58 MaxRows = traits<Lhs_>::MaxRowsAtCompileTime,
59 Rows = traits<Lhs_>::RowsAtCompileTime,
60 MaxCols = traits<Rhs_>::MaxColsAtCompileTime,
61 Cols = traits<Rhs_>::ColsAtCompileTime,
63 traits<Rhs_>::MaxRowsAtCompileTime),
65 traits<Rhs_>::RowsAtCompileTime)
72 rows_select = product_size_category<Rows,MaxRows>::value,
73 cols_select = product_size_category<Cols,MaxCols>::value,
74 depth_select = product_size_category<Depth,MaxDepth>::value
76 typedef product_type_selector<rows_select, cols_select, depth_select> selector;
80 value = selector::ret,
83 #ifdef EIGEN_DEBUG_PRODUCT
101 template<
int M,
int N>
struct product_type_selector<
M,N,1> {
enum { ret =
OuterProduct }; };
104 template<
int Depth>
struct product_type_selector<1, 1, Depth> {
enum { ret =
InnerProduct }; };
105 template<>
struct product_type_selector<1, 1, 1> {
enum { ret =
InnerProduct }; };
/* Dispatcher for dense matrix * vector products that accumulate
 * dest += alpha * lhs * rhs through a static `run(lhs, rhs, dest, alpha)`.
 *
 * Side           : whether the matrix operand is on the left (OnTheLeft) or
 *                  the right (OnTheRight) of the vector.
 * StorageOrder   : storage order of the matrix operand (ColMajor / RowMajor).
 * BlasCompatible : whether the operands can be fed directly to the
 *                  general_matrix_vector_product kernel.
 *
 * Declared only; each supported combination is provided by a specialization.
 */
template<int Side, int StorageOrder, bool BlasCompatible>
struct gemv_dense_selector;
/* Optional temporary buffer of `Scalar` used by the gemv selectors when the
 * destination (or right-hand side) cannot be handed to the kernel directly;
 * callers obtain its storage through `data()`.
 *
 * Size / MaxSize : compile-time size and maximum size of the vector it stands in for.
 * Cond           : whether a temporary may actually be needed; the `false`
 *                  specialization presumably provides no storage -- NOTE(review):
 *                  its body is not visible here, confirm.
 *
 * Declared only; behavior comes from the partial specializations.
 */
template<typename Scalar, int Size, int MaxSize, bool Cond>
struct gemv_static_vector_if;
165 template<
typename Scalar,
int Size,
int MaxSize>
166 struct gemv_static_vector_if<Scalar,Size,MaxSize,false>
171 template<
typename Scalar,
int Size>
172 struct gemv_static_vector_if<Scalar,Size,
Dynamic,true>
177 template<
typename Scalar,
int Size,
int MaxSize>
178 struct gemv_static_vector_if<Scalar,Size,MaxSize,true>
181 ForceAlignment = internal::packet_traits<Scalar>::Vectorizable,
184 #if EIGEN_MAX_STATIC_ALIGN_BYTES!=0
187 EIGEN_STRONG_INLINE Scalar*
data() {
return m_data.array; }
192 EIGEN_STRONG_INLINE Scalar*
data() {
193 return ForceAlignment
201 template<
int StorageOrder,
bool BlasCompatible>
202 struct gemv_dense_selector<
OnTheLeft,StorageOrder,BlasCompatible>
204 template<
typename Lhs,
typename Rhs,
typename Dest>
205 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
207 Transpose<Dest> destT(dest);
209 gemv_dense_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
210 ::run(rhs.transpose(), lhs.transpose(), destT, alpha);
216 template<
typename Lhs,
typename Rhs,
typename Dest>
217 static inline void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
219 typedef typename Lhs::Scalar LhsScalar;
220 typedef typename Rhs::Scalar RhsScalar;
221 typedef typename Dest::Scalar ResScalar;
223 typedef internal::blas_traits<Lhs> LhsBlasTraits;
224 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
225 typedef internal::blas_traits<Rhs> RhsBlasTraits;
226 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
230 ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
231 ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
236 typedef std::conditional_t<Dest::IsVectorAtCompileTime, Dest, typename Dest::ColXpr> ActualDest;
241 EvalToDestAtCompileTime = (ActualDest::InnerStrideAtCompileTime==1),
243 MightCannotUseDest = ((!EvalToDestAtCompileTime) || ComplexByReal) && (ActualDest::MaxSizeAtCompileTime!=0)
246 typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
247 typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
248 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
250 if(!MightCannotUseDest)
254 general_matrix_vector_product
255 <
Index,LhsScalar,LhsMapper,
ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
256 actualLhs.rows(), actualLhs.cols(),
257 LhsMapper(actualLhs.data(), actualLhs.outerStride()),
258 RhsMapper(actualRhs.data(), actualRhs.innerStride()),
264 gemv_static_vector_if<ResScalar,ActualDest::SizeAtCompileTime,ActualDest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
267 const bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
270 evalToDest ? dest.data() : static_dest.data());
274 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
276 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
278 if(!alphaIsCompatible)
280 MappedDest(actualDestPtr, dest.size()).setZero();
281 compatibleAlpha = RhsScalar(1);
284 MappedDest(actualDestPtr, dest.size()) = dest;
287 general_matrix_vector_product
288 <
Index,LhsScalar,LhsMapper,
ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
289 actualLhs.rows(), actualLhs.cols(),
290 LhsMapper(actualLhs.data(), actualLhs.outerStride()),
291 RhsMapper(actualRhs.data(), actualRhs.innerStride()),
297 if(!alphaIsCompatible)
298 dest.matrix() += actualAlpha * MappedDest(actualDestPtr, dest.size());
300 dest = MappedDest(actualDestPtr, dest.size());
308 template<
typename Lhs,
typename Rhs,
typename Dest>
309 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
311 typedef typename Lhs::Scalar LhsScalar;
312 typedef typename Rhs::Scalar RhsScalar;
313 typedef typename Dest::Scalar ResScalar;
315 typedef internal::blas_traits<Lhs> LhsBlasTraits;
316 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
317 typedef internal::blas_traits<Rhs> RhsBlasTraits;
318 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
319 typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;
321 std::add_const_t<ActualLhsType> actualLhs = LhsBlasTraits::extract(lhs);
322 std::add_const_t<ActualRhsType> actualRhs = RhsBlasTraits::extract(rhs);
329 DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1 || ActualRhsTypeCleaned::MaxSizeAtCompileTime==0
332 gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
335 DirectlyUseRhs ?
const_cast<RhsScalar*
>(actualRhs.data()) : static_rhs.data());
339 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
341 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
343 Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
346 typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
347 typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
348 general_matrix_vector_product
349 <
Index,LhsScalar,LhsMapper,
RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
350 actualLhs.rows(), actualLhs.cols(),
351 LhsMapper(actualLhs.data(), actualLhs.outerStride()),
352 RhsMapper(actualRhsPtr, 1),
353 dest.data(), dest.col(0).innerStride(),
360 template<
typename Lhs,
typename Rhs,
typename Dest>
361 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
363 EIGEN_STATIC_ASSERT((!nested_eval<Lhs,1>::Evaluate),EIGEN_INTERNAL_COMPILATION_ERROR_OR_YOU_MADE_A_PROGRAMMING_MISTAKE);
365 typename nested_eval<Rhs,1>::type actual_rhs(rhs);
368 dest += (alpha*actual_rhs.coeff(k)) * lhs.col(k);
374 template<
typename Lhs,
typename Rhs,
typename Dest>
375 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
377 EIGEN_STATIC_ASSERT((!nested_eval<Lhs,1>::Evaluate),EIGEN_INTERNAL_COMPILATION_ERROR_OR_YOU_MADE_A_PROGRAMMING_MISTAKE);
378 typename nested_eval<Rhs,Lhs::RowsAtCompileTime>::type actual_rhs(rhs);
381 dest.coeffRef(
i) += alpha * (lhs.row(
i).cwiseProduct(actual_rhs.transpose())).sum();
397 template<
typename Derived>
398 template<
typename OtherDerived>
400 const Product<Derived, OtherDerived>
408 ProductIsValid = Derived::ColsAtCompileTime==
Dynamic
409 || OtherDerived::RowsAtCompileTime==
Dynamic
410 || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
411 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
418 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
420 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
422 #ifdef EIGEN_DEBUG_PRODUCT
423 internal::product_type<Derived,OtherDerived>::debug();
440 template<
typename Derived>
441 template<
typename OtherDerived>
447 ProductIsValid = Derived::ColsAtCompileTime==
Dynamic
448 || OtherDerived::RowsAtCompileTime==
Dynamic
449 || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
450 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
457 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
459 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
const ImagReturnType imag() const
#define eigen_internal_assert(x)
#define EIGEN_DEBUG_VAR(x)
#define EIGEN_DEVICE_FUNC
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
#define EIGEN_STATIC_ASSERT(X, MSG)
#define EIGEN_PREDICATE_SAME_MATRIX_SIZE(TYPE0, TYPE1)
internal::traits< Homogeneous< MatrixType, Direction_ > >::Scalar Scalar
Base class for all dense matrices, vectors, and expressions.
const Product< Derived, OtherDerived, LazyProduct > lazyProduct(const MatrixBase< OtherDerived > &other) const
const Product< Derived, OtherDerived > operator*(const MatrixBase< OtherDerived > &other) const
Expression of the product of two arbitrary matrices or vectors.
constexpr int plain_enum_min(A a, B b)
EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar &alpha, const Lhs &lhs, const Rhs &rhs)
constexpr int min_size_prefer_fixed(A a, B b)
bool is_exactly_zero(const X &x)
@ LazyCoeffBasedProductMode
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.