#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_H
#define EIGEN_SELFADJOINT_MATRIX_MATRIX_H

#include "../InternalHeaderCheck.h"

namespace Eigen {
namespace internal {
// Pack a self-adjoint block of the lhs into sequential memory, expanding the
// stored triangle so the generic gebp kernel can read it as a plain dense panel.
template<typename Scalar, typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs
{
  template<int BlockRows> inline
  void pack(Scalar* blockA, const const_blas_data_mapper<Scalar,Index,StorageOrder>& lhs, Index cols, Index i, Index& count)
  {
    // normal copy of the panel before the diagonal block
    for(Index k=0; k<i; k++)
      for(Index w=0; w<BlockRows; w++)
        blockA[count++] = lhs(i+w,k);              // normal

    // the diagonal block: mix of normal, real-diagonal, and conjugated copies
    Index h = 0;
    for(Index k=i; k<i+BlockRows; k++)
    {
      for(Index w=0; w<h; w++)
        blockA[count++] = numext::conj(lhs(k, i+w)); // transposed
      blockA[count++] = numext::real(lhs(k,k));      // real (diagonal)
      for(Index w=h+1; w<BlockRows; w++)
        blockA[count++] = lhs(i+w, k);               // normal
      ++h;
    }

    // transposed copy of the panel after the diagonal block
    for(Index k=i+BlockRows; k<cols; k++)
      for(Index w=0; w<BlockRows; w++)
        blockA[count++] = numext::conj(lhs(k, i+w)); // transposed
  }
  void operator()(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
  {
    typedef typename unpacket_traits<typename packet_traits<Scalar>::type>::half HalfPacket;
    typedef typename unpacket_traits<typename unpacket_traits<typename packet_traits<Scalar>::type>::half>::half QuarterPacket;
    enum { PacketSize = packet_traits<Scalar>::size,
           HalfPacketSize = unpacket_traits<HalfPacket>::size,
           QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
           HasHalf = (int)HalfPacketSize < (int)PacketSize,
           HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize };

    const_blas_data_mapper<Scalar,Index,StorageOrder> lhs(_lhs,lhsStride);
    Index count = 0;

    // boundaries of the row ranges handled with 3, 2, 1, half and quarter packets
    const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0;
    const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
    const Index peeled_mc1 = Pack1>=1*PacketSize ? peeled_mc2+((rows-peeled_mc2)/(1*PacketSize))*(1*PacketSize) : 0;
    const Index peeled_mc_half = Pack1>=HalfPacketSize ? peeled_mc1+((rows-peeled_mc1)/(HalfPacketSize))*(HalfPacketSize) : 0;
    const Index peeled_mc_quarter = Pack1>=QuarterPacketSize ? peeled_mc_half+((rows-peeled_mc_half)/(QuarterPacketSize))*(QuarterPacketSize) : 0;
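    // A worked example of the peeling arithmetic above (sizes are illustrative;
    // the actual packet sizes depend on Scalar and the SIMD target): with
    // PacketSize=4, HalfPacketSize=2, QuarterPacketSize=1, rows=23 and Pack1>=12:
    //   peeled_mc3        = (23/12)*12         = 12
    //   peeled_mc2        = 12 + ((23-12)/8)*8 = 20
    //   peeled_mc1        = 20 + (3/4)*4       = 20
    //   peeled_mc_half    = 20 + (3/2)*2       = 22
    //   peeled_mc_quarter = 22 + (1/1)*1       = 23
    // so rows [0,12) are packed with 3-packet blocks, [12,20) with 2-packet
    // blocks, [20,22) with half packets, and row 22 with quarter packets.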
    if(Pack1>=3*PacketSize)
      for(Index i=0; i<peeled_mc3; i+=3*PacketSize)
        pack<3*PacketSize>(blockA, lhs, cols, i, count);

    if(Pack1>=2*PacketSize)
      for(Index i=peeled_mc3; i<peeled_mc2; i+=2*PacketSize)
        pack<2*PacketSize>(blockA, lhs, cols, i, count);

    if(Pack1>=1*PacketSize)
      for(Index i=peeled_mc2; i<peeled_mc1; i+=1*PacketSize)
        pack<1*PacketSize>(blockA, lhs, cols, i, count);

    if(HasHalf && Pack1>=HalfPacketSize)
      for(Index i=peeled_mc1; i<peeled_mc_half; i+=HalfPacketSize)
        pack<HalfPacketSize>(blockA, lhs, cols, i, count);

    if(HasQuarter && Pack1>=QuarterPacketSize)
      for(Index i=peeled_mc_half; i<peeled_mc_quarter; i+=QuarterPacketSize)
        pack<QuarterPacketSize>(blockA, lhs, cols, i, count);
    // same normal / diagonal / transposed split for the remaining rows, one at a time (mr==1)
    for(Index i=peeled_mc_quarter; i<rows; i++)
    {
      for(Index k=0; k<i; k++)
        blockA[count++] = lhs(i, k);                 // normal
      blockA[count++] = numext::real(lhs(i, i));     // real (diagonal)
      for(Index k=i+1; k<cols; k++)
        blockA[count++] = numext::conj(lhs(k, i));   // transposed
    }
  }
};
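// For example, packing a BlockRows=3 block at i=0 of a 3x3 self-adjoint lhs
// emits, column by column:
//   re(a00) a10 a20 | conj(a10) re(a11) a21 | conj(a20) conj(a21) re(a22)
// i.e. full columns of the Hermitian matrix are reconstructed from the stored
// triangle, so the generic gebp kernel can consume blockA like a plain dense block.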
template<typename Scalar, typename Index, int nr, int StorageOrder>
struct symm_pack_rhs
{
  enum { PacketSize = packet_traits<Scalar>::size };
  void operator()(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    Index end_k = k2 + rows;
    Index count = 0;
    const_blas_data_mapper<Scalar,Index,StorageOrder> rhs(_rhs,rhsStride);
    Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
    Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
    // first part: standard case, columns before the diagonal block
    for(Index j2=0; j2<k2; j2+=nr)
    {
      for(Index k=k2; k<end_k; k++)
      {
        blockB[count+0] = rhs(k,j2+0);
        blockB[count+1] = rhs(k,j2+1);
        if (nr>=4)
        {
          blockB[count+2] = rhs(k,j2+2);
          blockB[count+3] = rhs(k,j2+3);
        }
        if (nr>=8)
        {
          blockB[count+4] = rhs(k,j2+4);
          blockB[count+5] = rhs(k,j2+5);
          blockB[count+6] = rhs(k,j2+6);
          blockB[count+7] = rhs(k,j2+7);
        }
        count += nr;
      }
    }
    // second part: the diagonal block
    Index end8 = nr>=8 ? (std::min)(k2+rows,packet_cols8) : k2;
    if(nr>=8)
    {
      for(Index j2=k2; j2<end8; j2+=8)
      {
        // split vertically into three parts: transposed, symmetric, normal
        // transposed
        for(Index k=k2; k<j2; k++)
        {
          for(Index w=0; w<8; w++)
            blockB[count+w] = numext::conj(rhs(j2+w,k));
          count += 8;
        }
        // symmetric
        Index h = 0;
        for(Index k=j2; k<j2+8; k++)
        {
          for(Index w=0; w<h; w++)
            blockB[count+w] = rhs(k,j2+w);               // normal
          blockB[count+h] = numext::real(rhs(k,k));      // real (diagonal)
          for(Index w=h+1; w<8; w++)
            blockB[count+w] = numext::conj(rhs(j2+w,k)); // transposed
          count += 8;
          ++h;
        }
        // normal
        for(Index k=j2+8; k<end_k; k++)
        {
          blockB[count+0] = rhs(k,j2+0);
          blockB[count+1] = rhs(k,j2+1);
          blockB[count+2] = rhs(k,j2+2);
          blockB[count+3] = rhs(k,j2+3);
          blockB[count+4] = rhs(k,j2+4);
          blockB[count+5] = rhs(k,j2+5);
          blockB[count+6] = rhs(k,j2+6);
          blockB[count+7] = rhs(k,j2+7);
          count += 8;
        }
      }
    }
    if(nr>=4)
    {
      for(Index j2=end8; j2<(std::min)(k2+rows,packet_cols4); j2+=4)
      {
        // transposed
        for(Index k=k2; k<j2; k++)
        {
          for(Index w=0; w<4; w++)
            blockB[count+w] = numext::conj(rhs(j2+w,k));
          count += 4;
        }
        // symmetric
        Index h = 0;
        for(Index k=j2; k<j2+4; k++)
        {
          for(Index w=0; w<h; w++)
            blockB[count+w] = rhs(k,j2+w);               // normal
          blockB[count+h] = numext::real(rhs(k,k));      // real (diagonal)
          for(Index w=h+1; w<4; w++)
            blockB[count+w] = numext::conj(rhs(j2+w,k)); // transposed
          count += 4;
          ++h;
        }
        // normal
        for(Index k=j2+4; k<end_k; k++)
        {
          blockB[count+0] = rhs(k,j2+0);
          blockB[count+1] = rhs(k,j2+1);
          blockB[count+2] = rhs(k,j2+2);
          blockB[count+3] = rhs(k,j2+3);
          count += 4;
        }
      }
    }
    // third part: transposed, columns after the diagonal block
    if(nr>=8)
    {
      for(Index j2=k2+rows; j2<packet_cols8; j2+=8)
      {
        for(Index k=k2; k<end_k; k++)
        {
          for(Index w=0; w<8; w++)
            blockB[count+w] = numext::conj(rhs(j2+w,k));
          count += 8;
        }
      }
    }
    if(nr>=4)
    {
      for(Index j2=(std::max)(packet_cols8,k2+rows); j2<packet_cols4; j2+=4)
      {
        for(Index k=k2; k<end_k; k++)
        {
          for(Index w=0; w<4; w++)
            blockB[count+w] = numext::conj(rhs(j2+w,k));
          count += 4;
        }
      }
    }
    // copy the remaining columns one at a time (nr==1)
    for(Index j2=packet_cols4; j2<cols; ++j2)
    {
      // transposed
      Index half = (std::min)(end_k,j2);
      for(Index k=k2; k<half; k++)
      {
        blockB[count] = numext::conj(rhs(j2,k));
        count += 1;
      }

      if(half==j2 && half<k2+rows)
      {
        blockB[count] = numext::real(rhs(j2,j2));
        count += 1;
      }
      else
        half--;

      // normal
      for(Index k=half+1; k<k2+rows; k++)
      {
        blockB[count] = rhs(k,j2);
        count += 1;
      }
    }
  }
};
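// For example, with nr=4, k2=4, rows=4 (so end_k=8) and cols=12, the code
// above packs columns [0,4) normally (part one), columns [4,8) through the
// diagonal-block scheme mixing normal, real-diagonal and conjugated copies
// (part two), and columns [8,12) entirely from the opposite triangle via
// conj(rhs(j2+w,k)) (part three).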
/* Optimized self-adjoint matrix * matrix (SYMM) product,
 * built on top of the general matrix-matrix product.
 */
template <typename Scalar, typename Index,
          int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
          int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs,
          int ResStorageOrder, int ResInnerStride>
struct product_selfadjoint_matrix;
template <typename Scalar, typename Index,
          int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
          int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs,
          int ResInnerStride>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,LhsSelfAdjoint,ConjugateLhs,
                                  RhsStorageOrder,RhsSelfAdjoint,ConjugateRhs,RowMajor,ResInnerStride>
{
  static EIGEN_STRONG_INLINE void run(
    Index rows, Index cols,
    const Scalar* lhs, Index lhsStride,
    const Scalar* rhs, Index rhsStride,
    Scalar* res,       Index resIncr, Index resStride,
    const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
  {
    // delegate to the column-major implementation with swapped operands
    product_selfadjoint_matrix<Scalar, Index,
      logical_xor(RhsSelfAdjoint, RhsStorageOrder==RowMajor) ? ColMajor : RowMajor,
      RhsSelfAdjoint, NumTraits<Scalar>::IsComplex && logical_xor(ConjugateRhs, RhsSelfAdjoint),
      logical_xor(LhsSelfAdjoint, LhsStorageOrder==RowMajor) ? ColMajor : RowMajor,
      LhsSelfAdjoint, NumTraits<Scalar>::IsComplex && logical_xor(ConjugateLhs, LhsSelfAdjoint),
      ColMajor, ResInnerStride>
      ::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resIncr, resStride, alpha, blocking);
  }
};
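// Note on the specialization above: a row-major result is handled via the
// identity C^T = (A*B)^T = B^T * A^T, evaluated into a column-major view of C.
// Swapping the operands flips their storage orders, and because a self-adjoint
// operand satisfies A^T = conj(A), the self-adjoint flags toggle the
// conjugation flags for complex scalars instead of physically transposing.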
template <typename Scalar, typename Index,
          int LhsStorageOrder, bool ConjugateLhs,
          int RhsStorageOrder, bool ConjugateRhs,
          int ResInnerStride>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs,
                                  RhsStorageOrder,false,ConjugateRhs,ColMajor,ResInnerStride>
{
  static EIGEN_DONT_INLINE void run(
    Index rows, Index cols,
    const Scalar* _lhs, Index lhsStride,
    const Scalar* _rhs, Index rhsStride,
    Scalar* res,        Index resIncr, Index resStride,
    const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index,
          int LhsStorageOrder, bool ConjugateLhs,
          int RhsStorageOrder, bool ConjugateRhs,
          int ResInnerStride>
EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs,
                                                  RhsStorageOrder,false,ConjugateRhs,ColMajor,ResInnerStride>::run(
    Index rows, Index cols,
    const Scalar* _lhs, Index lhsStride,
    const Scalar* _rhs, Index rhsStride,
    Scalar* _res,       Index resIncr, Index resStride,
    const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
  Index size = rows;

  typedef gebp_traits<Scalar,Scalar> Traits;

  typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
  typedef const_blas_data_mapper<Scalar, Index, (LhsStorageOrder == RowMajor) ? ColMajor : RowMajor> LhsTransposeMapper;
  typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
  typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
  LhsMapper lhs(_lhs,lhsStride);
  LhsTransposeMapper lhs_transpose(_lhs,lhsStride);
  RhsMapper rhs(_rhs,rhsStride);
  ResMapper res(_res, resStride, resIncr);

  Index kc = blocking.kc();                   // cache block size along the K direction
  Index mc = (std::min)(rows,blocking.mc());  // cache block size along the M direction
  // kc must be smaller than mc
  kc = (std::min)(kc,mc);
  std::size_t sizeA = kc*mc;
  std::size_t sizeB = kc*cols;
  ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
  ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());

  gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
  symm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
  gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
  gemm_pack_lhs<Scalar, Index, LhsTransposeMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing,
                LhsStorageOrder==RowMajor?ColMajor:RowMajor, true> pack_lhs_transposed;
  for(Index k2=0; k2<size; k2+=kc)
  {
    const Index actual_kc = (std::min)(k2+kc,size)-k2;

    // we have selected one row panel of rhs and one column panel of lhs;
    // pack rhs's panel into a sequential chunk of memory
    pack_rhs(blockB, rhs.getSubMapper(k2,0), actual_kc, cols);

    // the selected lhs panel has to be split in three different parts:
    //  1 - the panel above the diagonal block => transposed packed copy
    //  2 - the diagonal block => special packed copy
    //  3 - the panel below the diagonal block => generic packed copy
    for(Index i2=0; i2<k2; i2+=mc)
    {
      const Index actual_mc = (std::min)(i2+mc,k2)-i2;
      // transposed packed copy
      pack_lhs_transposed(blockA, lhs_transpose.getSubMapper(i2, k2), actual_kc, actual_mc);

      gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
    }
    // the diagonal block
    {
      const Index actual_mc = (std::min)(k2+kc,size)-k2;
      // symmetric packed copy
      pack_lhs(blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc);

      gebp_kernel(res.getSubMapper(k2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
    }
    // the panel below the diagonal block => generic packed copy
    for(Index i2=k2+kc; i2<size; i2+=mc)
    {
      const Index actual_mc = (std::min)(i2+mc,size)-i2;
      gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing,
                    LhsStorageOrder, false>()
        (blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);

      gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
    }
  }
}
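// A minimal usage sketch (assuming <Eigen/Dense> is included): the kernel
// above is what ultimately runs for a product with the self-adjoint operand
// on the left, e.g.
//   Eigen::MatrixXd A(n,n), B(n,m), C(n,m);
//   C.noalias() = A.selfadjointView<Eigen::Lower>() * B;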
// matrix * self-adjoint product
template <typename Scalar, typename Index,
          int LhsStorageOrder, bool ConjugateLhs,
          int RhsStorageOrder, bool ConjugateRhs,
          int ResInnerStride>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs,
                                  RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>
{
  static EIGEN_DONT_INLINE void run(
    Index rows, Index cols,
    const Scalar* _lhs, Index lhsStride,
    const Scalar* _rhs, Index rhsStride,
    Scalar* res,        Index resIncr, Index resStride,
    const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index,
          int LhsStorageOrder, bool ConjugateLhs,
          int RhsStorageOrder, bool ConjugateRhs,
          int ResInnerStride>
EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs,
                                                  RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>::run(
    Index rows, Index cols,
    const Scalar* _lhs, Index lhsStride,
    const Scalar* _rhs, Index rhsStride,
    Scalar* _res,       Index resIncr, Index resStride,
    const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
  Index size = cols;

  typedef gebp_traits<Scalar,Scalar> Traits;

  typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
  typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
  LhsMapper lhs(_lhs,lhsStride);
  ResMapper res(_res, resStride, resIncr);

  Index kc = blocking.kc();                   // cache block size along the K direction
  Index mc = (std::min)(rows,blocking.mc());  // cache block size along the M direction
  std::size_t sizeA = kc*mc;
  std::size_t sizeB = kc*cols;
  ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
  ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());

  gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
  gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs;
  symm_pack_rhs<Scalar, Index, Traits::nr, RhsStorageOrder> pack_rhs;
  for(Index k2=0; k2<size; k2+=kc)
  {
    const Index actual_kc = (std::min)(k2+kc,size)-k2;

    pack_rhs(blockB, _rhs, rhsStride, actual_kc, cols, k2);

    // => GEPP
    for(Index i2=0; i2<rows; i2+=mc)
    {
      const Index actual_mc = (std::min)(i2+mc,rows)-i2;
      pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);

      gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
    }
  }
}
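// The corresponding usage sketch for a self-adjoint operand on the right
// (assuming <Eigen/Dense> is included):
//   Eigen::MatrixXd A(n,n), B(m,n), C(m,n);
//   C.noalias() = B * A.selfadjointView<Eigen::Upper>();
// Note the design difference from the lhs case above: here all the symmetry
// handling lives in symm_pack_rhs, so the main loop is a plain GEPP over
// row panels of the lhs.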
/***************************************************************************
* Wrapper to product_selfadjoint_matrix
***************************************************************************/

template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
struct selfadjoint_product_impl<Lhs,LhsMode,false,Rhs,RhsMode,false>
{
  typedef typename Product<Lhs,Rhs>::Scalar Scalar;

  typedef internal::blas_traits<Lhs> LhsBlasTraits;
  typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
  typedef internal::blas_traits<Rhs> RhsBlasTraits;
  typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;

  enum {
    LhsIsUpper = (LhsMode&(Upper|Lower))==Upper,
    LhsIsSelfAdjoint = (LhsMode&SelfAdjoint)==SelfAdjoint,
    RhsIsUpper = (RhsMode&(Upper|Lower))==Upper,
    RhsIsSelfAdjoint = (RhsMode&SelfAdjoint)==SelfAdjoint
  };
  template<typename Dest>
  static void run(Dest &dst, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha)
  {
    eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols());

    add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
    add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);

    Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
                               * RhsBlasTraits::extractScalarFactor(a_rhs);
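    // For example (a usage sketch, assuming <Eigen/Dense>), in
    //   C.noalias() = A.selfadjointView<Eigen::Lower>() * (3.0 * B);
    // blas_traits strips the factor 3.0 from the rhs expression, so the kernel
    // below runs on the plain matrices with actualAlpha == 3.0 * alpha.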
    typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor, Scalar, Scalar,
              Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime, 1> BlockingType;

    BlockingType blocking(lhs.rows(), rhs.cols(), lhs.cols(), 1, false);
    internal::product_selfadjoint_matrix<Scalar, Index,
      internal::logical_xor(LhsIsUpper, internal::traits<Lhs>::Flags&RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint,
      NumTraits<Scalar>::IsComplex && internal::logical_xor(LhsIsUpper, bool(LhsBlasTraits::NeedToConjugate)),
      internal::logical_xor(RhsIsUpper, internal::traits<Rhs>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint,
      NumTraits<Scalar>::IsComplex && internal::logical_xor(RhsIsUpper, bool(RhsBlasTraits::NeedToConjugate)),
      internal::traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor,
      Dest::InnerStrideAtCompileTime>
      ::run(
        lhs.rows(), rhs.cols(),                                    // sizes
        &lhs.coeffRef(0,0), lhs.outerStride(),                     // lhs info
        &rhs.coeffRef(0,0), rhs.outerStride(),                     // rhs info
        &dst.coeffRef(0,0), dst.innerStride(), dst.outerStride(),  // result info
        actualAlpha, blocking                                      // scaling factor
      );
  }
};

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_H