MatrixProductMMAbfloat16.h
#ifndef EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H

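// BFLOAT16_UNROLL asks the compiler to unroll the short fixed-trip-count
// loops in this file 8 ways; clang and GCC spell the pragma differently.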
#if EIGEN_COMP_LLVM
#define BFLOAT16_UNROLL _Pragma("unroll 8")
#else
#define BFLOAT16_UNROLL _Pragma("GCC unroll(8)")
#endif

namespace Eigen {

namespace internal {

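// Load eight bfloat16 values from indexA. When zero is set (used for the
// odd final depth iteration), the first four values are interleaved with
// zeros so the unused lane of each bf16 pair contributes nothing to the
// MMA accumulation.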
template<bool zero>
EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16(const bfloat16* indexA)
{
  Packet8bf lhs1 = ploadu<Packet8bf>(indexA);
  if(zero){
    Packet8bf lhs2 = pset1<Packet8bf>(Eigen::bfloat16(0));
    return vec_mergeh(lhs1.m_val, lhs2.m_val);
  } else {
    return lhs1;
  }
}

template<bool zero>
EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* blockB, Index strideB, Index i)
{
  return loadBfloat16<zero>(blockB + strideB*i);
}

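// One k-step of the GEMM kernel: load num_rhs RHS packets and num_lhs LHS
// packets, then issue one __builtin_mma_xvbf16ger2pp (bfloat16 rank-2
// outer-product accumulate) per accumulator. The zero flag pads the second
// element of each bf16 pair with zeros for the final odd k iteration.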
template<Index num_acc, Index num_packets, bool zero, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
EIGEN_ALWAYS_INLINE void KLoop
(
  const bfloat16* indexA,
  const bfloat16* indexB,
  __vector_quad (&quad_acc)[num_acc],
  Index strideB,
  Index k,
  Index offsetB,
  Index extra_cols,
  Index extra_rows
)
{
  Packet8bf lhs[num_lhs], rhs[num_rhs];

  BFLOAT16_UNROLL
  for(Index i = 0; i < (num_rhs - (rhsExtraCols ? 1 : 0)); i++){
    rhs[i] = loadRhsBfloat16<zero>(indexB + k*4, strideB, i);
  }
  if(rhsExtraCols) {
    rhs[num_rhs - 1] = loadRhsBfloat16<zero>(indexB + k*extra_cols - offsetB, strideB, num_rhs - 1);
  }

  indexA += k*(lhsExtraRows ? extra_rows : num_packets);
  if (num_lhs == 1) {
    lhs[0] = loadBfloat16<zero>(indexA);
  } else {
    BFLOAT16_UNROLL
    for(Index j = 0; j < num_lhs; j += 2) {
      Packet8bf lhs1 = ploadu<Packet8bf>(indexA + (j + 0)*(zero ? 4 : 8));
      if (zero) {
        Packet8bf lhs2 = pset1<Packet8bf>(Eigen::bfloat16(0));
        lhs[j + 0] = vec_mergeh(lhs1.m_val, lhs2.m_val);
        lhs[j + 1] = vec_mergel(lhs1.m_val, lhs2.m_val);
      } else {
        lhs[j + 0] = lhs1;
        lhs[j + 1] = ploadu<Packet8bf>(indexA + (j + 1)*8);
      }
    }
  }

  BFLOAT16_UNROLL
  for(Index i = 0, x = 0; i < num_rhs; i++) {
    BFLOAT16_UNROLL
    for(Index j = 0; j < num_lhs; j++, x++) {
      __builtin_mma_xvbf16ger2pp(&(quad_acc[x]), reinterpret_cast<Packet16uc>(rhs[i].m_val), reinterpret_cast<Packet16uc>(lhs[j].m_val));
    }
  }
}

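// Thin wrappers around the MMA accumulator builtins: clear the
// __vector_quad accumulators before a tile, and unpack each accumulator
// into four Packet4f rows once the tile is complete.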
template<Index num_acc>
EIGEN_ALWAYS_INLINE void zeroAccumulators(__vector_quad (&quad_acc)[num_acc])
{
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k++)
    __builtin_mma_xxsetaccz(&(quad_acc[k]));
}

template<Index num_acc>
EIGEN_ALWAYS_INLINE void disassembleAccumulators(__vector_quad (&quad_acc)[num_acc], Packet4f (&acc)[num_acc][4])
{
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k++)
    __builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k]));
}

template<Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
EIGEN_ALWAYS_INLINE void outputResults(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result, const Index extra_cols, Index extra_rows)
{
  BFLOAT16_UNROLL
  for(Index i = 0, k = 0; i < num_rhs - (rhsExtraCols ? 1 : 0); i++, result += 4*rows){
    BFLOAT16_UNROLL
    for(Index j = 0; j < num_lhs; j++, k++) {
      storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result + j*4, extra_cols, extra_rows);
    }
  }
  if(rhsExtraCols) {
    storeResults<rhsExtraCols, lhsExtraRows>(acc[num_acc - 1], rows, pAlpha, result, extra_cols, extra_rows);
  }
}

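// Compute one tile of the result: zero the accumulators, feed KLoop two
// bf16 depth values at a time (plus a zero-padded pass for an odd depth),
// disassemble, then scale by alpha and store via outputResults.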
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows, bool multiIter = false>
EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result, const Index extra_cols, const Index extra_rows)
{
  constexpr Index num_lhs = multiIter ? (num_packets / 4) : 1;
  constexpr Index num_rhs = (num_acc + num_lhs - 1) / num_lhs;

  for(Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += (multiIter ? 0 : 8), indexB += (multiIter ? (num_rhs*strideB) : 0), result += (multiIter ? (4*rows*num_rhs) : 4)) {
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    Index k;
    for(k = 0; k + 2 <= depth; k += 2){
      KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
    }
    if(depth&1){
      KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA - (multiIter ? 0 : offset_row), indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
    }

    disassembleAccumulators<num_acc>(quad_acc, acc);

    outputResults<num_acc, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(acc, rows, pAlpha, result, extra_cols, extra_rows);
  }
}

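// The POWER10 MMA unit provides eight accumulator registers, so at most
// eight independent 4x4 float tiles can be in flight at once.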
#define MAX_BFLOAT16_ACC 8

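// Process strips of step = num_acc*4 columns. With all eight accumulators
// in use and no column remainder, the do/while keeps consuming further
// strips instead of re-dispatching through the callers.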
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result)
{
  constexpr Index step = (num_acc * 4); // each accumulator has 4 elements
  const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
  const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC);
  constexpr bool normIters = multiIters && ((num_acc % (num_packets / 4)) == 0);

  do{
    colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows, normIters>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);

    indexB += strideB*num_acc;
    result += rows*step;
  } while(multiIters && (step <= cols - (col += step)));
}

template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
{
  if (MAX_BFLOAT16_ACC > num_acc) {
    colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
  }
}

template<const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
{
  switch ((cols - col) >> 2) {
  case 7:
    colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 6:
    colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 5:
    colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 4:
    colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 3:
    colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 2:
    colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  case 1:
    colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
  default:
    if (rhsExtraCols) {
      colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    }
    break;
  }
}

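// Column driver: consume full-width strips with all eight accumulators
// first, then route the leftover columns (including any non-multiple-of-4
// remainder) through the switch in colLoopBodyExtra.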
template<const Index num_packets, bool lhsExtraRows = false>
EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
{
  Index col = 0;
  if (cols >= (MAX_BFLOAT16_ACC * 4)) {
    colLoopBody<MAX_BFLOAT16_ACC, num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result);
    blockB += (strideB >> 2)*col;
    result += rows*col;
  }
  if (cols & 3) {
    colLoopBodyExtra<num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
  } else {
    colLoopBodyExtra<num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result);
  }
}

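// Convert eight float results back to bfloat16: split a vector pair, round
// each half with xvcvspbf16, and pack the two halves into one Packet8bf.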
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res)
{
  Packet16uc fp16[2];
  __vector_pair fp16_vp = *reinterpret_cast<__vector_pair *>(const_cast<float *>(res));
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(fp16), &fp16_vp);
  fp16[0] = __builtin_vsx_xvcvspbf16(fp16[0]);
  fp16[1] = __builtin_vsx_xvcvspbf16(fp16[1]);
  return vec_pack(reinterpret_cast<Packet4ui>(fp16[0]), reinterpret_cast<Packet4ui>(fp16[1]));
}

template<typename DataMapper, const Index size>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16Col(float *result, Index col, Index rows, const DataMapper& res)
{
  const DataMapper res2 = res.getSubMapper(0, col);
  Index row;
  float *result2 = result + col*rows;
  for(row = 0; row + 8 <= rows; row += 8, result2 += 8){
    // get and save block
    PacketBlock<Packet8bf,size> block;
    BFLOAT16_UNROLL
    for(Index j = 0; j < size; j++){
      block.packet[j] = convertF32toBF16(result2 + j*rows);
    }
    res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
  }
  // extra rows
  if(row < rows){
    BFLOAT16_UNROLL
    for(Index j = 0; j < size; j++){
      Packet8bf fp16 = convertF32toBF16(result2 + j*rows);
      res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
    }
  }
}

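// Convert a contiguous float buffer to bfloat16 in chunks of size elements
// (32, 16, 8, then 1); only the 32-element instantiation loops, the smaller
// sizes run at most once each to mop up the remainder.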
template<const Index size, bool non_unit_stride = false>
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
{
  constexpr Index extra = ((size < 8) ? 8 : size);
  while (i + size <= rows){
    PacketBlock<Packet8bf,(size+7)/8> r32;
    r32.packet[0] = convertF32toBF16(result + i + 0);
    if (size >= 16) {
      r32.packet[1] = convertF32toBF16(result + i + 8);
    }
    if (size >= 32) {
      r32.packet[2] = convertF32toBF16(result + i + 16);
      r32.packet[3] = convertF32toBF16(result + i + 24);
    }
    storeBF16fromResult<size, non_unit_stride, 0>(dst, r32.packet[0], resInc, rows & 7);
    if (size >= 16) {
      storeBF16fromResult<size, non_unit_stride, 8>(dst, r32.packet[1], resInc);
    }
    if (size >= 32) {
      storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
      storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
    }
    i += extra; dst += extra*resInc;
    if (size != 32) break;
  }
}

template<bool non_unit_stride = false>
EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float *result, Index rows, bfloat16* dst, Index resInc = 1)
{
  Index i = 0;
  convertPointerF32toBF16<32,non_unit_stride>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<16,non_unit_stride>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<8,non_unit_stride>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<1,non_unit_stride>(i, result, rows, dst, resInc);
}

template<typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index rows, const DataMapper& res)
{
  Index col;
  for(col = 0; col + 4 <= cols; col += 4){
    convertArrayF32toBF16Col<DataMapper,4>(result, col, rows, res);
  }
  // extra cols
  switch (cols - col) {
  case 1:
    convertArrayF32toBF16Col<DataMapper,1>(result, col, rows, res);
    break;
  case 2:
    convertArrayF32toBF16Col<DataMapper,2>(result, col, rows, res);
    break;
  case 3:
    convertArrayF32toBF16Col<DataMapper,3>(result, col, rows, res);
    break;
  }
}

template<Index size>
EIGEN_ALWAYS_INLINE void calcColLoops(const bfloat16*& indexA, Index& row, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexB, Index strideB, Index offsetA, Index offsetB, Index bigSuffix, float *result)
{
  if ((size == 16) || (rows & size)) {
    indexA += size*offsetA;
    colLoops<size>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
    row += size;
    indexA += bigSuffix*size/16;
  }
}

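// Top-level bfloat16 GEMM: accumulate into a float workspace (MMA products
// of bf16 inputs accumulate in f32), walk the packed LHS in 16-, 8- and
// 4-row blocks plus a sub-4-row remainder, then convert the workspace back
// to bfloat16 through the result mapper.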
template<typename DataMapper>
void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);
  ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);

  convertArrayBF16toF32<DataMapper>(result, cols, rows, res);

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;
  // Packing is done in blocks.
  // There are 4 possible block sizes:
  // Blocks of 8 columns with 16 elements (8x16)
  // Blocks of 8 columns with 8 elements (8x8). This happens when 16 > rows >= 8
  // Blocks of 8 columns with 4 elements (8x4). This happens when 8 > rows >= 4
  // Blocks of 8 columns with < 4 elements. This happens when fewer than 4 rows remain

  // Loop for LHS standard block (8x16)
  Index bigSuffix = (2*8) * (strideA-offsetA);
  indexB += 4*offsetB;
  strideB *= 4;
  offsetB *= 3;

  Index row = 0;
  while(row + 16 <= rows){
    calcColLoops<16>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
  }
  // LHS (8x8) block
  calcColLoops<8>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
  // LHS (8x4) block
  calcColLoops<4>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
  // extra rows
  if(rows & 3){
    // This index is the beginning of the remaining block.
    colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
  }

  // Convert back to bfloat16
  convertArrayF32toBF16<DataMapper>(result, cols, rows, res);
}

#undef MAX_BFLOAT16_ACC

#if !EIGEN_ALTIVEC_DISABLE_MMA
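// The kernels below implement bfloat16 GEMV (matrix-vector product) for
// column-major and row-major access, reusing the MMA accumulator helpers.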
template<Index num_acc, typename LhsMapper, bool zero>
EIGEN_ALWAYS_INLINE void loadVecLoop(Index k, LhsMapper& lhs, Packet8bf (&a0)[num_acc], Packet8bf b1)
{
  a0[k + 0] = lhs.template loadPacket<Packet8bf>(k*4, 0);
  if (!zero) {
    b1 = lhs.template loadPacket<Packet8bf>(k*4, 1);
  }
  if (num_acc > (k + 1)) {
    a0[k + 1] = vec_mergel(a0[k + 0].m_val, b1.m_val);
  }
  a0[k + 0] = vec_mergeh(a0[k + 0].m_val, b1.m_val);
}

template<Index num_acc>
EIGEN_ALWAYS_INLINE void multVec(__vector_quad (&quad_acc)[num_acc], Packet8bf (&a0)[num_acc], Packet8bf b0)
{
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k++) {
    __builtin_mma_xvbf16ger2pp(&(quad_acc[k]), reinterpret_cast<Packet16uc>(b0.m_val), reinterpret_cast<Packet16uc>(a0[k].m_val));
  }
}

template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc])
{
  Packet8bf a0[num_acc];
  Packet8bf b1 = pset1<Packet8bf>(Eigen::bfloat16(0));
  Packet8bf b0 = loadColData<RhsMapper, linear>(rhs, j);

  if (zero) {
    b0 = vec_mergeh(b0.m_val, b1.m_val);
  }

  LhsMapper lhs2 = lhs.getSubMapper(0, j);
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k += 2) {
    loadVecLoop<num_acc, LhsMapper, zero>(k, lhs2, a0, b1);
  }

  multVec<num_acc>(quad_acc, a0, b0);
}

#define MAX_BFLOAT16_VEC_ACC 8

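// Column-major GEMV body: each accumulator covers four result rows, and the
// columns of the LHS are consumed two at a time, with a zero-padded pass
// when the column count is odd.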
template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  constexpr Index step = (num_acc * 4);
  const Index extra_rows = (extraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !extraRows && (num_acc == MAX_BFLOAT16_VEC_ACC);

  do{
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    LhsMapper lhs2 = lhs.getSubMapper(row, 0);
    for(Index j = 0; j + 2 <= cend; j += 2) {
      vecColLoop<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, quad_acc);
    }
    if (cend & 1) {
      vecColLoop<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, quad_acc);
    }

    disassembleAccumulators<num_acc>(quad_acc, acc);

    outputVecColResults<num_acc, extraRows>(acc, result, pAlpha, extra_rows);

    result += step;
  } while(multiIters && (step <= rows - (row += step)));
}

template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  if (MAX_BFLOAT16_VEC_ACC > num_acc) {
    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  }
}

template<typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  switch ((rows - row) >> 2) {
  case 7:
    colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  case 6:
    colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  case 5:
    colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  case 4:
    colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  case 3:
    colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  case 2:
    colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  case 1:
    colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
  default:
    if (extraRows) {
      colVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    }
    break;
  }
}

template<typename LhsMapper, typename RhsMapper, bool linear>
EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  Index row = 0;
  if (rows >= (MAX_BFLOAT16_VEC_ACC * 4)) {
    colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    result += row;
  }
  if (rows & 3) {
    colVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  } else {
    colVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  }
}

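// Column-major GEMV driver: results are accumulated in a float buffer, and
// the columns are processed in blocks sized by the heuristic below so the
// touched part of the LHS stays cache-friendly.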
template<typename LhsMapper, typename RhsMapper>
EIGEN_STRONG_INLINE void gemvMMA_bfloat16_col(
  Index rows, Index cols,
  const LhsMapper& alhs,
  const RhsMapper& rhs,
  bfloat16* res, Index resIncr,
  bfloat16 alpha)
{
  typedef typename RhsMapper::LinearMapper LinearMapper;

  EIGEN_UNUSED_VARIABLE(resIncr);
  eigen_internal_assert(resIncr == 1);

  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);
  RhsMapper rhs2(rhs);

  const Index lhsStride = lhs.stride();

  // TODO: improve the following heuristic:
  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(bfloat16) < 16000 ? 16 : 8);
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  Packet4f pAlpha = pset1<Packet4f>(falpha);

  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);

  convertArrayPointerBF16toF32(result, 1, rows, res);

  for (Index j2 = 0; j2 < cols; j2 += block_cols)
  {
    Index jend = numext::mini(j2 + block_cols, cols);

    LhsMapper lhs2 = lhs.getSubMapper(0, j2);
    if (rhs.stride() == 1) {
      LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
      calcVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
    } else {
      RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
      calcVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
    }
  }

  convertArrayPointerF32toBF16(result, rows, res);
}

static Packet16uc p16uc_ELEMENT_VEC3 = { 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f, 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f };

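// Horizontal reductions for the row-major GEMV: fold the four rows of each
// accumulator (and neighboring accumulators) down to one dot-product value
// per accumulator.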
template<Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResults2(Packet4f (&acc)[num_acc][4], Index k)
{
  if (num_acc > (k + 1)) {
    acc[k][0] = vec_mergeh(acc[k][0], acc[k + 1][0]);
    acc[k][1] = vec_mergeo(acc[k][1], acc[k + 1][1]);
    acc[k][2] = vec_mergel(acc[k][2], acc[k + 1][2]);
    acc[k][3] = vec_perm(acc[k][3], acc[k + 1][3], p16uc_ELEMENT_VEC3);

    acc[k][0] = (acc[k][0] + acc[k][2]) + (acc[k][1] + acc[k][3]);
  } else {
    acc[k][0] = vec_mergeh(acc[k][0], acc[k][1]);
    acc[k][0] += vec_mergel(acc[k][2], acc[k][3]);
#ifdef _BIG_ENDIAN
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12);
#else
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4);
#endif
  }
}

template<Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResults(Packet4f (&acc)[num_acc][4])
{
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k += 4) {
    preduxVecResults2<num_acc>(acc, k + 0);
    if (num_acc > (k + 2)) {
      preduxVecResults2<num_acc>(acc, k + 2);
      acc[k + 0][0] = reinterpret_cast<Packet4f>(vec_mergeh(reinterpret_cast<Packet2ul>(acc[k + 0][0]), reinterpret_cast<Packet2ul>(acc[k + 2][0])));
    }
  }
}

template<Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
EIGEN_ALWAYS_INLINE void multVecLoop(__vector_quad (&quad_acc)[num_acc], const LhsMapper& lhs, RhsMapper& rhs, Index j, Index extra_cols)
{
  Packet8bf a0[num_acc], b0;

  if (extra) {
    b0 = rhs.template loadPacketPartial<Packet8bf>(j, extra_cols);
  } else {
    b0 = rhs.template loadPacket<Packet8bf>(j);
  }

  const LhsMapper lhs2 = lhs.getSubMapper(0, j);
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k++) {
    if (extra) {
      a0[k] = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
    } else {
      a0[k] = lhs2.template loadPacket<Packet8bf>(k, 0);
    }
  }

  multVec<num_acc>(quad_acc, a0, b0);
}

template<Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void vecLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc], Index extra_cols)
{
  Index j = 0;
  for(; j + 8 <= cols; j += 8){
    multVecLoop<num_acc, LhsMapper, RhsMapper, false>(quad_acc, lhs, rhs, j, extra_cols);
  }

  if (extra_cols) {
    multVecLoop<num_acc, LhsMapper, RhsMapper, true>(quad_acc, lhs, rhs, j, extra_cols);
  }
}

template<const Index num_acc, typename LhsMapper, typename RhsMapper>
void colVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  constexpr bool multiIters = (num_acc == MAX_BFLOAT16_VEC_ACC);
  const Index extra_cols = (cols & 7);

  do{
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    const LhsMapper lhs2 = lhs.getSubMapper(row, 0);
    vecLoop<num_acc, LhsMapper, RhsMapper>(cols, lhs2, rhs, quad_acc, extra_cols);

    disassembleAccumulators<num_acc>(quad_acc, acc);

    preduxVecResults<num_acc>(acc);

    outputVecResults<num_acc>(acc, result, pAlpha);

    result += num_acc;
  } while(multiIters && (num_acc <= rows - (row += num_acc)));
}

template<const Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVecLoopBodyExtraN(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  if (MAX_BFLOAT16_VEC_ACC > num_acc) {
    colVecLoopBody<num_acc, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
  }
}

template<typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  switch (rows - row) {
  case 7:
    colVecLoopBodyExtraN<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  case 6:
    colVecLoopBodyExtraN<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  case 5:
    colVecLoopBodyExtraN<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  case 4:
    colVecLoopBodyExtraN<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  case 3:
    colVecLoopBodyExtraN<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  case 2:
    colVecLoopBodyExtraN<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  case 1:
    colVecLoopBodyExtraN<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    break;
  }
}

template<typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void calcVecLoops(Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  Index row = 0;
  if (rows >= MAX_BFLOAT16_VEC_ACC) {
    colVecLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    result += row;
  }
  colVecLoopBodyExtra<LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
}

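// Row-major GEMV driver: each accumulator reduces to a single result
// element, and strided output (resIncr != 1) is handled by the strided
// bfloat16 conversions.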
template<typename LhsMapper, typename RhsMapper>
EIGEN_STRONG_INLINE void gemvMMA_bfloat16_row(
  Index rows, Index cols,
  const LhsMapper& alhs,
  const RhsMapper& rhs,
  bfloat16* res, Index resIncr,
  bfloat16 alpha)
{
  typedef typename RhsMapper::LinearMapper LinearMapper;

  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);
  LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  eigen_internal_assert(rhs.stride() == 1);

  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);

  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
  if (resIncr == 1) {
    convertArrayPointerBF16toF32(result, 1, rows, res);
  } else {
    convertArrayPointerBF16toF32<true>(result, 1, rows, res, resIncr);
  }
  calcVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);
  if (resIncr == 1) {
    convertArrayPointerF32toBF16(result, rows, res);
  } else {
    convertArrayPointerF32toBF16<true>(result, rows, res, resIncr);
  }
}
#endif

#undef MAX_BFLOAT16_VEC_ACC
#undef BFLOAT16_UNROLL

} // end namespace internal
} // end namespace Eigen

#endif // EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H