#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> >
{
  typedef typename gebp_traits<std::remove_const_t<typename LhsXprType::Scalar>,
                               std::remove_const_t<typename RhsXprType::Scalar>>::ResScalar Scalar;

  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef std::remove_reference_t<LhsNested> LhsNested_;
  typedef std::remove_reference_t<RhsNested> RhsNested_;

  // ...
  typedef std::conditional_t<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
                             typename traits<LhsXprType>::PointerType,
                             typename traits<RhsXprType>::PointerType>
      PointerType;
};
template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, Eigen::Dense>
{
  typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>& type;
};
template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> >::type>
{
  typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> type;
};
template<typename Indices_, typename LeftArgType_, typename RightArgType_,
         typename OutputKernelType_, typename Device_>
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_, OutputKernelType_>, Device_> > {
  typedef Indices_ Indices;
  typedef LeftArgType_ LeftArgType;
  typedef RightArgType_ RightArgType;
  typedef OutputKernelType_ OutputKernelType;
  typedef Device_ Device;
  // ...
};
template <typename LhsScalar, typename RhsScalar>
struct TensorContractionBlockMemAllocator {
  typedef void* BlockMemHandle;

  template <typename Device>
  EIGEN_DEVICE_FUNC static BlockMemHandle allocate(Device& d, const Index bm,
                                                   const Index bk, const Index bn,
                                                   LhsScalar** lhs_block,
                                                   RhsScalar** rhs_block) {
    BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
    char* block_mem = static_cast<char*>(d.allocate(sz.lhs_size + sz.rhs_size));
    *lhs_block = static_cast<LhsScalar*>(static_cast<void*>(block_mem));
    *rhs_block = static_cast<RhsScalar*>(static_cast<void*>(block_mem + sz.lhs_size));
    return block_mem;
  }
  template <typename Device>
  EIGEN_DEVICE_FUNC static BlockMemHandle allocateSlices(
      Device& d, const Index bm, const Index bk, const Index bn,
      const Index num_lhs, const Index num_rhs, const Index num_slices,
      std::vector<LhsScalar*>* lhs_blocks,
      std::vector<RhsScalar*>* rhs_blocks) {
    BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
    void* block_mem = d.allocate(
        (num_lhs * sz.lhs_size + num_rhs * sz.rhs_size) * num_slices);
    char* mem = static_cast<char*>(block_mem);

    for (Index x = 0; x < num_slices; x++) {
      if (num_lhs > 0) lhs_blocks[x].resize(num_lhs);
      for (Index m = 0; m < num_lhs; m++) {
        lhs_blocks[x][m] = static_cast<LhsScalar*>(static_cast<void*>(mem));
        mem += sz.lhs_size;
      }
      if (num_rhs > 0) rhs_blocks[x].resize(num_rhs);
      for (Index n = 0; n < num_rhs; n++) {
        rhs_blocks[x][n] = static_cast<RhsScalar*>(static_cast<void*>(mem));
        mem += sz.rhs_size;
      }
    }
    return block_mem;
  }
  template <typename Device>
  EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) {
    d.deallocate(handle);
  }

 private:
  struct BlockSizes { Index lhs_size; Index rhs_size; };
  EIGEN_DEVICE_FUNC static BlockSizes ComputeLhsRhsBlockSizes(const Index bm, const Index bk, const Index bn) {
    Index align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1);
    BlockSizes sz;
    sz.lhs_size = divup<Index>(bm * bk * sizeof(LhsScalar), align) * align;
    sz.rhs_size = divup<Index>(bn * bk * sizeof(RhsScalar), align) * align;
    return sz;
  }
};
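// Usage sketch (illustrative, not part of the original header): how a caller
// might pair allocate() with deallocate() around the lifetime of a pair of
// packed blocks. The block dimensions and the use of DefaultDevice are
// assumptions chosen for the example; note that a single allocation backs
// both blocks, and lhs_block/rhs_block point into it.
//
//   Eigen::DefaultDevice device;
//   float* lhs_block = nullptr;
//   float* rhs_block = nullptr;
//   using Allocator =
//       Eigen::internal::TensorContractionBlockMemAllocator<float, float>;
//   Allocator::BlockMemHandle handle =
//       Allocator::allocate(device, /*bm=*/48, /*bk=*/64, /*bn=*/32,
//                           &lhs_block, &rhs_block);
//   // ... pack into lhs_block / rhs_block and run the GEBP kernel ...
//   Allocator::deallocate(device, handle);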
template <typename ResScalar, typename LhsScalar, typename RhsScalar,
          typename StorageIndex, typename OutputMapper, typename LhsMapper,
          typename RhsMapper>
struct TensorContractionKernel {
  // True if `invoke()` supports `beta` in `C <- alpha * A * B + beta * C`
  // (otherwise beta should always be equal to 1).
  enum { HasBeta = false };
  EIGEN_DEVICE_FUNC
  TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_,
                          StorageIndex bm_, StorageIndex bk_, StorageIndex bn_)
      : m(m_), k(k_), n(n_), bm(bm_), bk(bk_), bn(bn_) {}

  typedef LhsScalar* LhsBlock;
  typedef RhsScalar* RhsBlock;

  typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar>
      BlockMemAllocator;
  typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle;

  typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
  typedef internal::gemm_pack_lhs<
      LhsScalar, StorageIndex, typename LhsMapper::SubMapper, Traits::mr,
      Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor>
      LhsPacker;

  typedef internal::gemm_pack_rhs<RhsScalar, StorageIndex,
                                  typename RhsMapper::SubMapper, Traits::nr,
                                  ColMajor>
      RhsPacker;

  typedef internal::gebp_kernel<LhsScalar, RhsScalar, StorageIndex,
                                OutputMapper, Traits::mr, Traits::nr,
                                /*ConjugateLhs*/ false, /*ConjugateRhs*/ false>
      GebpKernel;
  template <typename Device>
  EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block,
                                            RhsBlock* rhs_block) {
    return BlockMemAllocator::allocate(d, bm, bk, bn, lhs_block, rhs_block);
  }

  template <typename Device>
  EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices(
      Device& d, const StorageIndex num_lhs, const StorageIndex num_rhs,
      const StorageIndex num_slices, std::vector<LhsBlock>* lhs_blocks,
      std::vector<RhsBlock>* rhs_blocks) {
    return BlockMemAllocator::allocateSlices(
        d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_blocks, rhs_blocks);
  }

  template <typename Device>
  EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) {
    BlockMemAllocator::deallocate(d, handle);
  }
  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(
      LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper,
      const StorageIndex depth, const StorageIndex rows) {
    LhsPacker()(*lhsBlock, data_mapper, depth, rows, /*stride*/ 0,
                /*offset*/ 0);
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(
      RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper,
      const StorageIndex depth, const StorageIndex cols) {
    RhsPacker()(*rhsBlock, data_mapper, depth, cols);
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(
      const OutputMapper& output_mapper, const LhsBlock& lhsBlock,
      const RhsBlock& rhsBlock, const StorageIndex rows,
      const StorageIndex depth, const StorageIndex cols,
      const ResScalar alpha, const ResScalar beta) {
    static const int kComputeStrideFromBlockDimensions = -1;
    GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha,
                 /*strideA*/ kComputeStrideFromBlockDimensions,
                 /*strideB*/ kComputeStrideFromBlockDimensions,
                 /*offsetA*/ 0, /*offsetB*/ 0);
  }
 private:
  const StorageIndex m;
  const StorageIndex k;
  const StorageIndex n;
  const StorageIndex bm;
  const StorageIndex bk;
  const StorageIndex bn;
};
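// Usage sketch (illustrative, not part of the original header): the calling
// sequence a contraction implementation typically follows with this kernel.
// `device`, `lhs`, `rhs` and `output` stand for an already-constructed device,
// input mappers and output mapper; the block sizes and indices are
// hypothetical.
//
//   using Kernel = internal::TensorContractionKernel<
//       float, float, float, Index, OutputMapper, LhsMapper, RhsMapper>;
//   Kernel kernel(m, k, n, /*bm=*/48, /*bk=*/64, /*bn=*/32);
//
//   Kernel::LhsBlock blockA;
//   Kernel::RhsBlock blockB;
//   Kernel::BlockMemHandle mem = kernel.allocate(device, &blockA, &blockB);
//
//   kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), kc, mc);
//   kernel.packRhs(&blockB, rhs.getSubMapper(k2, j2), kc, nc);
//   kernel.invoke(output.getSubMapper(i2, j2), blockA, blockB,
//                 /*rows=*/mc, /*depth=*/kc, /*cols=*/nc,
//                 /*alpha=*/1.0f, /*beta=*/1.0f);
//
//   Kernel::deallocate(device, mem);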
struct NoOpOutputKernel {
  template <typename Index, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(
      const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
      const TensorContractionParams& params, Index i, Index j,
      Index num_rows, Index num_cols) const {}
};
template<typename Indices, typename LhsXprType, typename RhsXprType,
         typename OutputKernelType = const NoOpOutputKernel>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
    typedef typename internal::gebp_traits<typename LhsXprType::CoeffReturnType,
                                           typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
        const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims,
        const OutputKernelType& output_kernel = OutputKernelType())
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims),
          m_output_kernel(output_kernel) {}
    // ...
};
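// Usage sketch (illustrative): TensorContractionOp is normally created through
// Tensor::contract() rather than constructed directly. Contracting dimension 1
// of A with dimension 0 of B below is equivalent to the matrix product A * B.
//
//   Eigen::Tensor<float, 2> A(3, 4);
//   Eigen::Tensor<float, 2> B(4, 5);
//   A.setRandom();
//   B.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {
//       Eigen::IndexPair<int>(1, 0)};
//   Eigen::Tensor<float, 2> C = A.contract(B, dims);  // C is 3 x 5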
template<typename Derived>
struct TensorContractionEvaluatorBase
{
  typedef typename internal::traits<Derived>::Indices Indices;
  typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
  typedef typename internal::traits<Derived>::RightArgType RightArgType;
  typedef typename internal::traits<Derived>::OutputKernelType OutputKernelType;
  typedef typename internal::traits<Derived>::Device Device;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef std::remove_const_t<typename XprType::Scalar> Scalar;
  // ...

  typedef std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor),
                             LeftArgType, RightArgType> EvalLeftArgType;
  typedef std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor),
                             RightArgType, LeftArgType> EvalRightArgType;

  static constexpr int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static constexpr int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static constexpr int ContractDims = internal::array_size<Indices>::value;

  // ...
  TensorContractionEvaluatorBase(const XprType& op, const Device& device)
      : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                          op.lhsExpression(), op.rhsExpression()), device),
        m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                           op.rhsExpression(), op.lhsExpression()), device),
        // ...
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<EvalLeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<EvalRightArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);
      eval_op_indices[i].first = op.indices()[i].first;
      eval_op_indices[i].second = op.indices()[i].second;
    // ...
        eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first &&
                     eval_op_indices[j].second != eval_op_indices[i].second &&
                     "contraction axes should be unique");
        if (eval_op_indices[j].first < eval_op_indices[i].first) {
          // ...
        }
    // ...
      lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i];
    // ...
      rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
    Index nocontract_idx = 0;
    // ...
      bool contracting = false;
      // ...
        if (eval_op_indices[j].first == i) {
      // ...
        if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) {
    // ...
      bool contracting = false;
      // ...
        if (eval_op_indices[j].second == i) {
      // ...
        if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) {
    // ...
      Index left = eval_op_indices[i].first;
      Index right = eval_op_indices[i].second;
      // ...
      eigen_assert(size == eval_right_dims[right] &&
                   "Contraction axes must be same size");

      if (i+1 < static_cast<int>(internal::array_size<contract_t>::value)) {
        // ...
      }
      // ...
      if (i > 0 && right < eval_op_indices[i-1].second) {
#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
      EvaluatorPointerType dest, EvalSubExprsCallback done) {
    m_leftImpl.evalSubExprsIfNeededAsync(nullptr, [this, done, dest](bool) {
      m_rightImpl.evalSubExprsIfNeededAsync(nullptr, [this, done, dest](bool) {
        if (dest) {
          evalToAsync(dest, [done]() { done(false); });
        } else {
          // ...
          evalToAsync(m_result, [done]() { done(true); });
        }
      });
    });
  }
#endif  // EIGEN_USE_THREADS
#ifndef TENSOR_CONTRACTION_DISPATCH
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS)    \
  if (this->m_lhs_inner_dim_contiguous) {                       \
    if (this->m_rhs_inner_dim_contiguous) {                     \
      if (this->m_rhs_inner_dim_reordered) {                    \
        METHOD<true, true, true, ALIGNMENT> ARGS;               \
      } else {                                                  \
        METHOD<true, true, false, ALIGNMENT> ARGS;              \
      }                                                         \
    } else {                                                    \
      if (this->m_rhs_inner_dim_reordered) {                    \
        METHOD<true, false, true, ALIGNMENT> ARGS;              \
      } else {                                                  \
        METHOD<true, false, false, ALIGNMENT> ARGS;             \
      }                                                         \
    }                                                           \
  } else {                                                      \
    if (this->m_rhs_inner_dim_contiguous) {                     \
      if (this->m_rhs_inner_dim_reordered) {                    \
        METHOD<false, true, true, ALIGNMENT> ARGS;              \
      } else {                                                  \
        METHOD<false, true, false, ALIGNMENT> ARGS;             \
      }                                                         \
    } else {                                                    \
      if (this->m_rhs_inner_dim_reordered) {                    \
        METHOD<false, false, true, ALIGNMENT> ARGS;             \
      } else {                                                  \
        METHOD<false, false, false, ALIGNMENT> ARGS;            \
      }                                                         \
    }                                                           \
  }
#endif
#ifndef TENSOR_CONTRACTION_ASYNC_DISPATCH
#define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN) \
  if (this->m_lhs_inner_dim_contiguous) {                                    \
    if (this->m_rhs_inner_dim_contiguous) {                                  \
      if (this->m_rhs_inner_dim_reordered) {                                 \
        (new METHOD<DONE, true, true, true, ALIGNMENT> ARGS)->FN;            \
      } else {                                                               \
        (new METHOD<DONE, true, true, false, ALIGNMENT> ARGS)->FN;           \
      }                                                                      \
    } else {                                                                 \
      if (this->m_rhs_inner_dim_reordered) {                                 \
        (new METHOD<DONE, true, false, true, ALIGNMENT> ARGS)->FN;           \
      } else {                                                               \
        (new METHOD<DONE, true, false, false, ALIGNMENT> ARGS)->FN;          \
      }                                                                      \
    }                                                                        \
  } else {                                                                   \
    if (this->m_rhs_inner_dim_contiguous) {                                  \
      if (this->m_rhs_inner_dim_reordered) {                                 \
        (new METHOD<DONE, false, true, true, ALIGNMENT> ARGS)->FN;           \
      } else {                                                               \
        (new METHOD<DONE, false, true, false, ALIGNMENT> ARGS)->FN;          \
      }                                                                      \
    } else {                                                                 \
      if (this->m_rhs_inner_dim_reordered) {                                 \
        (new METHOD<DONE, false, false, true, ALIGNMENT> ARGS)->FN;          \
      } else {                                                               \
        (new METHOD<DONE, false, false, false, ALIGNMENT> ARGS)->FN;         \
      }                                                                      \
    }                                                                        \
  }
#endif
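// The two macros above turn the three runtime flags (m_lhs_inner_dim_contiguous,
// m_rhs_inner_dim_contiguous, m_rhs_inner_dim_reordered) into one of eight
// compile-time instantiations of METHOD. A minimal sketch of the same pattern,
// with hypothetical names, for readers unfamiliar with this style of dispatch:
//
//   template <bool kContiguous>
//   void run_impl(float* out);  // specialized/optimized per flag value
//
//   void run(float* out, bool contiguous) {
//     if (contiguous) {
//       run_impl<true>(out);    // runtime flag becomes a template parameter
//     } else {
//       run_impl<false>(out);
//     }
//   }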
  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
    static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer);
  }
#ifdef EIGEN_USE_THREADS
  template <typename EvalToCallback>
  void evalToAsync(Scalar* buffer, EvalToCallback done) const {
    static_cast<const Derived*>(this)
        ->template evalProductAsync<EvalToCallback, Unaligned>(buffer,
                                                               std::move(done));
  }
#endif  // EIGEN_USE_THREADS
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalProductSequential(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous,
                              rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
                              Alignment>(buffer);
    } else {
      this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
                              rhs_inner_dim_reordered, Alignment>(buffer);
    }
  }
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
#if !defined(EIGEN_HIPCC)
  EIGEN_DEVICE_FUNC
#endif
  void evalGemv(Scalar* buffer) const {
    // ...
    typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
    typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
    // ...
    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
    // ...
    typedef internal::TensorContractionInputMapper<
        LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
        contract_t, lhs_packet_size, lhs_inner_dim_contiguous,
        false, lhs_alignment> LhsMapper;
    typedef internal::TensorContractionInputMapper<
        RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
        contract_t, rhs_packet_size, rhs_inner_dim_contiguous,
        rhs_inner_dim_reordered, rhs_alignment> RhsMapper;
    // ... (rows/cols, the lhs/rhs mappers, alpha and the zeroed output buffer are set up here)
    const Index resIncr(1);

    internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run(
        rows, cols, lhs, rhs,
        buffer, resIncr, alpha);

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
    m_output_kernel(OutputMapper(buffer, rows), m_tensor_contraction_params,
                    static_cast<Index>(0), static_cast<Index>(0), rows,
                    static_cast<Index>(1));
  }
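// Usage sketch (illustrative): a contraction whose output has a single column
// (m_j_size == 1 in evalProductSequential above) reduces to a matrix-vector
// product and is expected to take this GEMV path instead of the blocked GEMM
// path.
//
//   Eigen::Tensor<float, 2> M(3, 4);
//   Eigen::Tensor<float, 1> v(4);
//   M.setRandom();
//   v.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {
//       Eigen::IndexPair<int>(1, 0)};
//   Eigen::Tensor<float, 1> y = M.contract(v, dims);  // y = M * v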
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
#if !defined(EIGEN_HIPCC)
  EIGEN_DEVICE_FUNC
#endif
  void evalGemm(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    this->template evalGemmPartial<lhs_inner_dim_contiguous,
                                   rhs_inner_dim_contiguous,
                                   rhs_inner_dim_reordered,
                                   Alignment, true>(buffer, 0, k, 1);
  }
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemmPartialWithoutOutputKernel(
      Scalar* buffer, Index k_start, Index k_end, int num_threads) const {
    evalGemmPartial<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
                    rhs_inner_dim_reordered, Alignment,
                    /*use_output_kernel*/ false>(buffer, k_start, k_end,
                                                 num_threads);
  }
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment, bool use_output_kernel>
  EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start,
                                         Index k_end, int num_threads) const {
    // columns in slice on left side, rows on right side
    const Index k_slice = k_end - k_start;
    // ...
    typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
    typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
    // ...
    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;

    typedef internal::TensorContractionInputMapper<
        LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
        contract_t, lhs_packet_size, lhs_inner_dim_contiguous,
        false, Unaligned> LhsMapper;
    typedef internal::TensorContractionInputMapper<
        RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
        contract_t, rhs_packet_size, rhs_inner_dim_contiguous,
        rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    typedef internal::TensorContractionKernel<
        Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
        TensorContractionKernel;

    // ... (m, n, the lhs/rhs input mappers and alpha are set up here)
    OutputMapper output(buffer, m);

    internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                        internal::ShardByCol>
        blocking(k_slice, m, n, num_threads);
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());

    typedef typename TensorContractionKernel::LhsBlock LhsBlock;
    typedef typename TensorContractionKernel::RhsBlock RhsBlock;
    LhsBlock blockA;
    RhsBlock blockB;

    TensorContractionKernel kernel(m, k_slice, n, mc, kc, nc);

    typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
    const BlockMemHandle packed_mem =
        kernel.allocate(this->m_device, &blockA, &blockB);

    // If a contraction kernel does not support beta, explicitly initialize
    // the output buffer with zeroes.
    if (!TensorContractionKernel::HasBeta) {
      this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
    }

    for (Index i2 = 0; i2 < m; i2 += mc) {
      const Index actual_mc = numext::mini(i2 + mc, m) - i2;
      for (Index k2 = k_start; k2 < k_end; k2 += kc) {
        // make sure we don't overshoot right edge of left matrix, then pack vertical panel
        const Index actual_kc = numext::mini(k2 + kc, k_end) - k2;
        kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);

        // If the kernel supports beta, there is no need to initialize the
        // output buffer with zeroes.
        const Scalar beta = (TensorContractionKernel::HasBeta && k2 == k_start)
                                ? Scalar(0)
                                : Scalar(1);

        // series of horizontal blocks
        for (Index j2 = 0; j2 < n; j2 += nc) {
          // make sure we don't overshoot right edge of right matrix, then pack block
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          kernel.packRhs(&blockB, rhs.getSubMapper(k2, j2), actual_kc,
                         actual_nc);

          // call gebp (matrix kernel)
          const OutputMapper output_mapper = output.getSubMapper(i2, j2);
          kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc,
                        actual_nc, alpha, beta);

          // We are done with this [i2, j2] output block.
          if (use_output_kernel && k2 + kc >= k_end) {
            m_output_kernel(output_mapper, m_tensor_contraction_params, i2, j2,
                            actual_mc, actual_nc);
          }
        }
      }
    }
    kernel.deallocate(this->m_device, packed_mem);
  }
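// Structure sketch (illustrative, not Eigen code): the same three-level
// blocking idea as evalGemmPartial, written as a plain blocked matrix product
// on raw column-major arrays. The block sizes mc/kc/nc are hypothetical; Eigen
// picks them via TensorContractionBlocking. C must be zero-initialized.
//
//   #include <algorithm>  // std::min
//
//   void blocked_gemm(const float* A, const float* B, float* C,
//                     int m, int k, int n, int mc, int kc, int nc) {
//     for (int i2 = 0; i2 < m; i2 += mc)           // panels of output rows
//       for (int k2 = 0; k2 < k; k2 += kc)         // slices of the depth
//         for (int j2 = 0; j2 < n; j2 += nc)       // blocks of output columns
//           for (int i = i2; i < std::min(i2 + mc, m); ++i)
//             for (int p = k2; p < std::min(k2 + kc, k); ++p)
//               for (int j = j2; j < std::min(j2 + nc, n); ++j)
//                 C[i + j * m] += A[i + p * m] * B[p + j * k];
//   }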
template<typename Indices, typename LeftArgType, typename RightArgType,
         typename OutputKernelType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> :
    public TensorContractionEvaluatorBase<
        TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> > {
  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef std::remove_const_t<typename XprType::Scalar> Scalar;
  // ...

  static constexpr int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static constexpr int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static constexpr int ContractDims = internal::array_size<Indices>::value;

  // ...
  static constexpr int NumDims = LDims + RDims - 2 * ContractDims;

  // ...
  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) { }
  template <int Alignment>
  void evalProduct(Scalar* buffer) const {
    TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Alignment, (buffer));
  }
};
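// Usage sketch (illustrative): the evaluator above is what runs when a
// contraction is assigned through a device. With EIGEN_USE_THREADS defined
// before including the Tensor module, the same expression dispatches to the
// multi-threaded evaluator instead; the pool size of 4 is an arbitrary choice.
//
//   Eigen::ThreadPool pool(4);
//   Eigen::ThreadPoolDevice device(&pool, 4);
//   Eigen::Tensor<float, 2> A(64, 64), B(64, 64), C(64, 64);
//   A.setRandom();
//   B.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {
//       Eigen::IndexPair<int>(1, 0)};
//   C.device(device) = A.contract(B, dims);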