#ifndef EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H
#if EIGEN_COMP_LLVM
#define BFLOAT16_UNROLL _Pragma("unroll 8")
#else
#define BFLOAT16_UNROLL _Pragma("GCC unroll(8)")
#endif
namespace Eigen {

namespace internal {

template<bool zero>
EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16(const bfloat16* indexA)
{
  Packet8bf lhs1 = ploadu<Packet8bf>(indexA);
  if (zero) {
    Packet8bf lhs2 = pset1<Packet8bf>(Eigen::bfloat16(0));
    return vec_mergeh(lhs1.m_val, lhs2.m_val);
  }
  return lhs1;
}

template<bool zero>
EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* blockB, Index strideB, Index i)
{
  return loadBfloat16<zero>(blockB + strideB*i);
}
template<Index num_acc, Index num_packets, bool zero, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
EIGEN_ALWAYS_INLINE void KLoop(
  const bfloat16* indexA,
  const bfloat16* indexB,
  __vector_quad (&quad_acc)[num_acc],
  Index strideB,
  Index k,
  Index offsetB,
  Index extra_cols,
  Index extra_rows)
{
  Packet8bf lhs[num_lhs], rhs[num_rhs];
  BFLOAT16_UNROLL
  for(Index i = 0; i < (num_rhs - (rhsExtraCols ? 1 : 0)); i++){
    rhs[i] = loadRhsBfloat16<zero>(indexB + k*4, strideB, i);
  }
  if(rhsExtraCols) {
    rhs[num_rhs - 1] = loadRhsBfloat16<zero>(indexB + k*extra_cols - offsetB, strideB, num_rhs - 1);
  }

  indexA += k*(lhsExtraRows ? extra_rows : num_packets);
  if (num_lhs == 1) {
    lhs[0] = loadBfloat16<zero>(indexA);
  } else {
    BFLOAT16_UNROLL
    for(Index j = 0; j < num_lhs; j += 2) {
      lhs[j + 0] = vec_mergeh(lhs1.m_val, lhs2.m_val);
      lhs[j + 1] = vec_mergel(lhs1.m_val, lhs2.m_val);
    }
  }
  BFLOAT16_UNROLL
  for(Index i = 0, x = 0; i < num_rhs; i++) {
    BFLOAT16_UNROLL
    for(Index j = 0; j < num_lhs; j++, x++) {
      __builtin_mma_xvbf16ger2pp(&(quad_acc[x]), reinterpret_cast<Packet16uc>(rhs[i].m_val), reinterpret_cast<Packet16uc>(lhs[j].m_val));
    }
  }
}
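// Semantics sketch (paraphrasing the ISA reference, not code from this
// kernel): each __builtin_mma_xvbf16ger2pp is a rank-2 outer-product
// accumulate. Viewing one 16-byte operand as four pairs of bfloat16 values
// x[i][0..1] and the other as four pairs y[j][0..1], it updates a 4x4
// float32 tile held in the __vector_quad:
//
//   acc[i][j] += x[i][0]*y[j][0] + x[i][1]*y[j][1];   // i, j in 0..3
//
// which is why the k-loop above consumes the depth dimension two bfloat16
// values at a time.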
template<Index num_acc>
EIGEN_ALWAYS_INLINE void zeroAccumulators(__vector_quad (&quad_acc)[num_acc])
{
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k++)
    __builtin_mma_xxsetaccz(&(quad_acc[k]));
}
template<Index num_acc>
EIGEN_ALWAYS_INLINE void disassembleAccumulators(__vector_quad (&quad_acc)[num_acc], Packet4f (&acc)[num_acc][4])
{
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k++)
    __builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k]));
}
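// How the three builtins compose (an illustrative sketch under the same
// Power10 MMA assumptions as this file; `a` and `b` stand for packed
// bfloat16 operands and are hypothetical here):
//
//   __vector_quad acc;
//   __builtin_mma_xxsetaccz(&acc);                    // clear the 512-bit accumulator
//   for (Index k = 0; k + 2 <= depth; k += 2) {
//     // ... load a and b for this k ...
//     __builtin_mma_xvbf16ger2pp(&acc, b, a);         // accumulate a 4x4 f32 tile
//   }
//   Packet4f out[4];
//   __builtin_mma_disassemble_acc((void*)out, &acc);  // spill to four VSX vectors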
template<Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
EIGEN_ALWAYS_INLINE void outputResults(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float *result, const Index extra_cols, Index extra_rows)
{
  for(Index i = 0, k = 0; i < num_rhs - (rhsExtraCols ? 1 : 0); i++, result += 4*rows){
    for(Index j = 0; j < num_lhs; j++, k++) {
      storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result + j*4, extra_cols, extra_rows);
    }
  }
  if(rhsExtraCols) {
    storeResults<rhsExtraCols, lhsExtraRows>(acc[num_acc - 1], rows, pAlpha, result, extra_cols, extra_rows);
  }
}
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows, bool multiIter = false>
EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float *result, const Index extra_cols, const Index extra_rows)
{
  constexpr Index num_lhs = multiIter ? (num_packets / 4) : 1;
  constexpr Index num_rhs = (num_acc + num_lhs - 1) / num_lhs;
  for(Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += (multiIter ? 0 : 8), indexB += (multiIter ? (num_rhs*strideB) : 0), result += (multiIter ? (4*rows*num_rhs) : 4)) {
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    Index k;
    for(k = 0; k + 2 <= depth; k += 2){
      KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
    }
    if(depth & 1){
      KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA - (multiIter ? 0 : offset_row), indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
    }
    disassembleAccumulators<num_acc>(quad_acc, acc);
    outputResults<num_acc, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(acc, rows, pAlpha, result, extra_cols, extra_rows);
  }
}
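// Tail handling: each xvbf16ger2pp consumes two bfloat16 values per lane, so
// an odd depth leaves a single trailing element. The final KLoop above is
// instantiated with zero = true, which pads the missing partner with
// bfloat16(0) (see loadBfloat16), leaving the accumulators unchanged in the
// padded positions.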
#define MAX_BFLOAT16_ACC 8
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float *result)
{
  constexpr Index step = (num_acc * 4);
  const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
  const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC);
  constexpr bool normIters = multiIters && ((num_acc % (num_packets / 4)) == 0);

  do{
    colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows, normIters>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
    indexB += strideB*num_acc;
    result += step*rows;
  } while(multiIters && (step <= cols - (col += step)));
}
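// Loop shape: one pass of the body covers step = num_acc*4 columns, and
// `col` (passed by reference) advances by step per pass. Only the fully
// unrolled instantiation with no column remainder (multiIters) actually
// repeats the do-while; every other instantiation runs once and leaves the
// leftover columns to the dispatch below.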
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float *result)
{
  if (MAX_BFLOAT16_ACC > num_acc) {
    colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
  }
}
template<const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float *result)
{
  switch ((cols - col) >> 2) {
  case 7:
    colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 6:
    colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 5:
    colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 4:
    colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 3:
    colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 2:
    colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 1:
    colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  default:
    if (rhsExtraCols) {
      colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    }
    break;
  }
}
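// Remainder dispatch: after the main column loop at most
// MAX_BFLOAT16_ACC - 1 = 7 four-column groups remain, so every possible
// count gets its own fixed-num_acc instantiation and each kernel stays
// fully unrolled at compile time; the sub-4-column tail reuses colLoopBody
// with rhsExtraCols = true.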
template<const Index num_packets, bool lhsExtraRows = false>
EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float *result)
{
  Index col = 0;
  if (cols >= (MAX_BFLOAT16_ACC * 4)) {
    colLoopBody<MAX_BFLOAT16_ACC, num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result);
    blockB += (strideB >> 2)*col;
    result += rows*col;
  }
  if (cols & 3) {
    colLoopBodyExtra<num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
  } else {
    colLoopBodyExtra<num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result);
  }
}
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res)
{
  Packet16uc fp16[2];
  __vector_pair fp16_vp = *reinterpret_cast<__vector_pair *>(const_cast<float *>(res));
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(fp16), &fp16_vp);
  fp16[0] = __builtin_vsx_xvcvspbf16(fp16[0]);
  fp16[1] = __builtin_vsx_xvcvspbf16(fp16[1]);
  return vec_pack(reinterpret_cast<Packet4ui>(fp16[0]), reinterpret_cast<Packet4ui>(fp16[1]));
}
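// Scalar reference for the conversion each lane undergoes (a sketch that
// ignores NaN handling; `f` is a hypothetical input float):
//
//   uint32_t bits;
//   std::memcpy(&bits, &f, sizeof(bits));
//   uint32_t lsb = (bits >> 16) & 1;                       // ties-to-even bit
//   uint16_t bf  = uint16_t((bits + 0x7FFFu + lsb) >> 16); // round-to-nearest-even
//
// __builtin_vsx_xvcvspbf16 performs this on four floats at once; vec_pack
// then narrows two such vectors into one Packet8bf.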
template<typename DataMapper, const Index size>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16Col(float *result, Index col, Index rows, const DataMapper& res)
{
  const DataMapper res2 = res.getSubMapper(0, col);
  float *result2 = result + col*rows;

  PacketBlock<Packet8bf,size> block;
  res2.template storePacketBlock<Packet8bf,size>(row, 0, block);

  res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
}
template<const Index size, bool non_unit_stride = false>
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float *result, Index rows, bfloat16*& dst, Index resInc = 1)
{
  while (i + size <= rows) {
    storeBF16fromResult<size, non_unit_stride, 0>(dst, r32.packet[0], resInc, rows & 7);
    if (size >= 16) {
      storeBF16fromResult<size, non_unit_stride, 8>(dst, r32.packet[1], resInc);
    }
    if (size >= 32) {
      storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
      storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
    }
    i += extra; dst += extra*resInc;
    if (size != 32)
      break;
  }
}
template<bool non_unit_stride = false>
EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float *result, Index rows, bfloat16 *dst, Index resInc = 1)
{
  Index i = 0;
  convertPointerF32toBF16<32,non_unit_stride>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<16,non_unit_stride>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<8,non_unit_stride>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<1,non_unit_stride>(i, result, rows, dst, resInc);
}
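// The 32/16/8/1 cascade peels progressively smaller chunks: each call
// converts up to `size` rows at a time and advances `i` and `dst` (both
// passed by reference), and only the 32-row version loops (note the
// `size != 32` break above), so the later instantiations only mop up the
// remainder.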
template<typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index rows, const DataMapper& res)
{
  Index col;
  for(col = 0; col + 4 <= cols; col += 4){
    convertArrayF32toBF16Col<DataMapper,4>(result, col, rows, res);
  }
  switch (cols - col) {
  case 1:
    convertArrayF32toBF16Col<DataMapper,1>(result, col, rows, res);
    break;
  case 2:
    convertArrayF32toBF16Col<DataMapper,2>(result, col, rows, res);
    break;
  case 3:
    convertArrayF32toBF16Col<DataMapper,3>(result, col, rows, res);
    break;
  }
}
template<const Index size>
EIGEN_ALWAYS_INLINE void calcColLoops(const bfloat16*& indexA, Index& row, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexB, Index strideB, Index offsetA, Index offsetB, Index bigSuffix, float *result)
{
  if ((size == 16) || (rows & size)) {
    indexA += size*offsetA;
    colLoops<size>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
    row += size;
    indexA += bigSuffix*size/16;
  }
}
template<typename DataMapper>
void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);
  ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);

  convertArrayBF16toF32<DataMapper>(result, cols, rows, res);

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;
  Index bigSuffix = (2*8) * (strideA-offsetA);
  Index row = 0;

  while(row + 16 <= rows){
    calcColLoops<16>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
  }
  calcColLoops<8>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
  calcColLoops<4>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
  if (rows & 3) {
    colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
  }
  convertArrayF32toBF16<DataMapper>(result, cols, rows, res);
}
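// Overall flow: widen the existing bfloat16 output into the float32 scratch
// buffer, accumulate the whole product there in float32 (16-, 8- and 4-row
// blocks, then a masked sub-4-row tail), and narrow back to bfloat16 in one
// pass at the end, so rounding to bfloat16 happens only once per entry.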
#undef MAX_BFLOAT16_ACC

#if !EIGEN_ALTIVEC_DISABLE_MMA
template<Index num_acc, typename LhsMapper, bool zero>
EIGEN_ALWAYS_INLINE void loadVecLoop(Index k, LhsMapper& lhs, Packet8bf (&a0)[num_acc], Packet8bf b1)
{
  a0[k + 0] = lhs.template loadPacket<Packet8bf>(k*4, 0);
  if (!zero) {
    b1 = lhs.template loadPacket<Packet8bf>(k*4, 1);
  }
  if (num_acc > (k + 1)) {
    a0[k + 1] = vec_mergel(a0[k + 0].m_val, b1.m_val);
  }
  a0[k + 0] = vec_mergeh(a0[k + 0].m_val, b1.m_val);
}
template<Index num_acc>
EIGEN_ALWAYS_INLINE void multVec(__vector_quad (&quad_acc)[num_acc], Packet8bf (&a0)[num_acc], Packet8bf b0)
{
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k++) {
    __builtin_mma_xvbf16ger2pp(&(quad_acc[k]), reinterpret_cast<Packet16uc>(b0.m_val), reinterpret_cast<Packet16uc>(a0[k].m_val));
  }
}
template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc])
{
  Packet8bf a0[num_acc];
  Packet8bf b1 = pset1<Packet8bf>(Eigen::bfloat16(0));
  Packet8bf b0 = loadColData<RhsMapper, linear>(rhs, j);

  if (zero) {
    b0 = vec_mergeh(b0.m_val, b1.m_val);
  }

  LhsMapper lhs2 = lhs.getSubMapper(0, j);
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k += 2) {
    loadVecLoop<num_acc, LhsMapper, zero>(k, lhs2, a0, b1);
  }

  multVec<num_acc>(quad_acc, a0, b0);
}
#define MAX_BFLOAT16_VEC_ACC 8
template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  constexpr Index step = (num_acc * 4);
  const Index extra_rows = (extraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !extraRows && (num_acc == MAX_BFLOAT16_VEC_ACC);

  do{
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);
    LhsMapper lhs2 = lhs.getSubMapper(row, 0);
    for(Index j = 0; j + 2 <= cend; j += 2) {
      vecColLoop<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, quad_acc);
    }
    if (cend & 1) {
      vecColLoop<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, quad_acc);
    }
    disassembleAccumulators<num_acc>(quad_acc, acc);

    outputVecColResults<num_acc, extraRows>(acc, result, pAlpha, extra_rows);

    result += step;
  } while(multiIters && (step <= rows - (row += step)));
}
template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  if (MAX_BFLOAT16_VEC_ACC > num_acc) {
    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  }
}
template<typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  switch ((rows - row) >> 2) {
  case 7:
    colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  case 6:
    colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  case 5:
    colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  case 4:
    colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  case 3:
    colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  case 2:
    colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  case 1:
    colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  default:
    if (extraRows) {
      colVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    }
    break;
  }
}
template<typename LhsMapper, typename RhsMapper, bool linear>
EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  Index row = 0;
  if (rows >= (MAX_BFLOAT16_VEC_ACC * 4)) {
    colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    result += row;
  }
  if (rows & 3) {
    colVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  } else {
    colVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  }
}
template<typename LhsMapper, typename RhsMapper>
void gemvMMA_bfloat16_col(
  Index rows, Index cols,
  const LhsMapper& alhs,
  const RhsMapper& rhs,
  bfloat16* res, Index resIncr,
  bfloat16 alpha)
{
  typedef typename RhsMapper::LinearMapper LinearMapper;

  const Index lhsStride = lhs.stride();

  for (Index j2 = 0; j2 < cols; j2 += block_cols)
  {
    Index jend = numext::mini(j2 + block_cols, cols);
    LhsMapper lhs2 = lhs.getSubMapper(0, j2);
    if (rhs.stride() == 1) {
      LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
      calcVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
    } else {
      RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
      calcVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
    }
  }
}
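// When the right-hand side has unit stride its entries are contiguous, so
// the kernel switches to RhsMapper::LinearMapper (linear = true) and column
// data can be fetched with plain vector loads; otherwise it falls back to
// element access through the generic strided mapper.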
static Packet16uc p16uc_ELEMENT_VEC3 = { 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f, 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f };
template<Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResults2(Packet4f (&acc)[num_acc][4], Index k)
{
  if (num_acc > (k + 1)) {
    acc[k][0] = vec_mergeh(acc[k][0], acc[k + 1][0]);
    acc[k][1] = vec_mergeo(acc[k][1], acc[k + 1][1]);
    acc[k][2] = vec_mergel(acc[k][2], acc[k + 1][2]);
    acc[k][3] = vec_perm(acc[k][3], acc[k + 1][3], p16uc_ELEMENT_VEC3);
    acc[k][0] = (acc[k][0] + acc[k][2]) + (acc[k][1] + acc[k][3]);
  } else {
    acc[k][0] = vec_mergeh(acc[k][0], acc[k][1]);
    acc[k][0] += vec_mergel(acc[k][2], acc[k][3]);
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12);
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4);
  }
}
template<Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResults(Packet4f (&acc)[num_acc][4])
{
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k += 4) {
    preduxVecResults2<num_acc>(acc, k + 0);
    if (num_acc > (k + 2)) {
      preduxVecResults2<num_acc>(acc, k + 2);
      acc[k + 0][0] = reinterpret_cast<Packet4f>(vec_mergeh(reinterpret_cast<Packet2ul>(acc[k + 0][0]), reinterpret_cast<Packet2ul>(acc[k + 2][0])));
    }
  }
}
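// For reference, the usual single-vector horizontal add on this ISA (a
// sketch independent of the pairwise scheme above):
//
//   Packet4f v = ...;        // [a, b, c, d]
//   v += vec_sld(v, v, 8);   // [a+c, b+d, c+a, d+b]
//   v += vec_sld(v, v, 4);   // every lane now holds a+b+c+d
//
// preduxVecResults instead reduces up to four accumulators at once,
// interleaving them with vec_merge*/vec_perm so that each lane of the
// result ends up holding one accumulator's total.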
template<Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
EIGEN_ALWAYS_INLINE void multVecLoop(__vector_quad (&quad_acc)[num_acc], const LhsMapper& lhs, RhsMapper& rhs, Index j, Index extra_cols)
{
  Packet8bf a0[num_acc], b0;

  if (extra) {
    b0 = rhs.template loadPacketPartial<Packet8bf>(j, extra_cols);
  } else {
    b0 = rhs.template loadPacket<Packet8bf>(j);
  }

  const LhsMapper lhs2 = lhs.getSubMapper(0, j);
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k++) {
    if (extra) {
      a0[k] = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
    } else {
      a0[k] = lhs2.template loadPacket<Packet8bf>(k, 0);
    }
  }

  multVec<num_acc>(quad_acc, a0, b0);
}
template<Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void vecLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc], Index extra_cols)
{
  Index j = 0;
  for(; j + 8 <= cols; j += 8){
    multVecLoop<num_acc, LhsMapper, RhsMapper, false>(quad_acc, lhs, rhs, j, extra_cols);
  }

  if (extra_cols) {
    multVecLoop<num_acc, LhsMapper, RhsMapper, true>(quad_acc, lhs, rhs, j, extra_cols);
  }
}
template<const Index num_acc, typename LhsMapper, typename RhsMapper>
void colVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  constexpr bool multiIters = (num_acc == MAX_BFLOAT16_VEC_ACC);
  const Index extra_cols = (cols & 7);

  do{
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    const LhsMapper lhs2 = lhs.getSubMapper(row, 0);
    vecLoop<num_acc, LhsMapper, RhsMapper>(cols, lhs2, rhs, quad_acc, extra_cols);

    disassembleAccumulators<num_acc>(quad_acc, acc);

    preduxVecResults<num_acc>(acc);

    outputVecResults<num_acc>(acc, result, pAlpha);

    result += num_acc;
  } while(multiIters && (num_acc <= rows - (row += num_acc)));
}
template<const Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVecLoopBodyExtraN(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  if (MAX_BFLOAT16_VEC_ACC > num_acc) {
    colVecLoopBody<num_acc, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
  }
}
template<typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  switch (rows - row) {
  case 7:
    colVecLoopBodyExtraN<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  case 6:
    colVecLoopBodyExtraN<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  case 5:
    colVecLoopBodyExtraN<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  case 4:
    colVecLoopBodyExtraN<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  case 3:
    colVecLoopBodyExtraN<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  case 2:
    colVecLoopBodyExtraN<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  case 1:
    colVecLoopBodyExtraN<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  }
}
template<typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void calcVecLoops(Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  Index row = 0;
  if (rows >= MAX_BFLOAT16_VEC_ACC) {
    colVecLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    result += row;
  }
  colVecLoopBodyExtra<LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
}
template<typename LhsMapper, typename RhsMapper>
void gemvMMA_bfloat16_row(
  Index rows, Index cols,
  const LhsMapper& alhs,
  const RhsMapper& rhs,
  bfloat16* res, Index resIncr,
  bfloat16 alpha)
{
  typedef typename RhsMapper::LinearMapper LinearMapper;

  LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  convertArrayPointerBF16toF32<true>(result, 1, rows, res, resIncr);
  calcVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);
  convertArrayPointerF32toBF16<true>(result, rows, res, resIncr);
}
#endif // !EIGEN_ALTIVEC_DISABLE_MMA

#undef MAX_BFLOAT16_VEC_ACC
#undef BFLOAT16_UNROLL

} // end namespace internal
} // end namespace Eigen

#endif // EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H