10 #ifndef EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H
11 #define EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H
13 #include "../../InternalHeaderCheck.h"
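// Matrix-vector product (GEMV) kernels for PowerPC: plain VSX paths plus MMA
// accumulator paths when __MMA__ is available, covering real, complex and bfloat16 data.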
15 #if defined(__MMA__) && !EIGEN_ALTIVEC_DISABLE_MMA
16 #if EIGEN_COMP_LLVM || (__GNUC__ > 10 || __GNUC_MINOR__ >= 3)
20 #if !EIGEN_COMP_LLVM && (__GNUC__ < 11)
22 #define GCC_ONE_VECTORPAIR_BUG
29 #ifdef EIGEN_POWER_USE_GEMV_PREFETCH
30 #define EIGEN_POWER_GEMV_PREFETCH(p) prefetch(p)
32 #define EIGEN_POWER_GEMV_PREFETCH(p)
36 #if !__has_builtin(__builtin_vsx_assemble_pair)
37 #define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
39 #if !__has_builtin(__builtin_vsx_disassemble_pair)
40 #define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
45 #define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
46 __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src2, (__vector unsigned char)src1)
49 #if (__GNUC_MINOR__ > 3)
50 #define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
51 __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src2, (__vector unsigned char)src1)
53 #define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
54 __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src1, (__vector unsigned char)src2)
57 #define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
58 __builtin_vsx_build_pair(&dst, (__vector unsigned char)src1, (__vector unsigned char)src2)
62 #define GEMV_IS_COMPLEX_COMPLEX ((sizeof(LhsPacket) == 16) && (sizeof(RhsPacket) == 16))
63 #define GEMV_IS_FLOAT (ResPacketSize == (16 / sizeof(float)))
64 #define GEMV_IS_SCALAR (sizeof(ResPacket) != 16)
65 #define GEMV_IS_COMPLEX_FLOAT (ResPacketSize == (16 / sizeof(std::complex<float>)))
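// Compile-time predicates on the packet types used to pick a kernel variant:
// GEMV_IS_COMPLEX_COMPLEX - lhs and rhs packets are both full 16-byte vector packets
//                           (the complex * complex case),
// GEMV_IS_FLOAT           - the result packet holds 4 floats,
// GEMV_IS_SCALAR          - the result is not a full vector packet,
// GEMV_IS_COMPLEX_FLOAT   - the result packet holds 2 std::complex<float> values.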
68 template<typename ResPacket, typename ResScalar>
74 template<typename ResScalar>
80 #define GEMV_UNROLL(func, N) \
81 func(0, N) func(1, N) func(2, N) func(3, N) \
82 func(4, N) func(5, N) func(6, N) func(7, N)
84 #define GEMV_UNROLL_HALF(func, N) \
85 func(0, 0, 1, N) func(1, 2, 3, N) func(2, 4, 5, N) func(3, 6, 7, N)
87 #define GEMV_GETN(N) (((N) * ResPacketSize) >> 2)
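// GEMV_UNROLL expands func(iter, N) for iter = 0..7 and GEMV_UNROLL_HALF pairs the
// iterations as (0,1), (2,3), (4,5), (6,7). GEMV_GETN(N) rescales the unroll count by
// the result packet size ((N * ResPacketSize) / 4); for example, with ResPacketSize == 4
// GEMV_GETN(2) == 2, so guards of the form "if (GEMV_GETN(N) > iter)" keep only the
// first two iterations and let the rest compile away.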
89 #define GEMV_LOADPACKET_COL(iter) \
90 lhs.template load<LhsPacket, LhsAlignment>(i + ((iter) * LhsPacketSize), j)
93 #define GEMV_UNROLL3(func, N, which) \
94 func(0, N, which) func(1, N, which) func(2, N, which) func(3, N, which) \
95 func(4, N, which) func(5, N, which) func(6, N, which) func(7, N, which)
97 #define GEMV_UNUSED_VAR(iter, N, which) \
98 if (GEMV_GETN(N) <= iter) { \
99 EIGEN_UNUSED_VARIABLE(which##iter); \
102 #define GEMV_UNUSED_EXTRA_VAR(iter, N, which) \
104 EIGEN_UNUSED_VARIABLE(which##iter); \
107 #define GEMV_UNUSED_EXTRA(N, which) \
108 GEMV_UNROLL3(GEMV_UNUSED_EXTRA_VAR, N, which)
110 #define GEMV_UNUSED(N, which) \
111 GEMV_UNROLL3(GEMV_UNUSED_VAR, N, which)
113 #define GEMV_INIT_MMA(iter, N) \
114 if (GEMV_GETN(N) > iter) { \
115 __builtin_mma_xxsetaccz(&e##iter); \
119 #define GEMV_LOADPAIR_COL_MMA(iter1, iter2) \
120 GEMV_BUILDPAIR_MMA(b##iter1, GEMV_LOADPACKET_COL(iter2), GEMV_LOADPACKET_COL((iter2) + 1));
122 #define GEMV_LOADPAIR_COL_MMA(iter1, iter2) \
123 const LhsScalar& src##iter1 = lhs(i + ((iter1 * 32) / sizeof(LhsScalar)), j); \
124 b##iter1 = *reinterpret_cast<__vector_pair *>(const_cast<LhsScalar *>(&src##iter1));
127 #define GEMV_LOAD1A_COL_MMA(iter, N) \
128 if (GEMV_GETN(N) > iter) { \
129 if (GEMV_IS_FLOAT) { \
130 g##iter = GEMV_LOADPACKET_COL(iter); \
131 EIGEN_UNUSED_VARIABLE(b##iter); \
133 GEMV_LOADPAIR_COL_MMA(iter, iter << 1) \
134 EIGEN_UNUSED_VARIABLE(g##iter); \
137 EIGEN_UNUSED_VARIABLE(b##iter); \
138 EIGEN_UNUSED_VARIABLE(g##iter); \
141 #define GEMV_WORK1A_COL_MMA(iter, N) \
142 if (GEMV_GETN(N) > iter) { \
143 if (GEMV_IS_FLOAT) { \
144 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter, a0, g##iter); \
146 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter, b##iter, a0); \
150 #define GEMV_LOAD1B_COL_MMA(iter1, iter2, iter3, N) \
151 if (GEMV_GETN(N) > iter1) { \
152 if (GEMV_IS_FLOAT) { \
153 GEMV_LOADPAIR_COL_MMA(iter2, iter2) \
154 EIGEN_UNUSED_VARIABLE(b##iter3); \
156 GEMV_LOADPAIR_COL_MMA(iter2, iter2 << 1) \
157 GEMV_LOADPAIR_COL_MMA(iter3, iter3 << 1) \
160 EIGEN_UNUSED_VARIABLE(b##iter2); \
161 EIGEN_UNUSED_VARIABLE(b##iter3); \
163 EIGEN_UNUSED_VARIABLE(g##iter2); \
164 EIGEN_UNUSED_VARIABLE(g##iter3);
166 #define GEMV_WORK1B_COL_MMA(iter1, iter2, iter3, N) \
167 if (GEMV_GETN(N) > iter1) { \
168 if (GEMV_IS_FLOAT) { \
170 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(h), &b##iter2); \
171 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter2, a0, h[0]); \
172 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter3, a0, h[1]); \
174 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter2, b##iter2, a0); \
175 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter3, b##iter3, a0); \
180 #define GEMV_LOAD_COL_MMA(N) \
181 if (GEMV_GETN(N) > 1) { \
182 GEMV_UNROLL_HALF(GEMV_LOAD1B_COL_MMA, (N >> 1)) \
184 GEMV_UNROLL(GEMV_LOAD1A_COL_MMA, N) \
187 #define GEMV_WORK_COL_MMA(N) \
188 if (GEMV_GETN(N) > 1) { \
189 GEMV_UNROLL_HALF(GEMV_WORK1B_COL_MMA, (N >> 1)) \
191 GEMV_UNROLL(GEMV_WORK1A_COL_MMA, N) \
194 #define GEMV_LOAD_COL_MMA(N) \
195 GEMV_UNROLL(GEMV_LOAD1A_COL_MMA, N)
197 #define GEMV_WORK_COL_MMA(N) \
198 GEMV_UNROLL(GEMV_WORK1A_COL_MMA, N)
201 #define GEMV_DISASSEMBLE_MMA(iter, N) \
202 if (GEMV_GETN(N) > iter) { \
203 __builtin_mma_disassemble_acc(&result##iter.packet, &e##iter); \
204 if (!GEMV_IS_FLOAT) { \
205 result##iter.packet[0][1] = result##iter.packet[1][0]; \
206 result##iter.packet[2][1] = result##iter.packet[3][0]; \
210 #define GEMV_LOADPAIR2_COL_MMA(iter1, iter2) \
211 b##iter1 = *reinterpret_cast<__vector_pair *>(res + i + ((iter2) * ResPacketSize));
213 #define GEMV_LOAD2_COL_MMA(iter1, iter2, iter3, N) \
214 if (GEMV_GETN(N) > iter1) { \
215 if (GEMV_IS_FLOAT) { \
216 GEMV_LOADPAIR2_COL_MMA(iter2, iter2); \
217 EIGEN_UNUSED_VARIABLE(b##iter3); \
219 GEMV_LOADPAIR2_COL_MMA(iter2, iter2 << 1); \
220 GEMV_LOADPAIR2_COL_MMA(iter3, iter3 << 1); \
223 EIGEN_UNUSED_VARIABLE(b##iter2); \
224 EIGEN_UNUSED_VARIABLE(b##iter3); \
228 #define GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter4) \
229 ResPacket f##iter2[2]; \
230 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(f##iter2), &b##iter2); \
231 f##iter2[0] = pmadd(result##iter2.packet[0], palpha, f##iter2[0]); \
232 f##iter2[1] = pmadd(result##iter3.packet[(iter2 == iter3) ? 2 : 0], palpha, f##iter2[1]); \
233 GEMV_BUILDPAIR_MMA(b##iter2, f##iter2[0], f##iter2[1]);
235 #define GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter4) \
236 if (GEMV_IS_FLOAT) { \
237 __asm__ ("xvmaddasp %0,%x1,%x3\n\txvmaddasp %L0,%x2,%x3" : "+&d" (b##iter2) : "wa" (result##iter3.packet[0]), "wa" (result##iter2.packet[0]), "wa" (palpha)); \
239 __asm__ ("xvmaddadp %0,%x1,%x3\n\txvmaddadp %L0,%x2,%x3" : "+&d" (b##iter2) : "wa" (result##iter2.packet[2]), "wa" (result##iter2.packet[0]), "wa" (palpha)); \
243 #define GEMV_WORK2_COL_MMA(iter1, iter2, iter3, N) \
244 if (GEMV_GETN(N) > iter1) { \
245 if (GEMV_IS_FLOAT) { \
246 GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter2); \
248 GEMV_WORKPAIR2_COL_MMA(iter2, iter2, iter2 << 1); \
249 GEMV_WORKPAIR2_COL_MMA(iter3, iter3, iter3 << 1); \
253 #define GEMV_STOREPAIR2_COL_MMA(iter1, iter2) \
254 *reinterpret_cast<__vector_pair *>(res + i + ((iter2) * ResPacketSize)) = b##iter1;
256 #define GEMV_STORE_COL_MMA(iter, N) \
257 if (GEMV_GETN(N) > iter) { \
258 if (GEMV_IS_FLOAT) { \
259 storeMaddData<ResPacket, ResScalar>(res + i + (iter * ResPacketSize), palpha, result##iter.packet[0]); \
261 GEMV_LOADPAIR2_COL_MMA(iter, iter << 1) \
262 GEMV_WORKPAIR2_COL_MMA(iter, iter, iter << 1) \
263 GEMV_STOREPAIR2_COL_MMA(iter, iter << 1) \
267 #define GEMV_STORE2_COL_MMA(iter1, iter2, iter3, N) \
268 if (GEMV_GETN(N) > iter1) { \
269 if (GEMV_IS_FLOAT) { \
270 GEMV_STOREPAIR2_COL_MMA(iter2, iter2); \
272 GEMV_STOREPAIR2_COL_MMA(iter2, iter2 << 1) \
273 GEMV_STOREPAIR2_COL_MMA(iter3, iter3 << 1) \
277 #define GEMV_PROCESS_COL_ONE_MMA(N) \
278 GEMV_UNROLL(GEMV_INIT_MMA, N) \
280 __vector_pair b0, b1, b2, b3, b4, b5, b6, b7; \
282 LhsPacket g0, g1, g2, g3, g4, g5, g6, g7; \
283 RhsPacket a0 = pset1<RhsPacket>(rhs2(j, 0)); \
284 GEMV_UNROLL(GEMV_PREFETCH, N) \
285 GEMV_LOAD_COL_MMA(N) \
286 GEMV_WORK_COL_MMA(N) \
287 } while (++j < jend); \
288 GEMV_UNROLL(GEMV_DISASSEMBLE_MMA, N) \
289 if (GEMV_GETN(N) <= 1) { \
290 GEMV_UNROLL(GEMV_STORE_COL_MMA, N) \
292 GEMV_UNROLL_HALF(GEMV_LOAD2_COL_MMA, (N >> 1)) \
293 GEMV_UNROLL_HALF(GEMV_WORK2_COL_MMA, (N >> 1)) \
294 GEMV_UNROLL_HALF(GEMV_STORE2_COL_MMA, (N >> 1)) \
296 i += (ResPacketSize * N);
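// GEMV_PROCESS_COL_ONE_MMA(N) is one column-major strip: zero N MMA accumulators,
// loop over the columns of the block broadcasting rhs2(j, 0) and accumulating rank-1
// updates against the lhs packets, then disassemble the accumulators, scale by alpha
// and add into res[i .. i + N * ResPacketSize).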
299 #define GEMV_INIT(iter, N) \
301 c##iter = pset1<ResPacket>(ResScalar(0)); \
303 EIGEN_UNUSED_VARIABLE(c##iter); \
306 #ifdef EIGEN_POWER_USE_GEMV_PREFETCH
307 #define GEMV_PREFETCH(iter, N) \
308 if (GEMV_GETN(N) > ((iter >> 1) + ((N >> 1) * (iter & 1)))) { \
309 lhs.prefetch(i + (iter * LhsPacketSize) + prefetch_dist, j); \
312 #define GEMV_PREFETCH(iter, N)
315 #define GEMV_WORK_COL(iter, N) \
317 c##iter = pcj.pmadd(GEMV_LOADPACKET_COL(iter), a0, c##iter); \
320 #define GEMV_STORE_COL(iter, N) \
322 pstoreu(res + i + (iter * ResPacketSize), pmadd(c##iter, palpha, ploadu<ResPacket>(res + i + (iter * ResPacketSize)))); \
326 #define GEMV_PROCESS_COL_ONE(N) \
327 GEMV_UNROLL(GEMV_INIT, N) \
330 RhsPacket a0 = pset1<RhsPacket>(rhs2(j, 0)); \
331 GEMV_UNROLL(GEMV_PREFETCH, N) \
332 GEMV_UNROLL(GEMV_WORK_COL, N) \
333 } while (++j < jend); \
334 GEMV_UNROLL(GEMV_STORE_COL, N) \
335 i += (ResPacketSize * N);
338 #define GEMV_PROCESS_COL(N) \
339 GEMV_PROCESS_COL_ONE_MMA(N)
341 #define GEMV_PROCESS_COL(N) \
342 GEMV_PROCESS_COL_ONE(N)
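// GEMV_PROCESS_COL selects the MMA strip when USE_GEMV_MMA is defined and otherwise
// falls back to the plain VSX strip.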
347 template<typename LhsPacket, typename RhsPacket, bool accumulate>
352 __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
356 __builtin_mma_xvf32ger(acc, (__vector unsigned char)a, (__vector unsigned char)b);
361 template<typename LhsPacket, typename RhsPacket, bool accumulate>
366 __builtin_mma_xvf64gerpp(acc, a, (__vector unsigned char)b);
370 __builtin_mma_xvf64ger(acc, a, (__vector unsigned char)b);
375 template<typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
378 const LhsMapper& alhs,
379 const RhsMapper& rhs,
383 typedef gemv_traits<LhsScalar, RhsScalar> Traits;
385 typedef typename Traits::LhsPacket LhsPacket;
386 typedef typename Traits::RhsPacket RhsPacket;
387 typedef typename Traits::ResPacket ResPacket;
397 conj_helper<LhsScalar, RhsScalar, false, false> cj;
398 conj_helper<LhsPacket, RhsPacket, false, false> pcj;
400 const Index lhsStride = lhs.stride();
404 ResPacketSize = Traits::ResPacketSize,
405 LhsPacketSize = Traits::LhsPacketSize,
406 RhsPacketSize = Traits::RhsPacketSize,
409 #ifndef GCC_ONE_VECTORPAIR_BUG
410 const Index n8 = rows - 8 * ResPacketSize + 1;
411 const Index n4 = rows - 4 * ResPacketSize + 1;
412 const Index n2 = rows - 2 * ResPacketSize + 1;
414 const Index n1 = rows - 1 * ResPacketSize + 1;
415 #ifdef EIGEN_POWER_USE_GEMV_PREFETCH
416 const Index prefetch_dist = 64 * LhsPacketSize;
420 const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 16000 ? 16 : 8);
421 ResPacket palpha = pset1<ResPacket>(alpha);
423 for (Index j2 = 0; j2 < cols; j2 += block_cols)
427 ResPacket c0, c1, c2, c3, c4, c5, c6, c7;
429 __vector_quad e0, e1, e2, e3, e4, e5, e6, e7;
430 PacketBlock<ResPacket, 4> result0, result1, result2, result3, result4, result5, result6, result7;
432 GEMV_UNUSED(8, result)
433 GEMV_UNUSED_EXTRA(1, c)
435 #ifndef GCC_ONE_VECTORPAIR_BUG
460 d0 += cj.pmul(lhs(i, j), rhs2(j, 0));
461 } while (++j < jend);
462 res[i] += alpha * d0;
467 template<bool extraRows>
471 d0 = pmadd(acc, pAlpha, d0);
479 template<Index num_acc, bool extraRows, Index size>
482 constexpr Index real_acc = (num_acc - (extraRows ? 1 : 0));
483 for(Index k = 0; k < real_acc; k++) {
484 outputVecCol<false>(acc[k][0], result + k*4, pAlpha, extra_rows);
487 outputVecCol<true>(acc[real_acc][0], result + real_acc*4, pAlpha, extra_rows);
491 static Packet16uc p16uc_MERGE16_32_V1 = { 0, 1, 16,17, 0, 1, 16,17, 0, 1, 16,17, 0, 1, 16,17 };
492 static Packet16uc p16uc_MERGE16_32_V2 = { 2, 3, 18,19, 2, 3, 18,19, 2, 3, 18,19, 2, 3, 18,19 };
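// Byte-permute masks that combine the first (V1) or second (V2) 16-bit lane of two
// source vectors into a broadcast 32-bit lane; used by the bfloat16 column kernel
// below when widening bfloat16 lanes to float32.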
494 template<Index num_acc, typename LhsMapper, bool zero>
497 Packet8bf c0 = lhs.template loadPacket<Packet8bf>(k*4, 0);
500 b1 = lhs.template loadPacket<Packet8bf>(k*4, 1);
506 if (num_acc > (k + 1)) {
514 template<Index num_acc, bool zero>
517 for(Index k = 0; k < num_acc; k++) {
518 for(Index i = 0; i < (zero ? 1 : 2); i++) {
519 acc[k][i] = pmadd(b0[i], a0[k][i], acc[k][i]);
524 template<typename RhsMapper, bool linear>
534 template<typename RhsMapper>
540 return rhs.template loadPacket<Packet8bf>(j + 0);
544 template<typename RhsMapper, bool linear>
550 template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
554 Packet8bf b2 = loadColData<RhsMapper, linear>(rhs, j);
561 LhsMapper lhs2 = lhs.getSubMapper(0, j);
562 for(Index k = 0; k < num_acc; k += 2) {
563 loadVecLoopVSX<num_acc, LhsMapper, zero>(k, lhs2, a0);
566 multVecVSX<num_acc, zero>(acc, a0, b0);
569 template<Index num_acc>
572 for(Index i = 0; i < num_acc; i++) {
573 acc[i][0] = acc[i][0] + acc[i][1];
578 #define MAX_BFLOAT16_VEC_ACC_VSX 8
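// The bfloat16 kernels use at most 8 vector accumulators per pass; the ExtraN loop
// bodies below handle a remainder of 1..7 accumulators.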
580 template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
583 constexpr Index step = (num_acc * 4);
584 const Index extra_rows = (extraRows) ? (rows & 3) : 0;
590 zeroAccumulators<num_acc, 2>(acc);
592 LhsMapper lhs2 = lhs.getSubMapper(row, 0);
593 for(Index j = 0; j + 2 <= cend; j += 2) {
594 vecColLoopVSX<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, acc);
597 vecColLoopVSX<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, acc);
600 addResultsVSX<num_acc>(acc);
602 outputVecColResults<num_acc, extraRows, 2>(acc, result, pAlpha, extra_rows);
605 } while(multiIters && (step <= rows - (row += step)));
608 template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
612 colVSXVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
616 template<typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
621 colVSXVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
624 colVSXVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
627 colVSXVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
630 colVSXVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
633 colVSXVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
636 colVSXVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
639 colVSXVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
643 colVSXVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
649 template<typename LhsMapper, typename RhsMapper, bool linear>
654 colVSXVecColLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
658 colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
660 colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
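// Dispatch for the column-major bfloat16 loop: full passes run with
// MAX_BFLOAT16_VEC_ACC_VSX accumulators, the ExtraN bodies peel the 7..1 accumulator
// remainder, and the extraRows variants fold a tail of (rows & 3) rows into the last
// accumulator.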
664 template<const Index size, bool inc, Index delta>
682 template<const Index size, bool inc = false>
696 storeBF16fromResult<size, inc, 0>(dst, r32.packet[0], resInc, rows & 7);
698 storeBF16fromResult<size, inc, 8>(dst, r32.packet[1], resInc);
701 storeBF16fromResult<size, inc, 16>(dst, r32.packet[2], resInc);
702 storeBF16fromResult<size, inc, 24>(dst, r32.packet[3], resInc);
704 i += extra; dst += extra*resInc;
705 if (size != 32) break;
709 template<bool inc = false>
713 convertPointerF32toBF16VSX<32,inc>(i, result, rows, dst, resInc);
714 convertPointerF32toBF16VSX<16,inc>(i, result, rows, dst, resInc);
715 convertPointerF32toBF16VSX<8,inc>(i, result, rows, dst, resInc);
716 convertPointerF32toBF16VSX<1,inc>(i, result, rows, dst, resInc);
719 template<typename LhsMapper, typename RhsMapper>
722 const LhsMapper& alhs,
723 const RhsMapper& rhs,
727 typedef typename RhsMapper::LinearMapper LinearMapper;
737 const Index lhsStride = lhs.stride();
740 const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(bfloat16) < 16000 ? 16 : 8);
748 for (Index j2 = 0; j2 < cols; j2 += block_cols)
752 LhsMapper lhs2 = lhs.getSubMapper(0, j2);
753 if (rhs.stride() == 1) {
754 LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
755 calcVSXVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
757 RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
758 calcVSXVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
765 template<Index num_acc, Index size>
768 constexpr Index extra = num_acc & 3;
770 for(Index k = 0; k < num_acc; k += 4) {
772 d0 = pmadd(acc[k + 0][0], pAlpha, d0);
774 if (num_acc > (k + 3)) {
780 memcpy((void *)(result + k), (void *)(&d0), sizeof(float) * extra);
786 template<Index num_acc>
789 if (num_acc > (k + 1)) {
790 acc[k][1] = vec_mergel(acc[k + 0][0], acc[k + 1][0]);
791 acc[k][0] = vec_mergeh(acc[k + 0][0], acc[k + 1][0]);
792 acc[k][0] = acc[k][0] + acc[k][1];
793 acc[k][0] += vec_sld(acc[k][0], acc[k][0], 8);
795 acc[k][0] += vec_sld(acc[k][0], acc[k][0], 8);
797 acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12);
799 acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4);
804 template<Index num_acc>
807 for(Index k = 0; k < num_acc; k += 4) {
808 preduxVecResults2VSX<num_acc>(acc, k + 0);
809 if (num_acc > (k + 2)) {
810 preduxVecResults2VSX<num_acc>(acc, k + 2);
811 #ifdef EIGEN_VECTORIZE_VSX
812 acc[k + 0][0] = reinterpret_cast<Packet4f>(vec_mergeh(reinterpret_cast<Packet2ul>(acc[k + 0][0]), reinterpret_cast<Packet2ul>(acc[k + 2][0])));
832 template<Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
839 b1 = rhs.template loadPacketPartial<Packet8bf>(j, extra_cols);
844 b1 = rhs.template loadPacket<Packet8bf>(j);
849 const LhsMapper lhs2 = lhs.getSubMapper(0, j);
850 for(Index k = 0; k < num_acc; k++) {
852 a1 = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
857 a1 = lhs2.template loadPacket<Packet8bf>(k, 0);
863 multVecVSX<num_acc, false>(acc, a0, b0);
866 template<Index num_acc, typename LhsMapper, typename RhsMapper>
870 for(; j + 8 <= cols; j += 8){
871 multVSXVecLoop<num_acc, LhsMapper, RhsMapper, false>(acc, lhs, rhs, j, extra_cols);
875 multVSXVecLoop<num_acc, LhsMapper, RhsMapper, true>(acc, lhs, rhs, j, extra_cols);
879 template<const Index num_acc, typename LhsMapper, typename RhsMapper>
888 zeroAccumulators<num_acc, 2>(acc);
890 const LhsMapper lhs2 = lhs.getSubMapper(row, 0);
891 vecVSXLoop<num_acc, LhsMapper, RhsMapper>(cols, lhs2, rhs, acc, extra_cols);
893 addResultsVSX<num_acc>(acc);
895 preduxVecResultsVSX<num_acc>(acc);
897 outputVecResults<num_acc, 2>(acc, result, pAlpha);
900 } while(multiIters && (num_acc <= rows - (row += num_acc)));
903 template<const Index num_acc, typename LhsMapper, typename RhsMapper>
907 colVSXVecLoopBody<num_acc, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
911 template<typename LhsMapper, typename RhsMapper>
916 colVSXVecLoopBodyExtraN<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
919 colVSXVecLoopBodyExtraN<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
922 colVSXVecLoopBodyExtraN<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
925 colVSXVecLoopBodyExtraN<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
928 colVSXVecLoopBodyExtraN<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
931 colVSXVecLoopBodyExtraN<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
934 colVSXVecLoopBodyExtraN<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
939 template<typename LhsMapper, typename RhsMapper>
944 colVSXVecLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
947 colVSXVecLoopBodyExtra<LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
950 template<typename LhsMapper, typename RhsMapper>
953 const LhsMapper& alhs,
954 const RhsMapper& rhs,
958 typedef typename RhsMapper::LinearMapper LinearMapper;
963 LinearMapper rhs2 = rhs.getLinearMapper(0, 0);
974 convertArrayPointerBF16toF32<true>(result, 1, rows, res, resIncr);
976 calcVSXVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);
980 convertArrayPointerF32toBF16VSX<true>(result, rows, res, resIncr);
984 #undef MAX_BFLOAT16_VEC_ACC_VSX
986 const Packet16uc p16uc_COMPLEX32_XORFLIP = { 0x44,0x55,0x66,0x77, 0x00,0x11,0x22,0x33, 0xcc,0xdd,0xee,0xff, 0x88,0x99,0xaa,0xbb };
987 const Packet16uc p16uc_COMPLEX64_XORFLIP = { 0x88,0x99,0xaa,0xbb, 0xcc,0xdd,0xee,0xff, 0x00,0x11,0x22,0x33, 0x44,0x55,0x66,0x77 };
990 const Packet16uc p16uc_COMPLEX32_CONJ_XOR = { 0x00,0x00,0x00,0x00, 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x80,0x00,0x00,0x00 };
991 const Packet16uc p16uc_COMPLEX64_CONJ_XOR = { 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00 };
992 const Packet16uc p16uc_COMPLEX32_CONJ_XOR2 = { 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00 };
993 const Packet16uc p16uc_COMPLEX64_CONJ_XOR2 = { 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00 };
994 const Packet16uc p16uc_COMPLEX32_NEGATE = { 0x80,0x00,0x00,0x00, 0x80,0x00,0x00,0x00, 0x80,0x00,0x00,0x00, 0x80,0x00,0x00,0x00 };
995 const Packet16uc p16uc_COMPLEX64_NEGATE = { 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x80,0x00,0x00,0x00, 0x00,0x00,0x00,0x00 };
997 const Packet16uc p16uc_COMPLEX32_CONJ_XOR = { 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80 };
998 const Packet16uc p16uc_COMPLEX64_CONJ_XOR = { 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80 };
999 const Packet16uc p16uc_COMPLEX32_CONJ_XOR2 = { 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x00 };
1000 const Packet16uc p16uc_COMPLEX64_CONJ_XOR2 = { 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00 };
1001 const Packet16uc p16uc_COMPLEX32_NEGATE = { 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x80 };
1002 const Packet16uc p16uc_COMPLEX64_NEGATE = { 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x80 };
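// Sign-bit masks for complex arithmetic: XORing with the CONJ_XOR/CONJ_XOR2 masks flips
// the sign of the imaginary or real lanes (conjugation), and the NEGATE masks flip the
// sign of every lane. The two groups above give the byte layouts for the two
// endiannesses.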
1006 #define COMPLEX_DELTA 0
1008 #define COMPLEX_DELTA 2
1022 #ifdef __POWER8_VECTOR__
1033 #if defined(_ARCH_PWR8) && (!EIGEN_COMP_LLVM || __clang_major__ >= 12)
1034 #define PERMXOR_GOOD
1078 #ifdef __POWER8_VECTOR__
1079 return Packet2cf(vec_neg(a.v));
1087 #ifdef __POWER8_VECTOR__
1088 return Packet1cd(vec_neg(a.v));
1121 #ifdef EIGEN_VECTORIZE_VSX
1122 return Packet1cd(__builtin_vsx_xxpermdi(a.v, a.v, 2));
1132 #ifdef EIGEN_VECTORIZE_VSX
1134 __asm__("lxsdx %x0,%y1" : "=wa" (t) : "Z" (*src));
1136 *reinterpret_cast<std::complex<float>*>(reinterpret_cast<float*>(&t) + COMPLEX_DELTA) = *src;
1142 template<typename RhsScalar>
1146 __asm__("lxvwsx %x0,%y1" : "=wa" (r) : "Z" (*(reinterpret_cast<float*>(src) + 0)));
1147 __asm__("lxvwsx %x0,%y1" : "=wa" (i) : "Z" (*(reinterpret_cast<float*>(src) + 1)));
1155 template<typename RhsScalar>
1158 #ifdef EIGEN_VECTORIZE_VSX
1159 __asm__("lxvdsx %x0,%y1" : "=wa" (r) : "Z" (*(reinterpret_cast<double*>(src) + 0)));
1160 __asm__("lxvdsx %x0,%y1" : "=wa" (i) : "Z" (*(reinterpret_cast<double*>(src) + 1)));
1163 r = vec_splat(t, 0);
1164 i = vec_splat(t, 1);
1168 #ifndef __POWER8_VECTOR__
1169 const Packet16uc p16uc_MERGEE = { 0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13, 0x08, 0x09, 0x0A, 0x0B, 0x18, 0x19, 0x1A, 0x1B };
1171 const Packet16uc p16uc_MERGEO = { 0x04, 0x05, 0x06, 0x07, 0x14, 0x15, 0x16, 0x17, 0x0C, 0x0D, 0x0E, 0x0F, 0x1C, 0x1D, 0x1E, 0x1F };
1175 template<typename RhsScalar>
1179 #ifdef __POWER8_VECTOR__
1180 r = vec_mergee(t, t);
1181 i = vec_mergeo(t, t);
1188 template<typename RhsScalar>
1197 #ifdef EIGEN_VECTORIZE_VSX
1199 __asm__("lxvdsx %x0,%y1" : "=wa" (ret) : "Z" (*(reinterpret_cast<double*>(src) + 0)));
1223 template<typename ResPacket>
1235 template<typename ResPacket>
1242 template<typename ResPacket>
1248 template<typename ResPacket>
1301 return vec_mergeh(ret, ret);
1320 template<typename ResPacket>
1331 template<typename ResPacket>
1350 template<typename Scalar, typename ResScalar>
1353 return (which) ? ((conj) ? -alpha.real() : alpha.real()) : ((conj) ? -alpha.imag() : alpha.imag());
1357 template<typename Scalar, typename ResScalar, typename ResPacket, int which>
1361 ret.v[COMPLEX_DELTA + 0] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x01), (which & 0x04));
1362 ret.v[COMPLEX_DELTA + 1] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x02), (which & 0x08));
1368 template<typename Scalar, typename ResScalar, typename ResPacket, int which>
1372 ret.v[0] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x01), (which & 0x04));
1373 ret.v[1] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x02), (which & 0x08));
1378 template<typename Packet>
1397 template<typename Packet, typename LhsPacket, typename RhsPacket>
1402 return pset_zero<Packet>();
1410 template<typename PResPacket, typename ResPacket, typename ResScalar, typename Scalar>
1414 separate.r = pset1_complex<Scalar, ResScalar, ResPacket, 0x3>(alpha);
1415 separate.i = pset1_complex<Scalar, ResScalar, ResPacket, 0x0>(alpha);
1424 template<typename ScalarPacket, typename AlphaData>
1427 return pmadd(c2, b0.separate.i.v, pmadd(c0, b0.separate.r.v, c4));
1431 template<typename Scalar, typename ScalarPacket, typename PResPacket, typename ResPacket, typename ResScalar, typename AlphaData>
1436 ScalarPacket c4 = ploadu<ScalarPacket>(reinterpret_cast<Scalar*>(res));
1437 ScalarPacket c3 = pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0);
1440 ScalarPacket c4 = pload_complex<ResPacket>(res);
1441 PResPacket c3 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0));
1446 template<typename ScalarPacket, typename PResPacket, typename ResPacket, typename ResScalar, typename AlphaData, Index ResPacketSize, Index iter2>
1451 #if !defined(_ARCH_PWR10)
1452 ScalarPacket c4 = pload_complex<ResPacket>(res + (iter2 * ResPacketSize));
1453 ScalarPacket c5 = pload_complex<ResPacket>(res + ((iter2 + 1) * ResPacketSize));
1454 PResPacket c6 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0));
1455 PResPacket c7 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c1.v, c3.v, c5, b0));
1457 pstoreu(res + ((iter2 + 1) * ResPacketSize), c7);
1459 __vector_pair a = *reinterpret_cast<__vector_pair *>(res + (iter2 * ResPacketSize));
1462 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(c6), &a);
1463 c6[0] = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c6[0].v, b0));
1464 c6[1] = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c1.v, c3.v, c6[1].v, b0));
1468 __asm__ ("xvmaddasp %L0,%x1,%x2\n\txvmaddasp %0,%x1,%x3" : "+&d" (a) : "wa" (b0.separate.r.v), "wa" (c0.v), "wa" (c1.v));
1469 __asm__ ("xvmaddasp %L0,%x1,%x2\n\txvmaddasp %0,%x1,%x3" : "+&d" (a) : "wa" (b0.separate.i.v), "wa" (c2.v), "wa" (c3.v));
1471 __asm__ ("xvmaddadp %L0,%x1,%x2\n\txvmaddadp %0,%x1,%x3" : "+&d" (a) : "wa" (b0.separate.r.v), "wa" (c0.v), "wa" (c1.v));
1472 __asm__ ("xvmaddadp %L0,%x1,%x2\n\txvmaddadp %0,%x1,%x3" : "+&d" (a) : "wa" (b0.separate.i.v), "wa" (c2.v), "wa" (c3.v));
1475 *reinterpret_cast<__vector_pair *>(res + (iter2 * ResPacketSize)) = a;
1480 template<typename Scalar, typename LhsScalar, typename LhsMapper, typename LhsPacket>
1483 if (sizeof(Scalar) == sizeof(LhsScalar)) {
1484 const LhsScalar& src = lhs(i + 0, j);
1487 return lhs.template load<LhsPacket, Unaligned>(i + 0, j);
1491 template<typename ComplexPacket, typename RealPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
1494 if (ConjugateLhs && ConjugateRhs) {
1495 return vec_madd(a, pconj2(ComplexPacket(b)).v, c);
1497 else if (Negate && !ConjugateLhs && ConjugateRhs) {
1498 return vec_nmsub(a, b, c);
1501 return vec_madd(a, b, c);
1506 template<typename ComplexPacket, typename RealPacket, bool Conjugate>
1510 return vec_madd(a, pconj2(ComplexPacket(b)).v, c);
1513 return vec_madd(a, b, c);
1517 template<typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1520 conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
1523 b0 = pset1<RhsPacket>(*b);
1526 b0 = ploadu<RhsPacket>(b);
1528 c0 = pcj.pmadd(a0, b0, c0);
1532 template<typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1535 ScalarPacket br, bi;
1537 pload_realimag<RhsScalar>(b, br, bi);
1540 pload_realimag_row<RhsScalar>(b, br, bi);
1542 if (ConjugateLhs && !ConjugateRhs) a0 = pconj2(a0);
1544 ScalarPacket cr = pmadd_complex_complex<LhsPacket, ScalarPacket, ConjugateLhs, ConjugateRhs, false>(a0.v, br, c0.v);
1545 ScalarPacket ci = pmadd_complex_complex<LhsPacket, ScalarPacket, ConjugateLhs, ConjugateRhs, true>(a1.v, bi, c1.v);
1547 c0 = PResPacket(cr);
1551 template<typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1561 ScalarPacket cri = pmadd_complex_real<PResPacket, ScalarPacket, ConjugateRhs>(a0, b0, c0.v);
1562 c0 = PResPacket(cri);
1566 template<typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1569 ScalarPacket a1 = pload_complex<ResPacket>(&a0);
1575 b0 = pload_real_row<ResPacket>(b);
1577 ScalarPacket cri = pmadd_complex_real<PResPacket, ScalarPacket, ConjugateLhs>(a1, b0, c0.v);
1578 c0 = PResPacket(cri);
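// The three GEMV_MULT_* macro families below instantiate gemv_mult_complex overloads for
// the complex*complex, real*complex and complex*real operand combinations, so the column
// kernel can call a single helper name regardless of which side is complex.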
1581 #define GEMV_MULT_COMPLEX_COMPLEX(LhsType, RhsType, ResType) \
1582 template<typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1583 EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType& c0, ResType& c1) \
1585 gemv_mult_complex_complex<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0, c1); \
1591 #define GEMV_MULT_REAL_COMPLEX(LhsType, RhsType, ResType) \
1592 template<typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1593 EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType& c0, RhsType&) \
1595 gemv_mult_real_complex<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0); \
1603 #define GEMV_MULT_COMPLEX_REAL(LhsType, RhsType, ResType1, ResType2) \
1604 template<typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1605 EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType1& c0, ResType2&) \
1607 gemv_mult_complex_real<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0); \
1617 template<typename T>
1634 template<typename T>
1642 return Packet2cf(a);
1647 return Packet1cd(a);
1651 template<typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename ResPacket>
1654 a = SLhsPacket(pload_complex<ResPacket>(&a));
1657 template<typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename ResPacket>
1664 template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
1667 if (NegativeAccumulate)
1669 __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
1672 __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
1677 template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
1680 if (NegativeAccumulate)
1682 __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
1685 __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
1689 template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
1696 template<typename RealPacket, typename LhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
1699 if (ConjugateLhs && ConjugateRhs) {
1700 RealPacket b2 = pconj2(convertComplex(b)).v;
1701 return pger_vecMMA<RealPacket, RealPacket, false>(c, b2, a.v);
1703 else if (Negate && !ConjugateLhs && ConjugateRhs) {
1704 return pger_vecMMA<RealPacket, RealPacket, true>(c, b, a.v);
1707 return pger_vecMMA<RealPacket, RealPacket, false>(c, b, a.v);
1711 template<typename RealPacket, typename LhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
1714 if (ConjugateLhs && ConjugateRhs) {
1715 RealPacket b2 = pconj2(convertComplex(b)).v;
1716 return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b2);
1718 else if (Negate && !ConjugateLhs && ConjugateRhs) {
1719 return pger_vecMMA<RealPacket, __vector_pair, true>(c, a, b);
1722 return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b);
1727 template<typename RealPacket, typename LhsPacket, bool Conjugate, int StorageOrder>
1730 RealPacket a2 = convertReal(a);
1732 RealPacket b2 = pconj2(convertComplex(b)).v;
1734 return pger_vecMMA<RealPacket, RealPacket, false>(c, b2, a2);
1736 return pger_vecMMA<RealPacket, RealPacket, false>(c, a2, b2);
1741 return pger_vecMMA<RealPacket, RealPacket, false>(c, b, a2);
1743 return pger_vecMMA<RealPacket, RealPacket, false>(c, a2, b);
1749 template<typename RealPacket, typename LhsPacket, bool Conjugate, int StorageOrder>
1753 RealPacket b2 = pconj2(convertComplex(b)).v;
1754 return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b2);
1757 return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b);
1762 template<typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1763 EIGEN_ALWAYS_INLINE void gemv_mult_complex_complex_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0)
1771 pmadd_complex_complex_MMA<ScalarPacket, LhsPacket, ConjugateLhs, ConjugateRhs, false>(a0, b0, c0);
1775 template<typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1776 EIGEN_ALWAYS_INLINE void gemv_mult_complex_real_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0)
1778 pload_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, ResPacket>(a0);
1784 b0 = pload_real_row<ResPacket>(b);
1786 pmadd_complex_real_MMA<ScalarPacket, LhsPacket, ConjugateLhs, ColMajor>(a0, b0, c0);
1790 template<typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1791 EIGEN_ALWAYS_INLINE void gemv_mult_real_complex_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0)
1800 pmadd_complex_real_MMA<ScalarPacket, LhsPacket, ConjugateRhs, (sizeof(RhsScalar) == sizeof(std::complex<float>)) ? StorageOrder : ColMajor>(a0, b0, c0);
1803 #define GEMV_MULT_COMPLEX_COMPLEX_MMA(LhsType, RhsType) \
1804 template<typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1805 EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) \
1807 gemv_mult_complex_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0); \
1810 GEMV_MULT_COMPLEX_COMPLEX_MMA(Packet2cf, std::complex<float>)
1811 GEMV_MULT_COMPLEX_COMPLEX_MMA(__vector_pair, std::complex<float>)
1812 GEMV_MULT_COMPLEX_COMPLEX_MMA(Packet1cd, std::complex<double>)
1815 template<typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1816 EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(__vector_pair& a0, std::complex<double>* b, __vector_quad* c0)
1818 if (sizeof(LhsScalar) == 16) {
1819 gemv_mult_complex_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0);
1822 gemv_mult_real_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0);
1826 #define GEMV_MULT_REAL_COMPLEX_MMA(LhsType, RhsType) \
1827 template<typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1828 EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) \
1830 gemv_mult_real_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0); \
1833 GEMV_MULT_REAL_COMPLEX_MMA(Packet4f, std::complex<float>)
1834 GEMV_MULT_REAL_COMPLEX_MMA(Packet2d, std::complex<double>)
1836 #define GEMV_MULT_COMPLEX_REAL_MMA(LhsType, RhsType) \
1837 template<typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1838 EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) \
1840 gemv_mult_complex_real_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, StorageOrder>(a0, b, c0); \
1843 GEMV_MULT_COMPLEX_REAL_MMA(Packet2cf, float)
1844 GEMV_MULT_COMPLEX_REAL_MMA(Packet1cd, double)
1845 GEMV_MULT_COMPLEX_REAL_MMA(__vector_pair, float)
1846 GEMV_MULT_COMPLEX_REAL_MMA(__vector_pair, double)
1849 template <typename Scalar, typename ScalarPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
1850 EIGEN_ALWAYS_INLINE void disassembleResults2(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0)
1852 __builtin_mma_disassemble_acc(&result0.packet, c0);
1853 if (sizeof(LhsPacket) == 16) {
1854 if (sizeof(RhsPacket) == 16) {
1855 ScalarPacket tmp0, tmp2;
1856 tmp2 = vec_mergeh(result0.packet[2], result0.packet[3]);
1857 tmp0 = vec_mergeh(result0.packet[0], result0.packet[1]);
1858 result0.packet[3] = vec_mergel(result0.packet[3], result0.packet[2]);
1859 result0.packet[1] = vec_mergel(result0.packet[1], result0.packet[0]);
1860 result0.packet[2] = tmp2;
1861 result0.packet[0] = tmp0;
1864 result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
1865 result0.packet[2] = pconj2(convertComplex(result0.packet[2])).v;
1866 } else if (ConjugateRhs) {
1867 result0.packet[1] = pconj2(convertComplex(result0.packet[1])).v;
1868 result0.packet[3] = pconj2(convertComplex(result0.packet[3])).v;
1870 result0.packet[1] = pconjinv(convertComplex(result0.packet[1])).v;
1871 result0.packet[3] = pconjinv(convertComplex(result0.packet[3])).v;
1873 result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
1874 result0.packet[2] = vec_add(result0.packet[2], result0.packet[3]);
1876 result0.packet[0][1] = result0.packet[1][1];
1877 result0.packet[2][1] = result0.packet[3][1];
1882 template <typename Scalar, typename ScalarPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
1883 EIGEN_ALWAYS_INLINE void disassembleResults4(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0)
1885 __builtin_mma_disassemble_acc(&result0.packet, c0);
1888 result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
1889 result0.packet[1] = pcplxflip2(convertComplex(result0.packet[1])).v;
1892 result0.packet[1] = pcplxconjflip(convertComplex(result0.packet[1])).v;
1894 result0.packet[1] = pcplxflipconj(convertComplex(result0.packet[1])).v;
1897 result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
1898 } else if (sizeof(LhsPacket) == sizeof(std::complex<float>)) {
1900 result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
1903 result0.packet[0] = vec_mergee(result0.packet[0], result0.packet[1]);
1907 template <typename Scalar, typename ScalarPacket, int ResPacketSize, typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
1908 EIGEN_ALWAYS_INLINE void disassembleResults(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0)
1911 disassembleResults2<Scalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(c0, result0);
1913 disassembleResults4<Scalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(c0, result0);
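// disassembleResults extracts an MMA accumulator into a PacketBlock and fixes up lane
// order and conjugation according to which operands were complex; the *2 and *4 helpers
// above cover the two result-packet widths.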
1918 #define GEMV_GETN_COMPLEX(N) (((N) * ResPacketSize) >> 1)
1920 #define GEMV_LOADPACKET_COL_COMPLEX(iter) \
1921 loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket>(lhs, i + ((iter) * ResPacketSize), j)
1923 #define GEMV_LOADPACKET_COL_COMPLEX_DATA(iter) \
1924 convertReal(GEMV_LOADPACKET_COL_COMPLEX(iter))
1927 #define GEMV_INIT_COL_COMPLEX_MMA(iter, N) \
1928 if (GEMV_GETN_COMPLEX(N) > iter) { \
1929 __builtin_mma_xxsetaccz(&e0##iter); \
1933 #define GEMV_LOADPAIR_COL_COMPLEX_MMA(iter1, iter2) \
1934 GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_COL_COMPLEX_DATA(iter2), GEMV_LOADPACKET_COL_COMPLEX_DATA((iter2) + 1)); \
1935 EIGEN_UNUSED_VARIABLE(f##iter1);
1937 #define GEMV_LOADPAIR_COL_COMPLEX_MMA(iter1, iter2) \
1938 if (sizeof(LhsPacket) == 16) { \
1939 const LhsScalar& src = lhs(i + ((32 * iter1) / sizeof(LhsScalar)), j); \
1940 a##iter1 = *reinterpret_cast<__vector_pair *>(const_cast<LhsScalar *>(&src)); \
1941 EIGEN_UNUSED_VARIABLE(f##iter1); \
1943 f##iter1 = lhs.template load<PLhsPacket, Unaligned>(i + ((iter2) * ResPacketSize), j); \
1944 GEMV_BUILDPAIR_MMA(a##iter1, vec_splat(convertReal(f##iter1), 0), vec_splat(convertReal(f##iter1), 1)); \
1948 #define GEMV_LOAD1_COL_COMPLEX_MMA(iter, N) \
1949 if (GEMV_GETN_COMPLEX(N) > iter) { \
1950 if (GEMV_IS_COMPLEX_FLOAT) { \
1951 f##iter = GEMV_LOADPACKET_COL_COMPLEX(iter); \
1952 EIGEN_UNUSED_VARIABLE(a##iter); \
1954 GEMV_LOADPAIR_COL_COMPLEX_MMA(iter, iter << 1) \
1957 EIGEN_UNUSED_VARIABLE(a##iter); \
1958 EIGEN_UNUSED_VARIABLE(f##iter); \
1961 #define GEMV_WORK1_COL_COMPLEX_MMA(iter, N) \
1962 if (GEMV_GETN_COMPLEX(N) > iter) { \
1963 if (GEMV_IS_COMPLEX_FLOAT) { \
1964 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(f##iter, b, &e0##iter); \
1966 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(a##iter, b, &e0##iter); \
1970 #define GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter1, iter2) \
1971 GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_COL_COMPLEX_DATA(iter2), GEMV_LOADPACKET_COL_COMPLEX_DATA((iter2) + 1));
1973 #define GEMV_LOAD2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
1974 if (GEMV_GETN_COMPLEX(N) > iter1) { \
1975 if (GEMV_IS_COMPLEX_FLOAT) { \
1976 GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter2, iter2); \
1977 EIGEN_UNUSED_VARIABLE(a##iter3) \
1979 GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter2, iter2 << 1); \
1980 GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter3, iter3 << 1); \
1983 EIGEN_UNUSED_VARIABLE(a##iter2); \
1984 EIGEN_UNUSED_VARIABLE(a##iter3); \
1986 EIGEN_UNUSED_VARIABLE(f##iter2); \
1987 EIGEN_UNUSED_VARIABLE(f##iter3);
1989 #define GEMV_WORK2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
1990 if (GEMV_GETN_COMPLEX(N) > iter1) { \
1991 if (GEMV_IS_COMPLEX_FLOAT) { \
1993 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(g), &a##iter2); \
1994 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(g[0], b, &e0##iter2); \
1995 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(g[1], b, &e0##iter3); \
1997 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(a##iter2, b, &e0##iter2); \
1998 gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(a##iter3, b, &e0##iter3); \
2003 #define GEMV_LOAD_COL_COMPLEX_MMA(N) \
2004 if (GEMV_GETN_COMPLEX(N) > 1) { \
2005 GEMV_UNROLL_HALF(GEMV_LOAD2_COL_COMPLEX_MMA, (N >> 1)) \
2007 GEMV_UNROLL(GEMV_LOAD1_COL_COMPLEX_MMA, N) \
2010 #define GEMV_WORK_COL_COMPLEX_MMA(N) \
2011 if (GEMV_GETN_COMPLEX(N) > 1) { \
2012 GEMV_UNROLL_HALF(GEMV_WORK2_COL_COMPLEX_MMA, (N >> 1)) \
2014 GEMV_UNROLL(GEMV_WORK1_COL_COMPLEX_MMA, N) \
2017 #define GEMV_LOAD_COL_COMPLEX_MMA(N) \
2018 GEMV_UNROLL(GEMV_LOAD1_COL_COMPLEX_MMA, N)
2020 #define GEMV_WORK_COL_COMPLEX_MMA(N) \
2021 GEMV_UNROLL(GEMV_WORK1_COL_COMPLEX_MMA, N)
2024 #define GEMV_DISASSEMBLE_COMPLEX_MMA(iter) \
2025 disassembleResults<Scalar, ScalarPacket, ResPacketSize, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(&e0##iter, result0##iter);
2027 #define GEMV_STORE_COL_COMPLEX_MMA(iter, N) \
2028 if (GEMV_GETN_COMPLEX(N) > iter) { \
2029 GEMV_DISASSEMBLE_COMPLEX_MMA(iter); \
2030 c0##iter = PResPacket(result0##iter.packet[0]); \
2031 if (GEMV_IS_COMPLEX_FLOAT) { \
2032 pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>(c0##iter, alpha_data, res + i + (iter * ResPacketSize)); \
2034 pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>(c0##iter, alpha_data, res + i + ((iter << 1) * ResPacketSize)); \
2035 c0##iter = PResPacket(result0##iter.packet[2]); \
2036 pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>(c0##iter, alpha_data, res + i + (((iter << 1) + 1) * ResPacketSize)); \
2040 #define GEMV_STORE2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
2041 if (GEMV_GETN_COMPLEX(N) > iter1) { \
2042 GEMV_DISASSEMBLE_COMPLEX_MMA(iter2); \
2043 GEMV_DISASSEMBLE_COMPLEX_MMA(iter3); \
2044 c0##iter2 = PResPacket(result0##iter2.packet[0]); \
2045 if (GEMV_IS_COMPLEX_FLOAT) { \
2046 c0##iter3 = PResPacket(result0##iter3.packet[0]); \
2047 pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2>(c0##iter2, c0##iter3, alpha_data, res + i); \
2049 c0##iter3 = PResPacket(result0##iter2.packet[2]); \
2050 pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2 << 1>(c0##iter2, c0##iter3, alpha_data, res + i); \
2051 c0##iter2 = PResPacket(result0##iter3.packet[0]); \
2052 c0##iter3 = PResPacket(result0##iter3.packet[2]); \
2053 pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter3 << 1>(c0##iter2, c0##iter3, alpha_data, res + i); \
2057 #define GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N) \
2058 GEMV_UNROLL(GEMV_INIT_COL_COMPLEX_MMA, N) \
2061 const RhsScalar& b1 = rhs2(j, 0); \
2062 RhsScalar* b = const_cast<RhsScalar *>(&b1); \
2063 GEMV_UNROLL(GEMV_PREFETCH, N) \
2064 GEMV_LOAD_COL_COMPLEX_MMA(N) \
2065 GEMV_WORK_COL_COMPLEX_MMA(N) \
2066 } while (++j < jend); \
2067 if (GEMV_GETN(N) <= 2) { \
2068 GEMV_UNROLL(GEMV_STORE_COL_COMPLEX_MMA, N) \
2070 GEMV_UNROLL_HALF(GEMV_STORE2_COL_COMPLEX_MMA, (N >> 1)) \
2072 i += (ResPacketSize * N);
2075 #define GEMV_INIT_COMPLEX(iter, N) \
2077 c0##iter = pset_zero<PResPacket>(); \
2078 c1##iter = pset_init<ResPacket, LhsPacket, RhsPacket>(c1##iter); \
2080 EIGEN_UNUSED_VARIABLE(c0##iter); \
2081 EIGEN_UNUSED_VARIABLE(c1##iter); \
2084 #define GEMV_WORK_COL_COMPLEX(iter, N) \
2086 f##iter = GEMV_LOADPACKET_COL_COMPLEX(iter); \
2087 gemv_mult_complex<ScalarPacket, PLhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, ConjugateRhs, ColMajor>(f##iter, b, c0##iter, c1##iter); \
2089 EIGEN_UNUSED_VARIABLE(f##iter); \
2092 #define GEMV_STORE_COL_COMPLEX(iter, N) \
2094 if (GEMV_IS_COMPLEX_COMPLEX) { \
2095 c0##iter = padd(c0##iter, c1##iter); \
2097 pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>(c0##iter, alpha_data, res + i + (iter * ResPacketSize)); \
2101 #define GEMV_PROCESS_COL_COMPLEX_ONE(N) \
2102 GEMV_UNROLL(GEMV_INIT_COMPLEX, N) \
2105 const RhsScalar& b1 = rhs2(j, 0); \
2106 RhsScalar* b = const_cast<RhsScalar *>(&b1); \
2107 GEMV_UNROLL(GEMV_PREFETCH, N) \
2108 GEMV_UNROLL(GEMV_WORK_COL_COMPLEX, N) \
2109 } while (++j < jend); \
2110 GEMV_UNROLL(GEMV_STORE_COL_COMPLEX, N) \
2111 i += (ResPacketSize * N);
2113 #if defined(USE_GEMV_MMA) && (EIGEN_COMP_LLVM || defined(USE_SLOWER_GEMV_MMA))
2114 #define USE_GEMV_COL_COMPLEX_MMA
2117 #ifdef USE_GEMV_COL_COMPLEX_MMA
2118 #define GEMV_PROCESS_COL_COMPLEX(N) \
2119 GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N)
2121 #if defined(USE_GEMV_MMA) && (__GNUC__ > 10)
2122 #define GEMV_PROCESS_COL_COMPLEX(N) \
2123 if (sizeof(Scalar) != sizeof(LhsPacket)) { \
2124 GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N) \
2126 GEMV_PROCESS_COL_COMPLEX_ONE(N) \
2129 #define GEMV_PROCESS_COL_COMPLEX(N) \
2130 GEMV_PROCESS_COL_COMPLEX_ONE(N)
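// Complex column kernel selection: with clang (or USE_SLOWER_GEMV_MMA) the MMA strip is
// always used; otherwise gcc 11+ takes the MMA strip only when sizeof(Scalar) !=
// sizeof(LhsPacket), and all remaining configurations use the plain VSX strip.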
2134 template<typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
2137 const LhsMapper& alhs,
2138 const RhsMapper& rhs,
2142 typedef gemv_traits<LhsScalar, RhsScalar> Traits;
2144 typedef typename Traits::LhsPacket LhsPacket;
2145 typedef typename Traits::RhsPacket RhsPacket;
2146 typedef typename Traits::ResPacket ResPacket;
2148 typedef typename packet_traits<Scalar>::type ScalarPacket;
2149 typedef typename packet_traits<LhsScalar>::type PLhsPacket;
2150 typedef typename packet_traits<ResScalar>::type PResPacket;
2151 typedef gemv_traits<ResPacket, ResPacket> PTraits;
2158 LhsMapper lhs(alhs);
2159 RhsMapper rhs2(rhs);
2161 conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
2163 const Index lhsStride = lhs.stride();
2167 ResPacketSize = PTraits::ResPacketSize,
2168 LhsPacketSize = PTraits::LhsPacketSize,
2169 RhsPacketSize = PTraits::RhsPacketSize,
2171 #ifdef EIGEN_POWER_USE_GEMV_PREFETCH
2172 const Index prefetch_dist = 64 * LhsPacketSize;
2175 #ifndef GCC_ONE_VECTORPAIR_BUG
2176 const Index n8 = rows - 8 * ResPacketSize + 1;
2177 const Index n4 = rows - 4 * ResPacketSize + 1;
2178 const Index n2 = rows - 2 * ResPacketSize + 1;
2180 const Index n1 = rows - 1 * ResPacketSize + 1;
2183 const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 16000 ? 16 : 8);
2186 AlphaData alpha_data(alpha);
2188 for (Index j2 = 0; j2 < cols; j2 += block_cols)
2192 PResPacket c00, c01, c02, c03, c04, c05, c06, c07;
2193 ResPacket c10, c11, c12, c13, c14, c15, c16, c17;
2194 PLhsPacket f0, f1, f2, f3, f4, f5, f6, f7;
2196 __vector_quad e00, e01, e02, e03, e04, e05, e06, e07;
2197 __vector_pair a0, a1, a2, a3, a4, a5, a6, a7;
2198 PacketBlock<ScalarPacket, 4> result00, result01, result02, result03, result04, result05, result06, result07;
2200 GEMV_UNUSED(8, result0)
2203 #if !defined(GCC_ONE_VECTORPAIR_BUG) && defined(USE_GEMV_COL_COMPLEX_MMA)
2207 #ifndef GCC_ONE_VECTORPAIR_BUG
2234 d0 += cj.pmul(lhs(i, j), rhs2(j, 0));
2235 } while (++j < jend);
2236 res[i] += alpha * d0;
2246 static Packet16uc p16uc_ELEMENT_3 = { 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f, 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f };
2249 template<typename ResScalar, typename ResPacket>
2252 PacketBlock<ResPacket, 4> result0, result1;
2253 __builtin_mma_disassemble_acc(&result0.packet, acc0);
2254 __builtin_mma_disassemble_acc(&result1.packet, acc1);
2255 result0.packet[0] = vec_mergeh(result0.packet[0], result1.packet[0]);
2256 result0.packet[1] = vec_mergeo(result0.packet[1], result1.packet[1]);
2257 result0.packet[2] = vec_mergel(result0.packet[2], result1.packet[2]);
2258 result0.packet[3] = vec_perm(result0.packet[3], result1.packet[3], p16uc_ELEMENT_3);
2259 result0.packet[0] = vec_add(vec_add(result0.packet[0], result0.packet[2]), vec_add(result0.packet[1], result0.packet[3]));
2266 PacketBlock<Packet2d, 4> result0, result1;
2267 __builtin_mma_disassemble_acc(&result0.packet, acc0);
2268 __builtin_mma_disassemble_acc(&result1.packet, acc1);
2269 result0.packet[0] = vec_add(vec_mergeh(result0.packet[0], result1.packet[0]), vec_mergel(result0.packet[1], result1.packet[1]));
2274 template<typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
2278 result0.packet[0] = reinterpret_cast<Packet4f>(vec_mergeh(reinterpret_cast<Packet2d>(result0.packet[0]), reinterpret_cast<Packet2d>(result1.packet[0])));
2279 result0.packet[2] = reinterpret_cast<Packet4f>(vec_mergel(reinterpret_cast<Packet2d>(result0.packet[2]), reinterpret_cast<Packet2d>(result1.packet[2])));
2280 result0.packet[0] = vec_add(result0.packet[0], result0.packet[2]);
2282 result0.packet[1] = reinterpret_cast<Packet4f>(vec_mergeh(reinterpret_cast<Packet2d>(result0.packet[1]), reinterpret_cast<Packet2d>(result1.packet[1])));
2283 result0.packet[3] = reinterpret_cast<Packet4f>(vec_mergel(reinterpret_cast<Packet2d>(result0.packet[3]), reinterpret_cast<Packet2d>(result1.packet[3])));
2284 result0.packet[1] = vec_add(result0.packet[1], result0.packet[3]);
2286 result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
2287 result0.packet[1] = pcplxflip2(convertComplex(result0.packet[1])).v;
2288 } else if (ConjugateRhs) {
2289 result0.packet[1] = pcplxconjflip(convertComplex(result0.packet[1])).v;
2291 result0.packet[1] = pcplxflipconj(convertComplex(result0.packet[1])).v;
2293 result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
2295 if (ConjugateLhs && (sizeof(LhsPacket) == sizeof(std::complex<float>))) {
2296 result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
2299 cc0.scalar[0].real(result0.packet[0][0]);
2300 cc0.scalar[0].imag(result0.packet[0][1]);
2301 cc0.scalar[1].real(result0.packet[0][2]);
2302 cc0.scalar[1].imag(result0.packet[0][3]);
2306 template<
typename LhsPacket,
typename RhsPacket,
bool ConjugateLhs,
bool ConjugateRhs>
2315 template<
typename ResScalar,
typename ResPacket,
typename LhsPacket,
typename RhsPacket,
bool ConjugateLhs,
bool ConjugateRhs>
2318 PacketBlock<ResPacket, 4> result0, result1;
2319 __builtin_mma_disassemble_acc(&result0.packet, acc0);
2320 __builtin_mma_disassemble_acc(&result1.packet, acc1);
2321 return addComplexResults<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(result0, result1);
2324 template<
typename ResScalar,
typename ResPacket>
2327 PacketBlock<ResPacket, 4> result0;
2328 __builtin_mma_disassemble_acc(&result0.packet, acc0);
2329 result0.packet[0] = vec_add(vec_mergeh(result0.packet[0], result0.packet[2]), vec_mergel(result0.packet[1], result0.packet[3]));
template<typename ResScalar, typename ResPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
  PacketBlock<ResPacket, 4> result0;
  __builtin_mma_disassemble_acc(&result0.packet, acc0);
  if (GEMV_IS_COMPLEX_COMPLEX) {
    if (ConjugateLhs) {
      result0.packet[1] = pconjinv(convertComplex(result0.packet[1])).v;
      result0.packet[3] = pconjinv(convertComplex(result0.packet[3])).v;
    } else if (ConjugateRhs) {
      result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
      result0.packet[2] = pconj2(convertComplex(result0.packet[2])).v;
    } else {
      result0.packet[1] = pconj2(convertComplex(result0.packet[1])).v;
      result0.packet[3] = pconj2(convertComplex(result0.packet[3])).v;
    }
    result0.packet[0] = vec_add(result0.packet[0], __builtin_vsx_xxpermdi(result0.packet[1], result0.packet[1], 2));
    result0.packet[2] = vec_add(result0.packet[2], __builtin_vsx_xxpermdi(result0.packet[3], result0.packet[3], 2));
  } else {
    result0.packet[0] = __builtin_vsx_xxpermdi(result0.packet[0], result0.packet[1], 1);
    result0.packet[2] = __builtin_vsx_xxpermdi(result0.packet[2], result0.packet[3], 1);
  }
  cc0.scalar[0].real(result0.packet[0][0]);
  cc0.scalar[0].imag(result0.packet[0][1]);
  cc0.scalar[1].real(result0.packet[2][0]);
  cc0.scalar[1].imag(result0.packet[2][1]);
  return cc0;
}
template<typename ResScalar, typename ResPacket>

template<typename ResScalar, typename ResPacket>
  return predux_real<ResScalar, ResPacket>(a, b);
}
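// Illustrative note (added comment, not from the original source): the predux_real / predux_complex
// helpers above reduce one or two accumulators into a ScalarBlock carrying two scalar results,
// roughly the horizontal sums needed to finish two rows of a row-major GEMV.  A scalar sketch of
// the contract, with hypothetical placeholders, would be:
//
//   ScalarBlock<ResScalar, 2> cc;
//   cc.scalar[0] = /* horizontal sum of accumulator a */;
//   cc.scalar[1] = /* horizontal sum of accumulator b */;
//   return cc;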
#define GEMV_UNROLL_ROW(func, N) \
  func(0, N) func(1, N) func(2, N) func(3, N) func(4, N) func(5, N) func(6, N) func(7, N)

#define GEMV_UNROLL_ROW_HALF(func, N) \
  func(0, 0, 1, N) func(1, 2, 3, N) func(2, 4, 5, N) func(3, 6, 7, N)

#define GEMV_LOADPACKET_ROW(iter) \
  lhs.template load<LhsPacket, Unaligned>(i + (iter), j)

#define GEMV_UNROLL3_ROW(func, N, which) \
  func(0, N, which) func(1, N, which) func(2, N, which) func(3, N, which) \
  func(4, N, which) func(5, N, which) func(6, N, which) func(7, N, which)

#define GEMV_UNUSED_ROW(N, which) \
  GEMV_UNROLL3_ROW(GEMV_UNUSED_VAR, N, which)

#define GEMV_INIT_ROW(iter, N) \
  if (GEMV_GETN(N) > iter) { \
    __builtin_mma_xxsetaccz(&c##iter); \

#define GEMV_LOADPAIR_ROW(iter1, iter2) \
  GEMV_BUILDPAIR_MMA(b##iter1, GEMV_LOADPACKET_ROW(iter2), GEMV_LOADPACKET_ROW((iter2) + 1));

#define GEMV_WORK_ROW(iter, N) \
  if (GEMV_GETN(N) > iter) { \
    if (GEMV_IS_FLOAT) { \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&c##iter, a0, GEMV_LOADPACKET_ROW(iter)); \
    } else { \
      __vector_pair b##iter; \
      GEMV_LOADPAIR_ROW(iter, iter << 1) \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&c##iter, b##iter, a0); \

#define GEMV_PREDUX2(iter1, iter2, iter3, N) \
    if (GEMV_IS_FLOAT) { \
      cc##iter1 = predux_real<ResScalar, ResPacket>(&c##iter2, &c##iter3); \
    } else { \
      cc##iter1 = predux_real<ResScalar, ResPacket>(&c##iter1); \
    EIGEN_UNUSED_VARIABLE(cc##iter1); \

#define GEMV_INIT_ROW(iter, N) \
    c##iter = pset1<ResPacket>(ResScalar(0)); \
    EIGEN_UNUSED_VARIABLE(c##iter); \

#define GEMV_WORK_ROW(iter, N) \
    c##iter = pcj.pmadd(GEMV_LOADPACKET_ROW(iter), a0, c##iter); \

#define GEMV_PREDUX2(iter1, iter2, iter3, N) \
    cc##iter1 = predux_real<ResScalar, ResPacket>(c##iter2, c##iter3); \
    EIGEN_UNUSED_VARIABLE(cc##iter1); \

#define GEMV_MULT(iter1, iter2, iter3, N) \
    cc##iter1.scalar[0] += cj.pmul(lhs(i + iter2, j), a0); \
    cc##iter1.scalar[1] += cj.pmul(lhs(i + iter3, j), a0); \

#define GEMV_STORE_ROW(iter1, iter2, iter3, N) \
    storeMaddData<ResScalar>(res + ((i + iter2) * resIncr), alpha, cc##iter1.scalar[0]); \
    storeMaddData<ResScalar>(res + ((i + iter3) * resIncr), alpha, cc##iter1.scalar[1]); \

#define GEMV_PROCESS_ROW(N) \
  for (; i < n##N; i += N) { \
    GEMV_UNROLL_ROW(GEMV_INIT_ROW, N) \
    for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
      RhsPacket a0 = rhs2.template load<RhsPacket, Unaligned>(j); \
      GEMV_UNROLL_ROW(GEMV_WORK_ROW, N) \
    GEMV_UNROLL_ROW_HALF(GEMV_PREDUX2, (N >> 1)) \
    for (; j < cols; ++j) { \
      RhsScalar a0 = rhs2(j); \
      GEMV_UNROLL_ROW_HALF(GEMV_MULT, (N >> 1)) \
    GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW, (N >> 1)) \
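// Illustrative sketch (added comment, not part of the kernel): ignoring unrolling, packet loads and
// the MMA accumulators, the row-major macros above compute the usual GEMV update per row, i.e.
//
//   for (Index i = 0; i < rows; ++i) {
//     ResScalar acc = ResScalar(0);
//     for (Index j = 0; j < cols; ++j) acc += cj.pmul(lhs(i, j), rhs2(j));
//     res[i * resIncr] += alpha * acc;  // the store performed by storeMaddData()
//   }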
template<typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
                     const LhsMapper& alhs,
                     const RhsMapper& rhs,
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  LhsMapper lhs(alhs);
  typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  conj_helper<LhsScalar, RhsScalar, false, false> cj;
  conj_helper<LhsPacket, RhsPacket, false, false> pcj;

#ifndef GCC_ONE_VECTORPAIR_BUG
  const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? (rows - 7) : (rows - 7);

    ResPacketSize = Traits::ResPacketSize,
    LhsPacketSize = Traits::LhsPacketSize,
    RhsPacketSize = Traits::RhsPacketSize,

  __vector_quad c0, c1, c2, c3, c4, c5, c6, c7;
  GEMV_UNUSED_ROW(8, c)
  ResPacket c0, c1, c2, c3, c4, c5, c6, c7;
#ifndef GCC_ONE_VECTORPAIR_BUG
    ResPacket d0 = pset1<ResPacket>(ResScalar(0));
    for (; j + LhsPacketSize <= cols; j += LhsPacketSize)
      RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j);
      d0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, d0);
    ResScalar dd0 = predux(d0);
      dd0 += cj.pmul(lhs(i, j), rhs2(j));
    res[i * resIncr] += alpha * dd0;
#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(Scalar) \
  template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, Scalar, LhsMapper, ColMajor, ConjugateLhs, Scalar, RhsMapper, ConjugateRhs, Version> \
    typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar; \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
      Index rows, Index cols, \
      const LhsMapper& lhs, \
      const RhsMapper& rhs, \
      ResScalar* res, Index resIncr, \
      ResScalar alpha) { \
      gemv_col<Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \

#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(Scalar) \
  template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, Scalar, LhsMapper, RowMajor, ConjugateLhs, Scalar, RhsMapper, ConjugateRhs, Version> \
    typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar; \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
      Index rows, Index cols, \
      const LhsMapper& lhs, \
      const RhsMapper& rhs, \
      ResScalar* res, Index resIncr, \
      ResScalar alpha) { \
      gemv_row<Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
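// Assumed usage (illustrative only): each real specialization macro is meant to be instantiated once
// per supported scalar type, along the lines of
//
//   EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(float)
//   EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float)
//   EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(double)
//   EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double)
//
// which routes general_matrix_vector_product<...>::run() to gemv_col / gemv_row for that scalar.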
#define gemv_bf16_col gemvMMA_bfloat16_col
#define gemv_bf16_row gemvMMA_bfloat16_row

#define gemv_bf16_col gemv_bfloat16_col
#define gemv_bf16_row gemv_bfloat16_row

#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16() \
  template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, bfloat16, LhsMapper, ColMajor, ConjugateLhs, bfloat16, RhsMapper, ConjugateRhs, Version> \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
      Index rows, Index cols, \
      const LhsMapper& lhs, \
      const RhsMapper& rhs, \
      bfloat16* res, Index resIncr, \
      gemv_bf16_col<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \

#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16() \
  template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, bfloat16, LhsMapper, RowMajor, ConjugateLhs, bfloat16, RhsMapper, ConjugateRhs, Version> \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
      Index rows, Index cols, \
      const LhsMapper& lhs, \
      const RhsMapper& rhs, \
      bfloat16* res, Index resIncr, \
      gemv_bf16_row<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
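// Assumed usage (illustrative only): the bfloat16 specializations are generated with
//
//   EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16()
//   EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16()
//
// where gemv_bf16_col / gemv_bf16_row alias either the MMA or the plain VSX bfloat16 kernels,
// depending on the build configuration.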
template<typename ResScalar, typename PResPacket, typename ResPacket, typename LhsPacket, typename RhsPacket>
  return predux_complex<ResScalar, PResPacket>(a0, b0);
#define GEMV_LOADPACKET_ROW_COMPLEX(iter) \
  loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket>(lhs, i + (iter), j)

#define GEMV_LOADPACKET_ROW_COMPLEX_DATA(iter) \
  convertReal(GEMV_LOADPACKET_ROW_COMPLEX(iter))

#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(which, N) \
  for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
    const RhsScalar& b1 = rhs2(j); \
    RhsScalar* b = const_cast<RhsScalar *>(&b1); \
    GEMV_UNROLL_ROW(which, N) \

#define GEMV_PROCESS_END_ROW_COMPLEX(N) \
  for (; j < cols; ++j) { \
    RhsScalar b0 = rhs2(j); \
    GEMV_UNROLL_ROW_HALF(GEMV_MULT_COMPLEX, (N >> 1)) \
  GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW_COMPLEX, (N >> 1))

#define GEMV_INIT_ROW_COMPLEX_MMA(iter, N) \
  if (GEMV_GETN_COMPLEX(N) > iter) { \
    __builtin_mma_xxsetaccz(&e0##iter); \

#define GEMV_LOADPAIR_ROW_COMPLEX_MMA(iter1, iter2) \
  GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_ROW_COMPLEX_DATA(iter2), GEMV_LOADPACKET_ROW_COMPLEX_DATA((iter2) + 1));

#define GEMV_WORK_ROW_COMPLEX_MMA(iter, N) \
  if (GEMV_GETN_COMPLEX(N) > iter) { \
    if (GEMV_IS_COMPLEX_FLOAT) { \
      PLhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX(iter); \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, RowMajor>(a##iter, b, &e0##iter); \
    } else { \
      __vector_pair a##iter; \
      GEMV_LOADPAIR_ROW_COMPLEX_MMA(iter, iter << 1) \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, ConjugateLhs, ConjugateRhs, RowMajor>(a##iter, b, &e0##iter); \

#define GEMV_PREDUX4_COMPLEX_MMA(iter1, iter2, iter3, N) \
    if (GEMV_IS_COMPLEX_FLOAT) { \
      cc##iter1 = predux_complex<ResScalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(&e0##iter2, &e0##iter3); \
    } else { \
      cc##iter1 = predux_complex<ResScalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(&e0##iter1); \
    EIGEN_UNUSED_VARIABLE(cc##iter1); \

#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_MMA(N) \
  GEMV_UNROLL_ROW(GEMV_INIT_ROW_COMPLEX_MMA, N) \
  GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(GEMV_WORK_ROW_COMPLEX_MMA, N)

#define GEMV_PROCESS_ROW_COMPLEX_ONE_MMA(N) \
  for (; i < n##N; i += N) { \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_MMA(N) \
    GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX_MMA, (N >> 1)) \
    GEMV_PROCESS_END_ROW_COMPLEX(N); \
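// Expansion sketch (added comment): for example, GEMV_PROCESS_ROW_COMPLEX_ONE_MMA(8) unrolls to a
// loop over blocks of 8 rows that zeroes eight __vector_quad accumulators, streams the rhs while
// multiply-accumulating into them, reduces pairs of accumulators with predux_complex, and then
// handles the remaining columns and the stores:
//
//   for (; i < n8; i += 8) {
//     GEMV_PROCESS_ROW_COMPLEX_SINGLE_MMA(8)
//     GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX_MMA, 4)
//     GEMV_PROCESS_END_ROW_COMPLEX(8);
//   }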
#define GEMV_WORK_ROW_COMPLEX(iter, N) \
    PLhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX(iter); \
    gemv_mult_complex<ScalarPacket, PLhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, ConjugateRhs, RowMajor>(a##iter, b, c0##iter, c1##iter); \

#define GEMV_PREDUX4_COMPLEX(iter1, iter2, iter3, N) \
    cc##iter1 = predux_complex<ResScalar, PResPacket, ResPacket, LhsPacket, RhsPacket>(c0##iter2, c0##iter3, c1##iter2, c1##iter3); \
    EIGEN_UNUSED_VARIABLE(cc##iter1); \

#define GEMV_MULT_COMPLEX(iter1, iter2, iter3, N) \
    cc##iter1.scalar[0] += cj.pmul(lhs(i + iter2, j), b0); \
    cc##iter1.scalar[1] += cj.pmul(lhs(i + iter3, j), b0); \

#define GEMV_STORE_ROW_COMPLEX(iter1, iter2, iter3, N) \
    storeMaddData<ResScalar>(res + ((i + iter2) * resIncr), alpha, cc##iter1.scalar[0]); \
    storeMaddData<ResScalar>(res + ((i + iter3) * resIncr), alpha, cc##iter1.scalar[1]); \

#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
  GEMV_UNROLL_ROW(GEMV_INIT_COMPLEX, N) \
  GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(GEMV_WORK_ROW_COMPLEX, N)

#define GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N) \
  for (; i < n##N; i += N) { \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
    GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX, (N >> 1)) \
    GEMV_PROCESS_END_ROW_COMPLEX(N); \

#define GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter) \
  if (GEMV_IS_COMPLEX_COMPLEX) { \
    c0##iter = padd(c0##iter, c1##iter); \
  } \
  dd0 = predux(c0##iter);

#define GEMV_PROCESS_ROW_COMPLEX_SINGLE(N) \
  GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N)

#define GEMV_PROCESS_ROW_COMPLEX_ONE(N) \
  GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N)

#define GEMV_PROCESS_ROW_COMPLEX_PREDUX(iter) \
  GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter)
#define GEMV_LOADPACKET_ROW_COMPLEX_OLD(iter) \
  lhs.template load<LhsPacket, LhsAlignment>(i + (iter), j)

#define GEMV_INIT_COMPLEX_OLD(iter, N) \
  EIGEN_UNUSED_VARIABLE(c0##iter); \
    c1##iter = pset_zero<ResPacket>(); \
    EIGEN_UNUSED_VARIABLE(c1##iter); \

#define GEMV_WORK_ROW_COMPLEX_OLD(iter, N) \
    LhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX_OLD(iter); \
    c1##iter = pcj.pmadd(a##iter, b0, c1##iter); \

#define GEMV_PREDUX4_COMPLEX_OLD(iter1, iter2, iter3, N) \
    cc##iter1.scalar[0] = predux(c1##iter2); \
    cc##iter1.scalar[1] = predux(c1##iter3); \
    EIGEN_UNUSED_VARIABLE(cc##iter1); \

#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \
  GEMV_UNROLL_ROW(GEMV_INIT_COMPLEX_OLD, N) \
  for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
    RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j); \
    GEMV_UNROLL_ROW(GEMV_WORK_ROW_COMPLEX_OLD, N) \

#define GEMV_PROCESS_ROW_COMPLEX_ONE_OLD(N) \
  for (; i < n##N; i += N) { \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \
    GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX_OLD, (N >> 1)) \
    GEMV_PROCESS_END_ROW_COMPLEX(N) \

#define GEMV_PROCESS_ROW_COMPLEX_PREDUX_OLD(iter) \
  dd0 = predux(c1##iter);

#define GEMV_PROCESS_ROW_COMPLEX_IS_NEW 1

#define GEMV_PROCESS_ROW_COMPLEX_IS_NEW \
  (sizeof(Scalar) == sizeof(float)) || GEMV_IS_COMPLEX_COMPLEX

#define GEMV_PROCESS_ROW_COMPLEX_SINGLE(N) \
  if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
  } else { \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \

#define GEMV_PROCESS_ROW_COMPLEX_ONE(N) \
  if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
    GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N) \
  } else { \
    GEMV_PROCESS_ROW_COMPLEX_ONE_OLD(N) \

#define GEMV_PROCESS_ROW_COMPLEX_PREDUX(iter) \
  if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
    GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter) \
  } else { \
    GEMV_PROCESS_ROW_COMPLEX_PREDUX_OLD(iter) \

#define GEMV_PROCESS_ROW_COMPLEX(N) \
  GEMV_PROCESS_ROW_COMPLEX_ONE_MMA(N)

#define GEMV_PROCESS_ROW_COMPLEX(N) \
  GEMV_PROCESS_ROW_COMPLEX_ONE(N)
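// Note (added comment): GEMV_PROCESS_ROW_COMPLEX_IS_NEW is a compile-time constant, so the if/else
// in the dispatch macros above folds away at compilation; the "new" path covers single precision and
// complex*complex products, while the conj_helper-based "old" path handles the remaining cases.
// The two definitions of GEMV_PROCESS_ROW_COMPLEX are presumably selected by conditional
// compilation: the first targets the MMA accumulators, the second the plain VSX path.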
template<typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal,
         typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
                             const LhsMapper& alhs,
                             const RhsMapper& rhs,
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef typename packet_traits<Scalar>::type ScalarPacket;
  typedef typename packet_traits<LhsScalar>::type PLhsPacket;
  typedef typename packet_traits<ResScalar>::type PResPacket;
  typedef gemv_traits<ResPacket, ResPacket> PTraits;

  LhsMapper lhs(alhs);
  typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
#if !EIGEN_COMP_LLVM
  conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;

#ifndef GCC_ONE_VECTORPAIR_BUG
  const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? (rows - 7) : (rows - 7);

    ResPacketSize = PTraits::ResPacketSize,
    LhsPacketSize = PTraits::LhsPacketSize,
    RhsPacketSize = PTraits::RhsPacketSize,

  PResPacket c00, c01, c02, c03, c04, c05, c06, c07;
  ResPacket c10, c11, c12, c13, c14, c15, c16, c17;

  __vector_quad e00, e01, e02, e03, e04, e05, e06, e07;
  GEMV_UNUSED_ROW(8, e0)
  GEMV_UNUSED_EXTRA(1, c0)
  GEMV_UNUSED_EXTRA(1, c1)

#ifndef GCC_ONE_VECTORPAIR_BUG
      dd0 += cj.pmul(lhs(i, j), rhs2(j));
    res[i * resIncr] += alpha * dd0;
#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(Scalar, LhsScalar, RhsScalar) \
  template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, Version> \
    typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
      Index rows, Index cols, \
      const LhsMapper& lhs, \
      const RhsMapper& rhs, \
      ResScalar* res, Index resIncr, \
      ResScalar alpha) { \
      gemv_complex_col<Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \

#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(Scalar, LhsScalar, RhsScalar) \
  template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, Version> \
    typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
      Index rows, Index cols, \
      const LhsMapper& lhs, \
      const RhsMapper& rhs, \
      ResScalar* res, Index resIncr, \
      ResScalar alpha) { \
      gemv_complex_row<Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
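// Assumed usage (illustrative only): the complex specialization macros are expected to be
// instantiated for each supported combination of real and complex operands, e.g.
//
//   EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, std::complex<float>, std::complex<float>)
//   EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, float, std::complex<float>)
//   EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, std::complex<double>, double)
//
// so that general_matrix_vector_product<...>::run() dispatches to gemv_complex_col / gemv_complex_row.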