#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

// When dynamic dispatch is enabled, compile this translation unit for Power10 MMA.
#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#pragma GCC push_options
#pragma GCC target("cpu=power10,htm")
#endif

// Older compilers only expose the __builtin_mma_* spellings of the pair builtins.
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#if !__has_builtin(__builtin_vsx_disassemble_pair)
#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
#endif

#include "../../InternalHeaderCheck.h"

namespace Eigen {

namespace internal {
#define accColsC (accCols / 2)

// Clear an MMA accumulator.
EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
{
  __builtin_mma_xxsetaccz(acc);
}
#ifdef USE_PARTIAL_PACKETS
template<typename DataMapper, typename Packet, bool full>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Index elements, __vector_quad* acc)
#else
template<typename DataMapper, typename Packet, const Index accCols, const Index accCols2>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Packet& pMask, __vector_quad* acc)
#endif
{
  PacketBlock<Packet, 4> result;
  __builtin_mma_disassemble_acc(&result.packet, acc);

  PacketBlock<Packet, 4> tRes;
#ifdef USE_PARTIAL_PACKETS
  if (full) {
    EIGEN_UNUSED_VARIABLE(elements);
    bload<DataMapper, Packet, 0, ColMajor, false, 4>(tRes, data, i, 0);
    bscale<Packet, 4>(tRes, result, alpha);
    bstore<DataMapper, Packet, 4>(tRes, data, i);
  } else {
    bload_partial<DataMapper, Packet, 0, false, 4>(tRes, data, i, elements);
    bscale<Packet, 4>(tRes, result, alpha);
    bstore_partial<DataMapper, Packet, 4>(tRes, data, i, elements);
  }
#else
  bload<DataMapper, Packet, 0, ColMajor, false, 4>(tRes, data, i, 0);
  bscale<Packet, 4, (accCols != accCols2)>(tRes, result, alpha, pMask);
  bstore<DataMapper, Packet, 4>(tRes, data, i);
#endif
}
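// Illustrative sketch (not part of the kernel, assuming Packet = Packet4f): each
// __vector_quad holds a 4x4 tile of the output, so the store step above is equivalent to
//
//   PacketBlock<Packet4f, 4> result, tRes;
//   __builtin_mma_disassemble_acc(&result.packet, acc);                    // acc -> 4 packets
//   bload<DataMapper, Packet4f, 0, ColMajor, false, 4>(tRes, data, i, 0);  // tRes = C block
//   for (Index j = 0; j < 4; j++)
//     tRes.packet[j] = pmadd(alpha, result.packet[j], tRes.packet[j]);     // C += alpha * acc
//   bstore<DataMapper, Packet4f, 4>(tRes, data, i);
//
// bscale performs essentially that per-column fused multiply-add; the masked/partial
// paths only change how the tail rows are read and written back.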
template<typename DataMapper, typename Packet, typename Packetc, const Index accCols, const Index accCols2>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, const Packet& pMask, __vector_quad* accReal, __vector_quad* accImag)
{
  constexpr bool full = (accCols2 > accColsC);
  PacketBlock<Packet, 4> resultReal, resultImag;
  __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
  __builtin_mma_disassemble_acc(&resultImag.packet, accImag);

  PacketBlock<Packetc, 8> tRes;
  bload<DataMapper, Packetc, accColsC, ColMajor, true, 4, full>(tRes, data, i, 0);

  PacketBlock<Packet, 4> taccReal, taccImag;
  bscalec<Packet, 4, (accCols != accCols2)>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag, pMask);

  PacketBlock<Packetc, 4> acc1, acc2;
  bcouple<Packet, Packetc, 4, full>(taccReal, taccImag, tRes, acc1, acc2);

  bstore<DataMapper, Packetc, 4>(acc1, data, i);
  if (full) {
    bstore<DataMapper, Packetc, 4>(acc2, data, i + accColsC);
  }
}
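// Worked scaling step (what bscalec computes above): with alpha = ar + ai*i and an
// accumulated tile acc = accR + accI*i, the update written back to the result is
//
//   C.real += ar*accR - ai*accI
//   C.imag += ar*accI + ai*accR
//
// pMask zeroes the contribution of the lanes past remaining_rows when a partial tile
// is scaled.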
// Rank-1 update of a float32 MMA accumulator: adds (or subtracts, when
// NegativeAccumulate) the outer product of a and b to the accumulator tile.
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}
// Rank-1 update of a float64 MMA accumulator; the first operand is a register pair.
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}
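// Illustrative sketch (not part of the kernel): one k step of the float32 micro kernel
// built from these helpers looks like
//
//   __vector_quad acc;
//   bsetzeroMMA(&acc);                                      // acc = 0
//   Packet4f lhsV = ploadu<Packet4f>(lhs);                  // 4 packed LHS values (output rows)
//   Packet4f rhsV = ploadu<Packet4f>(rhs);                  // 4 packed RHS values (output columns)
//   pgerMMA<Packet4f, Packet4f, false>(&acc, rhsV, lhsV);   // rank-1 update of a 4x4 tile
//
// For double, `a` is a __vector_pair carrying four RHS values, so a single xvf64ger
// instruction still updates a full accumulator tile per call.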
template<typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, Packet& lhsVi, const RhsPacket& rhsV, RhsPacket& rhsVi)
{
  pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
  if(LhsIsReal) {
    pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    EIGEN_UNUSED_VARIABLE(lhsVi);
  } else {
    if(!RhsIsReal) {
      pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
      pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
  }
}
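// The calls above implement the complex product split across two accumulators.
// Writing lhs = lr + li*i and rhs = rr + ri*i (sketch, ignoring conjugation):
//
//   accReal += lr*rr - li*ri
//   accImag += lr*ri + li*rr
//
// ConjugateLhs/ConjugateRhs only flip the sign of the li/ri contributions, which is why
// the boolean template parameters select between the *gerpp and *gernp builtins, and the
// LhsIsReal/RhsIsReal cases simply drop the missing imaginary terms.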
// RHS packet load helpers. The double specialization produces a __vector_pair, which is
// what the xvf64ger builtins consume.
template<typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadRhs(const __UNPACK_TYPE__(Packet)* rhs)
{
  return ploadu<Packet>(rhs);
}

template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
{
  rhsV = ploadRhs<Packet>(rhs);
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const double* rhs, __vector_pair& rhsV)
{
#if EIGEN_COMP_LLVM
  __builtin_vsx_assemble_pair(&rhsV,
    reinterpret_cast<__vector unsigned char>(ploadRhs<Packet2d>(rhs + (sizeof(Packet2d) / sizeof(double)))),
    reinterpret_cast<__vector unsigned char>(ploadRhs<Packet2d>(rhs)));
#else
  rhsV = *reinterpret_cast<__vector_pair *>(const_cast<double *>(rhs));
#endif
}

EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
{
  lhsV = *reinterpret_cast<__vector_pair *>(const_cast<double *>(lhs));
}
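// Illustrative note (assumption: Packet2d loads two doubles): the EIGEN_COMP_LLVM branch
// of ploadRhsMMA above builds the register pair explicitly, i.e.
//
//   Packet2d lo = ploadu<Packet2d>(rhs);      // rhs[0], rhs[1]
//   Packet2d hi = ploadu<Packet2d>(rhs + 2);  // rhs[2], rhs[3]
//   __vector_pair pair;
//   __builtin_vsx_assemble_pair(&pair,
//       reinterpret_cast<__vector unsigned char>(hi),
//       reinterpret_cast<__vector unsigned char>(lo));
//
// while the other branch simply reinterprets the four contiguous doubles as a
// __vector_pair.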
#if (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
#define VECTOR_PAIR_LOADS_LHS
#endif

#define PEEL_MMA 7
#define MICRO_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_MMA_WORK(func, type, peel) \
  func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
  func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel)

#define MICRO_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV[peel], lhsV##iter); \
  }

#ifdef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_WORK_TWO(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV[peel], lhsV2##iter.packet[peel & 1]); \
  }

#define MICRO_MMA_LOAD1_TWO(lhs_ptr, iter) \
  if (unroll_factor > iter) { \
    if (MICRO_NORMAL(iter)) { \
      ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr##iter), plhsV##iter); \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsV2##iter.packet), &plhsV##iter); \
      lhs_ptr##iter += accCols*2; \
    } else { \
      lhsV2##iter.packet[0] = ploadLhs<Packet>(lhs_ptr##iter); \
      lhsV2##iter.packet[1] = ploadLhs<Packet>(lhs_ptr##iter + accCols2); \
      lhs_ptr##iter += accCols2*2; \
      EIGEN_UNUSED_VARIABLE(plhsV##iter) \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV2##iter); \
    EIGEN_UNUSED_VARIABLE(plhsV##iter) \
  }

#define MICRO_MMA_LOAD_TWO(iter) MICRO_MMA_LOAD1_TWO(lhs_ptr, iter)
#endif
#define MICRO_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
  if (PEEL_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    ploadRhsMMA(rhs_ptr + (accRows * peel), rhsV[peel]); \
    MICRO_MMA_UNROLL(funcl) \
    MICRO_MMA_WORK(funcw, type, peel) \
  }
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
  type rhsV[8]; \
  MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,1) \
  MICRO_MMA_TYPE_PEEL(funcw,funcl,type,2) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,3) \
  MICRO_MMA_TYPE_PEEL(funcw,funcl,type,4) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,5) \
  MICRO_MMA_TYPE_PEEL(funcw,funcl,type,6) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,7)
#else
#define MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
  if (PEEL_MMA > peel2) { \
    PacketBlock<Packet,2> lhsV20, lhsV21, lhsV22, lhsV23, lhsV24, lhsV25, lhsV26, lhsV27; \
    __vector_pair plhsV0, plhsV1, plhsV2, plhsV3, plhsV4, plhsV5, plhsV6, plhsV7; \
    if (sizeof(type) == 16) { \
      ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr + (accRows * peel1)), prhsV##peel1); \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV[peel1]), &prhsV##peel1); \
    } else { \
      EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
      ploadRhsMMA(rhs_ptr + (accRows * peel1), rhsV[peel1]); \
      ploadRhsMMA(rhs_ptr + (accRows * peel2), rhsV[peel2]); \
    } \
    MICRO_MMA_UNROLL(funcl2) \
    MICRO_MMA_WORK(funcw2, type, peel1) \
    MICRO_MMA_WORK(funcw2, type, peel2) \
  } else { \
    EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
    MICRO_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \
  }
#define MICRO_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
  type rhsV[8]; \
  __vector_pair prhsV0, prhsV2, prhsV4, prhsV6; \
  MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \
  MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3) \
  MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,4,5) \
  MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,6,7)
#endif
#define MICRO_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
  type rhsV[1]; \
  MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0)

#define MICRO_MMA_UNROLL_TYPE(MICRO_MMA_TYPE, size) \
  MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, RhsPacket) \
  rhs_ptr += (accRows * size);
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_PEEL, PEEL_MMA)
#else
#define MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_TYPE, size) \
  MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, MICRO_MMA_WORK_TWO, MICRO_MMA_LOAD_TWO, RhsPacket) \
  rhs_ptr += (accRows * size);

#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_UNROLL_TYPE_PEEL2, PEEL_MMA)
#endif
#define MICRO_MMA_ONE MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_ONE, 1)

#define MICRO_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA(&accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)

#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_SRC_PTR_ONE)

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_PREFETCH_ONE)
#ifdef USE_PARTIAL_PACKETS
#define MICRO_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeAccumulator<DataMapper, Packet, MICRO_NORMAL_PARTIAL(iter)>(row + iter*accCols, res, pAlpha, accCols2, &accZero##iter); \
  }
#else
#define MICRO_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeAccumulator<DataMapper, Packet, accCols, (unroll_factor != (iter + 1)) ? accCols : accCols2>(row + iter*accCols, res, pAlpha, pMask, &accZero##iter); \
  }
#endif

#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
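// Illustrative expansion (hand-written, not generated): with unroll_factor == 2 and
// Packet = Packet4f, one non-peeled k iteration (MICRO_MMA_ONE) reduces to roughly
//
//   Packet4f lhsV0 = ploadLhs<Packet4f>(lhs_ptr0);  lhs_ptr0 += accCols;
//   Packet4f lhsV1 = ploadLhs<Packet4f>(lhs_ptr1);  lhs_ptr1 += accCols;
//   RhsPacket rhsV[1];
//   ploadRhsMMA(rhs_ptr, rhsV[0]);
//   pgerMMA<Packet4f, RhsPacket, false>(&accZero0, rhsV[0], lhsV0);
//   pgerMMA<Packet4f, RhsPacket, false>(&accZero1, rhsV[0], lhsV1);
//   rhs_ptr += accRows;
//
// i.e. one RHS load is shared by every unrolled LHS tile, and MICRO_MMA_STORE later
// writes each accumulator back through storeAccumulator.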
#ifdef USE_PARTIAL_PACKETS
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool full>
#else
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2>
#endif
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index& row,
  const Packet& pAlpha,
#ifdef USE_PARTIAL_PACKETS
  Index accCols2
#else
  const Packet& pMask
#endif
  )
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

  MICRO_MMA_SRC_PTR
  MICRO_MMA_DST_PTR

  Index k = 0, depth2 = depth - PEEL_MMA;
  for(; k <= depth2; k += PEEL_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_MMA_PREFETCH
    MICRO_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_MMA_ONE
  }
  MICRO_MMA_STORE

  MICRO_UPDATE
}
#ifdef USE_PARTIAL_PACKETS
#define MICRO_MMA_UNROLL_ITER2(N, M) \
  gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, !M>(res3, lhs_base, rhs_base, depth, strideA, offsetA, row, pAlpha, M ? remaining_rows : accCols); \
  if (M) return;
#else
#define MICRO_MMA_UNROLL_ITER2(N, M) \
  gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, M ? M : accCols>(res3, lhs_base, rhs_base, depth, strideA, offsetA, row, pAlpha, pMask); \
  if (M) return;
#endif
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemmMMA_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_MMA_UNROLL 7
  while(row + MAX_MMA_UNROLL*accCols <= rows) {
    MICRO_MMA_UNROLL_ITER2(MAX_MMA_UNROLL, 0);
  }
  switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
    case 7: MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 7) break;
#endif
#if MAX_MMA_UNROLL > 6
    case 6: MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 6) break;
#endif
#if MAX_MMA_UNROLL > 5
    case 5: MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 5) break;
#endif
#if MAX_MMA_UNROLL > 4
    case 4: MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 4) break;
#endif
#if MAX_MMA_UNROLL > 3
    case 3: MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 3) break;
#endif
#if MAX_MMA_UNROLL > 2
    case 2: MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 2) break;
#endif
#if MAX_MMA_UNROLL > 1
    case 1: MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 1) break;
#endif
    default:
      break;
  }
#undef MAX_MMA_UNROLL

  if(remaining_rows > 0)
  {
    gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask);
  }
}
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask = bmask<Packet>(remaining_rows);

  typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
  }

  if (col != cols)
  {
    gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}
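// Usage sketch (hedged; the real call sites are the gebp_kernel specializations in
// MatrixProduct.h): blockA/blockB must already be packed by the Power GEMM packing
// routines, after which a float kernel instantiation looks roughly like
//
//   // C (rows x cols) += alpha * A (rows x depth) * B (depth x cols)
//   gemmMMA<float, Packet4f, RhsPacket, DataMapper, /*accRows=*/4, /*accCols=*/4>(
//       res, blockA, blockB, rows, depth, cols, alpha,
//       /*strideA=*/-1, /*strideB=*/-1, /*offsetA=*/0, /*offsetB=*/0);
//
// Passing -1 for the strides means "use depth", as handled at the top of gemmMMA.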
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

#define PEEL_COMPLEX_MMA 3
#define MICRO_COMPLEX_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_MMA_WORK(func, type, peel) \
  func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel)

#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV[peel], rhsVi[peel]); \
  }

#ifdef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV2##iter.packet[peel & 1], lhsVi2##iter.packet[peel & 1], rhsV[peel], rhsVi[peel]); \
  }

#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter) \
  if (!LhsIsReal && (unroll_factor > iter)) { \
    if (MICRO_NORMAL(iter)) { \
      ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr_real##iter + imag_delta), plhsVi##iter); \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsVi2##iter.packet), &plhsVi##iter); \
    } else { \
      lhsVi2##iter.packet[0] = ploadLhs<Packet>(lhs_ptr_real##iter + imag_delta2); \
      lhsVi2##iter.packet[1] = ploadLhs<Packet>(lhs_ptr_real##iter + imag_delta2 + accCols2); \
      EIGEN_UNUSED_VARIABLE(plhsVi##iter) \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsVi2##iter); \
    EIGEN_UNUSED_VARIABLE(plhsVi##iter) \
  } \
  MICRO_MMA_LOAD1_TWO(lhs_ptr_real, iter)

#define MICRO_COMPLEX_MMA_LOAD_TWO(iter) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter)
#endif
#define MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
  if (PEEL_COMPLEX_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
    ploadRhsMMA(rhs_ptr_real + (accRows * peel), rhsV[peel]); \
    if(!RhsIsReal) { \
      ploadRhsMMA(rhs_ptr_imag + (accRows * peel), rhsVi[peel]); \
    } \
    MICRO_COMPLEX_MMA_UNROLL(funcl) \
    MICRO_COMPLEX_MMA_WORK(funcw, type, peel) \
  }
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
  type rhsV[4], rhsVi[4]; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,1) \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,2) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,3)
#else
#define MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
  if (PEEL_COMPLEX_MMA > peel2) { \
    PacketBlock<Packet,2> lhsV20, lhsV21, lhsV22, lhsV23; \
    PacketBlock<Packet,2> lhsVi20, lhsVi21, lhsVi22, lhsVi23; \
    __vector_pair plhsV0, plhsV1, plhsV2, plhsV3; \
    __vector_pair plhsVi0, plhsVi1, plhsVi2, plhsVi3; \
    if (sizeof(type) == 16) { \
      ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_real + (accRows * peel1)), prhsV##peel1); \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV[peel1]), &prhsV##peel1); \
      if(!RhsIsReal) { \
        ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_imag + (accRows * peel1)), prhsVi##peel1); \
        __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsVi[peel1]), &prhsVi##peel1); \
      } else { \
        EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
      } \
    } else { \
      EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
      EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
      ploadRhsMMA(rhs_ptr_real + (accRows * peel1), rhsV[peel1]); \
      ploadRhsMMA(rhs_ptr_real + (accRows * peel2), rhsV[peel2]); \
      if(!RhsIsReal) { \
        ploadRhsMMA(rhs_ptr_imag + (accRows * peel1), rhsVi[peel1]); \
        ploadRhsMMA(rhs_ptr_imag + (accRows * peel2), rhsVi[peel2]); \
      } \
    } \
    MICRO_COMPLEX_MMA_UNROLL(funcl2) \
    MICRO_COMPLEX_MMA_WORK(funcw2, type, peel1) \
    MICRO_COMPLEX_MMA_WORK(funcw2, type, peel2) \
  } else { \
    EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
    EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
    MICRO_COMPLEX_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \
  }
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
  type rhsV[4], rhsVi[4]; \
  __vector_pair prhsV0, prhsV2; \
  __vector_pair prhsVi0, prhsVi2; \
  MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \
  MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3)
#endif
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
  type rhsV[1], rhsVi[1]; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0)

#define MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_TYPE, size) \
  MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, RhsPacket) \
  rhs_ptr_real += (accRows * size); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * size);

#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL, PEEL_COMPLEX_MMA)
#else
#define MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_TYPE, size) \
  MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, MICRO_COMPLEX_MMA_WORK_TWO, MICRO_COMPLEX_MMA_LOAD_TWO, RhsPacket) \
  rhs_ptr_real += (accRows * size); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * size);

#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2, PEEL_COMPLEX_MMA)
#endif

#define MICRO_COMPLEX_MMA_ONE MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE, 1)

#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA(&accReal##iter); \
    bsetzeroMMA(&accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)

#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)

#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeComplexAccumulator<DataMapper, Packet, Packetc, accCols, (unroll_factor != (iter + 1)) ? accCols : accCols2>(row + iter*accCols, res, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index& row,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag = NULL;
  const Index imag_delta = accCols*strideA;
  const Index imag_delta2 = accCols2*strideA;
  if(!RhsIsReal) {
    rhs_ptr_imag = rhs_base + accRows*strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  }
  const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
  const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
  __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;

  MICRO_COMPLEX_MMA_SRC_PTR
  MICRO_COMPLEX_MMA_DST_PTR

  Index k = 0, depth2 = depth - PEEL_COMPLEX_MMA;
  for(; k <= depth2; k += PEEL_COMPLEX_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if(!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    MICRO_COMPLEX_MMA_PREFETCH
    MICRO_COMPLEX_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_COMPLEX_MMA_ONE
  }
  MICRO_COMPLEX_MMA_STORE

  MICRO_COMPLEX_UPDATE
}
#define MICRO_COMPLEX_MMA_UNROLL_ITER2(N, M) \
  gemm_complex_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
  if (M) return;
template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_COMPLEX_MMA_UNROLL 4
  while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
    MICRO_COMPLEX_MMA_UNROLL_ITER2(MAX_COMPLEX_MMA_UNROLL, 0);
  }
  switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
    case 4: MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 4) break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
    case 3: MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 3) break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
    case 2: MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 2) break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
    case 1: MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 1) break;
#endif
    default:
      break;
  }
#undef MAX_COMPLEX_MMA_UNROLL

  if(remaining_rows > 0)
  {
    gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>(remaining_rows);

  const Scalar* blockA = (Scalar *) blockAc;
  const Scalar* blockB = (Scalar *) blockBc;

  typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket2, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }

  if (col != cols)
  {
    gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}
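// Usage sketch (hedged, mirroring the gemmMMA note above): for std::complex<float> the
// packed blocks are reinterpreted as float arrays and the kernel is instantiated
// roughly as
//
//   gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>,
//                   float, Packet4f, Packet2cf, RhsPacket, DataMapper,
//                   /*accRows=*/4, /*accCols=*/4,
//                   /*ConjugateLhs=*/false, /*ConjugateRhs=*/false,
//                   /*LhsIsReal=*/false, /*RhsIsReal=*/false>(
//       res, blockA, blockB, rows, depth, cols, alpha,
//       strideA, strideB, offsetA, offsetB);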
#undef accColsC
#undef advanceRows
#undef advanceCols

} // end namespace internal

} // end namespace Eigen

#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#pragma GCC pop_options
#endif

#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H