#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H

#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
#endif

#if !defined(EIGEN_ALTIVEC_DISABLE_MMA)
#define EIGEN_ALTIVEC_DISABLE_MMA 0
#endif

#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__has_builtin)
#if __has_builtin(__builtin_mma_assemble_acc)
#define EIGEN_ALTIVEC_MMA_SUPPORT
#endif
#endif

#if defined(EIGEN_ALTIVEC_MMA_SUPPORT)

#if !defined(EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH)
#define EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH 0
#endif

#if EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH && !EIGEN_COMP_LLVM
#define EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH 1
#elif defined(__MMA__)
#define EIGEN_ALTIVEC_MMA_ONLY 1
#endif

#endif // EIGEN_ALTIVEC_MMA_SUPPORT

#if defined(EIGEN_ALTIVEC_MMA_ONLY) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#include "MatrixProductMMA.h"
#endif

#include "../../InternalHeaderCheck.h"
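// quad_traits describes the packet shapes used by the kernels below:
// `vectortype` is the native packet for Scalar, `type` is the 4-packet
// accumulator block used by the GEMM micro-kernels, and `rhstype` is the
// packet shape used when broadcasting RHS values. For example,
// quad_traits<float>::vectorsize is 4 (Packet4f) while
// quad_traits<bfloat16>::vectorsize is 8 (Packet8bf).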
template<typename Scalar>
struct quad_traits
{
  typedef typename packet_traits<Scalar>::type vectortype;
  typedef PacketBlock<vectortype, 4>           type;
  typedef vectortype                           rhstype;
  // ...
};

template<>
struct quad_traits<double>
{
  typedef Packet2d                   vectortype;
  typedef PacketBlock<vectortype, 4> type;
  typedef PacketBlock<Packet2d, 2>   rhstype;
  // ...
};

template<>
struct quad_traits<bfloat16>
{
  typedef Packet8bf                  vectortype;
  typedef PacketBlock<vectortype, 4> type;
  typedef vectortype                 rhstype;
  // ...
};
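// Reads element (i, j) of a self-adjoint matrix that only stores one
// triangle: accesses that cross the diagonal return the conjugate of the
// mirrored entry, and the diagonal itself is treated as purely real.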
template<typename Scalar, int StorageOrder>
EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt)
{
  std::complex<Scalar> v;
  // ... (read dt(i, j), or the conjugate of dt(j, i) when mirrored, zeroing
  //      the imaginary part on the diagonal)
  return v;
}
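// Packs a complex self-adjoint RHS panel, splitting each column strip into
// a block of real parts followed (vectorDelta scalars later) by the
// matching block of imaginary parts - the layout the complex GEMM kernels
// consume.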
template<typename Scalar, int StorageOrder, int N>
EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
{
  const Index depth = k2 + rows;
  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> rhs(_rhs, rhsStride);
  const Index vectorSize = N*quad_traits<Scalar>::vectorsize;
  const Index vectorDelta = vectorSize * rows;
  Scalar* blockBf = reinterpret_cast<Scalar *>(blockB);

  Index rir = 0, rii, j = 0;
  for(; j + vectorSize <= cols; j+=vectorSize)
  {
    rii = rir + vectorDelta;

    for(Index i = k2; i < depth; i++)
    {
      for(Index k = 0; k < vectorSize; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j + k, rhs);

        blockBf[rir + k] = v.real();
        blockBf[rii + k] = v.imag();
      }
      rir += vectorSize;
      rii += vectorSize;
    }

    rir += vectorDelta;
  }

  for(; j < cols; j++)
  {
    rii = rir + rows;

    for(Index i = k2; i < depth; i++)
    {
      std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j, rhs);

      blockBf[rir] = v.real();
      blockBf[rii] = v.imag();

      rir += 1;
      rii += 1;
    }

    rir += rows;
  }
}
template<typename Scalar, int StorageOrder>
EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs, Index lhsStride, Index cols, Index rows)
{
  const Index depth = cols;
  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> lhs(_lhs, lhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;
  const Index vectorDelta = vectorSize * depth;
  Scalar* blockAf = reinterpret_cast<Scalar *>(blockA);

  Index rir = 0, rii, j = 0;
  for(; j + vectorSize <= rows; j+=vectorSize)
  {
    rii = rir + vectorDelta;

    for(Index i = 0; i < depth; i++)
    {
      for(Index k = 0; k < vectorSize; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(j+k, i, lhs);

        blockAf[rir + k] = v.real();
        blockAf[rii + k] = v.imag();
      }
      rir += vectorSize;
      rii += vectorSize;
    }

    rir += vectorDelta;
  }

  if(j < rows)
  {
    rii = rir + ((rows - j) * depth);

    for(Index i = 0; i < depth; i++)
    {
      Index k = j;
      for(; k < rows; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(k, i, lhs);

        blockAf[rir] = v.real();
        blockAf[rii] = v.imag();

        rir += 1;
        rii += 1;
      }
    }
  }
}
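// Real-valued counterparts of the self-adjoint packing above: each value is
// simply read from whichever triangle stores it, no conjugation involved.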
template<typename Scalar, int StorageOrder, int N>
EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
{
  const Index depth = k2 + rows;
  const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(_rhs, rhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;

  Index ri = 0, j = 0;
  for(; j + N*vectorSize <= cols; j+=N*vectorSize)
  {
    Index i = k2;
    for(; i < depth; i++)
    {
      for(Index k = 0; k < N*vectorSize; k++)
      {
        if(i <= j+k)
          blockB[ri + k] = rhs(j+k, i);
        else
          blockB[ri + k] = rhs(i, j+k);
      }
      ri += N*vectorSize;
    }
  }

  for(; j < cols; j++)
  {
    for(Index i = k2; i < depth; i++)
    {
      if(j <= i)
        blockB[ri] = rhs(i, j);
      else
        blockB[ri] = rhs(j, i);
      ri += 1;
    }
  }
}
template<typename Scalar, int StorageOrder>
EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
{
  const Index depth = cols;
  const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs, lhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;

  Index ri = 0, j = 0;
  for(; j + vectorSize <= rows; j+=vectorSize)
  {
    Index i = 0;
    for(; i < depth; i++)
    {
      for(Index k = 0; k < vectorSize; k++)
      {
        if(i <= j+k)
          blockA[ri + k] = lhs(j+k, i);
        else
          blockA[ri + k] = lhs(i, j+k);
      }
      ri += vectorSize;
    }
  }

  if(j < rows)
  {
    for(Index i = 0; i < depth; i++)
    {
      Index k = j;
      for(; k < rows; k++)
      {
        if(i <= k)
          blockA[ri] = lhs(k, i);
        else
          blockA[ri] = lhs(i, k);
        ri += 1;
      }
    }
  }
}
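// Wire the helpers above into Eigen's generic symm_pack_rhs/symm_pack_lhs
// entry points for every scalar type this backend vectorizes.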
template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder>
{
  void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_complex_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_complex_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder>
{
  void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_complex_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_complex_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<float, Index, nr, StorageOrder>
{
  void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<double, Index, nr, StorageOrder>
{
  void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};
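// Stores a PacketBlock to 16-byte-aligned memory, one packet per 16-byte
// slot; N is 2 or 4 packets.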
template<typename Scalar, typename Packet, int N>
EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,N>& block)
{
  const Index size = 16 / sizeof(Scalar);
  pstore<Scalar>(to + (0 * size), block.packet[0]);
  pstore<Scalar>(to + (1 * size), block.packet[1]);
  if (N > 2) {
    pstore<Scalar>(to + (2 * size), block.packet[2]);
  }
  if (N > 3) {
    pstore<Scalar>(to + (3 * size), block.packet[3]);
  }
}
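// General packing for one complex operand (LHS when UseLhs, RHS otherwise).
// vec_perm with the GETREAL/GETIMAG patterns deinterleaves real and
// imaginary parts into two separate sub-panels, negating the imaginary one
// when Conjugate is set.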
template<typename Scalar, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode, bool UseLhs>
struct dhs_cpack {
  EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
    Scalar* blockAt = reinterpret_cast<Scalar *>(blockA);
    Index j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,4> blockr, blocki;
        PacketBlock<PacketC,8> cblock;

        if (UseLhs) {
          bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, 0, i);
        } else {
          bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, i, 0);
        }

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
        blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32);
        blockr.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32);
        blockr.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32);

        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32);
        blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32);
        blocki.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32);
        blocki.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
          blocki.packet[2] = -blocki.packet[2];
          blocki.packet[3] = -blocki.packet[3];
        }

        if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs)))
        {
          ptranspose(blockr);
          ptranspose(blocki);
        }

        storeBlock<Scalar, Packet, 4>(blockAt + rir, blockr);
        storeBlock<Scalar, Packet, 4>(blockAt + rii, blocki);

        rir += 4*vectorSize;
        rii += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        PacketBlock<Packet,1> blockr, blocki;
        PacketBlock<PacketC,2> cblock;

        if(((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs)))
        {
          if (UseLhs) {
            cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
            cblock.packet[1] = lhs2.template loadPacket<PacketC>(2, i);
          } else {
            cblock.packet[0] = lhs2.template loadPacket<PacketC>(i, 0);
            cblock.packet[1] = lhs2.template loadPacket<PacketC>(i, 2);
          }
        } else {
          if (UseLhs) {
            cblock.packet[0] = pload2(lhs2(0, i), lhs2(1, i));
            cblock.packet[1] = pload2(lhs2(2, i), lhs2(3, i));
          } else {
            cblock.packet[0] = pload2(lhs2(i, 0), lhs2(i, 1));
            cblock.packet[1] = pload2(lhs2(i, 2), lhs2(i, 3));
          }
        }

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32);
        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<Scalar>(blockAt + rir, blockr.packet[0]);
        pstore<Scalar>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if(!UseLhs)
    {
      if(PanelMode) rir -= (offset*(vectorSize - 1));

      for(; j < rows; j++)
      {
        const DataMapper lhs2 = lhs.getSubMapper(0, j);
        rii = rir + ((PanelMode) ? stride : depth);

        for(Index i = 0; i < depth; i++)
        {
          blockAt[rir] = lhs2(i, 0).real();

          if(Conjugate)
            blockAt[rii] = -lhs2(i, 0).imag();
          else
            blockAt[rii] =  lhs2(i, 0).imag();

          rir += 1;
          rii += 1;
        }

        rir += ((PanelMode) ? (2*stride - depth) : depth);
      }
    } else {
      if(j < rows)
      {
        if(PanelMode) rir += (offset*(rows - j - vectorSize));
        rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

        for(Index i = 0; i < depth; i++)
        {
          Index k = j;
          for(; k < rows; k++)
          {
            blockAt[rir] = lhs(k, i).real();

            if(Conjugate)
              blockAt[rii] = -lhs(k, i).imag();
            else
              blockAt[rii] =  lhs(k, i).imag();

            rir += 1;
            rii += 1;
          }
        }
      }
    }
  }
};
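// General packing for one real operand: rows (or columns) are taken in
// vectorSize chunks and transposed when the storage order does not already
// match the k-major layout the GEMM kernel reads.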
template<typename Scalar, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
struct dhs_pack {
  EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);
      Index i = 0;

      if(PanelMode) ri += vectorSize*offset;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,4> block;

        if (UseLhs) {
          bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block, lhs2, 0, i);
        } else {
          bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block, lhs2, i, 0);
        }
        if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
        {
          ptranspose(block);
        }

        storeBlock<Scalar, Packet, 4>(blockA + ri, block);

        ri += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
        {
          blockA[ri+0] = lhs2(0, i);
          blockA[ri+1] = lhs2(1, i);
          blockA[ri+2] = lhs2(2, i);
          blockA[ri+3] = lhs2(3, i);
        } else {
          Packet lhsV;
          if (UseLhs) {
            lhsV = lhs2.template loadPacket<Packet>(0, i);
          } else {
            lhsV = lhs2.template loadPacket<Packet>(i, 0);
          }
          pstore<Scalar>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if(PanelMode) ri += vectorSize*(stride - offset - depth);
    }

    if(!UseLhs)
    {
      if(PanelMode) ri += offset;

      for(; j < rows; j++)
      {
        const DataMapper lhs2 = lhs.getSubMapper(0, j);
        for(Index i = 0; i < depth; i++)
        {
          blockA[ri] = lhs2(i, 0);
          ri += 1;
        }

        if(PanelMode) ri += stride - depth;
      }
    } else {
      if(j < rows)
      {
        if(PanelMode) ri += offset*(rows - j);

        for(Index i = 0; i < depth; i++)
        {
          Index k = j;
          for(; k < rows; k++)
          {
            blockA[ri] = lhs(k, i);
            ri += 1;
          }
        }
      }
    }
  }
};
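// double LHS specialization: Packet2d holds two doubles, so a 2x2
// micro-tile is packed per step, transposing on the fly for row-major
// input.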
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true>
{
  EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      const DataMapper lhs2 = lhs.getSubMapper(j, 0);
      Index i = 0;

      if(PanelMode) ri += vectorSize*offset;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet2d,2> block;
        if(StorageOrder == RowMajor)
        {
          block.packet[0] = lhs2.template loadPacket<Packet2d>(0, i);
          block.packet[1] = lhs2.template loadPacket<Packet2d>(1, i);

          ptranspose(block);
        } else {
          block.packet[0] = lhs2.template loadPacket<Packet2d>(0, i + 0);
          block.packet[1] = lhs2.template loadPacket<Packet2d>(0, i + 1);
        }

        storeBlock<double, Packet2d, 2>(blockA + ri, block);

        ri += 2*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(StorageOrder == RowMajor)
        {
          blockA[ri+0] = lhs2(0, i);
          blockA[ri+1] = lhs2(1, i);
        } else {
          Packet2d lhsV = lhs2.template loadPacket<Packet2d>(0, i);
          pstore<double>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if(PanelMode) ri += vectorSize*(stride - offset - depth);
    }

    if(j < rows)
    {
      if(PanelMode) ri += offset*(rows - j);

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          blockA[ri] = lhs(k, i);
          ri += 1;
        }
      }
    }
  }
};
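// double RHS specialization: four columns are packed together, transposed
// from column-major input so that each row of the panel holds the four RHS
// values of one depth step.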
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
{
  EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
    {
      const DataMapper rhs2 = rhs.getSubMapper(0, j);
      Index i = 0;

      if(PanelMode) ri += offset*(2*vectorSize);

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet2d,4> block;
        if(StorageOrder == ColMajor)
        {
          PacketBlock<Packet2d,2> block1, block2;
          block1.packet[0] = rhs2.template loadPacket<Packet2d>(i, 0);
          block1.packet[1] = rhs2.template loadPacket<Packet2d>(i, 1);
          block2.packet[0] = rhs2.template loadPacket<Packet2d>(i, 2);
          block2.packet[1] = rhs2.template loadPacket<Packet2d>(i, 3);

          ptranspose(block1);
          ptranspose(block2);

          block.packet[0] = block1.packet[0];
          block.packet[1] = block2.packet[0];
          block.packet[2] = block1.packet[1];
          block.packet[3] = block2.packet[1];
        } else {
          block.packet[0] = rhs2.template loadPacket<Packet2d>(i + 0, 0);
          block.packet[1] = rhs2.template loadPacket<Packet2d>(i + 0, 2);
          block.packet[2] = rhs2.template loadPacket<Packet2d>(i + 1, 0);
          block.packet[3] = rhs2.template loadPacket<Packet2d>(i + 1, 2);
        }

        storeBlock<double, Packet2d, 4>(blockB + ri, block);

        ri += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(StorageOrder == ColMajor)
        {
          blockB[ri+0] = rhs2(i, 0);
          blockB[ri+1] = rhs2(i, 1);

          ri += vectorSize;

          blockB[ri+0] = rhs2(i, 2);
          blockB[ri+1] = rhs2(i, 3);
        } else {
          Packet2d rhsV = rhs2.template loadPacket<Packet2d>(i, 0);
          pstore<double>(blockB + ri, rhsV);

          ri += vectorSize;

          rhsV = rhs2.template loadPacket<Packet2d>(i, 2);
          pstore<double>(blockB + ri, rhsV);
        }
        ri += vectorSize;
      }

      if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
    }

    if(PanelMode) ri += offset;

    for(; j < cols; j++)
    {
      const DataMapper rhs2 = rhs.getSubMapper(0, j);
      for(Index i = 0; i < depth; i++)
      {
        blockB[ri] = rhs2(i, 0);
        ri += 1;
      }

      if(PanelMode) ri += stride - depth;
    }
  }
};
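// bfloat16 LHS specialization: rows are packed 16, 8 and finally 4 at a
// time. Depth values are interleaved in pairs so that each 32-bit lane
// carries two consecutive k values, matching what the bfloat16 kernels
// (and the MMA instructions) expect.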
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, true>
{
  EIGEN_STRONG_INLINE void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<bfloat16>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + 2*vectorSize <= rows; j+=2*vectorSize)
    {
      const DataMapper lhs2 = lhs.getSubMapper(j, 0);
      Index i = 0;

      if(PanelMode) ri += 2*vectorSize*offset;

      if(StorageOrder == ColMajor)
      {
        for(; i + 2 <= depth; i+=2)
        {
          PacketBlock<Packet8bf,4> block;

          block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
          block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);
          block.packet[2] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);
          block.packet[3] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 1);

          Packet8us t0, t1;
          t0 = vec_mergeh(block.packet[0].m_val, block.packet[2].m_val);
          t1 = vec_mergel(block.packet[0].m_val, block.packet[2].m_val);
          block.packet[2] = vec_mergeh(block.packet[1].m_val, block.packet[3].m_val);
          block.packet[3] = vec_mergel(block.packet[1].m_val, block.packet[3].m_val);
          block.packet[0] = t0;
          block.packet[1] = t1;

          storeBlock<bfloat16, Packet8bf, 4>(blockA + ri, block);

          ri += 2*2*vectorSize;
        }
        if(depth & 1)
        {
          PacketBlock<Packet8bf,2> block;

          block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
          block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);

          storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);

          ri += 2*vectorSize;
        }
      } else {
        for(; i + vectorSize <= depth; i+=vectorSize)
        {
          PacketBlock<Packet8bf,8> block1, block2;

          bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
          bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block2, lhs2, 1 * vectorSize, i);

          Packet4ui v1[8], v2[8];

          v1[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
          v1[1] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
          v1[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
          v1[3] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
          v1[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
          v1[5] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
          v1[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
          v1[7] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
          v2[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[0].m_val), reinterpret_cast<Packet4ui>(block2.packet[1].m_val));
          v2[1] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[0].m_val), reinterpret_cast<Packet4ui>(block2.packet[1].m_val));
          v2[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[2].m_val), reinterpret_cast<Packet4ui>(block2.packet[3].m_val));
          v2[3] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[2].m_val), reinterpret_cast<Packet4ui>(block2.packet[3].m_val));
          v2[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[4].m_val), reinterpret_cast<Packet4ui>(block2.packet[5].m_val));
          v2[5] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[4].m_val), reinterpret_cast<Packet4ui>(block2.packet[5].m_val));
          v2[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[6].m_val), reinterpret_cast<Packet4ui>(block2.packet[7].m_val));
          v2[7] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[6].m_val), reinterpret_cast<Packet4ui>(block2.packet[7].m_val));

#ifdef EIGEN_VECTORIZE_VSX
          // ... (finish the 8x8 transpose with 64-bit merges of v1/v2 back
          //      into block1/block2)
#endif

          for(Index M = 0; M < vectorSize; M+=2) {
            pstore<bfloat16>(blockA + ri + (0 * vectorSize) + (2*vectorSize * M), block1.packet[M+0]);
            pstore<bfloat16>(blockA + ri + (1 * vectorSize) + (2*vectorSize * M), block1.packet[M+1]);
            pstore<bfloat16>(blockA + ri + (2 * vectorSize) + (2*vectorSize * M), block2.packet[M+0]);
            pstore<bfloat16>(blockA + ri + (3 * vectorSize) + (2*vectorSize * M), block2.packet[M+1]);
          }

          ri += 2*vectorSize*vectorSize;
        }
        for(; i + 2 <= depth; i+=2)
        {
          for(Index M = 0; M < 2*vectorSize; M++) {
            blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
            blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);
          }

          ri += 2*2*vectorSize;
        }
        if(depth & 1)
        {
          for(Index M = 0; M < 2*vectorSize; M++) {
            blockA[ri + M] = lhs2(M, i);
          }

          ri += 2*vectorSize;
        }
      }

      if(PanelMode) ri += 2*vectorSize*(stride - offset - depth);
    }
    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      const DataMapper lhs2 = lhs.getSubMapper(j, 0);
      Index i = 0;

      if(PanelMode) ri += vectorSize*offset;

      if(StorageOrder == ColMajor)
      {
        for(; i + 2 <= depth; i+=2)
        {
          PacketBlock<Packet8bf,2> block;

          block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
          block.packet[1] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);

          Packet8us t0;
          t0 = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
          block.packet[1] = vec_mergel(block.packet[0].m_val, block.packet[1].m_val);
          block.packet[0] = t0;

          storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);

          ri += 2*vectorSize;
        }
        if(depth & 1)
        {
          Packet8bf lhsV = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
          pstore<bfloat16>(blockA + ri, lhsV);

          ri += vectorSize;
        }
      } else {
        for(; i + vectorSize <= depth; i+=vectorSize)
        {
          PacketBlock<Packet8bf,8> block1;

          bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);

          Packet4ui v1[8];

          v1[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
          v1[1] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
          v1[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
          v1[3] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
          v1[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
          v1[5] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
          v1[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
          v1[7] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val));

#ifdef EIGEN_VECTORIZE_VSX
          // ... (64-bit merges of v1 back into block1, then store to blockA + ri)
#endif

          ri += vectorSize*vectorSize;
        }
        for(; i + 2 <= depth; i+=2)
        {
          for(Index M = 0; M < vectorSize; M++) {
            blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
            blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);
          }

          ri += 2*vectorSize;
        }
        if(depth & 1)
        {
          for(Index M = 0; M < vectorSize; M++) {
            blockA[ri + M] = lhs2(M, i);
          }

          ri += vectorSize;
        }
      }

      if(PanelMode) ri += vectorSize*(stride - offset - depth);
    }
    if(j + 4 <= rows)
    {
      const DataMapper lhs2 = lhs.getSubMapper(j, 0);
      Index i = 0;

      if(PanelMode) ri += 4*offset;

      for(; i + 2 <= depth; i+=2)
      {
        if(StorageOrder == ColMajor)
        {
          PacketBlock<Packet8bf,2> block;

          block.packet[0] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);
          block.packet[1] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 1, 4);

          block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);

          pstore<bfloat16>(blockA + ri, block.packet[0]);
        } else {
          blockA[ri+0] = lhs2(0, i + 0);
          blockA[ri+1] = lhs2(0, i + 1);
          blockA[ri+2] = lhs2(1, i + 0);
          blockA[ri+3] = lhs2(1, i + 1);
          blockA[ri+4] = lhs2(2, i + 0);
          blockA[ri+5] = lhs2(2, i + 1);
          blockA[ri+6] = lhs2(3, i + 0);
          blockA[ri+7] = lhs2(3, i + 1);
        }

        ri += 2*4;
      }
      if(depth & 1)
      {
        if(StorageOrder == ColMajor)
        {
          Packet8bf lhsV = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);
          pstore_partial<bfloat16>(blockA + ri, lhsV, 4);
        } else {
          blockA[ri+0] = lhs2(0, i);
          blockA[ri+1] = lhs2(1, i);
          blockA[ri+2] = lhs2(2, i);
          blockA[ri+3] = lhs2(3, i);
        }

        ri += 4;
      }

      if(PanelMode) ri += 4*(stride - offset - depth);

      j += 4;
    }
    if(j < rows)
    {
      if(PanelMode) ri += offset*(rows - j);

      Index i = 0;
      for(; i + 2 <= depth; i+=2)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          blockA[ri+0] = lhs(k, i + 0);
          blockA[ri+1] = lhs(k, i + 1);
          ri += 2;
        }
      }
      if(depth & 1)
      {
        for(; j < rows; j++)
        {
          blockA[ri] = lhs(j, i);
          ri += 1;
        }
      }
    }
  }
};
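// bfloat16 RHS specialization: four columns at a time, again pairing
// consecutive depth values within each 32-bit lane.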
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, false>
{
  EIGEN_STRONG_INLINE void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<bfloat16>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + 4 <= cols; j+=4)
    {
      const DataMapper rhs2 = rhs.getSubMapper(0, j);
      Index i = 0;

      if(PanelMode) ri += 4*offset;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        if(StorageOrder == ColMajor)
        {
          PacketBlock<Packet8bf,4> block;

          bload<DataMapper, Packet8bf, 4, StorageOrder, false, 4>(block, rhs2, i, 0);

          // ... (interleave the four columns pairwise)
#ifdef EIGEN_VECTORIZE_VSX
          // ... (finish the transpose with 64-bit merges)
#endif

          storeBlock<bfloat16, Packet8bf, 4>(blockB + ri, block);
        } else {
          PacketBlock<Packet8bf,8> block;

          for (int M = 0; M < 8; M++) {
            block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);
          }

          block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
          block.packet[1] = vec_mergeh(block.packet[2].m_val, block.packet[3].m_val);
          block.packet[2] = vec_mergeh(block.packet[4].m_val, block.packet[5].m_val);
          block.packet[3] = vec_mergeh(block.packet[6].m_val, block.packet[7].m_val);

          const Index size = 16 / sizeof(bfloat16);

          for (int M = 0; M < 4; M++) {
            pstore<bfloat16>(blockB + ri + (M * size), block.packet[M]);
          }
        }

        ri += 4*vectorSize;
      }
      for (; i + 2 <= depth; i += 2) {
        if(StorageOrder == ColMajor)
        {
          blockB[ri+0] = rhs2(i + 0, 0);
          blockB[ri+1] = rhs2(i + 1, 0);
          blockB[ri+2] = rhs2(i + 0, 1);
          blockB[ri+3] = rhs2(i + 1, 1);
          blockB[ri+4] = rhs2(i + 0, 2);
          blockB[ri+5] = rhs2(i + 1, 2);
          blockB[ri+6] = rhs2(i + 0, 3);
          blockB[ri+7] = rhs2(i + 1, 3);
        } else {
          PacketBlock<Packet8bf,2> block;

          for (int M = 0; M < 2; M++) {
            block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);
          }

          block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
          pstore<bfloat16>(blockB + ri, block.packet[0]);
        }

        ri += 2*4;
      }
      if(depth & 1)
      {
        blockB[ri+0] = rhs2(i, 0);
        blockB[ri+1] = rhs2(i, 1);
        blockB[ri+2] = rhs2(i, 2);
        blockB[ri+3] = rhs2(i, 3);

        ri += 4;
      }

      if(PanelMode) ri += 4*(stride - offset - depth);
    }

    if(j < cols)
    {
      if(PanelMode) ri += offset*(cols - j);

      Index i = 0;
      for(; i + 2 <= depth; i+=2)
      {
        Index k = j;
        for(; k < cols; k++)
        {
          blockB[ri+0] = rhs(i + 0, k);
          blockB[ri+1] = rhs(i + 1, k);
          ri += 2;
        }
      }
      if(depth & 1)
      {
        for(; j < cols; j++)
        {
          blockB[ri] = rhs(i, j);
          ri += 1;
        }
      }
    }
  }
};
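// Complex double LHS packing: vec_mergeh/vec_mergel split each pair of
// complex loads into a real Packet2d and an imaginary Packet2d.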
template<typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true>
{
  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
    double* blockAt = reinterpret_cast<double *>(blockA);
    Index j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      const DataMapper lhs2 = lhs.getSubMapper(j, 0);
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,2> blockr, blocki;
        PacketBlock<PacketC,4> cblock;

        if(StorageOrder == ColMajor)
        {
          cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i + 0);
          cblock.packet[1] = lhs2.template loadPacket<PacketC>(0, i + 1);

          cblock.packet[2] = lhs2.template loadPacket<PacketC>(1, i + 0);
          cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1);

          blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v);
          blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v);

          blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v);
          blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v);
        } else {
          cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
          cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i);

          cblock.packet[2] = lhs2.template loadPacket<PacketC>(0, i + 1);
          cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1);

          blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
          blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v);

          blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
          blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
        }

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
        }

        storeBlock<double, Packet, 2>(blockAt + rir, blockr);
        storeBlock<double, Packet, 2>(blockAt + rii, blocki);

        rir += 2*vectorSize;
        rii += 2*vectorSize;
      }
      for(; i < depth; i++)
      {
        PacketBlock<Packet,1> blockr, blocki;
        PacketBlock<PacketC,2> cblock;

        cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
        cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i);

        blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
        blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<double>(blockAt + rir, blockr.packet[0]);
        pstore<double>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if(j < rows)
    {
      if(PanelMode) rir += (offset*(rows - j - vectorSize));
      rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          blockAt[rir] = lhs(k, i).real();

          if(Conjugate)
            blockAt[rii] = -lhs(k, i).imag();
          else
            blockAt[rii] =  lhs(k, i).imag();

          rir += 1;
          rii += 1;
        }
      }
    }
  }
};
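// Complex double RHS packing: same real/imaginary split as the LHS case,
// over panels of four columns (2*vectorSize).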
template<typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false>
{
  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    const Index vectorDelta = 2*vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (2*vectorSize*offset) : 0), rii;
    double* blockBt = reinterpret_cast<double *>(blockB);
    Index j = 0;

    for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
    {
      const DataMapper rhs2 = rhs.getSubMapper(0, j);
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i < depth; i++)
      {
        PacketBlock<PacketC,4> cblock;
        PacketBlock<Packet,2> blockr, blocki;

        bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs2, i, 0);

        blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
        blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v);

        blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
        blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
        }

        storeBlock<double, Packet, 2>(blockBt + rir, blockr);
        storeBlock<double, Packet, 2>(blockBt + rii, blocki);

        rir += 2*vectorSize;
        rii += 2*vectorSize;
      }

      rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if(PanelMode) rir -= (offset*(2*vectorSize - 1));

    for(; j < cols; j++)
    {
      const DataMapper rhs2 = rhs.getSubMapper(0, j);
      rii = rir + ((PanelMode) ? stride : depth);

      for(Index i = 0; i < depth; i++)
      {
        blockBt[rir] = rhs2(i, 0).real();

        if(Conjugate)
          blockBt[rii] = -rhs2(i, 0).imag();
        else
          blockBt[rii] =  rhs2(i, 0).imag();

        rir += 1;
        rii += 1;
      }

      rir += ((PanelMode) ? (2*stride - depth) : depth);
    }
  }
};
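// GEMM micro-kernel primitives. pger_common is the rank-1 update at the
// core of the kernel: with accRows == 4 it performs
// acc->packet[j] += lhsV * rhsV[j] for j = 0..3, one fused multiply-add
// (or multiply-subtract, when NegativeAccumulate) per result column.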
template<typename Packet, bool NegativeAccumulate, int N>
EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,N>* acc, const Packet& lhsV, const Packet* rhsV)
{
  if(NegativeAccumulate)
  {
    for (int M = 0; M < N; M++) {
      acc->packet[M] = vec_nmsub(lhsV, rhsV[M], acc->packet[M]);
    }
  } else {
    for (int M = 0; M < N; M++) {
      acc->packet[M] = vec_madd(lhsV, rhsV[M], acc->packet[M]);
    }
  }
}
template<int N, typename Scalar, typename Packet, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV)
{
  Packet lhsV = pload<Packet>(lhs);

  pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
}
template<int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Packet &lhsV, Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi)
{
  pger_common<Packet, false, N>(accReal, lhsV, rhsV);
  if(LhsIsReal)
  {
    pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
    EIGEN_UNUSED_VARIABLE(lhsVi);
  } else {
    if (!RhsIsReal) {
      pger_common<Packet, ConjugateLhs == ConjugateRhs, N>(accReal, lhsVi, rhsVi);
      pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pger_common<Packet, ConjugateLhs, N>(accImag, lhsVi, rhsV);
  }
}
template<int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
{
  Packet lhsV = ploadLhs<Packet>(lhs_ptr);
  Packet lhsVi;
  if(!LhsIsReal) lhsVi = ploadLhs<Packet>(lhs_ptr_imag);
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);

  pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
}
template<typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs)
{
  return ploadu<Packet>(lhs);
}
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,N>& acc)
{
  for (int M = 0; M < N; M++) {
    acc.packet[M] = pset1<Packet>((__UNPACK_TYPE__(Packet))0);
  }
}
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
{
  for (int M = 0; M < N; M++) {
    acc.packet[M] = vec_mul(accZ.packet[M], pAlpha);
  }
}
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,N>& acc, const Packet& pMask)
{
  for (int M = 0; M < N; M++) {
    acc.packet[M] = pand<Packet>(acc.packet[M], pMask);
  }
}
template<typename Packet, int N, bool mask>
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag, const Packet& pMask)
{
  if (mask) {
    band<Packet, N>(aReal, pMask);
    band<Packet, N>(aImag, pMask);
  } else {
    EIGEN_UNUSED_VARIABLE(pMask);
  }

  bscalec_common<Packet, N>(cReal, aReal, bReal);

  bscalec_common<Packet, N>(cImag, aImag, bReal);

  pger_common<Packet, true, N>(&cReal, bImag, aImag.packet);

  pger_common<Packet, false, N>(&cImag, bImag, aReal.packet);
}
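// bload/bstore move whole accumulator-sized blocks of the result matrix;
// the complex variants can load a second set of packets offset by accCols
// when `full` is set.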
template<typename DataMapper, typename Packet, const Index accCols, int StorageOrder, bool Complex, int N, bool full>
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index col)
{
  if (StorageOrder == RowMajor) {
    for (int M = 0; M < N; M++) {
      acc.packet[M] = res.template loadPacket<Packet>(row + M, col);
    }
    if (Complex) {
      for (int M = 0; M < N; M++) {
        acc.packet[M+N] = res.template loadPacket<Packet>(row + M, col + accCols);
      }
    }
  } else {
    for (int M = 0; M < N; M++) {
      acc.packet[M] = res.template loadPacket<Packet>(row, col + M);
    }
    if (Complex && full) {
      for (int M = 0; M < N; M++) {
        acc.packet[M+N] = res.template loadPacket<Packet>(row + accCols, col + M);
      }
    }
  }
}
template<typename DataMapper, typename Packet, int N>
EIGEN_ALWAYS_INLINE void bstore(PacketBlock<Packet,N>& acc, const DataMapper& res, Index row)
{
  for (int M = 0; M < N; M++) {
    res.template storePacket<Packet>(row, M, acc.packet[M]);
  }
}
#ifdef USE_PARTIAL_PACKETS
template<typename DataMapper, typename Packet, const Index accCols, bool Complex, Index N, bool full>
EIGEN_ALWAYS_INLINE void bload_partial(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index elements)
{
  for (Index M = 0; M < N; M++) {
    acc.packet[M] = res.template loadPacketPartial<Packet>(row, M, elements);
  }
  if (Complex && full) {
    for (Index M = 0; M < N; M++) {
      acc.packet[M+N] = res.template loadPacketPartial<Packet>(row + accCols, M, elements);
    }
  }
}

template<typename DataMapper, typename Packet, Index N>
EIGEN_ALWAYS_INLINE void bstore_partial(PacketBlock<Packet,N>& acc, const DataMapper& res, Index row, Index elements)
{
  for (Index M = 0; M < N; M++) {
    res.template storePacketPartial<Packet>(row, M, acc.packet[M], elements);
  }
}
#endif
#ifdef _ARCH_PWR10
#define USE_P10_AND_PVIPR2_0 (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
#else
#define USE_P10_AND_PVIPR2_0 0
#endif

#if !USE_P10_AND_PVIPR2_0
const static Packet4i mask4[4] = { { 0, 0, 0, 0 }, { -1, 0, 0, 0 }, { -1, -1, 0, 0 }, { -1, -1, -1, 0 } };
#endif

template<typename Packet>
EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows)
{
#if USE_P10_AND_PVIPR2_0
#ifdef _BIG_ENDIAN
  return Packet(vec_reve(vec_genwm((1 << remaining_rows) - 1)));
#else
  return Packet(vec_genwm((1 << remaining_rows) - 1));
#endif
#else
  return Packet(mask4[remaining_rows]);
#endif
}

template<>
EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const Index remaining_rows)
{
#if USE_P10_AND_PVIPR2_0
  // ...
#else
  Packet2l ret = { -remaining_rows, 0 };
  return Packet2d(ret);
#endif
}

template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
{
  for (int M = 0; M < N; M++) {
    acc.packet[M] = pmadd<Packet>(pAlpha, accZ.packet[M], acc.packet[M]);
  }
}
template<typename Packet, int N, bool mask>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha, const Packet& pMask)
{
  if (mask) {
    band<Packet, N>(accZ, pMask);
  } else {
    EIGEN_UNUSED_VARIABLE(pMask);
  }

  bscale<Packet, N>(acc, accZ, pAlpha);
}
template<typename Packet, int N, bool real>
EIGEN_ALWAYS_INLINE void pbroadcastN(const __UNPACK_TYPE__(Packet)* ap0, const __UNPACK_TYPE__(Packet)* ap1, const __UNPACK_TYPE__(Packet)* ap2,
  Packet& a0, Packet& a1, Packet& a2, Packet& a3)
{
  a0 = pset1<Packet>(ap0[0]);
  if (N == 4) {
    a1 = pset1<Packet>(ap0[1]);
    a2 = pset1<Packet>(ap0[2]);
    a3 = pset1<Packet>(ap0[3]);
  } else {
    if (N > 1) {
      a1 = pset1<Packet>(ap1[0]);
    }
    // ...
    if (N > 2) {
      a2 = pset1<Packet>(ap2[0]);
    }
    // ...
  }
}

// ... (specializations; for Packet2d the four RHS values arrive as two
//      loaded pairs in a1 and a3 and are splatted out:)
template<>
EIGEN_ALWAYS_INLINE void pbroadcastN<Packet2d,4,false>(const double* ap0, const double*, const double*,
  Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
{
  a1 = pload<Packet2d>(ap0);
  a3 = pload<Packet2d>(ap0 + 2);
  a0 = vec_splat(a1, 0);
  a1 = vec_splat(a1, 1);
  a2 = vec_splat(a3, 0);
  a3 = vec_splat(a3, 1);
}
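// bcouple re-interleaves the separately accumulated real and imaginary
// parts back into complex packets (vec_mergeh for the low halves,
// vec_mergel for the high halves) and adds them onto the loaded result.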
template<typename Packet, typename Packetc, int N, bool full>
EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,N>& taccReal, PacketBlock<Packet,N>& taccImag, PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2)
{
  for (int M = 0; M < N; M++) {
    acc1.packet[M].v = vec_mergeh(taccReal.packet[M], taccImag.packet[M]);
  }

  if (full) {
    for (int M = 0; M < N; M++) {
      acc2.packet[M].v = vec_mergel(taccReal.packet[M], taccImag.packet[M]);
    }
  }
}
template<typename Packet, typename Packetc, int N, bool full>
EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,N>& taccReal, PacketBlock<Packet,N>& taccImag, PacketBlock<Packetc,N*2>& tRes, PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2)
{
  bcouple_common<Packet, Packetc, N, full>(taccReal, taccImag, acc1, acc2);

  for (int M = 0; M < N; M++) {
    acc1.packet[M] = padd<Packetc>(tRes.packet[M], acc1.packet[M]);
  }

  if (full) {
    for (int M = 0; M < N; M++) {
      acc2.packet[M] = padd<Packetc>(tRes.packet[M+N], acc2.packet[M]);
    }
  }
}
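// The kernels themselves are assembled from the MICRO_* macros below: they
// unroll up to 8 LHS row-blocks and PEEL depth steps per iteration,
// generating the register-blocked inner loops at compile time.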
#define MICRO_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_NORMAL_ROWS \
  accRows == quad_traits<Scalar>::rows || accRows == 1

#define MICRO_NEW_ROWS ((MICRO_NORMAL_ROWS) ? accRows : 1)

#define MICRO_RHS(ptr, N) rhs_##ptr##N

#define MICRO_ZERO_PEEL(peel) \
  if ((PEEL_ROW > peel) && (peel != 0)) { \
    bsetzero<Packet, accRows>(accZero##peel); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##peel); \
  }

#define MICRO_ADD(ptr, N) \
  if (MICRO_NORMAL_ROWS) { \
    MICRO_RHS(ptr,0) += (accRows * N); \
  } else { \
    MICRO_RHS(ptr,0) += N; \
    MICRO_RHS(ptr,1) += N; \
    if (accRows == 3) { \
      MICRO_RHS(ptr,2) += N; \
    } \
  }

#define MICRO_ADD_ROWS(N) MICRO_ADD(ptr, N)

#define MICRO_BROADCAST1(peel, ptr, rhsV, real) \
  if (MICRO_NORMAL_ROWS) { \
    pbroadcastN<Packet,accRows,real>(MICRO_RHS(ptr,0) + (accRows * peel), MICRO_RHS(ptr,0), MICRO_RHS(ptr,0), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
  } else { \
    pbroadcastN<Packet,accRows,real>(MICRO_RHS(ptr,0) + peel, MICRO_RHS(ptr,1) + peel, MICRO_RHS(ptr,2) + peel, rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
  }

#define MICRO_BROADCAST(peel) MICRO_BROADCAST1(peel, ptr, rhsV, true)

#define MICRO_BROADCAST_EXTRA1(ptr, rhsV, real) \
  pbroadcastN<Packet,accRows,real>(MICRO_RHS(ptr,0), MICRO_RHS(ptr,1), MICRO_RHS(ptr,2), rhsV[0], rhsV[1], rhsV[2], rhsV[3]);

#define MICRO_BROADCAST_EXTRA \
  Packet rhsV[4]; \
  MICRO_BROADCAST_EXTRA1(ptr, rhsV, true) \
  MICRO_ADD_ROWS(1)

#define MICRO_SRC2(ptr, N, M) \
  if (MICRO_NORMAL_ROWS) { \
    EIGEN_UNUSED_VARIABLE(strideB); \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr,1)); \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr,2)); \
  } else { \
    MICRO_RHS(ptr,1) = rhs_base + N + M; \
    if (accRows == 3) { \
      MICRO_RHS(ptr,2) = rhs_base + N*2 + M; \
    } else { \
      EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr,2)); \
    } \
  }

#define MICRO_SRC2_PTR MICRO_SRC2(ptr, strideB, 0)

#define MICRO_ZERO_PEEL_ROW MICRO_UNROLL(MICRO_ZERO_PEEL)

#define MICRO_WORK_PEEL(peel) \
  if (PEEL_ROW > peel) { \
    MICRO_BROADCAST(peel) \
    pger<accRows, Scalar, Packet, false>(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_WORK_PEEL_ROW \
  Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \
  MICRO_UNROLL(MICRO_WORK_PEEL) \
  lhs_ptr += (remaining_rows * PEEL_ROW); \
  MICRO_ADD_ROWS(PEEL_ROW)

#define MICRO_ADD_PEEL(peel, sum) \
  if (PEEL_ROW > peel) { \
    for (Index i = 0; i < accRows; i++) { \
      accZero##sum.packet[i] += accZero##peel.packet[i]; \
    } \
  }

#define MICRO_ADD_PEEL_ROW \
  MICRO_ADD_PEEL(4, 0) MICRO_ADD_PEEL(5, 1) MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) \
  MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0)

#define MICRO_PREFETCHN1(ptr, N) \
  EIGEN_POWER_PREFETCH(MICRO_RHS(ptr,0)); \
  if (N == 2 || N == 3) { \
    EIGEN_POWER_PREFETCH(MICRO_RHS(ptr,1)); \
    if (N == 3) { \
      EIGEN_POWER_PREFETCH(MICRO_RHS(ptr,2)); \
    } \
  }

#define MICRO_PREFETCHN(N) MICRO_PREFETCHN1(ptr, N)

#define MICRO_COMPLEX_PREFETCHN(N) \
  MICRO_PREFETCHN1(ptr_real, N); \
  if(!RhsIsReal) { \
    MICRO_PREFETCHN1(ptr_imag, N); \
  }
template<typename Scalar, typename Packet, const Index accRows, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
  const Scalar* &lhs_ptr,
  const Scalar* &rhs_ptr0,
  const Scalar* &rhs_ptr1,
  const Scalar* &rhs_ptr2,
  PacketBlock<Packet,accRows> &accZero)
{
  MICRO_BROADCAST_EXTRA
  pger<accRows, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
  lhs_ptr += remaining_rows;
}

template<typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index row,
  Index rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const Scalar* rhs_ptr0 = rhs_base, * rhs_ptr1 = NULL, * rhs_ptr2 = NULL;
  const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
  PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc;

  MICRO_SRC2_PTR
  bsetzero<Packet, accRows>(accZero0);

  Index k = 0;
  // ... (PEEL_ROW-deep peeled loop: prefetch, broadcast and pger per step,
  //      then MICRO_ADD_PEEL_ROW folds accZero1..7 into accZero0)
  for(; k < depth; k++)
  {
    MICRO_EXTRA_ROW<Scalar, Packet, accRows, remaining_rows>(lhs_ptr, rhs_ptr0, rhs_ptr1, rhs_ptr2, accZero0);
  }

#ifdef USE_PARTIAL_PACKETS
  bload_partial<DataMapper, Packet, 0, false, accRows>(acc, res, row, remaining_rows);
  bscale<Packet,accRows>(acc, accZero0, pAlpha);
  bstore_partial<DataMapper, Packet, accRows>(acc, res, row, remaining_rows);
#else
  bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row, 0);
  if ((accRows == 1) || (rows >= accCols))
  {
    bscale<Packet,accRows,true>(acc, accZero0, pAlpha, pMask);
    bstore<DataMapper, Packet, accRows>(acc, res, row);
  } else {
    bscale<Packet,accRows,false>(acc, accZero0, pAlpha, pMask);
    for(Index j = 0; j < accRows; j++) {
      for(Index i = 0; i < remaining_rows; i++) {
        res(row + i, j) = acc.packet[j][i];
      }
    }
  }
#endif
}
#define MICRO_EXTRA(MICRO_EXTRA_UNROLL, value, is_col) \
  switch(value) { \
    default: \
      MICRO_EXTRA_UNROLL(1) \
      break; \
    case 2: \
      if (is_col || (sizeof(Scalar) == sizeof(float))) { \
        MICRO_EXTRA_UNROLL(2) \
      } \
      break; \
    case 3: \
      if (is_col || (sizeof(Scalar) == sizeof(float))) { \
        MICRO_EXTRA_UNROLL(3) \
      } \
      break; \
  }

#define MICRO_EXTRA_ROWS(N) \
  gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, accRows, accCols, N>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlpha, pMask);

template<typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_extra_row(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index row,
  Index rows,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  MICRO_EXTRA(MICRO_EXTRA_ROWS, remaining_rows, false)
}
#define MICRO_UNROLL_WORK(func, func2, peel) \
  MICRO_UNROLL(func2); \
  func(0,peel) func(1,peel) func(2,peel) func(3,peel) \
  func(4,peel) func(5,peel) func(6,peel) func(7,peel)

#define MICRO_WORK_ONE(iter, peel) \
  if (unroll_factor > iter) { \
    pger_common<Packet, false, accRows>(&accZero##iter, lhsV##iter, rhsV##peel); \
  }

#define MICRO_TYPE_PEEL4(func, func2, peel) \
  if (PEEL > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    MICRO_BROADCAST(peel) \
    MICRO_UNROLL_WORK(func, func2, peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M]; \
  func(func1,func2,0) func(func1,func2,1) \
  func(func1,func2,2) func(func1,func2,3) \
  func(func1,func2,4) func(func1,func2,5) \
  func(func1,func2,6) func(func1,func2,7)

#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M]; \
  func(func1,func2,0)

#define MICRO_UNROLL_TYPE(MICRO_TYPE, size) \
  MICRO_TYPE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE) \
  MICRO_ADD_ROWS(size)

#define MICRO_ONE_PEEL4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_PEEL, PEEL)

#define MICRO_ONE4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_ONE, 1)

#define MICRO_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzero<Packet, accRows>(accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)

#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)

#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)

#ifdef USE_PARTIAL_PACKETS
#define MICRO_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    if (MICRO_NORMAL_PARTIAL(iter)) { \
      bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter*accCols, 0); \
      bscale<Packet,accRows>(acc, accZero##iter, pAlpha); \
      bstore<DataMapper, Packet, accRows>(acc, res, row + iter*accCols); \
    } else { \
      bload_partial<DataMapper, Packet, 0, false, accRows>(acc, res, row + iter*accCols, accCols2); \
      bscale<Packet,accRows>(acc, accZero##iter, pAlpha); \
      bstore_partial<DataMapper, Packet, accRows>(acc, res, row + iter*accCols, accCols2); \
    } \
  }
#else
#define MICRO_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter*accCols, 0); \
    bscale<Packet,accRows,!(MICRO_NORMAL(iter))>(acc, accZero##iter, pAlpha, pMask); \
    bstore<DataMapper, Packet, accRows>(acc, res, row + iter*accCols); \
  }
#endif

#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
#ifdef USE_PARTIAL_PACKETS
template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols, bool full>
#else
template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2>
#endif
EIGEN_ALWAYS_INLINE void gemm_unrolled_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index& row,
  const Packet& pAlpha,
#ifdef USE_PARTIAL_PACKETS
  Index accCols2
#else
  const Packet& pMask
#endif
  )
{
  const Scalar* rhs_ptr0 = rhs_base, * rhs_ptr1 = NULL, * rhs_ptr2 = NULL;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
  PacketBlock<Packet,accRows> acc;

  MICRO_SRC2_PTR
  MICRO_DST_PTR
  MICRO_SRC_PTR

  Index k = 0;
  // ... (PEEL-unrolled main loop via MICRO_PREFETCH and MICRO_ONE_PEEL4)
  for(; k < depth; k++)
  {
    MICRO_ONE4
  }
  MICRO_STORE
  // ... (advance row past the unrolled block)
}
#ifdef USE_PARTIAL_PACKETS
#define MICRO_UNROLL_ITER2(N, M) \
  gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, !M>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, M ? remaining_rows : accCols); \
  if (M) return;
#else
#define MICRO_UNROLL_ITER2(N, M) \
  gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, M ? M : accCols>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, pMask); \
  if (M) return;
#endif
template<typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + col*strideB + MICRO_NEW_ROWS*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_UNROLL 7
  while(row + MAX_UNROLL*accCols <= rows) {
    MICRO_UNROLL_ITER2(MAX_UNROLL, 0);
  }
  switch( (rows-row)/accCols ) {
    // ... (cases MAX_UNROLL-1 down to 1, each expanding MICRO_UNROLL_ITER2)
    default:
      break;
  }
#undef MAX_UNROLL

  if(remaining_rows > 0)
  {
    gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask);
  }
}
#define MICRO_EXTRA_COLS(N) \
  gemm_cols<Scalar, Packet, DataMapper, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);

template<typename Scalar, typename Packet, typename DataMapper, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_extra_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  MICRO_EXTRA(MICRO_EXTRA_COLS, cols-col, true)
}
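// Top-level real GEMM entry point: full accRows-wide column panels first,
// leftover columns via gemm_extra_cols.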
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask = bmask<Packet>(remaining_rows);

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemm_cols<Scalar, Packet, DataMapper, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
  }

  if (col != cols)
  {
    gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}
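// Complex GEMM. A packed complex operand stores a full real panel followed
// by the matching imaginary panel, so packed strides advance by two panels
// per operand unless one side is real (advanceRows/advanceCols).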
#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

#define PEEL_COMPLEX 3
#define PEEL_COMPLEX_ROW 3

#define MICRO_COMPLEX_UNROLL(func) \
  func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_ZERO_PEEL(peel) \
  if ((PEEL_COMPLEX_ROW > peel) && (peel != 0)) { \
    bsetzero<Packet, accRows>(accReal##peel); \
    bsetzero<Packet, accRows>(accImag##peel); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##peel); \
    EIGEN_UNUSED_VARIABLE(accImag##peel); \
  }

#define MICRO_COMPLEX_ADD_ROWS(N, used) \
  MICRO_ADD(ptr_real, N) \
  if(!RhsIsReal) { \
    MICRO_ADD(ptr_imag, N) \
  } else if (used) { \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,0)); \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,1)); \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,2)); \
  }

#define MICRO_COMPLEX_BROADCAST(peel) \
  MICRO_BROADCAST1(peel, ptr_real, rhsV, false) \
  if(!RhsIsReal) { \
    MICRO_BROADCAST1(peel, ptr_imag, rhsVi, false) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_BROADCAST_EXTRA \
  Packet rhsV[4], rhsVi[4]; \
  MICRO_BROADCAST_EXTRA1(ptr_real, rhsV, false) \
  if(!RhsIsReal) { \
    MICRO_BROADCAST_EXTRA1(ptr_imag, rhsVi, false) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsVi); \
  } \
  MICRO_COMPLEX_ADD_ROWS(1, true)
#define MICRO_COMPLEX_SRC2_PTR \
  MICRO_SRC2(ptr_real, strideB*advanceCols, 0) \
  if(!RhsIsReal) { \
    MICRO_RHS(ptr_imag,0) = rhs_base + MICRO_NEW_ROWS*strideB; \
    MICRO_SRC2(ptr_imag, strideB*advanceCols, strideB) \
  } else { \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,0)); \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,1)); \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,2)); \
  }
#define MICRO_COMPLEX_ZERO_PEEL_ROW MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_ZERO_PEEL)

#define MICRO_COMPLEX_WORK_PEEL(peel) \
  if (PEEL_COMPLEX_ROW > peel) { \
    MICRO_COMPLEX_BROADCAST(peel) \
    pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##peel, &accImag##peel, lhs_ptr_real + (remaining_rows * peel), lhs_ptr_imag + (remaining_rows * peel), rhsV##peel, rhsVi##peel); \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_ADD_COLS(size) \
  lhs_ptr_real += (remaining_rows * size); \
  if(!LhsIsReal) lhs_ptr_imag += (remaining_rows * size); \
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);

#define MICRO_COMPLEX_WORK_PEEL_ROW \
  Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4]; \
  Packet rhsVi0[4], rhsVi1[4], rhsVi2[4], rhsVi3[4]; \
  MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_WORK_PEEL) \
  MICRO_COMPLEX_ADD_COLS(PEEL_COMPLEX_ROW) \
  MICRO_COMPLEX_ADD_ROWS(PEEL_COMPLEX_ROW, false)

#define MICRO_COMPLEX_ADD_PEEL(peel, sum) \
  if (PEEL_COMPLEX_ROW > peel) { \
    for (Index i = 0; i < accRows; i++) { \
      accReal##sum.packet[i] += accReal##peel.packet[i]; \
      accImag##sum.packet[i] += accImag##peel.packet[i]; \
    } \
  }

#define MICRO_COMPLEX_ADD_PEEL_ROW \
  MICRO_COMPLEX_ADD_PEEL(2, 0) MICRO_COMPLEX_ADD_PEEL(3, 1) \
  MICRO_COMPLEX_ADD_PEEL(1, 0)
template<typename Scalar, typename Packet, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
  const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
  const Scalar* &rhs_ptr_real0, const Scalar* &rhs_ptr_real1, const Scalar* &rhs_ptr_real2,
  const Scalar* &rhs_ptr_imag0, const Scalar* &rhs_ptr_imag1, const Scalar* &rhs_ptr_imag2,
  PacketBlock<Packet,accRows> &accReal, PacketBlock<Packet,accRows> &accImag)
{
  MICRO_COMPLEX_BROADCAST_EXTRA
  pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
  MICRO_COMPLEX_ADD_COLS(1)
}
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index row,
  Index rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const Scalar* rhs_ptr_real0 = rhs_base, * rhs_ptr_real1 = NULL, * rhs_ptr_real2 = NULL;
  const Scalar* rhs_ptr_imag0 = NULL, * rhs_ptr_imag1 = NULL, * rhs_ptr_imag2 = NULL;
  const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
  const Scalar* lhs_ptr_imag = NULL;
  if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;

  PacketBlock<Packet,accRows> accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
  PacketBlock<Packet,accRows> taccReal, taccImag;
  PacketBlock<Packetc,accRows> acc0, acc1;
  PacketBlock<Packetc,accRows*2> tRes;

  MICRO_COMPLEX_SRC2_PTR
  bsetzero<Packet, accRows>(accReal0);
  bsetzero<Packet, accRows>(accImag0);

  Index k = 0;
  // ... (PEEL_COMPLEX_ROW-deep peeled loop, then MICRO_COMPLEX_ADD_PEEL_ROW
  //      folds accReal1..3/accImag1..3 into accReal0/accImag0)
  for(; k < depth; k++)
  {
    MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real0, rhs_ptr_real1, rhs_ptr_real2, rhs_ptr_imag0, rhs_ptr_imag1, rhs_ptr_imag2, accReal0, accImag0);
  }

  constexpr bool full = (remaining_rows > accColsC);
  bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row, 0);
  if ((accRows == 1) || (rows >= accCols))
  {
    bscalec<Packet,accRows,true>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
    bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
    bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);
    if (full) {
      bstore<DataMapper, Packetc, accRows>(acc1, res, row + accColsC);
    }
  } else {
    bscalec<Packet,accRows,false>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
    bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);

    if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
    {
      for(Index j = 0; j < accRows; j++) {
        res(row + 0, j) = pfirst<Packetc>(acc0.packet[j]);
      }
    } else {
      bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);
      if (full) {
        for(Index j = 0; j < accRows; j++) {
          // ... (write the upper accColsC rows element-wise)
        }
      }
    }
  }
}
#define MICRO_COMPLEX_EXTRA_ROWS(N) \
  gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, N>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlphaReal, pAlphaImag, pMask);

template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index row,
  Index rows,
  Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  MICRO_EXTRA(MICRO_COMPLEX_EXTRA_ROWS, remaining_rows, false)
}
#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
  MICRO_COMPLEX_UNROLL(func2); \
  func(0,peel) func(1,peel) func(2,peel) func(3,peel)

#define MICRO_COMPLEX_WORK_ONE4(iter, peel) \
  if (unroll_factor > iter) { \
    pgerc_common<accRows, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \
  if (PEEL_COMPLEX > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
    MICRO_COMPLEX_BROADCAST(peel) \
    MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M]; \
  Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M]; \
  func(func1,func2,0) func(func1,func2,1) \
  func(func1,func2,2) func(func1,func2,3)

#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M], rhsVi0[M];\
  func(func1,func2,0)

#define MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_TYPE, size) \
  MICRO_COMPLEX_TYPE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE) \
  MICRO_COMPLEX_ADD_ROWS(size, false)

#define MICRO_COMPLEX_ONE_PEEL4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_PEEL, PEEL_COMPLEX)

#define MICRO_COMPLEX_ONE4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_ONE, 1)

#define MICRO_COMPLEX_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzero<Packet, accRows>(accReal##iter); \
    bsetzero<Packet, accRows>(accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)

#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)

#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)

#define MICRO_COMPLEX_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    constexpr bool full = ((MICRO_NORMAL(iter)) || (accCols2 > accColsC)); \
    bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row + iter*accCols, 0); \
    bscalec<Packet,accRows,!(MICRO_NORMAL(iter))>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask); \
    bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1); \
    bstore<DataMapper, Packetc, accRows>(acc0, res, row + iter*accCols + 0); \
    if (full) { \
      bstore<DataMapper, Packetc, accRows>(acc1, res, row + iter*accCols + accColsC); \
    } \
  }

#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index& row,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const Scalar* rhs_ptr_real0 = rhs_base, * rhs_ptr_real1 = NULL, * rhs_ptr_real2 = NULL;
  const Scalar* rhs_ptr_imag0 = NULL, * rhs_ptr_imag1 = NULL, * rhs_ptr_imag2 = NULL;
  const Index imag_delta = accCols*strideA;
  const Index imag_delta2 = accCols2*strideA;
  const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
  const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
  PacketBlock<Packet,accRows> accReal0, accImag0, accReal1, accImag1;
  PacketBlock<Packet,accRows> accReal2, accImag2, accReal3, accImag3;
  PacketBlock<Packet,accRows> taccReal, taccImag;
  PacketBlock<Packetc,accRows> acc0, acc1;
  PacketBlock<Packetc,accRows*2> tRes;

  MICRO_COMPLEX_SRC2_PTR
  MICRO_COMPLEX_SRC_PTR
  MICRO_COMPLEX_DST_PTR

  Index k = 0, depth2 = depth - PEEL_COMPLEX;
  for(; k <= depth2; k += PEEL_COMPLEX)
  {
    MICRO_COMPLEX_PREFETCHN(accRows)
    MICRO_COMPLEX_PREFETCH
    MICRO_COMPLEX_ONE_PEEL4
  }
  for(; k < depth; k++)
  {
    MICRO_COMPLEX_ONE4
  }
  MICRO_COMPLEX_STORE

  MICRO_COMPLEX_UPDATE
}
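// Illustrative only: for a panel of two accCols-wide row blocks the dispatcher in
// gemm_complex_cols below effectively emits (via MICRO_COMPLEX_UNROLL_ITER2):
//
//   gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper,
//                                   accRows, accCols, accCols,
//                                   ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>
//       (res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB,
//        row, pAlphaReal, pAlphaImag, pMask);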
#define MICRO_COMPLEX_UNROLL_ITER2(N, M) \
  gemm_complex_unrolled_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, DataMapper, accRows, accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
  if (M) return;
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_COMPLEX_UNROLL 4
  while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
    MICRO_COMPLEX_UNROLL_ITER2(MAX_COMPLEX_UNROLL, 0);
  }
  switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_UNROLL > 4
    case 4:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 4)
      break;
#endif
#if MAX_COMPLEX_UNROLL > 3
    case 3:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 3)
      break;
#endif
#if MAX_COMPLEX_UNROLL > 2
    case 2:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 2)
      break;
#endif
#if MAX_COMPLEX_UNROLL > 1
    case 1:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 1)
      break;
#endif
    default:
      break;
  }
#undef MAX_COMPLEX_UNROLL

  if(remaining_rows > 0)
  {
    gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}
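// gemm_complex_cols consumes full MAX_COMPLEX_UNROLL*accCols row panels first, then
// the switch above emits one residual iteration with the largest unroll factor that
// still fits, and any sub-accCols remainder is handed to gemm_complex_extra_row with
// pMask selecting the valid lanes.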
#define MICRO_COMPLEX_EXTRA_COLS(N) \
  gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, N, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_extra_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  MICRO_EXTRA(MICRO_COMPLEX_EXTRA_COLS, cols-col, true)
}
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>(remaining_rows);

  const Scalar* blockA = (Scalar *) blockAc;
  const Scalar* blockB = (Scalar *) blockBc;

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(
      res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB,
      col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }

  if (col != cols)
  {
    gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(
      res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB,
      col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}
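// A minimal end-to-end exercise of this kernel (illustrative; any complex<float>
// product large enough to take the blocked GEBP path routes through gemm_complex on
// VSX builds without MMA):
//
//   #include <Eigen/Dense>
//   int main() {
//     Eigen::MatrixXcf a = Eigen::MatrixXcf::Random(256, 256);
//     Eigen::MatrixXcf b = Eigen::MatrixXcf::Random(256, 256);
//     Eigen::MatrixXcf c = std::complex<float>(2.f, 1.f) * (a * b);
//     return c(0, 0).real() > 0.f;  // keep the result alive
//   }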
EIGEN_ALWAYS_INLINE bool supportsMMA()
{
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
  return true;
#else
#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) && __has_builtin(__builtin_cpu_supports)
  return __builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma");
#else
  return false;
#endif
#endif
}
EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result)
{
  Packet4f result_block = ploadu<Packet4f>(result);
  return pmadd(acc, pAlpha, result_block);
}
template<bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows)
{
  if (lhsExtraRows) {
    pstoreu_partial(result, result_block, extra_rows);
  } else {
    pstoreu(result, result_block);
  }
  result += rows;
}
template<bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows)
{
  Index x = 0;
  if (rhsExtraCols) {
    do{
      Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result);
      storeF32<lhsExtraRows>(result, result_block, rows, extra_rows);
    } while (++x < extra_cols);
  } else {
    Packet4f result_block[4];
    float *result2 = result;
    do{
      result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result);
      result += rows;
    } while (++x < 4);
    x = 0;
    do{
      storeF32<lhsExtraRows>(result2, result_block[x], rows, extra_rows);
    } while (++x < 4);
  }
}
EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Hi(Packet8us data)
{
  Packet8us z = pset1<Packet8us>(0);
#ifdef _BIG_ENDIAN
  return reinterpret_cast<Packet4f>(vec_mergeh(data, z));
#else
  return reinterpret_cast<Packet4f>(vec_mergeh(z, data));
#endif
}

EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Lo(Packet8us data)
{
  Packet8us z = pset1<Packet8us>(0);
#ifdef _BIG_ENDIAN
  return reinterpret_cast<Packet4f>(vec_mergel(data, z));
#else
  return reinterpret_cast<Packet4f>(vec_mergel(z, data));
#endif
}
template<Index N, Index M>
EIGEN_ALWAYS_INLINE void storeConvertTwoBF16(float* to, PacketBlock<Packet8bf,(N+7)/8>& block, Index extra = 0)
{
  if (N < (M*8+4)) {
    pstoreu_partial(to + 0, oneConvertBF16Hi(block.packet[M].m_val), extra);
  } else if (N >= (M*8+4)) {
    pstoreu(to + 0, oneConvertBF16Hi(block.packet[M].m_val));
    if (N >= (M*8+8)) {
      pstoreu(to + 4, oneConvertBF16Lo(block.packet[M].m_val));
    } else {
      pstoreu_partial(to + 4, oneConvertBF16Lo(block.packet[M].m_val), extra);
    }
  }
}

template<Index N>
EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock<Packet8bf,(N+7)/8>& block, Index extra)
{
  storeConvertTwoBF16<N, 0>(to + 0, block, extra);
  if (N >= 8) {
    storeConvertTwoBF16<N, 1>(to + 8, block);
  }
  if (N >= 16) {
    storeConvertTwoBF16<N, 2>(to + 16, block);
    storeConvertTwoBF16<N, 3>(to + 24, block);
  }
}
template<bool non_unit_stride, Index delta>
EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16* src, Index resInc)
{
  if (non_unit_stride) {
    return pgather<bfloat16, Packet8bf>(src + delta*resInc, resInc);
  } else {
    return ploadu<Packet8bf>(src + delta);
  }
}
static Packet16uc p16uc_MERGE16_32_1 = {  0, 1, 16,17,  2, 3, 18,19,  0, 1, 16,17,  2, 3, 18,19 };
static Packet16uc p16uc_MERGE16_32_2 = {  4, 5, 20,21,  6, 7, 22,23,  4, 5, 20,21,  6, 7, 22,23 };
static Packet16uc p16uc_MERGE16_32_3 = {  8, 9, 24,25, 10,11, 26,27,  8, 9, 24,25, 10,11, 26,27 };
static Packet16uc p16uc_MERGE16_32_4 = { 12,13, 28,29, 14,15, 30,31, 12,13, 28,29, 14,15, 30,31 };

static Packet16uc p16uc_MERGE16_32_5 = { 0,1, 16,17, 16,17, 16,17, 0,1, 16,17, 16,17, 16,17 };
static Packet16uc p16uc_MERGE16_32_6 = { 2,3, 18,19, 18,19, 18,19, 2,3, 18,19, 18,19, 18,19 };
static Packet16uc p16uc_MERGE16_32_7 = { 4,5, 20,21, 20,21, 20,21, 4,5, 20,21, 20,21, 20,21 };
static Packet16uc p16uc_MERGE16_32_8 = { 6,7, 22,23, 22,23, 22,23, 6,7, 22,23, 22,23, 22,23 };
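// bfloat16 is the high half of an IEEE-754 binary32, so widening is a pure shuffle:
// interleave each 16-bit value with 16 zero bits below it. For example the bfloat16
// bit pattern 0x3F80 (1.0) becomes the float pattern 0x3F800000 (1.0f). The
// p16uc_MERGE16_32_1..4 masks widen pairs of adjacent values; _5..8 serve the
// duplicating conversions used for the odd-column tail below.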
EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Perm(Packet8us data, Packet16uc mask)
{
  Packet8us z = pset1<Packet8us>(0);
#ifdef _BIG_ENDIAN
  return reinterpret_cast<Packet4f>(vec_perm(data, z, mask));
#else
  return reinterpret_cast<Packet4f>(vec_perm(z, data, mask));
#endif
}
template<bool lhsExtraRows, bool odd, Index size>
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index rows, const bfloat16* src, Index extra_rows)
{
  // Widen packed bfloat16 values to float, duplicating each value across the
  // lane pairs consumed by KLoop (using the p16uc_MERGE16_32_* permutes).
  Index i = 0;
  do{
    // ...
  } while (++i < extra_rows);
}
template<bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32Dup(float *result, Index cols, Index rows, const bfloat16* src, Index delta, Index extra_rows)
{
  Index col = 0;
  src += delta*2;
  for(; col + 4*2 <= cols; col += 4*2, result += 4*4*4, src += 4*rows) {
    convertArrayPointerBF16toF32DupOne<lhsExtraRows,false,4>(result, rows, src, extra_rows);
  }
  for(; col + 2 <= cols; col += 2, result += 4*4, src += rows) {
    convertArrayPointerBF16toF32DupOne<lhsExtraRows,false,1>(result, rows, src, extra_rows);
  }
  if (cols & 1) {
    convertArrayPointerBF16toF32DupOne<lhsExtraRows,true,1>(result, rows, src - delta, extra_rows);
  }
}
template<const Index size, bool non_unit_stride>
EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float* result, Index rows, bfloat16*& src, Index resInc)
{
  constexpr Index extra = ((size < 4) ? 4 : size);
  while (i + size <= rows) {
    PacketBlock<Packet8bf,(size+7)/8> r32;
    r32.packet[0] = loadBF16fromResult<non_unit_stride, 0>(src, resInc);
    if (size >= 16) {
      r32.packet[1] = loadBF16fromResult<non_unit_stride, 8>(src, resInc);
    }
    if (size >= 32) {
      r32.packet[2] = loadBF16fromResult<non_unit_stride, 16>(src, resInc);
      r32.packet[3] = loadBF16fromResult<non_unit_stride, 24>(src, resInc);
    }
    storeConvertBlockBF16<size>(result + i, r32, rows & 3);
    i += extra; src += extra*resInc;
    if (size != 32) break;
  }
}
template<bool non_unit_stride>
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index cols, Index rows, bfloat16* src, Index resInc)
{
  for(Index j = 0; j < cols; j++, src += (rows*resInc), result += rows) {
    bfloat16* src2 = src;
    Index i = 0;
    convertPointerBF16toF32<32, non_unit_stride>(i, result, rows, src2, resInc);
    convertPointerBF16toF32<16, non_unit_stride>(i, result, rows, src2, resInc);
    convertPointerBF16toF32<8, non_unit_stride>(i, result, rows, src2, resInc);
    convertPointerBF16toF32<4, non_unit_stride>(i, result, rows, src2, resInc);
    convertPointerBF16toF32<1, non_unit_stride>(i, result, rows, src2, resInc);
  }
}
template<Index num_acc, Index size = 4>
EIGEN_ALWAYS_INLINE void zeroAccumulators(Packet4f (&acc)[num_acc][size])
{
  Packet4f z = pset1<Packet4f>(float(0));

  for(Index k = 0; k < num_acc; k++) {
    for(Index j = 0; j < size; j++) {
      acc[k][j] = z;
    }
  }
}
template<Index num_acc>
EIGEN_ALWAYS_INLINE void tranposeResults(Packet4f (&acc)[num_acc][4])
{
  for(Index i = 0; i < num_acc; i++) {
    Packet4ui t0, t1, t2, t3;
    t0 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
    t1 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
    t2 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
    t3 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
    acc[i][0] = reinterpret_cast<Packet4f>(vec_mergeh(t0, t2));
    acc[i][1] = reinterpret_cast<Packet4f>(vec_mergel(t0, t2));
    acc[i][2] = reinterpret_cast<Packet4f>(vec_mergeh(t1, t3));
    acc[i][3] = reinterpret_cast<Packet4f>(vec_mergel(t1, t3));
  }
}
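// Standard 4x4 in-register transpose in two rounds of mergeh/mergel: the first round
// interleaves rows 0/2 and 1/3 into the 2x2 tiles t0..t3, the second interleaves the
// tiles into the transposed rows. It operates on Packet4ui views since this is a
// pure lane shuffle; the float payload is not interpreted.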
template<Index num_acc>
EIGEN_ALWAYS_INLINE void addResults(Packet4f (&acc)[num_acc][4])
{
  for(Index i = 0, j = 0; j < num_acc; i++, j += 2) {
    for(Index x = 0, y = 0; x < 2; x++, y += 2) {
      for(Index w = 0, z = 0; w < 2; w++, z += 2) {
        acc[i][y+w] = acc[j+x][z+0] + acc[j+x][z+1];
      }
    }
  }
}
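// After the transpose, adjacent vector pairs of each accumulator hold partial sums
// of the same output elements; this folds accumulators 2i and 2i+1 together by
// summing those pairs, packing the reduced results into acc[i] and halving the live
// accumulator count.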
template<Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs>
EIGEN_ALWAYS_INLINE void outputResultsVSX(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result, const Index extra_cols, Index extra_rows)
{
  tranposeResults<num_acc>(acc);
  addResults<num_acc>(acc);

  constexpr Index real_rhs = ((num_rhs / 2) - (rhsExtraCols ? 1 : 0));
  Index k = 0;
  for(Index i = 0; i < real_rhs; i++, result += 4*rows, k++){
    storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
  }
  if (rhsExtraCols) {
    storeResults<rhsExtraCols, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
  }
}
template<bool zero>
EIGEN_ALWAYS_INLINE void loadTwoRhsFloat32(const float* block, Index strideB, Index i, Packet4f& dhs0, Packet4f& dhs1)
{
  dhs0 = ploadu<Packet4f>(block + strideB*i);
  Packet4f dhs2;
  if (zero) {
    dhs2 = pset1<Packet4f>(float(0));
  } else {
    dhs2 = ploadu<Packet4f>(block + strideB*(i + 1));
  }
  dhs1 = vec_mergel(dhs0, dhs2);
  dhs0 = vec_mergeh(dhs0, dhs2);
}
template<Index num_acc, bool zero, bool rhsExtraCols, Index num_rhs>
EIGEN_ALWAYS_INLINE void KLoop(
  const float* indexA,
  const float* indexB,
  Packet4f (&acc)[num_acc][4],
  Index strideB,
  Index k,
  Index offsetB,
  Index extra_cols)
{
  constexpr Index num_lhs = 4;
  Packet4f lhs[num_lhs], rhs[num_rhs];

  constexpr Index real_rhs = (num_rhs - (rhsExtraCols ? 2 : 0));
  for(Index i = 0; i < real_rhs; i += 2){
    loadTwoRhsFloat32<zero>(indexB + k*4, strideB, i, rhs[i + 0], rhs[i + 1]);
  }
  if (rhsExtraCols) {
    loadTwoRhsFloat32<zero>(indexB + k*extra_cols - offsetB, strideB, real_rhs, rhs[real_rhs + 0], rhs[real_rhs + 1]);
  }

  indexA += 2*k*num_lhs;
  for(Index j = 0; j < num_lhs; j++) {
    lhs[j] = ploadu<Packet4f>(indexA + j*4);
  }

  for(Index j = 0; j < num_rhs; j++) {
    for(Index i = 0; i < num_lhs; i++) {
      acc[j][i] = pmadd(rhs[j], lhs[i], acc[j][i]);
    }
  }
}
template<const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void colVSXLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const float* indexA, const float* indexB, Index strideB, Index offsetB, float* result, const Index extra_cols, const Index extra_rows)
{
  constexpr Index num_rhs = num_acc;

  Packet4f acc[num_acc][4];

  zeroAccumulators<num_acc>(acc);

  Index k;
  for(k = 0; k + 2 <= depth; k += 2){
    KLoop<num_acc, false, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);
  }
  if (depth & 1) {
    KLoop<num_acc, true, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);
  }

  outputResultsVSX<num_acc, rhsExtraCols, lhsExtraRows, num_rhs>(acc, rows, pAlpha, result, extra_cols, extra_rows);
}
#define MAX_BFLOAT16_ACC_VSX 4
template<const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
void colVSXLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA, const float* indexB, Index strideB, Index offsetB, float* result)
{
  constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC_VSX);
  constexpr Index step = (num_acc * 4);
  const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
  const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;

  do{
    colVSXLoopBodyIter<num_acc*2, rhsExtraCols, lhsExtraRows>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);

    indexB += strideB*(num_acc * 2);
    result += rows*step;
  } while(multiIters && (step <= cols - (col += step)));
}
template<const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void colVSXLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA, const float* blockB, Index strideB, Index offsetB, float* result)
{
  if (MAX_BFLOAT16_ACC_VSX > num_acc) {
    colVSXLoopBody<num_acc + (rhsExtraCols ? 1 : 0), rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
  }
}
template<bool rhsExtraCols, bool lhsExtraRows>
void colVSXLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA, const float* blockB, Index strideB, Index offsetB, float* result)
{
  switch ((cols - col) >> 2) {
  case 3:
    colVSXLoopBodyExtraN<3, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 2:
    colVSXLoopBodyExtraN<2, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 1:
    colVSXLoopBodyExtraN<1, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  default:
    if (rhsExtraCols) {
      colVSXLoopBody<1, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    }
    break;
  }
}
template<Index size, bool lhsExtraRows = false>
EIGEN_ALWAYS_INLINE void colVSXLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const float* indexA2, const float* blockB2, Index strideA, Index strideB, Index offsetB, float* result2)
{
  Index row = 0, col = 0;
  const Index delta_rows = lhsExtraRows ? (rows & 3) : size;
  convertArrayPointerBF16toF32Dup<lhsExtraRows>(const_cast<float *>(indexA2), strideA, delta_rows, indexA, row, rows & 3);

  const float *blockB = blockB2;
  float *result = result2 + row;

  if (cols >= (MAX_BFLOAT16_ACC_VSX * 4)) {
    colVSXLoopBody<MAX_BFLOAT16_ACC_VSX, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, 0, result);
    blockB += (strideB >> 1)*col;
    result += rows*col;
  }
  if (cols & 3) {
    colVSXLoopBodyExtra<true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, offsetB, result);
  } else {
    colVSXLoopBodyExtra<false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, 0, result);
  }
}
template<Index size>
EIGEN_ALWAYS_INLINE void calcVSXColLoops(const bfloat16*& indexA, const float* indexA2, Index& row, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexB, Index strideA, Index strideB, Index offsetA, Index offsetB, Index bigSuffix, float *result)
{
  if ((size == 16) || (rows & size)) {
    indexA += size*offsetA;
    colVSXLoops<size>(depth, cols, rows, pAlpha, indexA, indexA2, indexB, strideA, strideB, offsetB, result + row);
    row += size;
    indexA += bigSuffix*size/16;
  }
}
template<const Index size, typename DataMapper>
EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float* result, Index rows, const DataMapper& src)
{
  constexpr Index extra = ((size < 4) ? 4 : size);
  while (i + size <= rows) {
    PacketBlock<Packet8bf,(size+7)/8> r32;
    r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
    if (size >= 16) {
      r32.packet[1] = src.template loadPacket<Packet8bf>(i + 8);
    }
    if (size >= 32) {
      r32.packet[2] = src.template loadPacket<Packet8bf>(i + 16);
      r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
    }
    storeConvertBlockBF16<size>(result + i, r32, rows & 3);
    i += extra;
    if (size != 32) break;
  }
}
template<typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src)
{
  typedef typename DataMapper::LinearMapper LinearMapper;
  for(Index j = 0; j < cols; j++, result += rows){
    const LinearMapper src2 = src.getLinearMapper(0, j);
    Index i = 0;
    convertBF16toF32<32, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<16, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<8, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<4, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<1, LinearMapper>(i, result, rows, src2);
  }
}
template<typename DataMapper, const Index size>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16ColVSX(float *result, Index col, Index rows, const DataMapper& res)
{
  const DataMapper res2 = res.getSubMapper(0, col);
  Index row;
  float *result2 = result + col*rows;
  for(row = 0; row + 8 <= rows; row += 8){
    // convert and save one full block of columns
    PacketBlock<Packet8bf,size> block;
    for(Index j = 0; j < size; j++){
      block.packet[j] = convertF32toBF16VSX(result2 + j*rows + row);
    }
    res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
  }
  // remaining rows
  if (row < rows) {
    for(Index j = 0; j < size; j++){
      Packet8bf fp16 = convertF32toBF16VSX(result2 + j*rows + row);
      res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
    }
  }
}
template<typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16VSX(float *result, Index cols, Index rows, const DataMapper& res)
{
  Index col;
  for(col = 0; col + 4 <= cols; col += 4){
    convertArrayF32toBF16ColVSX<DataMapper,4>(result, col, rows, res);
  }
  // extra columns
  switch (cols - col) {
  case 1:
    convertArrayF32toBF16ColVSX<DataMapper,1>(result, col, rows, res);
    break;
  case 2:
    convertArrayF32toBF16ColVSX<DataMapper,2>(result, col, rows, res);
    break;
  case 3:
    convertArrayF32toBF16ColVSX<DataMapper,3>(result, col, rows, res);
    break;
  }
}
template<typename DataMapper>
void gemmbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);
  ei_declare_aligned_stack_constructed_variable(float, indexA2, strideA*2, 0);
  ei_declare_aligned_stack_constructed_variable(float, indexB2, strideB*cols, 0);

  convertArrayBF16toF32<DataMapper>(result, cols, rows, res);
  convertArrayPointerBF16toF32<false>(indexB2, cols, strideB, const_cast<bfloat16*>(indexB), 1);

  Index row = 0;

  Index bigSuffix = 2*8*(strideA-offsetA);
  float* indexBF32 = indexB2 + 4*offsetB;

  // LHS panels of 16 rows
  while(row + 16 <= rows){
    calcVSXColLoops<16>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB, bigSuffix, result);
  }
  // LHS panel of 8 rows
  calcVSXColLoops<8>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB, bigSuffix, result);
  // LHS panel of 4 rows
  calcVSXColLoops<4>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB, bigSuffix, result);
  // remaining rows (< 4)
  if(rows & 3){
    colVSXLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexA2, indexBF32, strideA, strideB, offsetB, result + row);
  }

  // Narrow the float accumulation back into the bfloat16 destination.
  convertArrayF32toBF16VSX<DataMapper>(result, cols, rows, res);
}
#undef MAX_BFLOAT16_ACC_VSX
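// Summary of the VSX bfloat16 path: widen A, B and the destination to float
// (convertArray*BF16toF32*), run the blocked float kernels over 16/8/4-row panels
// plus the sub-4-row remainder (calcVSXColLoops / colVSXLoops), then narrow the
// float accumulation back to bfloat16 (convertArrayF32toBF16VSX). Accumulating in
// float keeps bfloat16's 8-bit-mantissa rounding out of the dot products.
// Illustrative use (a large enough bfloat16 product lands here via gebp_kernel):
//
//   using BfMat = Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic, Eigen::Dynamic>;
//   BfMat a = BfMat::Constant(128, 128, Eigen::bfloat16(1.0f));
//   BfMat b = BfMat::Constant(128, 128, Eigen::bfloat16(0.5f));
//   BfMat c = a * b;  // each entry is exactly 64 in bfloat16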
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
{
  void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
  ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_pack<double, DataMapper, Packet2d, ColMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
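// All gemm_pack_lhs / gemm_pack_rhs specializations below share this shape: each
// operator() is a one-liner forwarding to dhs_pack (real scalars) or dhs_cpack
// (complex scalars), parameterized on scalar type, packet type, storage order,
// panel mode and an lhs/rhs flag, so every panel layout lives in one templated
// implementation.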
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{
  void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
  ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
  void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
  ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_pack<double, DataMapper, Packet2d, ColMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
{
  void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
  ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
#endif
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
  void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
  ::operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_pack<bfloat16, DataMapper, Packet8bf, ColMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
{
  void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
  ::operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
{
  void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
  ::operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_pack<bfloat16, DataMapper, Packet8bf, ColMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{
  void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
  ::operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{
  void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
  ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_pack<float, DataMapper, Packet4f, RowMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
{
  void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
  ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_pack<float, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
  ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_cpack<float, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
  ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_cpack<float, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
  void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
  ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_pack<float, DataMapper, Packet4f, ColMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
{
  void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
  ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_pack<float, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
#endif
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
  ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_cpack<float, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
  ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_cpack<float, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
  ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_cpack<double, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
  ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_cpack<double, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
  ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_cpack<double, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
  ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_cpack<double, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef typename quad_traits<float>::vectortype Packet;
  typedef typename quad_traits<float>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const float* blockA, const float* blockB,
                  Index rows, Index depth, Index cols, float alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const float* blockA, const float* blockB,
               Index rows, Index depth, Index cols, float alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  static void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
    (supportsMMA()) ?
      &Eigen::internal::gemmMMA<float, Packet, RhsPacket, DataMapper, accRows, accCols> :
#endif
      &Eigen::internal::gemm<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
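// The function-pointer initializer above is resolved once per template
// instantiation: with MMA compiled in (EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H),
// supportsMMA() picks between gemmMMA and the plain VSX gemm at first call; without
// it, the #ifdef collapses the expression to the VSX kernel alone. Every scalar
// combination below repeats this pattern.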
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef Packet4f Packet;
  typedef Packet2cf Packetc;
  typedef Packet4f RhsPacket;

  void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
                  Index rows, Index depth, Index cols, std::complex<float> alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
               Index rows, Index depth, Index cols, std::complex<float> alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*,
    Index, Index, Index, std::complex<float>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
    (supportsMMA()) ?
      &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false> :
#endif
      &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef Packet4f Packet;
  typedef Packet2cf Packetc;
  typedef Packet4f RhsPacket;

  void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
                  Index rows, Index depth, Index cols, std::complex<float> alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
               Index rows, Index depth, Index cols, std::complex<float> alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  static void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*,
    Index, Index, Index, std::complex<float>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
    (supportsMMA()) ?
      &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false> :
#endif
      &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef Packet4f Packet;
  typedef Packet2cf Packetc;
  typedef Packet4f RhsPacket;

  void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
                  Index rows, Index depth, Index cols, std::complex<float> alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
               Index rows, Index depth, Index cols, std::complex<float> alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*,
    Index, Index, Index, std::complex<float>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
    (supportsMMA()) ?
      &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true> :
#endif
      &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef typename quad_traits<double>::vectortype Packet;
  typedef typename quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const double* blockA, const double* blockB,
                  Index rows, Index depth, Index cols, double alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const double* blockA, const double* blockB,
               Index rows, Index depth, Index cols, double alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  static void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
    (supportsMMA()) ?
      &Eigen::internal::gemmMMA<double, Packet, RhsPacket, DataMapper, accRows, accCols> :
#endif
      &Eigen::internal::gemm<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef quad_traits<double>::vectortype Packet;
  typedef Packet1cd Packetc;
  typedef quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
                  Index rows, Index depth, Index cols, std::complex<double> alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
               Index rows, Index depth, Index cols, std::complex<double> alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*,
    Index, Index, Index, std::complex<double>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
    (supportsMMA()) ?
      &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false> :
#endif
      &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef quad_traits<double>::vectortype Packet;
  typedef Packet1cd Packetc;
  typedef quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
                  Index rows, Index depth, Index cols, std::complex<double> alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
               Index rows, Index depth, Index cols, std::complex<double> alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*,
    Index, Index, Index, std::complex<double>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
    (supportsMMA()) ?
      &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true> :
#endif
      &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef quad_traits<double>::vectortype Packet;
  typedef Packet1cd Packetc;
  typedef quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
                  Index rows, Index depth, Index cols, std::complex<double> alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
               Index rows, Index depth, Index cols, std::complex<double> alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  static void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*,
    Index, Index, Index, std::complex<double>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
    (supportsMMA()) ?
      &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false> :
#endif
      &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef typename quad_traits<bfloat16>::vectortype Packet;
  typedef typename quad_traits<bfloat16>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB,
                  Index rows, Index depth, Index cols, bfloat16 alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB,
               Index rows, Index depth, Index cols, bfloat16 alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  static void (*gemm_function)(const DataMapper&, const bfloat16*, const bfloat16*, Index, Index, Index, bfloat16, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
    (supportsMMA()) ?
      &Eigen::internal::gemmMMAbfloat16<DataMapper> :
#endif
      &Eigen::internal::gemmbfloat16<DataMapper>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H