#ifndef EIGEN_GENERAL_MATRIX_MATRIX_H
#define EIGEN_GENERAL_MATRIX_MATRIX_H

#include "../InternalHeaderCheck.h"
template<typename LhsScalar_, typename RhsScalar_> class level3_blocking;
/* Specialization for a row-major destination matrix => simple transposition of the product */
template<
  typename Index,
  typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
  typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
  int ResInnerStride>
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride>
{
  typedef gebp_traits<RhsScalar,LhsScalar> Traits;

  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  static EIGEN_STRONG_INLINE void run(
    Index rows, Index cols, Index depth,
    const LhsScalar* lhs, Index lhsStride,
    const RhsScalar* rhs, Index rhsStride,
    ResScalar* res, Index resIncr, Index resStride,
    ResScalar alpha,
    level3_blocking<RhsScalar,LhsScalar>& blocking,
    GemmParallelInfo<Index>* info = 0)
  {
    // transpose the product such that the result is column major
    general_matrix_matrix_product<
      Index,
      RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
      LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
      ColMajor, ResInnerStride>
    ::run(cols, rows, depth, rhs, rhsStride, lhs, lhsStride, res, resIncr, resStride, alpha, blocking, info);
  }
};
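// The specialization above never computes anything itself: it relies on the
// identity (A*B)^T = B^T * A^T, so a row-major result C is obtained by
// computing C^T = B^T * A^T into the same memory interpreted as column major.
// That is why the operands, their storage orders, their conjugation flags, and
// the rows/cols arguments all trade places before delegating to the col-major
// specialization below.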
/*  Specialization for a col-major destination matrix
 *    => Blocking algorithm following Goto's paper  */
template<
  typename Index,
  typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
  typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
  int ResInnerStride>
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride>
{
  typedef gebp_traits<LhsScalar,RhsScalar> Traits;

  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  static void run(Index rows, Index cols, Index depth,
    const LhsScalar* _lhs, Index lhsStride,
    const RhsScalar* _rhs, Index rhsStride,
    ResScalar* _res, Index resIncr, Index resStride,
    ResScalar alpha,
    level3_blocking<LhsScalar,RhsScalar>& blocking,
    GemmParallelInfo<Index>* info = 0)
  {
    typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
    typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
    typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
    LhsMapper lhs(_lhs, lhsStride);
    RhsMapper rhs(_rhs, rhsStride);
    ResMapper res(_res, resStride, resIncr);

    Index kc = blocking.kc();                   // cache block size along the K direction
    Index mc = (std::min)(rows, blocking.mc()); // cache block size along the M direction
    Index nc = (std::min)(cols, blocking.nc()); // cache block size along the N direction
    gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs;
    gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
    gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
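    // The three functors above implement the classical GEBP decomposition:
    // pack_lhs copies an mc x kc panel of the lhs into a contiguous packed
    // buffer (blockA), pack_rhs does the same for a kc x nc panel of the rhs
    // (blockB), and gebp is the register-blocked micro-kernel that multiplies
    // the two packed panels and accumulates into the result mapper.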
#ifdef EIGEN_HAS_OPENMP
    if(info)
    {
      // this is the parallel version!
      int tid = omp_get_thread_num();
      int threads = omp_get_num_threads();

      LhsScalar* blockA = blocking.blockA();
      eigen_internal_assert(blockA!=0);

      std::size_t sizeB = kc*nc;
      ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, 0);
      // For each horizontal panel of the rhs, and corresponding vertical panel of the lhs...
      for(Index k=0; k<depth; k+=kc)
      {
        const Index actual_kc = (std::min)(k+kc,depth)-k; // => rows of B', and cols of A'

        // To reduce the chance that a thread has to wait for the others, start by packing B'.
        pack_rhs(blockB, rhs.getSubMapper(k,0), actual_kc, nc);

        // Pack A_k to A' in a parallel fashion: each thread packs the sub-block A_k,i
        // to A'_i, where i is the thread id. Before copying to A'_i we have to make
        // sure that no other thread is still using it, i.e., we wait until
        // info[tid].users reaches 0, then mark it as being in use by all threads.
        while(info[tid].users!=0) {}
        info[tid].users = threads;

        pack_lhs(blockA+info[tid].lhs_start*actual_kc, lhs.getSubMapper(info[tid].lhs_start,k), actual_kc, info[tid].lhs_length);

        // Notify the other threads that the part A'_i is ready to go.
        info[tid].sync = k;
        // Compute C_i += A' * B', one packed sub-block A'_i at a time
        for(int shift=0; shift<threads; ++shift)
        {
          int i = (tid+shift)%threads;

          // Make sure A'_i has been packed by thread i before using it
          // (no need to wait for our own sub-block, which was packed above).
          if(shift>0)
            while(info[i].sync!=k) {}

          gebp(res.getSubMapper(info[i].lhs_start, 0), blockA+info[i].lhs_start*actual_kc, blockB,
               info[i].lhs_length, actual_kc, nc, alpha);
        }
        // Then keep going as usual with the remaining B'
        for(Index j=nc; j<cols; j+=nc)
        {
          const Index actual_nc = (std::min)(j+nc,cols)-j;

          // pack B_k,j to B'
          pack_rhs(blockB, rhs.getSubMapper(k,j), actual_kc, actual_nc);

          // C_j += A' * B'
          gebp(res.getSubMapper(0,j), blockA, blockB, rows, actual_kc, actual_nc, alpha);
        }

        // Release all the sub-blocks A'_i of A' for the current thread,
        // i.e., decrement their number of users by one.
        for(Index i=0; i<threads; ++i)
          #pragma omp atomic
          info[i].users -= 1;
      }
    }
    else
#endif // EIGEN_HAS_OPENMP
    {
      EIGEN_UNUSED_VARIABLE(info);
      // this is the sequential version!
      std::size_t sizeA = kc*mc;
      std::size_t sizeB = kc*nc;

      ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA());
      ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB());

      // If the whole rhs fits into a single kc x nc block B', pack it only once,
      // during the first sweep over the row panels of the lhs.
      const bool pack_rhs_once = mc!=rows && kc==depth && nc==cols;
      // For each horizontal panel of the rhs, and corresponding panel of the lhs...
      for(Index i2=0; i2<rows; i2+=mc)
      {
        const Index actual_mc = (std::min)(i2+mc,rows)-i2;

        for(Index k2=0; k2<depth; k2+=kc)
        {
          const Index actual_kc = (std::min)(k2+kc,depth)-k2;

          // Pack the lhs panel into a sequential chunk of memory (L2/L3 caching).
          // This panel is read as many times as there are kc x nc blocks in the
          // rhs's horizontal panel, which is, in practice, a small number.
          pack_lhs(blockA, lhs.getSubMapper(i2,k2), actual_kc, actual_mc);

          // For each kc x nc block of the rhs's horizontal panel...
          for(Index j2=0; j2<cols; j2+=nc)
          {
            const Index actual_nc = (std::min)(j2+nc,cols)-j2;

            // Pack the rhs block into a sequential chunk of memory (L2 caching),
            // unless it was already packed during the first sweep (see pack_rhs_once).
            if((!pack_rhs_once) || i2==0)
              pack_rhs(blockB, rhs.getSubMapper(k2,j2), actual_kc, actual_nc);

            // Everything is packed, we can now call the panel * block kernel:
            gebp(res.getSubMapper(i2,j2), blockA, blockB, actual_mc, actual_kc, actual_nc, alpha);
          }
        }
      }
    }
  }
};
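// A minimal usage sketch of the kernel above (not part of Eigen itself; the
// dimensions m, n, k and the raw float buffers A, B, C are illustrative
// assumptions). All three matrices are taken column major with unit inner
// stride:
//
//   typedef internal::general_matrix_matrix_product<
//       Index,
//       float, ColMajor, false,   // lhs scalar, storage order, conjugation
//       float, ColMajor, false,   // rhs scalar, storage order, conjugation
//       ColMajor, 1> Gemm;        // result storage order and inner stride
//   internal::gemm_blocking_space<ColMajor, float, float, Dynamic, Dynamic, Dynamic>
//       blocking(m, n, k, /*num_threads=*/1, /*l3_blocking=*/true);
//   Gemm::run(m, n, k, A, m, B, k, C, 1, m, 1.0f, blocking, /*info=*/0);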
template<typename Scalar, typename Index, typename Gemm, typename Lhs, typename Rhs, typename Dest, typename BlockingType>
struct gemm_functor
{
  gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, const Scalar& actualAlpha, BlockingType& blocking)
    : m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha), m_blocking(blocking)
  {}
  void initParallelSession(Index num_threads) const
  {
    m_blocking.initParallel(m_lhs.rows(), m_rhs.cols(), m_lhs.cols(), num_threads);
    m_blocking.allocateA();
  }
  void operator() (Index row, Index rows, Index col=0, Index cols=-1, GemmParallelInfo<Index>* info=0) const
  {
    if(cols==-1)
      cols = m_rhs.cols();

    Gemm::run(rows, cols, m_lhs.cols(),
              &m_lhs.coeffRef(row,0), m_lhs.outerStride(),
              &m_rhs.coeffRef(0,col), m_rhs.outerStride(),
              (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.innerStride(), m_dest.outerStride(),
              m_actualAlpha, m_blocking, info);
  }
  typedef typename Gemm::Traits Traits;

  protected:
    const Lhs& m_lhs;
    const Rhs& m_rhs;
    Dest& m_dest;
    Scalar m_actualAlpha;
    BlockingType& m_blocking;
};
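// The functor is consumed by internal::parallelize_gemm (see Parallelizer.h),
// which either calls func(0, rows, 0, cols) directly in the single-threaded
// case, or splits the result into per-thread slabs and calls
// func(row, rows, col, cols, info) from each OpenMP thread, with the shared
// GemmParallelInfo array coordinating the cooperative packing of A' described
// in the kernel above.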
template<
  int StorageOrder,
  typename LhsScalar, typename RhsScalar,
  int MaxRows, int MaxCols, int MaxDepth,
  int KcFactor = 1,
  bool FiniteAtCompileTime = MaxRows!=Dynamic && MaxCols!=Dynamic && MaxDepth!=Dynamic>
class gemm_blocking_space;
template<typename LhsScalar_, typename RhsScalar_>
class level3_blocking
{
    typedef LhsScalar_ LhsScalar;
    typedef RhsScalar_ RhsScalar;

  protected:
    LhsScalar* m_blockA;
    RhsScalar* m_blockB;

    Index m_mc;
    Index m_nc;
    Index m_kc;

  public:

    level3_blocking()
      : m_blockA(0), m_blockB(0), m_mc(0), m_nc(0), m_kc(0)
    {}

    inline Index mc() const { return m_mc; }
    inline Index nc() const { return m_nc; }
    inline Index kc() const { return m_kc; }

    inline LhsScalar* blockA() { return m_blockA; }
    inline RhsScalar* blockB() { return m_blockB; }
};
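// A level3_blocking object bundles the three cache-blocking sizes used by the
// kernel above (mc x kc lhs panels packed into m_blockA, kc x nc rhs panels
// packed into m_blockB) together with the packed buffers themselves. The two
// specializations of gemm_blocking_space below provide the storage: statically
// when every dimension is bounded at compile time, on the heap otherwise.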
template<int StorageOrder, typename LhsScalar_, typename RhsScalar_, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
class gemm_blocking_space<StorageOrder,LhsScalar_,RhsScalar_,MaxRows, MaxCols, MaxDepth, KcFactor, true /* == FiniteAtCompileTime */>
  : public level3_blocking<
      std::conditional_t<StorageOrder==RowMajor,RhsScalar_,LhsScalar_>,
      std::conditional_t<StorageOrder==RowMajor,LhsScalar_,RhsScalar_>>
{
    enum {
      Transpose = StorageOrder==RowMajor,
      ActualRows = Transpose ? MaxCols : MaxRows,
      ActualCols = Transpose ? MaxRows : MaxCols
    };
    typedef std::conditional_t<Transpose,RhsScalar_,LhsScalar_> LhsScalar;
    typedef std::conditional_t<Transpose,LhsScalar_,RhsScalar_> RhsScalar;
    enum {
      SizeA = ActualRows * MaxDepth,
      SizeB = ActualCols * MaxDepth
    };

#if EIGEN_MAX_STATIC_ALIGN_BYTES >= EIGEN_DEFAULT_ALIGN_BYTES
    EIGEN_ALIGN_MAX LhsScalar m_staticA[SizeA];
    EIGEN_ALIGN_MAX RhsScalar m_staticB[SizeB];
#endif

  public:

    gemm_blocking_space(Index /*rows*/, Index /*cols*/, Index /*depth*/, Index /*num_threads*/, bool /*full_rows*/ = false)
    {
      this->m_mc = ActualRows;
      this->m_nc = ActualCols;
      this->m_kc = MaxDepth;
#if EIGEN_MAX_STATIC_ALIGN_BYTES >= EIGEN_DEFAULT_ALIGN_BYTES
      this->m_blockA = m_staticA;
      this->m_blockB = m_staticB;
#endif
    }

    void initParallel(Index, Index, Index, Index) {}

    // Storage is static: these are deliberately no-ops.
    inline void allocateA() {}
    inline void allocateB() {}
    inline void allocateAll() {}
};
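// With every dimension bounded at compile time, the packed panels have a
// compile-time size as well, so the blocking object can host them directly
// (the m_staticA/m_staticB members above) and this path never touches the
// heap. This is the variant selected for fixed-size products that are large
// enough to go through the GEMM path rather than the lazy product.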
template<int StorageOrder, typename LhsScalar_, typename RhsScalar_, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
class gemm_blocking_space<StorageOrder,LhsScalar_,RhsScalar_,MaxRows, MaxCols, MaxDepth, KcFactor, false>
  : public level3_blocking<
      std::conditional_t<StorageOrder==RowMajor,RhsScalar_,LhsScalar_>,
      std::conditional_t<StorageOrder==RowMajor,LhsScalar_,RhsScalar_>>
{
    enum {
      Transpose = StorageOrder==RowMajor
    };
    typedef std::conditional_t<Transpose,RhsScalar_,LhsScalar_> LhsScalar;
    typedef std::conditional_t<Transpose,LhsScalar_,RhsScalar_> RhsScalar;

    Index m_sizeA;
    Index m_sizeB;

  public:
    gemm_blocking_space(Index rows, Index cols, Index depth, Index num_threads, bool l3_blocking)
    {
      this->m_mc = Transpose ? cols : rows;
      this->m_nc = Transpose ? rows : cols;
      this->m_kc = depth;

      if(l3_blocking)
      {
        computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, this->m_mc, this->m_nc, num_threads);
      }
      else  // no l3 blocking
      {
        Index n = this->m_nc;
        computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, this->m_mc, n, num_threads);
      }

      m_sizeA = this->m_mc * this->m_kc;
      m_sizeB = this->m_kc * this->m_nc;
    }
    void initParallel(Index rows, Index cols, Index depth, Index num_threads)
    {
      this->m_mc = Transpose ? cols : rows;
      this->m_nc = Transpose ? rows : cols;
      this->m_kc = depth;

      eigen_internal_assert(this->m_blockA==0 && this->m_blockB==0);
      Index m = this->m_mc;
      computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, m, this->m_nc, num_threads);
      m_sizeA = this->m_mc * this->m_kc;
      m_sizeB = this->m_kc * this->m_nc;
    }
    inline void allocateA()
    {
      if(this->m_blockA==0)
        this->m_blockA = aligned_new<LhsScalar>(m_sizeA);
    }

    inline void allocateB()
    {
      if(this->m_blockB==0)
        this->m_blockB = aligned_new<RhsScalar>(m_sizeB);
    }

    inline void allocateAll()
    {
      allocateA();
      allocateB();
    }

    ~gemm_blocking_space()
    {
      aligned_delete(this->m_blockA, m_sizeA);
      aligned_delete(this->m_blockB, m_sizeB);
    }
};
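// Allocation is deliberately lazy: buffers are only created on first use, so a
// parallel session can size the blocks via initParallel() and then allocate
// only the shared blockA (through gemm_functor::initParallelSession), while
// each thread keeps its private blockB on the stack inside run().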
template<typename Lhs, typename Rhs>
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
  : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
{
  typedef typename Product<Lhs,Rhs>::Scalar Scalar;
  typedef typename Lhs::Scalar LhsScalar;
  typedef typename Rhs::Scalar RhsScalar;

  typedef internal::blas_traits<Lhs> LhsBlasTraits;
  typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
  typedef internal::remove_all_t<ActualLhsType> ActualLhsTypeCleaned;

  typedef internal::blas_traits<Rhs> RhsBlasTraits;
  typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
  typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;

  enum {
    MaxDepthAtCompileTime = min_size_prefer_fixed(Lhs::MaxColsAtCompileTime, Rhs::MaxRowsAtCompileTime)
  };

  typedef generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> lazyproduct;
  template<typename Dst>
  static void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
  {
    // For small enough products, fall back to the lazy, coefficient-based
    // implementation which avoids the packing and blocking overhead.
    if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0)
      lazyproduct::eval_dynamic(dst, lhs, rhs, internal::assign_op<typename Dst::Scalar,Scalar>());
    else
    {
      dst.setZero();
      scaleAndAddTo(dst, lhs, rhs, Scalar(1));
    }
  }
  template<typename Dst>
  static void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
  {
    if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0)
      lazyproduct::eval_dynamic(dst, lhs, rhs, internal::add_assign_op<typename Dst::Scalar,Scalar>());
    else
      scaleAndAddTo(dst, lhs, rhs, Scalar(1));
  }
  template<typename Dst>
  static void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
  {
    if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0)
      lazyproduct::eval_dynamic(dst, lhs, rhs, internal::sub_assign_op<typename Dst::Scalar,Scalar>());
    else
      scaleAndAddTo(dst, lhs, rhs, Scalar(-1));
  }
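  // The three evaluators above share one heuristic: when the accumulated
  // problem size rhs.rows()+dst.rows()+dst.cols() is below
  // EIGEN_GEMM_TO_COEFFBASED_THRESHOLD, packing and blocking cost more than
  // they save, so the product is evaluated coefficient-wise instead.
  // Otherwise everything funnels into scaleAndAddTo below, which computes
  // dst += alpha * lhs * rhs.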
  template<typename Dest>
  static void scaleAndAddTo(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
  {
    eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols());
    if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0)
      return;

    if (dst.cols() == 1)
    {
      // Fallback to GEMV if either the lhs or rhs is a runtime vector
      typename Dest::ColXpr dst_vec(dst.col(0));
      return internal::generic_product_impl<Lhs,typename Rhs::ConstColXpr,DenseShape,DenseShape,GemvProduct>
        ::scaleAndAddTo(dst_vec, a_lhs, a_rhs.col(0), alpha);
    }
    else if (dst.rows() == 1)
    {
      // Fallback to GEMV if either the lhs or rhs is a runtime vector
      typename Dest::RowXpr dst_vec(dst.row(0));
      return internal::generic_product_impl<typename Lhs::ConstRowXpr,Rhs,DenseShape,DenseShape,GemvProduct>
        ::scaleAndAddTo(dst_vec, a_lhs.row(0), a_rhs, alpha);
    }

    add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
    add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);

    // Fold the scalar factors carried by the lhs/rhs expressions into alpha.
    Scalar actualAlpha = combine_scalar_factors(alpha, a_lhs, a_rhs);

    typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
            Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;

    typedef internal::gemm_functor<
      Scalar, Index,
      internal::general_matrix_matrix_product<
        Index,
        LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
        RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
        (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,
        Dest::InnerStrideAtCompileTime>,
      ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;

    BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);

    internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>
      (GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), a_lhs.cols(), Dest::Flags&RowMajorBit);
  }
};
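// A minimal user-level sketch (plain Eigen API, not part of this header) of
// the expressions that end up in the GemmProduct path above:
//
//   Eigen::MatrixXd A(m,k), B(k,n), C(m,n);
//   C.noalias()  = A * B;        // evalTo
//   C.noalias() += A * B;        // addTo
//   C.noalias() -= 3.0 * A * B;  // subTo; the factor 3.0 is folded into
//                                // alpha by combine_scalar_factors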
#endif // EIGEN_GENERAL_MATRIX_MATRIX_H