10 #ifndef EIGEN_BLASUTIL_H
11 #define EIGEN_BLASUTIL_H
16 #include "../InternalHeaderCheck.h"
23 template<
typename LhsScalar,
typename RhsScalar,
typename Index,
typename DataMapper,
int mr,
int nr,
bool ConjugateLhs=false,
bool ConjugateRhs=false>
26 template<
typename Scalar,
typename Index,
typename DataMapper,
int nr,
int StorageOrder,
bool Conjugate = false,
bool PanelMode=false>
29 template<
typename Scalar,
typename Index,
typename DataMapper,
int Pack1,
int Pack2,
typename Packet,
int StorageOrder,
bool Conjugate = false,
bool PanelMode = false>
34 typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
35 typename RhsScalar,
int RhsStorageOrder,
bool ConjugateRhs,
36 int ResStorageOrder,
int ResInnerStride>
37 struct general_matrix_matrix_product;
// Forward declaration only: general_matrix_vector_product is declared here
// without a body. The template takes separate scalar type, mapper type and
// conjugation flag for the left- and right-hand sides; a storage-order
// parameter is present for the left-hand side only. The trailing Version
// parameter defaults to Specialized.
39 template<
typename Index,
40 typename LhsScalar,
typename LhsMapper,
int LhsStorageOrder,
bool ConjugateLhs,
41 typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version=
Specialized>
42 struct general_matrix_vector_product;
44 template<
typename From,
typename To>
struct get_factor {
48 template<
typename Scalar>
struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
54 template<
typename Scalar,
typename Index>
55 class BlasVectorMapper {
62 template <
typename Packet,
int AlignmentType>
64 return ploadt<Packet, AlignmentType>(m_data +
i);
67 template <
typename Packet>
69 return (std::uintptr_t(m_data+
i)%
sizeof(Packet))==0;
// Forward declaration of the BlasLinearMapper class template. The last
// parameter, Incr, defaults to 1, so BlasLinearMapper<Scalar, Index,
// AlignmentType> names the Incr == 1 instantiation.
76 template<
typename Scalar,
typename Index,
int AlignmentType,
int Incr=1>
77 class BlasLinearMapper;
79 template<
typename Scalar,
typename Index,
int AlignmentType>
98 template<
typename PacketType>
100 return ploadt<PacketType, AlignmentType>(m_data +
i);
103 template<
typename PacketType>
105 return ploadt_partial<PacketType, AlignmentType>(m_data +
i,
n, offset);
108 template<
typename PacketType,
int AlignmentT>
110 return ploadt<PacketType, AlignmentT>(m_data +
i);
113 template<
typename PacketType>
115 pstoret<Scalar, PacketType, AlignmentType>(m_data +
i,
p);
118 template<
typename PacketType>
120 pstoret_partial<Scalar, PacketType, AlignmentType>(m_data +
i,
p,
n, offset);
// Forward declaration of the blas_data_mapper class template.
// Defaults: AlignmentType = Unaligned, Incr = 1, so callers may supply only
// Scalar, Index and StorageOrder.
128 template<
typename Scalar,
typename Index,
int StorageOrder,
int AlignmentType = Unaligned,
int Incr = 1>
129 class blas_data_mapper;
135 template<
typename Index,
typename Scalar,
typename Packet,
int n,
int idx,
int StorageOrder>
136 struct PacketBlockManagement
138 PacketBlockManagement<
Index, Scalar, Packet,
n, idx - 1, StorageOrder> pbm;
140 pbm.store(to, stride,
i,
j,
block);
141 pstoreu<Scalar>(to +
i + (
j + idx)*stride,
block.packet[idx]);
146 template<
typename Index,
typename Scalar,
typename Packet,
int n,
int idx>
147 struct PacketBlockManagement<
Index, Scalar, Packet,
n, idx,
RowMajor>
149 PacketBlockManagement<
Index, Scalar, Packet,
n, idx - 1,
RowMajor> pbm;
151 pbm.store(to, stride,
i,
j,
block);
152 pstoreu<Scalar>(to +
j + (
i + idx)*stride,
block.packet[idx]);
156 template<
typename Index,
typename Scalar,
typename Packet,
int n,
int StorageOrder>
157 struct PacketBlockManagement<
Index, Scalar, Packet,
n, -1, StorageOrder>
168 template<
typename Index,
typename Scalar,
typename Packet,
int n>
169 struct PacketBlockManagement<
Index, Scalar, Packet,
n, -1,
RowMajor>
180 template<
typename Scalar,
typename Index,
int StorageOrder,
int AlignmentType>
184 typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
185 typedef BlasVectorMapper<Scalar, Index> VectorMapper;
188 : m_data(
data), m_stride(stride)
196 return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&
operator()(
i,
j), m_stride);
200 return LinearMapper(&
operator()(
i,
j));
204 return VectorMapper(&
operator()(
i,
j));
213 return m_data[StorageOrder==
RowMajor ?
j +
i*m_stride :
i +
j*m_stride];
216 template<
typename PacketType>
218 return ploadt<PacketType, AlignmentType>(&
operator()(
i,
j));
221 template<
typename PacketType>
223 return ploadt_partial<PacketType, AlignmentType>(&
operator()(
i,
j),
n, offset);
226 template <
typename PacketT,
int AlignmentT>
228 return ploadt<PacketT, AlignmentT>(&
operator()(
i,
j));
231 template<
typename PacketType>
233 pstoret<Scalar, PacketType, AlignmentType>(&
operator()(
i,
j),
p);
236 template<
typename PacketType>
238 pstoret_partial<Scalar, PacketType, AlignmentType>(&
operator()(
i,
j),
p,
n, offset);
241 template<
typename SubPacket>
243 pscatter<Scalar, SubPacket>(&
operator()(
i,
j),
p, m_stride);
246 template<
typename SubPacket>
248 return pgather<Scalar, SubPacket>(&
operator()(
i,
j), m_stride);
256 if (std::uintptr_t(m_data)%
sizeof(Scalar)) {
262 template<
typename SubPacket,
int n>
264 PacketBlockManagement<
Index, Scalar, SubPacket,
n,
n-1, StorageOrder> pbm;
265 pbm.store(m_data, m_stride,
i,
j,
block);
269 const Index m_stride;
275 template<
typename Scalar,
typename Index,
int AlignmentType,
int Incr>
276 class BlasLinearMapper
286 return m_data[
i*m_incr.value()];
289 template<
typename PacketType>
291 return pgather<Scalar,PacketType>(m_data +
i*m_incr.value(), m_incr.value());
294 template<
typename PacketType>
296 return pgather_partial<Scalar,PacketType>(m_data +
i*m_incr.value(), m_incr.value(),
n);
299 template<
typename PacketType>
301 pscatter<Scalar, PacketType>(m_data +
i*m_incr.value(),
p, m_incr.value());
304 template<
typename PacketType>
306 pscatter_partial<Scalar, PacketType>(m_data +
i*m_incr.value(),
p, m_incr.value(),
n);
311 const internal::variable_if_dynamic<Index,Incr> m_incr;
314 template<
typename Scalar,
typename Index,
int StorageOrder,
int AlignmentType,
int Incr>
315 class blas_data_mapper
318 typedef BlasLinearMapper<Scalar, Index, AlignmentType,Incr> LinearMapper;
324 return blas_data_mapper(&
operator()(
i,
j), m_stride, m_incr.value());
328 return LinearMapper(&
operator()(
i,
j), m_incr.value());
337 return m_data[StorageOrder==
RowMajor ?
j*m_incr.value() +
i*m_stride :
i*m_incr.value() +
j*m_stride];
340 template<
typename PacketType>
342 return pgather<Scalar,PacketType>(&
operator()(
i,
j),m_incr.value());
345 template<
typename PacketType>
347 return pgather_partial<Scalar,PacketType>(&
operator()(
i,
j),m_incr.value(),
n);
350 template <
typename PacketT,
int AlignmentT>
352 return pgather<Scalar,PacketT>(&
operator()(
i,
j),m_incr.value());
355 template<
typename PacketType>
357 pscatter<Scalar, PacketType>(&
operator()(
i,
j),
p, m_incr.value());
360 template<
typename PacketType>
362 pscatter_partial<Scalar, PacketType>(&
operator()(
i,
j),
p, m_incr.value(),
n);
365 template<
typename SubPacket>
367 pscatter<Scalar, SubPacket>(&
operator()(
i,
j),
p, m_stride);
370 template<
typename SubPacket>
372 return pgather<Scalar, SubPacket>(&
operator()(
i,
j), m_stride);
376 template<
typename SubPacket,
typename Scalar_,
int n,
int idx>
377 struct storePacketBlock_helper
379 storePacketBlock_helper<SubPacket, Scalar_,
n, idx-1> spbh;
382 sup->template storePacket<SubPacket>(
i,
j+idx,
block.packet[idx]);
386 template<
typename SubPacket,
int n,
int idx>
387 struct storePacketBlock_helper<SubPacket,
std::complex<float>,
n, idx>
389 storePacketBlock_helper<SubPacket, std::complex<float>,
n, idx-1> spbh;
392 sup->template storePacket<SubPacket>(
i,
j+idx,
block.packet[idx]);
396 template<
typename SubPacket,
int n,
int idx>
397 struct storePacketBlock_helper<SubPacket,
std::complex<double>,
n, idx>
399 storePacketBlock_helper<SubPacket, std::complex<double>,
n, idx-1> spbh;
404 std::complex<double> *
v = &sup->operator()(
i+l,
j+idx);
405 v->real(
block.packet[idx].v[2*l+0]);
406 v->imag(
block.packet[idx].v[2*l+1]);
411 template<
typename SubPacket,
typename Scalar_,
int n>
412 struct storePacketBlock_helper<SubPacket, Scalar_,
n, -1>
418 template<
typename SubPacket,
int n>
419 struct storePacketBlock_helper<SubPacket,
std::complex<float>,
n, -1>
425 template<
typename SubPacket,
int n>
426 struct storePacketBlock_helper<SubPacket,
std::complex<double>,
n, -1>
432 template<
typename SubPacket,
int n>
434 storePacketBlock_helper<SubPacket, Scalar,
n,
n-1> spb;
443 const Index m_stride;
444 const internal::variable_if_dynamic<Index,Incr> m_incr;
448 template<
typename Scalar,
typename Index,
int StorageOrder>
449 class const_blas_data_mapper :
public blas_data_mapper<const Scalar, Index, StorageOrder> {
454 return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->
operator()(i,
j)), this->m_stride);
462 template<
typename XprType>
struct blas_traits
464 typedef typename traits<XprType>::Scalar Scalar;
465 typedef const XprType& ExtractType;
466 typedef XprType ExtractType_;
469 IsTransposed =
false,
470 NeedToConjugate =
false,
472 && (
bool(XprType::IsVectorAtCompileTime)
473 || int(inner_stride_at_compile_time<XprType>::ret) == 1)
475 HasScalarFactor =
false
477 typedef std::conditional_t<
bool(HasUsableDirectAccess),
479 typename ExtractType_::PlainObject
480 > DirectLinearAccessType;
486 template<
typename Scalar,
typename NestedXpr>
487 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
488 : blas_traits<NestedXpr>
490 typedef blas_traits<NestedXpr> Base;
491 typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
492 typedef typename Base::ExtractType ExtractType;
496 NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
// Conjugate wrapper: extraction ignores the wrapper itself and forwards to
// Base::extract on the wrapped expression (x.nestedExpression()).
498 EIGEN_DEVICE_FUNC static inline ExtractType extract(
const XprType&
x) {
return Base::extract(
x.nestedExpression()); }
// Conjugate wrapper: the scalar factor is the one Base reports for the wrapped
// expression, passed through conj() before being returned.
499 EIGEN_DEVICE_FUNC static inline Scalar extractScalarFactor(
const XprType&
x) {
return conj(Base::extractScalarFactor(
x.nestedExpression())); }
503 template<
typename Scalar,
typename NestedXpr,
typename Plain>
504 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
505 : blas_traits<NestedXpr>
508 HasScalarFactor =
true
510 typedef blas_traits<NestedXpr> Base;
511 typedef CwiseBinaryOp<scalar_product_op<Scalar>,
const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
512 typedef typename Base::ExtractType ExtractType;
515 {
return x.lhs().functor().m_other * Base::extractScalarFactor(
x.rhs()); }
517 template<
typename Scalar,
typename NestedXpr,
typename Plain>
518 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
519 : blas_traits<NestedXpr>
522 HasScalarFactor =
true
524 typedef blas_traits<NestedXpr> Base;
525 typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr,
const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
526 typedef typename Base::ExtractType ExtractType;
// Scalar-product expression whose left operand is NestedXpr and whose right
// operand is the scalar_constant_op nullary expression: extraction is
// delegated to Base::extract on the left operand only (x.lhs()).
527 EIGEN_DEVICE_FUNC static inline ExtractType extract(
const XprType&
x) {
return Base::extract(
x.lhs()); }
529 {
return Base::extractScalarFactor(
x.lhs()) *
x.rhs().functor().m_other; }
531 template<
typename Scalar,
typename Plain1,
typename Plain2>
532 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1>,
533 const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain2> > >
534 : blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1> >
538 template<
typename Scalar,
typename NestedXpr>
539 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
540 : blas_traits<NestedXpr>
543 HasScalarFactor =
true
545 typedef blas_traits<NestedXpr> Base;
546 typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
547 typedef typename Base::ExtractType ExtractType;
// scalar_opposite_op wrapper: extraction forwards to Base::extract on the
// wrapped expression (x.nestedExpression()); the negation is not applied here.
548 EIGEN_DEVICE_FUNC static inline ExtractType extract(
const XprType&
x) {
return Base::extract(
x.nestedExpression()); }
550 {
return - Base::extractScalarFactor(
x.nestedExpression()); }
554 template<
typename NestedXpr>
555 struct blas_traits<Transpose<NestedXpr> >
556 : blas_traits<NestedXpr>
558 typedef typename NestedXpr::Scalar Scalar;
559 typedef blas_traits<NestedXpr> Base;
560 typedef Transpose<NestedXpr> XprType;
561 typedef Transpose<const typename Base::ExtractType_> ExtractType;
562 typedef Transpose<const typename Base::ExtractType_> ExtractType_;
563 typedef std::conditional_t<
bool(Base::HasUsableDirectAccess),
565 typename ExtractType::PlainObject
566 > DirectLinearAccessType;
568 IsTransposed = Base::IsTransposed ? 0 : 1
// Transpose wrapper: runs Base::extract on the wrapped expression and
// constructs an ExtractType from the result. In this specialization
// ExtractType is Transpose<const typename Base::ExtractType_>.
570 EIGEN_DEVICE_FUNC static inline ExtractType extract(
const XprType&
x) {
return ExtractType(Base::extract(
x.nestedExpression())); }
// Transpose wrapper: the scalar factor is returned exactly as Base reports it
// for the wrapped expression (x.nestedExpression()), with no modification.
571 EIGEN_DEVICE_FUNC static inline Scalar extractScalarFactor(
const XprType&
x) {
return Base::extractScalarFactor(
x.nestedExpression()); }
575 struct blas_traits<const
T>
579 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
580 struct extract_data_selector {
583 return blas_traits<T>::extract(
m).data();
588 struct extract_data_selector<
T,false> {
595 return extract_data_selector<T>::run(
m);
602 template<
typename ResScalar,
typename Lhs,
typename Rhs>
603 struct combine_scalar_factors_impl
607 return blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
611 return alpha * blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
614 template<
typename Lhs,
typename Rhs>
615 struct combine_scalar_factors_impl<
bool, Lhs, Rhs>
619 return blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
623 return alpha && blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
627 template<
typename ResScalar,
typename Lhs,
typename Rhs>
630 return combine_scalar_factors_impl<ResScalar,Lhs,Rhs>::run(alpha, lhs, rhs);
632 template<
typename ResScalar,
typename Lhs,
typename Rhs>
635 return combine_scalar_factors_impl<ResScalar,Lhs,Rhs>::run(lhs, rhs);
Array< int, Dynamic, 1 > v
EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL FixedBlockXpr<...,... >::Type block(Index startRow, Index startCol, NRowsType blockRows, NColsType blockCols)
RealReturnType real() const
IndexedView_or_Block operator()(const RowIndices &rowIndices, const ColIndices &colIndices)
#define EIGEN_ALWAYS_INLINE
#define EIGEN_UNUSED_VARIABLE(var)
#define EIGEN_DEVICE_FUNC
#define EIGEN_ONLY_USED_FOR_DEBUG(x)
const unsigned int DirectAccessBit
EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar &alpha, const Lhs &lhs, const Rhs &rhs)
static Index first_default_aligned(const DenseBase< Derived > &m)
void prefetch(const Scalar *addr)
EIGEN_ALWAYS_INLINE const T::Scalar * extract_data(const T &m)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_conjugate_op< typename Derived::Scalar >, const Derived > conj(const Eigen::ArrayBase< Derived > &x)