#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H

namespace TensorSycl {

#ifndef EIGEN_SYCL_DISABLE_GEMV

// TVPanelSize: tiling parameters for the tensor-vector (GEMV-like) contraction kernel.
template <typename Scalar, typename StorageIndex, StorageIndex NCWindow, StorageIndex CFactor, StorageIndex NCFactor>
struct TVPanelSize {
  // number of threads per work-group along the contracting / non-contracting dimensions
  static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeC = EIGEN_SYCL_LOCAL_THREAD_DIM0;
  static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeNC = EIGEN_SYCL_LOCAL_THREAD_DIM1;
  // tile sizes along the non-contracting and contracting dimensions
  static EIGEN_CONSTEXPR StorageIndex TileSizeDimNC = NCWindow / NCFactor;
  static EIGEN_CONSTEXPR StorageIndex TileSizeDimC = CFactor * LocalThreadSizeNC * LocalThreadSizeC;
  // elements handled per thread along each dimension
  static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadNC = TileSizeDimNC / LocalThreadSizeNC;
  static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadC = TileSizeDimC / LocalThreadSizeC;
};
#endif
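// Illustrative example (not from the original header): assuming the default work-group
// shape EIGEN_SYCL_LOCAL_THREAD_DIM0 == EIGEN_SYCL_LOCAL_THREAD_DIM1 == 16 and the
// hypothetical parameters NCWindow = 64, CFactor = 8, NCFactor = 1, a
// TVPanelSize<float, int, 64, 8, 1> panel resolves to
//   LocalThreadSizeC = 16, LocalThreadSizeNC = 16,
//   TileSizeDimNC = 64 / 1 = 64, TileSizeDimC = 8 * 16 * 16 = 2048,
//   WorkLoadPerThreadNC = 64 / 16 = 4, WorkLoadPerThreadC = 2048 / 16 = 128,
// i.e. each 16x16 work-group covers 64 output rows and strides over 2048 contraction
// entries per pass, 128 of them per thread.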
// TTPanelSize: tiling parameters for the tensor-tensor (GEMM-like) contraction kernel.
template <typename Scalar, typename StorageIndex, StorageIndex REG_SIZE_M, StorageIndex REG_SIZE_N, StorageIndex TSDK>
struct TTPanelSize {
  static EIGEN_CONSTEXPR StorageIndex TileSizeDimK = TSDK;
#ifndef EIGEN_SYCL_REG_M
#define EIGEN_SYCL_REG_M REG_SIZE_M
#endif
  static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadM = EIGEN_SYCL_REG_M;
#ifndef EIGEN_SYCL_REG_N
#define EIGEN_SYCL_REG_N REG_SIZE_N
#endif
  static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadN = EIGEN_SYCL_REG_N;
  static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeM = EIGEN_SYCL_LOCAL_THREAD_DIM0;
  static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeN = EIGEN_SYCL_LOCAL_THREAD_DIM1;
  static EIGEN_CONSTEXPR StorageIndex TileSizeDimM = LocalThreadSizeM * WorkLoadPerThreadM;
  static EIGEN_CONSTEXPR StorageIndex TileSizeDimN = LocalThreadSizeN * WorkLoadPerThreadN;
  static EIGEN_CONSTEXPR StorageIndex LoadPerThreadLhs =
      ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimN));
  static EIGEN_CONSTEXPR StorageIndex LoadPerThreadRhs =
      ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimM));
  // BC: pad local-memory tiles by one element to avoid bank conflicts
  static EIGEN_CONSTEXPR bool BC = true;
#ifdef EIGEN_SYCL_DISABLE_DOUBLE_BUFFER
  static EIGEN_CONSTEXPR bool DoubleBuffer = false;
#else
  static EIGEN_CONSTEXPR bool DoubleBuffer = true;
#endif
};
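// Illustrative example (not from the original header): adjustTT below instantiates
// TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 16> when local memory is available.
// Assuming the default 16x16 work-group (EIGEN_SYCL_LOCAL_THREAD_DIM0/1 == 16) and no
// EIGEN_SYCL_REG_M/N override, that panel resolves to
//   WorkLoadPerThreadM = WorkLoadPerThreadN = 4,
//   TileSizeDimM = TileSizeDimN = 16 * 4 = 64, TileSizeDimK = 16,
//   LoadPerThreadLhs = (16 * 4 * 4) / 64 = 4, LoadPerThreadRhs = 4,
// so each work-group produces a 64x64 output tile, marching over K in slices of 16.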
// read: loads a packet from global memory when the input block can be read contiguously.
template <bool PacketLoad, bool is_coalesced_layout, bool, typename PacketType, typename TensorMapper,
          typename StorageIndex>
static std::enable_if_t<PacketLoad, PacketType> read(const TensorMapper &tensorMapper, const StorageIndex &NCIndex,
                                                     const StorageIndex &CIndex, const StorageIndex &ld) {
  const StorageIndex row = (is_coalesced_layout) ? NCIndex : CIndex;
  const StorageIndex col = (is_coalesced_layout) ? CIndex : NCIndex;
  return tensorMapper.get_tensor().template packet<Unaligned>(row + (col * ld));
}
// read: element-wise fallback used when the block cannot be loaded as a packet.
template <bool PacketLoad, bool, bool IsRhs, typename PacketType, typename TensorMapper, typename StorageIndex>
static std::enable_if_t<!PacketLoad, PacketType> read(const TensorMapper &tensorMapper, const StorageIndex &NCIndex,
                                                      const StorageIndex &CIndex, const StorageIndex &) {
  const StorageIndex row = (IsRhs) ? CIndex : NCIndex;
  const StorageIndex col = (IsRhs) ? NCIndex : CIndex;
  return tensorMapper(row, col);
}
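// Index-mapping sketch (illustrative, not from the original header): for a coalesced
// layout the non-contracting index NCIndex is the fast-moving "row" and the element is
// fetched at offset row + col * ld with ld equal to the non-contracting extent, e.g.
//   NCIndex = 5, CIndex = 2, ld = 64  ->  linear offset 5 + 2 * 64 = 133.
// The scalar overload instead forwards (row, col) to the input mapper, which resolves
// the striding itself, so no leading dimension is needed there.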
// write: stores a packet element by element into local or private memory, with a
// compile-time stride of `ld` scalars between consecutive lanes.
template <typename StorageIndex, StorageIndex ld, data_source dt, typename PacketType, typename DataScalar>
static std::enable_if_t<dt != data_source::global_mem, void> write(PacketType &packet_data, DataScalar ptr) {
  EIGEN_CONSTEXPR int PacketSize = Eigen::internal::unpacket_traits<PacketType>::size;
  EIGEN_UNROLL_LOOP
  for (int i = 0; i < PacketSize; i++) {
    *ptr = PacketWrapper<PacketType, PacketSize>::scalarize(i, packet_data);
    ptr += ld;
  }
}
// write: global-memory overload; the whole packet is stored with one unaligned store.
template <data_source dt, typename PacketType, typename DataScalar>
static std::enable_if_t<Eigen::internal::unpacket_traits<PacketType>::size != 1 && dt == data_source::global_mem, void>
write(PacketType &packet_data, DataScalar *ptr) {
  ::Eigen::internal::pstoreu<DataScalar, PacketType>(ptr, packet_data);
}

// write: global-memory overload for the degenerate packet size of one (plain scalar store).
template <data_source dt, typename PacketType, typename DataScalar>
static std::enable_if_t<Eigen::internal::unpacket_traits<PacketType>::size == 1 && dt == data_source::global_mem, void>
write(PacketType &packet_data, DataScalar *ptr) {
  *ptr = packet_data;
}
// check_boundary: edge-condition test; the primary template (internal blocks) is always true,
// while the check_boundary<false> specialization for boundary blocks returns the condition itself.
template <bool is_internal>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_boundary(bool) {
  return true;
}
// BlockProperties: compile-time description of how a tile of the input is accessed.
template <bool is_transposed, bool is_rhs_, bool packet_load_, typename PacketType>
struct BlockProperties {
  static EIGEN_CONSTEXPR bool packet_load = packet_load_;
  typedef typename Eigen::internal::unpacket_traits<PacketType>::type OutScalar;
  static EIGEN_CONSTEXPR bool is_rhs = is_rhs_;
  typedef std::conditional_t<packet_load, PacketType, OutScalar> OutType;
  static EIGEN_CONSTEXPR int elements_per_access = Eigen::internal::unpacket_traits<OutType>::size;
  static EIGEN_CONSTEXPR bool is_coalesced_layout = !(is_transposed ^ is_rhs);
  static EIGEN_CONSTEXPR int nc_stride = (is_coalesced_layout ? elements_per_access : 1);
  static EIGEN_CONSTEXPR int c_stride = (is_coalesced_layout ? 1 : elements_per_access);
};
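// Illustrative example (not from the original header): for a non-transposed LHS loaded
// with a hypothetical 4-lane packet type, BlockProperties<false, false, true, Packet> yields
//   is_coalesced_layout = !(false ^ false) = true,
//   elements_per_access = 4, nc_stride = 4, c_stride = 1,
// i.e. each access walks four elements along the non-contracting dimension and one along K,
// which is what makes the packet load legal in the first place.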
// ThreadProperties: per-thread offsets and sizes computed once at the top of the kernel.
template <typename StorageIndex>
struct ThreadProperties {
  const StorageIndex linearLocalThreadId;
  const StorageIndex kGroupId;
  const StorageIndex mGroupOffset;
  const StorageIndex nGroupOffset;
  const StorageIndex kGroupOffset;
  const StorageIndex mLocalOffset;
  const StorageIndex nLocalOffset;
  const StorageIndex mGlobalOffset;
  const StorageIndex nGlobalOffset;
  StorageIndex kSize;
  const bool is_internal;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ThreadProperties(
      const StorageIndex linearLocalThreadId_, const StorageIndex kGroupId_, const StorageIndex mGroupOffset_,
      const StorageIndex nGroupOffset_, const StorageIndex kGroupOffset_, const StorageIndex mLocalOffset_,
      const StorageIndex nLocalOffset_, const StorageIndex mGlobalOffset_, const StorageIndex nGlobalOffset_,
      StorageIndex kSize_, const bool is_internal_)
      : linearLocalThreadId(linearLocalThreadId_),
        kGroupId(kGroupId_),
        mGroupOffset(mGroupOffset_),
        nGroupOffset(nGroupOffset_),
        kGroupOffset(kGroupOffset_),
        mLocalOffset(mLocalOffset_),
        nLocalOffset(nLocalOffset_),
        mGlobalOffset(mGlobalOffset_),
        nGlobalOffset(nGlobalOffset_),
        kSize(kSize_),
        is_internal(is_internal_) {}
};
// TensorContractionKernel: SYCL functor computing one output tile of C = A * B per work-group.
template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper,
          typename RhsMapper, typename StorageIndex, typename Properties, typename TripleDim, bool Vectorizable,
          typename input_mapper_properties, bool IsFinal, contraction_type contraction_tp>
class TensorContractionKernel {
 public:
  typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
      PacketReturnType;
  static EIGEN_CONSTEXPR StorageIndex PacketSize =
      Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
  static EIGEN_CONSTEXPR bool is_lhs_transposed =
      !::Eigen::internal::TensorContractionInputMapperTrait<LhsMapper>::inner_dim_contiguous;
  static EIGEN_CONSTEXPR bool is_rhs_transposed =
      !::Eigen::internal::TensorContractionInputMapperTrait<RhsMapper>::inner_dim_contiguous;

  typedef BlockProperties<is_lhs_transposed, false, input_mapper_properties::is_lhs_matrix && Vectorizable,
                          PacketReturnType>
      LHSBlockProperties;
  typedef BlockProperties<is_rhs_transposed, true, input_mapper_properties::is_rhs_matrix && Vectorizable,
                          PacketReturnType>
      RHSBlockProperties;

  typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
  typedef cl::sycl::multi_ptr<OutScalar, cl::sycl::access::address_space::local_space> local_ptr;
  typedef OutScalar * private_ptr;
  typedef std::conditional_t<contraction_tp == contraction_type::local, local_ptr, private_ptr> tile_ptr;

  // LSDL/LSDR: leading dimension of the LHS/RHS tile in scratch (padded by BC for the local strategy).
  static EIGEN_CONSTEXPR StorageIndex LSDL = contraction_tp == contraction_type::local
                                                 ? Properties::TileSizeDimM + Properties::BC
                                                 : Properties::WorkLoadPerThreadM;
  static EIGEN_CONSTEXPR StorageIndex LSDR = contraction_tp == contraction_type::local
                                                 ? Properties::TileSizeDimN + Properties::BC
                                                 : Properties::WorkLoadPerThreadN;
  static EIGEN_CONSTEXPR StorageIndex LocalOffset = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN;
  // MemHolder: owns the tile storage; a view into shared local memory for the `local`
  // strategy, or a private register array for the `no_local` strategy.
  template <contraction_type, StorageIndex>
  struct MemHolder {
    tile_ptr ptr;
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MemHolder(local_ptr block_start_ptr) : ptr(block_start_ptr) {}
  };
  template <StorageIndex MemSize>
  struct MemHolder<contraction_type::no_local, MemSize> {
    OutScalar ptr[MemSize] = {OutScalar{0}};
  };

  // TiledMemory: bundles the extraction buffers, compute pointers and local extract indices.
  struct TiledMemory {
    MemHolder<contraction_tp, Properties::WorkLoadPerThreadM * Properties::TileSizeDimK> lhs_scratch_extract;
    MemHolder<contraction_tp, Properties::WorkLoadPerThreadN * Properties::TileSizeDimK> rhs_scratch_extract;
    tile_ptr lhs_scratch_ptr_compute;
    tile_ptr rhs_scratch_ptr_compute;
    const std::pair<StorageIndex, StorageIndex> lhs_extract_index;
    const std::pair<StorageIndex, StorageIndex> rhs_extract_index;

    template <contraction_type tp = contraction_tp>
    TiledMemory(const ThreadProperties<StorageIndex> &, local_ptr,
                std::enable_if_t<tp == contraction_type::no_local> * = 0)
        : lhs_scratch_extract{},
          rhs_scratch_extract{},
          lhs_scratch_ptr_compute(lhs_scratch_extract.ptr),
          rhs_scratch_ptr_compute(rhs_scratch_extract.ptr),
          lhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})),
          rhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})) {}

    template <contraction_type tp = contraction_tp>
    TiledMemory(const ThreadProperties<StorageIndex> &thread_properties, local_ptr block_start_ptr,
                std::enable_if_t<tp == contraction_type::local> * = 0)
        : lhs_scratch_extract{block_start_ptr},
          rhs_scratch_extract{lhs_scratch_extract.ptr +
                              ((Properties::DoubleBuffer + 1) * LSDL * Properties::TileSizeDimK)},
          lhs_scratch_ptr_compute(lhs_scratch_extract.ptr + thread_properties.mLocalOffset),
          rhs_scratch_ptr_compute(rhs_scratch_extract.ptr + thread_properties.nLocalOffset),
          lhs_extract_index(
              local_id_extract<LHSBlockProperties, Properties::TileSizeDimM>(thread_properties.linearLocalThreadId)),
          rhs_extract_index(
              local_id_extract<RHSBlockProperties, Properties::TileSizeDimN>(thread_properties.linearLocalThreadId)) {}
  };
  Scratch scratch;
  const LhsMapper lhs;
  const RhsMapper rhs;
  OutAccessor out_res;
  const StorageIndex groupSizeM;
  const StorageIndex groupSizeN;
  const StorageIndex numTiles;
  const TripleDim triple_dim;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionKernel(Scratch scratch_, const LhsMapper lhs_,
                                                                const RhsMapper rhs_, OutAccessor out_res_,
                                                                const StorageIndex groupSizeM_,
                                                                const StorageIndex groupSizeN_,
                                                                const StorageIndex numTiles_,
                                                                const TripleDim triple_dim_)
      : scratch(scratch_),
        lhs(lhs_),
        rhs(rhs_),
        out_res(out_res_),
        groupSizeM(groupSizeM_),
        groupSizeN(groupSizeN_),
        numTiles(numTiles_),
        triple_dim(triple_dim_) {}

  // Convenience constructor used when the contraction is not split over N (groupSizeN is 1).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionKernel(Scratch scratch_, const LhsMapper lhs_,
                                                                const RhsMapper rhs_, OutAccessor out_res_,
                                                                const StorageIndex groupSizeM_,
                                                                const StorageIndex numTiles_,
                                                                const TripleDim triple_dim_)
      : TensorContractionKernel(scratch_, lhs_, rhs_, out_res_, groupSizeM_, 1, numTiles_, triple_dim_) {}
  EIGEN_DEVICE_FUNC void operator()(cl::sycl::nd_item<1> itemID) const {
    const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
    const StorageIndex nLocalThreadId = linearLocalThreadId / Properties::LocalThreadSizeM;
    const StorageIndex mLocalThreadId = linearLocalThreadId % Properties::LocalThreadSizeM;
    const StorageIndex mGroupId = itemID.get_group(0) % groupSizeM;
    const StorageIndex tmp = itemID.get_group(0) / groupSizeM;
    const StorageIndex nGroupId = IsFinal ? tmp : tmp % groupSizeN;
    const StorageIndex kGroupId = IsFinal ? 0 : tmp / groupSizeN;
    const StorageIndex mGroupOffset = mGroupId * Properties::TileSizeDimM;
    const StorageIndex nGroupOffset = nGroupId * Properties::TileSizeDimN;
    const StorageIndex mLocalOffset = PacketSize * mLocalThreadId;
    const StorageIndex nLocalOffset = NStride * nLocalThreadId;
    const StorageIndex mGlobalOffset = mGroupOffset + mLocalOffset;
    const StorageIndex nGlobalOffset = nGroupOffset + nLocalOffset;

    const StorageIndex kSizePerWG = IsFinal ? triple_dim.K : numTiles * Properties::TileSizeDimK;
    StorageIndex kGroupOffset = kGroupId * kSizePerWG;
    const bool is_internal = triple_dim.M - mGroupOffset >= Properties::TileSizeDimM &&
                             triple_dim.N - nGroupOffset >= Properties::TileSizeDimN &&
                             triple_dim.K - kGroupOffset >= kSizePerWG;
    // kSize is clamped so that the last work-group along K only consumes what is left.
    StorageIndex kSize = IsFinal ? triple_dim.K : std::min(kSizePerWG, triple_dim.K - kGroupOffset);
    // kGroupOffset - kSize later recovers the coordinate where this work-group's tiles start.
    kGroupOffset += kSize;

    auto thread_properties =
        ThreadProperties<StorageIndex>(linearLocalThreadId, kGroupId, mGroupOffset, nGroupOffset, kGroupOffset,
                                       mLocalOffset, nLocalOffset, mGlobalOffset, nGlobalOffset, kSize, is_internal);

    auto out_ptr = out_res + (IsFinal ? 0 : thread_properties.kGroupId * triple_dim.M * triple_dim.N);

    (thread_properties.is_internal) ? compute_panel<true>(itemID, thread_properties, out_ptr)
                                    : compute_panel<false>(itemID, thread_properties, out_ptr);
  }
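  // Work-group decomposition sketch (illustrative numbers, not from the original header):
  // with groupSizeM = 4, groupSizeN = 3 and a K split of groupSizeK = 2 (IsFinal == false),
  // group id 14 decomposes as
  //   mGroupId = 14 % 4 = 2,  tmp = 14 / 4 = 3,  nGroupId = 3 % 3 = 0,  kGroupId = 3 / 3 = 1,
  // so that work-group owns output tile (2, 0) and the second slice of the K range, and it
  // writes its partial result into the kGroupId-th M*N scratch buffer that is summed later
  // by the partial-reduction kernel.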
  // compute_block_per_tile: rank-1 update of the per-thread accumulator registers for one k.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_block_per_tile(OutScalar *lhs_block_ptr, OutScalar *rhs_block_ptr,
                                                                    PacketReturnType *privateRes) const {
    StorageIndex idx = 0;
    EIGEN_CONSTEXPR StorageIndex lhs_stride =
        contraction_tp == contraction_type::local ? (PacketSize * Properties::LocalThreadSizeM) : 1;
    EIGEN_UNROLL_LOOP
    for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN; wLPTN++) {
      auto rhsPacket = PacketReturnType{*(rhs_block_ptr + wLPTN)};
      StorageIndex lhs_index = 0;
      EIGEN_UNROLL_LOOP
      for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
        PacketReturnType lhsPack{};
        Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, PacketSize>::set_packet(lhsPack,
                                                                                             lhs_block_ptr + lhs_index);
        privateRes[idx] = ::Eigen::internal::pmadd(lhsPack, rhsPacket, privateRes[idx]);
        lhs_index += lhs_stride;
        idx++;
      }
    }
  }
  // store: writes the per-thread accumulators back to the output in global memory, packet-wise
  // for internal blocks and element-wise with bound checks on the boundary blocks.
  template <bool is_internal_block, StorageIndex PrivateNStride, typename OutPtr>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void store(OutPtr *out_ptr, PacketReturnType *privateRes,
                                                   StorageIndex mGlobalOffset, StorageIndex nGlobalOffset) const {
    auto chk_bound = [&](const StorageIndex &mIndex, const StorageIndex &nIndex) EIGEN_DEVICE_FUNC {
      return (mIndex + PacketSize - 1 < triple_dim.M && nGlobalOffset + nIndex < triple_dim.N);
    };
    EIGEN_CONSTEXPR StorageIndex GlobalNStride =
        contraction_tp == contraction_type::local ? 1 : Properties::LocalThreadSizeN;
    EIGEN_UNROLL_LOOP
    for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN / PrivateNStride; wLPTN++) {
      // output leading dimension
      StorageIndex outputLD = 0;
      EIGEN_UNROLL_LOOP
      for (StorageIndex nId = 0; nId < PrivateNStride; nId++) {
        StorageIndex globalRow = mGlobalOffset;
        EIGEN_UNROLL_LOOP
        for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
          PacketReturnType privateOut = privateRes[wLPTM];
          if (check_boundary<is_internal_block>(chk_bound(globalRow, nId))) {
            // The output always has M as the leading (coalesced) dimension.
            write<data_source::global_mem>(privateOut, out_ptr + outputLD + globalRow);
          } else {
            EIGEN_UNROLL_LOOP
            for (StorageIndex mId = 0; mId < PacketSize; mId++) {
              StorageIndex mOffset = globalRow + mId;
              if (mOffset < triple_dim.M && (nGlobalOffset + nId < triple_dim.N)) {
                out_ptr[mOffset + outputLD] =
                    Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, PacketSize>::scalarize(mId, privateOut);
              }
            }
          }
          globalRow += (PacketSize * Properties::LocalThreadSizeM);
        }
        outputLD += triple_dim.M;
        privateRes += Properties::WorkLoadPerThreadM / PacketSize;
      }
      out_ptr += (GlobalNStride * outputLD);
      nGlobalOffset += (PrivateNStride * GlobalNStride);
    }
  }
  // extract_block (no_local): each thread loads its own slice of the input tile straight
  // into private registers; boundary blocks fall back to guarded element-wise loads.
  template <typename InputBlockProperties, bool is_internal_block, typename Input, typename PrivateReg,
            contraction_type contract_tp = contraction_tp>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<contract_tp == contraction_type::no_local> extract_block(
      const Input &inpt, PrivateReg private_ptr, const std::pair<StorageIndex, StorageIndex> &,
      const StorageIndex &ncOffset, const StorageIndex cOffset) const {
    EIGEN_CONSTEXPR StorageIndex LocalThreadSizeNC =
        InputBlockProperties::is_rhs ? Properties::LocalThreadSizeN : Properties::LocalThreadSizeM;
    EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadNC =
        InputBlockProperties::is_rhs ? Properties::WorkLoadPerThreadN : Properties::WorkLoadPerThreadM;
    const StorageIndex &NC = InputBlockProperties::is_rhs ? triple_dim.N : triple_dim.M;

    auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC {
      return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) &&
              (NCIndex + InputBlockProperties::nc_stride - 1 < NC));
    };
    const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K;
    StorageIndex cIndex = cOffset;

    EIGEN_UNROLL_LOOP
    for (StorageIndex cId = 0; cId < Properties::TileSizeDimK / InputBlockProperties::c_stride; cId++) {
      StorageIndex ncIndex = ncOffset;
      EIGEN_UNROLL_LOOP
      for (StorageIndex ncId = 0; ncId < WorkLoadPerThreadNC / InputBlockProperties::nc_stride; ncId++) {
        if (check_boundary<is_internal_block>(chk_bound(cIndex, ncIndex))) {
          auto val =
              read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
                   InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, ncIndex, cIndex, ld);
          write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC),
                data_source::private_mem>(val, private_ptr);
        } else {
          EIGEN_UNROLL_LOOP
          for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
            const StorageIndex ncInd = ncIndex + (InputBlockProperties::is_coalesced_layout ? i : 0);
            const StorageIndex cInd = cIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i);
            OutScalar val =
                (ncInd < NC && cInd < triple_dim.K)
                    ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
                          inpt, ncInd, cInd, ld)
                    : OutScalar(0);
            write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC),
                  data_source::private_mem>(
                val, private_ptr + (InputBlockProperties::is_coalesced_layout ? i : 0) +
                         ((InputBlockProperties::is_coalesced_layout ? 0 : i) * WorkLoadPerThreadNC));
          }
        }
        // The LHS must still advance in a packet-friendly pattern even when its own load is scalar,
        // because the compute stage consumes it packet-wise.
        ncIndex = (!InputBlockProperties::is_rhs && InputBlockProperties::nc_stride == 1 && PacketSize != 1)
                      ? ncOffset + (ncId + 1) % PacketSize + ((ncId + 1) / PacketSize) * LocalThreadSizeNC
                      : (ncIndex + InputBlockProperties::nc_stride * LocalThreadSizeNC);
        private_ptr += InputBlockProperties::nc_stride;
      }
      // the inner loop has already advanced the pointer by one WorkLoadPerThreadNC
      private_ptr += (InputBlockProperties::c_stride - 1) * WorkLoadPerThreadNC;
      cIndex += InputBlockProperties::c_stride;
    }
  }
  // local_id_extract: maps the linear local thread id to (non-contracting, contracting) tile coordinates.
  template <typename InputBlockProperties, StorageIndex TileSizeDimNC>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::pair<StorageIndex, StorageIndex> local_id_extract(
      const StorageIndex &linearLocalThreadId) {
    const StorageIndex localThreadNC =
        (InputBlockProperties::is_coalesced_layout)
            ? linearLocalThreadId % (TileSizeDimNC / InputBlockProperties::nc_stride)
            : linearLocalThreadId / (Properties::TileSizeDimK / InputBlockProperties::c_stride);
    const StorageIndex localThreadC =
        (InputBlockProperties::is_coalesced_layout)
            ? linearLocalThreadId / (TileSizeDimNC / InputBlockProperties::nc_stride)
            : linearLocalThreadId % (Properties::TileSizeDimK / InputBlockProperties::c_stride);
    return std::pair<StorageIndex, StorageIndex>(localThreadNC, localThreadC);
  }
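  // Illustrative mapping (not from the original header): for a coalesced layout with
  // TileSizeDimNC = 64 and nc_stride = 4, the divisor is 64 / 4 = 16, so linear thread
  // id 37 lands on (localThreadNC, localThreadC) = (37 % 16, 37 / 16) = (5, 2); in the
  // non-coalesced case the same id is split against TileSizeDimK / c_stride instead.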
  // sync_mem: with double buffering only the buffer half is flipped; without it a barrier is
  // needed before the tile can be overwritten; with no local memory nothing is required.
  template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<db && ctp == contraction_type::local> sync_mem(
      const cl::sycl::nd_item<1> &, bool &db_offset) noexcept {
    db_offset = !db_offset;
  }

  template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<!db && ctp == contraction_type::local> sync_mem(
      const cl::sycl::nd_item<1> &itemID, bool &) noexcept {
    itemID.barrier(cl::sycl::access::fence_space::local_space);
  }

  template <contraction_type ctp = contraction_tp>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<ctp == contraction_type::no_local> sync_mem(
      const cl::sycl::nd_item<1> &, bool &) noexcept {
    return;
  }
  // sync_thread: separates the extract phase from the compute phase. For the no_local path a
  // barrier is only useful as an ARM GPU cache optimisation; for the local path it is mandatory.
  template <bool need_sync, contraction_type ctp = contraction_tp>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<need_sync && ctp == contraction_type::no_local>
  sync_thread(const cl::sycl::nd_item<1> &
#ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION
                  itemID
#endif
  ) {
#ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION
    itemID.barrier(cl::sycl::access::fence_space::local_space);
#else
    return;
#endif
  }

  template <bool need_sync, contraction_type ctp = contraction_tp>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<need_sync && ctp == contraction_type::local>
  sync_thread(const cl::sycl::nd_item<1> &itemID) {
    itemID.barrier(cl::sycl::access::fence_space::local_space);
  }

  template <bool need_sync>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<!need_sync> sync_thread(
      const cl::sycl::nd_item<1> &) {
    return;
  }
  template <bool is_internal_block>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_tile_per_panel(const cl::sycl::nd_item<1> &itemID,
                                                                    ThreadProperties<StorageIndex> &thread_properties,
                                                                    TiledMemory &tiled_input_block,
                                                                    PacketReturnType *privateRes,
                                                                    bool &db_offset) const {
    // Tile the RHS block from global into local/private memory.
    extract_block<RHSBlockProperties, is_internal_block>(
        rhs, tiled_input_block.rhs_scratch_extract.ptr + (db_offset * Properties::TileSizeDimK * LSDR),
        tiled_input_block.rhs_extract_index,
        contraction_tp == contraction_type::local ? thread_properties.nGroupOffset : thread_properties.nGlobalOffset,
        thread_properties.kGroupOffset - thread_properties.kSize);

    sync_thread<contraction_tp == contraction_type::no_local>(itemID);

    // Tile the LHS block from global into local/private memory.
    extract_block<LHSBlockProperties, is_internal_block>(
        lhs, tiled_input_block.lhs_scratch_extract.ptr + (db_offset * LSDL * Properties::TileSizeDimK),
        tiled_input_block.lhs_extract_index,
        contraction_tp == contraction_type::local ? thread_properties.mGroupOffset : thread_properties.mGlobalOffset,
        thread_properties.kGroupOffset - thread_properties.kSize);

    // Wait until the whole tile is resident before switching to compute mode.
    sync_thread<contraction_tp == contraction_type::local>(itemID);

    StorageIndex lhs_offset = (db_offset * LSDL * Properties::TileSizeDimK);
    StorageIndex rhs_offset = (db_offset * Properties::TileSizeDimK * LSDR);
    // Loop over the values of a single tile.
    for (StorageIndex k = 0; k < Properties::TileSizeDimK; k++) {
      compute_block_per_tile(tiled_input_block.lhs_scratch_ptr_compute + lhs_offset,
                             tiled_input_block.rhs_scratch_ptr_compute + rhs_offset, privateRes);
      lhs_offset += LSDL;
      rhs_offset += LSDR;
    }
    // Advance to the next K tile.
    thread_properties.kSize -= Properties::TileSizeDimK;
    sync_mem(itemID, db_offset);
  }
  // compute_panel: accumulates all K tiles of one output panel into registers, then stores it.
  template <bool is_internal_block, typename OutPtr>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_panel(const cl::sycl::nd_item<1> &itemID,
                                                           ThreadProperties<StorageIndex> &thread_properties,
                                                           OutPtr out_ptr) const {
    auto tiled_input_block = TiledMemory{thread_properties, scratch.get_pointer()};
    // Register space for the per-thread accumulators.
    PacketReturnType privateRes[Properties::WorkLoadPerThreadM * Properties::WorkLoadPerThreadN / PacketSize] = {
        PacketReturnType{0}};
    bool db_offset = 0;

    while (thread_properties.kSize >= Properties::TileSizeDimK) {
      compute_tile_per_panel<is_internal_block>(itemID, thread_properties, tiled_input_block, privateRes, db_offset);
    }
    if (thread_properties.kSize > 0) {
      compute_tile_per_panel<false>(itemID, thread_properties, tiled_input_block, privateRes, db_offset);
    }

    // Store the final result in the output buffer.
    store<is_internal_block,
          contraction_tp == contraction_type::local ? static_cast<StorageIndex>(1) : RHSBlockProperties::nc_stride>(
        out_ptr + thread_properties.nGlobalOffset * triple_dim.M, privateRes, thread_properties.mGlobalOffset,
        thread_properties.nGlobalOffset);
  }
  // extract_block (local): the whole work-group cooperatively stages the tile into shared
  // local memory, each thread copying LoadPerThread elements.
  template <typename InputBlockProperties, bool is_internal_block, typename Input, typename Local,
            contraction_type contract_tp = contraction_tp>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<contract_tp == contraction_type::local> extract_block(
      const Input &inpt, Local local_ptr, const std::pair<StorageIndex, StorageIndex> &local_index,
      const StorageIndex &ncOffset, const StorageIndex cOffset) const {
    EIGEN_CONSTEXPR StorageIndex TileSizeDimNC =
        InputBlockProperties::is_rhs ? Properties::TileSizeDimN : Properties::TileSizeDimM;
    EIGEN_CONSTEXPR StorageIndex LoadPerThread =
        InputBlockProperties::is_rhs ? Properties::LoadPerThreadRhs : Properties::LoadPerThreadLhs;
    EIGEN_CONSTEXPR StorageIndex LSD = InputBlockProperties::is_rhs ? LSDR : LSDL;
    static_assert(((LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride) == 0) &&
                   (LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride) == 0)),
                  " LocalOffset must be divisible by stride");
    const StorageIndex &NC = InputBlockProperties::is_rhs ? triple_dim.N : triple_dim.M;
    StorageIndex localThreadNC = local_index.first;
    StorageIndex localThreadC = local_index.second;
    auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC {
      return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) &&
              (NCIndex + InputBlockProperties::nc_stride - 1 < NC));
    };
    EIGEN_UNROLL_LOOP
    for (StorageIndex lPT = 0; lPT < LoadPerThread / InputBlockProperties::elements_per_access; lPT++) {
      const StorageIndex CIndex = cOffset + (InputBlockProperties::c_stride * localThreadC);
      const StorageIndex NCIndex = ncOffset + (InputBlockProperties::nc_stride * localThreadNC);
      const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K;
      if (check_boundary<is_internal_block>(chk_bound(CIndex, NCIndex))) {
        auto val =
            read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
                 InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, NCIndex, CIndex, ld);
        write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>(
            val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) +
                     (InputBlockProperties::c_stride * localThreadC * LSD));
      } else {
        EIGEN_UNROLL_LOOP
        for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
          const StorageIndex nCInd = NCIndex + (InputBlockProperties::is_coalesced_layout ? i : 0);
          const StorageIndex cInd = CIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i);
          OutScalar val =
              (nCInd < NC && cInd < triple_dim.K)
                  ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
                        inpt, nCInd, cInd, ld)
                  : OutScalar(0);
          write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>(
              val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) +
                       (InputBlockProperties::is_coalesced_layout ? i : 0) +
                       ((InputBlockProperties::c_stride * localThreadC +
                         (InputBlockProperties::is_coalesced_layout ? 0 : i)) *
                        LSD));
        }
      }
      localThreadNC += (InputBlockProperties::is_coalesced_layout)
                           ? LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride)
                           : LocalOffset / (Properties::TileSizeDimK / InputBlockProperties::c_stride);
      localThreadC += (InputBlockProperties::is_coalesced_layout)
                          ? LocalOffset / (TileSizeDimNC / InputBlockProperties::nc_stride)
                          : LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride);
    }
  }
};
#ifndef EIGEN_SYCL_DISABLE_GEMV

// GeneralVectorTensor: kernel for tensor-vector (and vector-tensor) contractions.
template <typename OutScalar, typename OutAccessor, typename VectorMapper, typename TensorMapper, typename StorageIndex,
          typename Properties, StorageIndex KFactor, bool Vectorizable, bool is_lhs_vec, bool IsFinal>
struct GeneralVectorTensor {
  typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
      PacketReturnType;
  static EIGEN_CONSTEXPR StorageIndex PacketSize =
      Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
  typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;

  static EIGEN_CONSTEXPR StorageIndex OutScratchOffset =
      KFactor * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;

  // A vector is always accessed in a coalesced way; the two flags are chosen so that
  // is_coalesced_layout (= !(is_transposed ^ is_rhs)) comes out true for either side.
  typedef BlockProperties<is_lhs_vec ? false : true, is_lhs_vec ? false : true, Vectorizable, PacketReturnType>
      VecBlockProperties;

  Scratch scratch;
  const VectorMapper vec;
  const TensorMapper mat;
  OutAccessor out_res;
  const StorageIndex nonContractGroupSize;
  const StorageIndex nonContractDim;
  const StorageIndex contractDim;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE GeneralVectorTensor(Scratch scratch_, const VectorMapper vec_,
                                                            const TensorMapper mat_, OutAccessor out_res_,
                                                            const StorageIndex nonContractGroupSize_,
                                                            const StorageIndex nonContractDim_,
                                                            const StorageIndex contractDim_)
      : scratch(scratch_),
        vec(vec_),
        mat(mat_),
        out_res(out_res_),
        nonContractGroupSize(nonContractGroupSize_),
        nonContractDim(nonContractDim_),
        contractDim(contractDim_) {}
  EIGEN_DEVICE_FUNC void operator()(cl::sycl::nd_item<1> itemID) const {
    auto scratch_ptr = scratch.get_pointer();
    const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
    StorageIndex nonContractId = is_lhs_vec ? linearLocalThreadId / Properties::LocalThreadSizeC
                                            : linearLocalThreadId % Properties::LocalThreadSizeNC;
    StorageIndex contractId = is_lhs_vec ? linearLocalThreadId % Properties::LocalThreadSizeC
                                         : linearLocalThreadId / Properties::LocalThreadSizeNC;
    const StorageIndex cGroupSize = itemID.get_group_range(0) / nonContractGroupSize;
    const StorageIndex nonContractGroupId =
        is_lhs_vec ? itemID.get_group(0) / cGroupSize : itemID.get_group(0) % nonContractGroupSize;
    const StorageIndex contractGroupId =
        is_lhs_vec ? itemID.get_group(0) % cGroupSize : itemID.get_group(0) / nonContractGroupSize;
    auto out_ptr = out_res + (IsFinal ? 0 : contractGroupId * nonContractDim);

    const StorageIndex nonContractGroupOffset = nonContractGroupId * Properties::TileSizeDimNC;
    const StorageIndex contractGroupOffset = contractGroupId * Properties::TileSizeDimC;
    auto outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
    const StorageIndex globalNonContractDimOffset = nonContractGroupOffset + nonContractId;
    const StorageIndex globalContractDimOffset = contractGroupOffset + contractId;
    auto local_output = scratch_ptr + OutScratchOffset;
    const bool is_internal = nonContractDim - nonContractGroupOffset >= Properties::TileSizeDimNC &&
                             contractDim - contractGroupOffset >= Properties::TileSizeDimC;
    is_internal
        ? compute_panel<true>(itemID, vec, mat, local_output, out_ptr,
                              scratch_ptr, contractGroupOffset,
                              nonContractGroupOffset, linearLocalThreadId, contractDim, nonContractDim, contractId,
                              nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex)
        : compute_panel<false>(itemID, vec, mat, local_output, out_ptr,
                               scratch_ptr, contractGroupOffset,
                               nonContractGroupOffset, linearLocalThreadId, contractDim, nonContractDim, contractId,
                               nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex);
  }
  // compute_panel: caches the vector slice in local memory (when enabled), accumulates the
  // per-thread partial dot products, then tree-reduces them across the contracting threads.
  template <bool is_internal_block, typename OutPtr>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_panel(
      const cl::sycl::nd_item<1> &itemID, const VectorMapper &vec, const TensorMapper &mat, OutScalar *local_output,
      OutPtr out_ptr, OutScalar *scratch_ptr, const StorageIndex contractGroupOffset,
      const StorageIndex nonContractGroupOffset, const StorageIndex linearLocalThreadId, StorageIndex contractDim,
      StorageIndex nonContractDim, StorageIndex contractId, StorageIndex nonContractId,
      StorageIndex globalContractDimOffset, StorageIndex globalNonContractDimOffset, StorageIndex outScratchIndex) {
    OutScalar outScalar[Properties::WorkLoadPerThreadNC] = {OutScalar(0)};
    // Stage the vector in local memory when local-memory usage is enabled.
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
    const StorageIndex vectorOffset = contractGroupOffset + linearLocalThreadId;
    extract_block<VecBlockProperties, is_internal_block, KFactor,
                  Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC>(vec, scratch_ptr, linearLocalThreadId,
                                                                                vectorOffset, contractDim);

    itemID.barrier(cl::sycl::access::fence_space::local_space);
    auto in_scratch_ptr = scratch_ptr + contractId;
#endif

    StorageIndex privateOffsetC = 0;
    EIGEN_UNROLL_LOOP
    for (StorageIndex i = 0; i < Properties::WorkLoadPerThreadC; i++) {
      StorageIndex privateOffsetNC = 0;
      bool contract_conds = ((globalContractDimOffset + privateOffsetC) < contractDim);
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
      auto vecScalar = *in_scratch_ptr;
#else
      auto vecScalar = (check_boundary<is_internal_block>(contract_conds))
                           ? vec(is_lhs_vec ? StorageIndex(0) : globalContractDimOffset + privateOffsetC,
                                 is_lhs_vec ? globalContractDimOffset + privateOffsetC : StorageIndex(0))
                           : OutScalar(0);
#endif
      EIGEN_UNROLL_LOOP
      for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
        auto matScalar = (check_boundary<is_internal_block>(
                             contract_conds && ((globalNonContractDimOffset + privateOffsetNC) < nonContractDim)))
                             ? mat(is_lhs_vec ? globalContractDimOffset + privateOffsetC
                                              : globalNonContractDimOffset + privateOffsetNC,
                                   is_lhs_vec ? globalNonContractDimOffset + privateOffsetNC
                                              : globalContractDimOffset + privateOffsetC)
                             : OutScalar(0);
        outScalar[j] = cl::sycl::mad(matScalar, vecScalar, outScalar[j]);
        privateOffsetNC += Properties::LocalThreadSizeNC;
      }
      privateOffsetC += Properties::LocalThreadSizeC;
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
      in_scratch_ptr += Properties::LocalThreadSizeC;
#endif
    }

    // Spill the per-thread partial results to local memory for the cross-thread reduction.
    auto out_scratch_ptr = local_output + outScratchIndex;
    EIGEN_UNROLL_LOOP
    for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
      *out_scratch_ptr = outScalar[j];
      out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
    }
    if (is_lhs_vec) {
      nonContractId = linearLocalThreadId % Properties::LocalThreadSizeNC;
      contractId = linearLocalThreadId / Properties::LocalThreadSizeNC;
      outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
    }

    // Tree reduction over the contracting threads of the work-group.
    out_scratch_ptr = local_output + outScratchIndex;
    EIGEN_UNROLL_LOOP
    for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
      EIGEN_UNROLL_LOOP
      for (StorageIndex offset = Properties::LocalThreadSizeC >> 1; offset > 0; offset >>= 1) {
        itemID.barrier(cl::sycl::access::fence_space::local_space);
        if (contractId < offset) {
          StorageIndex myNeighbourId = (Properties::LocalThreadSizeNC * offset);
          *out_scratch_ptr += out_scratch_ptr[myNeighbourId];
        }
      }
      // move to the next LocalThreadSizeNC x LocalThreadSizeC block
      out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
    }

    // The first row of contracting threads writes the reduced values to the output.
    if (contractId == 0) {
      out_scratch_ptr = local_output + nonContractId;
      StorageIndex global_final_offset = nonContractGroupOffset + nonContractId;
      out_ptr += global_final_offset;
      EIGEN_UNROLL_LOOP
      for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
        if (check_boundary<is_internal_block>(global_final_offset < nonContractDim)) {
          auto res = *out_scratch_ptr;
          *out_ptr = res;
          out_ptr += Properties::LocalThreadSizeNC;
        }
        out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
        if (!(is_internal_block)) global_final_offset += Properties::LocalThreadSizeNC;
      }
    }
  }
  // extract_block: stages CFactor * GroupSize entries of the vector into local memory.
  template <typename InputBlockProperties, bool is_internal_block, int CFactor, int GroupSize, typename Input,
            typename Local>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void extract_block(const Input &inpt, Local *local_ptr,
                                                                  const StorageIndex &linearLocalThreadId,
                                                                  const StorageIndex &cOffset, const StorageIndex &C) {
    local_ptr += InputBlockProperties::c_stride * linearLocalThreadId;
    StorageIndex cIndex = cOffset;
    for (StorageIndex cId = 0; cId < CFactor / InputBlockProperties::c_stride; cId++) {
      if (check_boundary<is_internal_block>(cIndex + InputBlockProperties::c_stride - 1 < C)) {
        auto val = read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
                        InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, StorageIndex(0),
                                                                                              cIndex, StorageIndex(1));
        write<StorageIndex, 1, data_source::local_mem>(val, local_ptr);
      } else {
        EIGEN_UNROLL_LOOP
        for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
          OutScalar val =
              (cIndex + i < C)
                  ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
                        inpt, StorageIndex(0), cIndex + i, StorageIndex(1))
                  : OutScalar(0);
          write<StorageIndex, 1, data_source::local_mem>(val, local_ptr + i);
        }
      }
      local_ptr += InputBlockProperties::c_stride * GroupSize;
      cIndex += InputBlockProperties::c_stride * GroupSize;
    }
  }
};
#endif
#ifndef EIGEN_SYCL_DISABLE_SCALAR

// GeneralScalarContraction: full reduction used when both output dimensions are 1 (a dot product).
template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper,
          typename RhsMapper, typename StorageIndex, bool Vectorizable>
struct GeneralScalarContraction {
  typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
  Scratch scratch;
  const LhsMapper lhs;
  const RhsMapper rhs;
  OutAccessor out_res;
  const StorageIndex rng;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE GeneralScalarContraction(Scratch scratch_, const LhsMapper lhs_,
                                                                 const RhsMapper rhs_, OutAccessor out_res_,
                                                                 const StorageIndex rng_)
      : scratch(scratch_), lhs(lhs_), rhs(rhs_), out_res(out_res_), rng(rng_) {}

  EIGEN_DEVICE_FUNC void operator()(cl::sycl::nd_item<1> itemID) const {
    auto out_ptr = out_res;
    OutScalar *scratch_ptr = scratch.get_pointer();

    StorageIndex globalid = itemID.get_global_id(0);
    StorageIndex localid = itemID.get_local_id(0);
    OutScalar accumulator = OutScalar(0);
    // Grid-stride loop accumulating lhs(0, i) * rhs(i, 0).
    for (StorageIndex i = globalid; i < rng; i += itemID.get_global_range(0)) {
      accumulator = cl::sycl::mad(lhs(0, i), rhs(i, 0), accumulator);
    }
    auto out_scratch_ptr = scratch_ptr + localid;
    *out_scratch_ptr = accumulator;
    // Tree reduction across the work-group in local memory.
    for (StorageIndex offset = itemID.get_local_range(0) >> 1; offset > 0; offset >>= 1) {
      itemID.barrier(cl::sycl::access::fence_space::local_space);
      if (localid < offset) {
        *out_scratch_ptr = (accumulator += out_scratch_ptr[offset]);
      }
    }
    if (localid == 0) {
      out_ptr[itemID.get_group(0)] = accumulator;
    }
  }
};
#endif
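// Reduction sketch (illustrative, not from the original header): with a work-group of 8
// threads whose scratch slots hold {1, 2, 3, 4, 5, 6, 7, 8}, the offsets 4, 2, 1 give
//   step 1: {6, 8, 10, 12, ...}   (thread t adds slot t + 4)
//   step 2: {16, 20, ...}         (thread t adds slot t + 2)
//   step 3: {36, ...}             (thread 0 adds slot 1)
// after which thread 0 writes the work-group's partial dot product to out_ptr[group_id];
// a second kernel then sums the per-group partials when more than one group was launched.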
template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Eigen::SyclDevice>
    : public TensorContractionEvaluatorBase<TensorEvaluator<
          const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Eigen::SyclDevice>> {
  static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
                "SYCL tensor contraction does not support output kernels.");

  typedef std::remove_const_t<typename XprType::Scalar> Scalar;

  enum {
    BlockAccess = false,
  };

  static constexpr int LDims = Base::LDims;
  static constexpr int RDims = Base::RDims;
  static constexpr int ContractDims = Base::ContractDims;
  static constexpr int NumDims = LDims + RDims - 2 * ContractDims;

  typedef std::remove_const_t<typename LeftEvaluator::CoeffReturnType> LhsScalar;
  typedef std::remove_const_t<typename RightEvaluator::CoeffReturnType> RhsScalar;

  // input_mapper_properties: tells the kernels whether each side can be treated as a plain,
  // contiguously striding matrix (which enables vectorised block loads).
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered>
  struct input_mapper_properties {
    static EIGEN_CONSTEXPR bool is_lhs_matrix = (LDims == 2 && ContractDims == 1) || lhs_inner_dim_contiguous;
    static EIGEN_CONSTEXPR bool is_rhs_matrix =
        (RDims == 2 && ContractDims == 1) || (rhs_inner_dim_contiguous && !rhs_inner_dim_reordered);
  };
  // evalSubExprsIfNeeded is redefined here so that the SYCL backend can allocate the result buffer.
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(typename Base::EvaluatorPointerType data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (!data) {
      this->m_result = this->m_device.get(
          static_cast<Scalar *>(this->m_device.allocate_temp(this->dimensions().TotalSize() * sizeof(Scalar))));
      data = this->m_result;
    }
    evalToSycl(data);
    return (this->m_result != NULL);
  }
  void evalToSycl(typename Base::EvaluatorPointerType buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        } else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        } else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        } else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        } else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(typename Base::EvaluatorPointerType buffer) const {
    const auto triple_dim = TripleDim{this->m_i_size, this->m_j_size, this->m_k_size};
    typedef internal::TensorContractionInputMapper<

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

#ifndef EIGEN_SYCL_DISABLE_SCALAR
    // Both output dimensions are 1: the contraction is a plain dot product.
    if (triple_dim.M == 1 && triple_dim.N == 1) {
      launchSC(buffer, lhs, rhs, triple_dim.K);
    } else
#endif
#ifndef EIGEN_SYCL_DISABLE_GEMV
        // Exactly one output dimension is 1: dispatch to the tensor-vector kernel.
        if (triple_dim.M != 1 && triple_dim.N == 1) {
      LaunchVT<false>(buffer, rhs, lhs, triple_dim.M, triple_dim.K);
    } else if (triple_dim.M == 1 && triple_dim.N != 1) {
      LaunchVT<true>(buffer, lhs, rhs, triple_dim.N, triple_dim.K);
    } else  // i.e. M != 1 && N != 1
#endif
    {
      typedef input_mapper_properties<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>
          inpt_mapper_properties;
#ifndef EIGEN_SYCL_DISABLE_SKINNY
      bool skinny = false;
      auto platform_name = this->device().getPlatformName();
      // Empirical heuristic for strongly rectangular ("skinny") problems.
      if (platform_name.find("AMD") == 0) {
        skinny = (triple_dim.M < triple_dim.K || triple_dim.N < triple_dim.K) &&
                 ((triple_dim.M < 1024 && triple_dim.N < 1024) ||
                  (triple_dim.K / std::min(triple_dim.N, triple_dim.M) > 100));
      } else {
        skinny = (((std::max(triple_dim.K, triple_dim.N) / std::min(triple_dim.K, triple_dim.N)) > 100) ||
                  ((std::max(triple_dim.K, triple_dim.M) / std::min(triple_dim.K, triple_dim.M)) > 100) ||
                  ((std::max(triple_dim.N, triple_dim.M) / std::min(triple_dim.N, triple_dim.M)) > 100));
      }
      if (skinny)
        adjustTT<true, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim);
      else
#endif
        adjustTT<false, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim);
    }
  }
  template <bool skinny, typename input_mapper_properties, typename LhsMapper, typename RhsMapper>
  void EIGEN_ALWAYS_INLINE adjustTT(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
                                    const TripleDim &triple_dim) const {
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
    if (device().has_local_memory()) {
      typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 16> PanelParameters;
      launchTT<TensorSycl::internal::contraction_type::local, skinny, input_mapper_properties, PanelParameters>(
          buffer, lhs, rhs, triple_dim);
    }
#endif
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_OFF
    if (!(device().has_local_memory())) {
      typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 4> PanelParameters;
      launchTT<TensorSycl::internal::contraction_type::no_local, skinny, input_mapper_properties, PanelParameters>(
          buffer, lhs, rhs, triple_dim);
    }
#endif
  }
  template <TensorSycl::internal::contraction_type ct, bool skinny, typename input_mapper_properties,
            typename Properties, typename LhsMapper, typename RhsMapper>
  void launchTT(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
                const TripleDim &triple_dim) const {
    const StorageIndex roundUpM = Eigen::TensorSycl::internal::roundUp(triple_dim.M, Properties::TileSizeDimM);
    const StorageIndex roundUpN = Eigen::TensorSycl::internal::roundUp(triple_dim.N, Properties::TileSizeDimN);
    const StorageIndex groupSizeM = roundUpM / Properties::TileSizeDimM;
    const StorageIndex groupSizeN = roundUpN / Properties::TileSizeDimN;

    const StorageIndex roundUpK = Eigen::TensorSycl::internal::roundUp(triple_dim.K, Properties::TileSizeDimK);
    StorageIndex totalTilesK = roundUpK / Properties::TileSizeDimK;
    // Skinny problems are additionally split over K, up to a multiple of the compute-unit count.
    StorageIndex groupSizeK =
        skinny ? std::max(std::min(totalTilesK,
                                   (StorageIndex)(device().getPowerOfTwo(device().getNumSyclMultiProcessors(), true) * 4) /
                                       (groupSizeM * groupSizeN)),
                          StorageIndex(1))
               : StorageIndex(1);

    const StorageIndex numTilesPerGroup = Eigen::TensorSycl::internal::roundUp(totalTilesK, groupSizeK) / groupSizeK;
    const StorageIndex totalGroupSize = groupSizeM * groupSizeN * groupSizeK;
    const StorageIndex localRange = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN;
    const StorageIndex globalRange = totalGroupSize * localRange;

    // Local-memory budget: double-buffered, bank-conflict-padded LHS and RHS tiles.
    const StorageIndex scratchSize = (ct == TensorSycl::internal::contraction_type::local)
                                         ? ((Properties::DoubleBuffer + 1) *
                                            (Properties::TileSizeDimM + Properties::BC) * (Properties::TileSizeDimK)) +
                                               ((Properties::DoubleBuffer + 1) * (Properties::TileSizeDimK) *
                                                (Properties::TileSizeDimN + Properties::BC))
                                         : StorageIndex(1);

    auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
    if (groupSizeK == 1) {
      typedef TensorSycl::internal::TensorContractionKernel<CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType,
                                                            LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim,
                                                            PacketAccess, input_mapper_properties, true, ct>
          ContractKernelName;
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
          lhs, rhs, buffer, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup, triple_dim).wait();
    } else {
      typedef TensorSycl::internal::TensorContractionKernel<CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType,
                                                            LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim,
                                                            PacketAccess, input_mapper_properties, false, ct>
          ContractKernelName;
      // Each K group writes an M*N partial result which is summed by a second reduction kernel.
      CoeffReturnType *temp_pointer = static_cast<CoeffReturnType *>(
          device().allocate_temp(triple_dim.M * triple_dim.N * groupSizeK * sizeof(CoeffReturnType)));
      EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);

      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
          lhs, rhs, tmp_global_accessor, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup,
          triple_dim).wait();

      typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
      device().template unary_kernel_launcher<CoeffReturnType, ReductionKernel>(
          tmp_global_accessor, buffer,
          cl::sycl::nd_range<1>(cl::sycl::range<1>(StorageIndex(
                                    Eigen::TensorSycl::internal::roundUp(triple_dim.M * triple_dim.N, localRange))),
                                cl::sycl::range<1>(localRange)),
          StorageIndex(1), Op(), StorageIndex(triple_dim.M * triple_dim.N), groupSizeK).wait();

      device().deallocate_temp(temp_pointer);
    }
  }
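  // Sizing sketch (illustrative numbers, not from the original header): for M = N = K = 1024
  // with the 64x64x16 panel assumed above, groupSizeM = groupSizeN = 1024 / 64 = 16 and
  // totalTilesK = 1024 / 16 = 64; a non-skinny launch keeps groupSizeK = 1, so
  // globalRange = 16 * 16 * 256 = 65536 work-items in groups of 256. The local-memory
  // budget is (DoubleBuffer + 1) * (64 + 1) * 16 * 2 = 4160 scalars for the two padded tiles.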
#ifndef EIGEN_SYCL_DISABLE_GEMV
  template <bool is_lhs_vec, typename VectorMapper, typename TensorMapper, typename StorageIndex>
  void EIGEN_ALWAYS_INLINE LaunchVT(EvaluatorPointerType buffer, const VectorMapper &vec, const TensorMapper &mat,
                                    StorageIndex NC, StorageIndex C) const {
    const StorageIndex nonContractDim = NC;
    typedef Eigen::TensorSycl::internal::TVPanelSize<CoeffReturnType, StorageIndex, NCWindow, CFactor, NCFactor>
        Properties;
    const StorageIndex roundUpC = Eigen::TensorSycl::internal::roundUp(C, Properties::TileSizeDimC);
    const StorageIndex cNumGroups = roundUpC / (Properties::LocalThreadSizeC * Properties::WorkLoadPerThreadC);
    const StorageIndex roundUpNC = Eigen::TensorSycl::internal::roundUp(nonContractDim, Properties::TileSizeDimNC);
    const StorageIndex nCNumGroups = roundUpNC / (Properties::LocalThreadSizeNC * Properties::WorkLoadPerThreadNC);
    const StorageIndex globalRange =
        (roundUpNC / (Properties::WorkLoadPerThreadNC)) * (roundUpC / (Properties::WorkLoadPerThreadC));
    const StorageIndex localRange = Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC;
    const StorageIndex scratchSize =
        (Properties::WorkLoadPerThreadNC + CFactor) * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;
    auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
    if (cNumGroups > 1) {
      // Multiple groups along the contracting dimension: accumulate partials and reduce them.
      typedef Eigen::TensorSycl::internal::GeneralVectorTensor<CoeffReturnType, EvaluatorPointerType, VectorMapper,
                                                               TensorMapper, StorageIndex, Properties, CFactor, false,
                                                               is_lhs_vec, false>
          ContractKernelName;
      CoeffReturnType *temp_pointer = static_cast<CoeffReturnType *>(
          device().allocate_temp(nonContractDim * cNumGroups * sizeof(CoeffReturnType)));
      EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);

      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
          vec, mat, tmp_global_accessor, thread_range, scratchSize, nCNumGroups, nonContractDim, C).wait();

      typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
      device().template unary_kernel_launcher<CoeffReturnType, ReductionKernel>(
          tmp_global_accessor, buffer,
          cl::sycl::nd_range<1>(cl::sycl::range<1>(Eigen::TensorSycl::internal::roundUp(nonContractDim, localRange)),
                                cl::sycl::range<1>(localRange)),
          StorageIndex(1), Op(), nonContractDim, cNumGroups).wait();
      device().deallocate_temp(temp_pointer);
    } else {
      typedef Eigen::TensorSycl::internal::GeneralVectorTensor<CoeffReturnType, EvaluatorPointerType, VectorMapper,
                                                               TensorMapper, StorageIndex, Properties, CFactor, false,
                                                               is_lhs_vec, true>
          ContractKernelName;
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
          vec, mat, buffer, thread_range, scratchSize, nCNumGroups, nonContractDim, C).wait();
    }
  }
#endif
#ifndef EIGEN_SYCL_DISABLE_SCALAR
  template <typename LhsMapper, typename RhsMapper>
  EIGEN_ALWAYS_INLINE void launchSC(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
                                    StorageIndex K) const {
    EIGEN_STATIC_ASSERT(!((EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1) &
                          (EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1 - 1)),
                        "The Local thread size must be a power of 2 for the reduction "
                        "operation");
    EIGEN_CONSTEXPR StorageIndex local_range = EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1;
    // Cap the launch at two reduction steps: multiple work-groups are used only when a single
    // group would otherwise have to stride over more than 512 elements per thread.
    const StorageIndex num_work_group = ((K + (512 * local_range - 1)) / (512 * local_range) > 1 ? local_range : 1);
    const StorageIndex global_range = num_work_group * local_range;

    typedef Eigen::TensorSycl::internal::GeneralScalarContraction<
        CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType, LhsMapper, RhsMapper, StorageIndex, false>
        ContractKernelName;
    auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(global_range), cl::sycl::range<1>(local_range));
    if (num_work_group > 1) {
      // One partial result per work-group, reduced by a second kernel.
      CoeffReturnType *temp_pointer =
          static_cast<CoeffReturnType *>(device().allocate_temp(num_work_group * sizeof(CoeffReturnType)));
      EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, tmp_global_accessor,
                                                                                    thread_range, local_range, K).wait();
      typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
      device().template unary_kernel_launcher<CoeffReturnType, GenericRKernel>(
          tmp_global_accessor, buffer,
          cl::sycl::nd_range<1>(cl::sycl::range<1>(local_range), cl::sycl::range<1>(local_range)), local_range,
          Op()).wait();
      device().deallocate_temp(temp_pointer);
    } else {
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, buffer, thread_range,
                                                                                    local_range, K).wait();
    }
  }
#endif
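  // Sizing sketch (illustrative numbers, not from the original header): with the default
  // 16 x 16 work-group, local_range = 256 and the threshold 512 * 256 = 131072, so a dot
  // product with K <= 131072 runs as a single work-group (num_work_group = 1) and writes its
  // result directly, while a larger K launches 256 work-groups whose partial sums are then
  // reduced by the second, single-group kernel.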
  EIGEN_STRONG_INLINE void cleanup() {
    this->m_leftImpl.cleanup();
    this->m_rightImpl.cleanup();

    if (this->m_result) {
      this->m_device.deallocate_temp(this->m_result);
      this->m_result = NULL;
    }
  }
};

#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H