10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
26 template<
typename ResScalar,
typename LhsScalar,
typename RhsScalar,
typename StorageIndex,
int ShardingType = ShardByCol>
27 class TensorContractionBlocking {
44 #if !defined(EIGEN_HIPCC)
47 TensorContractionBlocking(StorageIndex k, StorageIndex
m, StorageIndex
n, StorageIndex num_threads = 1) :
48 kc_(k), mc_(
m), nc_(
n)
51 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, mc_, nc_, num_threads);
54 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads);
57 const int rhs_packet_size = internal::packet_traits<RhsScalar>::size;
58 kc_ = (rhs_packet_size <= 8 || kc_ <= rhs_packet_size) ?
59 kc_ : (kc_ / rhs_packet_size) * rhs_packet_size;
// Fallback (empty) definitions for the Eigen function-attribute macros.
// Each one is guarded with #ifndef so that a definition made elsewhere
// (presumably Eigen's core macro header — NOTE(review): confirm) is neither
// redefined with a conflicting replacement list, which the preprocessor
// diagnoses, nor silently replaced by an empty expansion.
#ifndef EIGEN_ALWAYS_INLINE
#define EIGEN_ALWAYS_INLINE
#endif
#ifndef EIGEN_DEVICE_FUNC
#define EIGEN_DEVICE_FUNC
#endif
// TensorContractionSycl.h: provides various tensor contraction kernels for the SYCL backend.