TrsmUnrolls.inc
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2022 Intel Corporation
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
11 #define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
12 
13 template <bool isARowMajor = true>
14  EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
15  EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
16  else return i + j * LDA;
17 }
18 
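The idA helper above maps a two-dimensional index (i, j) with leading dimension LDA to a linear offset, selecting row-major or column-major addressing at compile time. A minimal standalone sketch of that behaviour (idA_ref is a hypothetical name, not part of this file):

#include <cassert>
#include <cstdint>

// Standalone analogue of idA<isARowMajor>(i, j, LDA), illustrative only.
template <bool isARowMajor>
std::int64_t idA_ref(std::int64_t i, std::int64_t j, std::int64_t LDA) {
  return isARowMajor ? i * LDA + j : i + j * LDA;
}

int main() {
  assert((idA_ref<true>(2, 3, 10) == 23));   // row-major: row 2, column 3
  assert((idA_ref<false>(2, 3, 10) == 32));  // column-major: same element
}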
57 namespace unrolls {
58 
59 template <int64_t N>
60 EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
61  EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
62  else EIGEN_IF_CONSTEXPR(N == 8) {
63  return 0xFF >> (8 - m);
64  }
65  else EIGEN_IF_CONSTEXPR(N == 4) {
66  return 0x0F >> (4 - m);
67  }
68  return 0;
69 }
70 
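remMask<N>(m) produces an integer mask with the low m bits set, which the kernels below pass to masked ploadu/pstoreu calls when only m < N lanes of a packet are valid. A standalone sketch of the same computation (remMaskRef is a hypothetical name, shown only to illustrate the mask values):

#include <cassert>
#include <cstdint>

// Standalone analogue of remMask<N>(m) for N = 16, 8, or 4 lanes.
template <int N>
std::uint32_t remMaskRef(int m) {
  if (N == 16) return 0xFFFFu >> (16 - m);
  if (N == 8)  return 0xFFu >> (8 - m);
  if (N == 4)  return 0x0Fu >> (4 - m);
  return 0;
}

int main() {
  // For m active lanes the mask equals (1 << m) - 1, e.g. remMaskRef<16>(5) == 0x1F.
  for (int m = 1; m <= 16; ++m) assert(remMaskRef<16>(m) == ((1u << m) - 1u));
}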
71 template <typename Packet>
72 EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet, 8> &kernel);
73 
74 template <>
75 EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8> &kernel) {
76  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
77  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
78  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
79  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
80  __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
81  __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
82  __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
83  __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);
84 
85  kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
86  kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
87  kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
88  kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
89  kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
90  kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
91  kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
92  kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
93 
94  T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
95  T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
96  T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
97  T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
98  T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
99  T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
100  T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
101  T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
102  T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
103  T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
104  T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
105  T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
106  T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
107  T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
108  T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
109  T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
110 
111  kernel.packet[0] = T0;
112  kernel.packet[1] = T1;
113  kernel.packet[2] = T2;
114  kernel.packet[3] = T3;
115  kernel.packet[4] = T4;
116  kernel.packet[5] = T5;
117  kernel.packet[6] = T6;
118  kernel.packet[7] = T7;
119 }
120 
121 template <>
122 EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet8d, 8> &kernel) {
123  ptranspose(kernel);
124 }
125 
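The Packet16f specialization above builds the transpose out of unpacklo/unpackhi steps followed by cross-lane permutes and blends, while the Packet8d specialization simply forwards to ptranspose of a full 8x8 tile. Per 8x8 block the net effect is an ordinary transpose; a plain scalar reference of that effect (an illustrative sketch only, not the register-level layout used by the intrinsics):

// Scalar picture of an in-place 8x8 transpose, the per-block effect the
// vectorized kernels above achieve on register-resident data.
void transpose8x8_ref(float block[8][8]) {
  for (int i = 0; i < 8; ++i)
    for (int j = i + 1; j < 8; ++j) {
      float tmp = block[i][j];
      block[i][j] = block[j][i];
      block[j][i] = tmp;
    }
}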
126 
129 template <typename Scalar>
130 class trans {
131  public:
132  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
133  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
134  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
135 
136 
151  template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
152  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
153  Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
154  constexpr int64_t counterReverse = endN - counter;
155  constexpr int64_t startN = counterReverse;
156 
157  EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) {
158  EIGEN_IF_CONSTEXPR(remM) {
159  pstoreu<Scalar>(
160  C_arr + LDC * startN,
161  padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
162  preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]),
163  remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
164  remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
165  }
166  else {
167  pstoreu<Scalar>(C_arr + LDC * startN,
168  padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
169  preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
170  }
171  }
172  else { // This block is only needed for fp32 case
173  // Reinterpret as __m512 for _mm512_shuffle_f32x4
174  vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
175  zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]);
176  // Swap lower and upper half of avx register.
177  zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] =
178  preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));
179 
180  EIGEN_IF_CONSTEXPR(remM) {
181  pstoreu<Scalar>(
182  C_arr + LDC * startN,
183  padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
184  preinterpret<vecHalf>(
185  zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
186  remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
187  }
188  else {
189  pstoreu<Scalar>(
190  C_arr + LDC * startN,
191  padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
192  preinterpret<vecHalf>(
193  zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
194  }
195  }
196  aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
197  }
198 
199  template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
200  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
201  Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
202  EIGEN_UNUSED_VARIABLE(C_arr);
203  EIGEN_UNUSED_VARIABLE(LDC);
204  EIGEN_UNUSED_VARIABLE(zmm);
205  EIGEN_UNUSED_VARIABLE(remM_);
206  }
207 
208  template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
209  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
210  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
211  int64_t remM_ = 0) {
212  aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
213  }
214 
240  template <int64_t unrollN, int64_t packetIndexOffset>
241  static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
242  // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
243  // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
244  constexpr int64_t zmmStride = unrollN / PacketSize;
245  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> r;
246  r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0];
247  r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1];
248  r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2];
249  r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3];
250  r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4];
251  r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5];
252  r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6];
253  r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7];
254  trans8x8blocks(r);
255  zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0];
256  zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1];
257  zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2];
258  zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3];
259  zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4];
260  zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5];
261  zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6];
262  zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7];
263  }
264 };
265 
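In the trans<Scalar> helpers above, an unrollN-column tile is held as unrollN/PacketSize packets per row, so row r of the slice starting at packetIndexOffset lives at packet index packetIndexOffset + (unrollN/PacketSize)*r. transpose() gathers eight such packets, transposes them, and scatters them back, while storeC() adds them into C, optionally under a remainder mask. A small standalone sketch of that indexing (the concrete values are chosen only for illustration):

#include <cstdio>

int main() {
  // e.g. float on AVX-512 with an unroll of 3*PacketSize columns
  const int unrollN = 48, PacketSize = 16;
  const int zmmStride = unrollN / PacketSize;  // packets per row of the tile
  const int packetIndexOffset = 1;             // second packet column of the tile
  for (int r = 0; r < 8; ++r)
    std::printf("row %d -> zmm.packet[%d]\n", r, packetIndexOffset + zmmStride * r);
}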
280 template <typename Scalar>
281 class transB {
282  public:
283  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
284  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
285  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
286 
287 
302  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
303  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
304  Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
305  int64_t remM_ = 0) {
306  constexpr int64_t counterReverse = endN - counter;
307  constexpr int64_t startN = counterReverse;
308 
309  EIGEN_IF_CONSTEXPR(remM) {
310  ymm.packet[packetIndexOffset + startN] =
311  ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
312  }
313  else {
314  EIGEN_IF_CONSTEXPR(remN_ == 0) {
315  ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
316  }
317  else ymm.packet[packetIndexOffset + startN] =
318  ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
319  }
320 
321  aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
322  }
323 
324  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
325  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
326  Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
327  int64_t remM_ = 0) {
328  EIGEN_UNUSED_VARIABLE(B_arr);
329  EIGEN_UNUSED_VARIABLE(LDB);
330  EIGEN_UNUSED_VARIABLE(ymm);
331  EIGEN_UNUSED_VARIABLE(remM_);
332  }
333 
340  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
341  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
342  Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
343  constexpr int64_t counterReverse = endN - counter;
344  constexpr int64_t startN = counterReverse;
345 
346  EIGEN_IF_CONSTEXPR(remK || remM) {
347  pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
348  remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
349  }
350  else {
351  pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]);
352  }
353 
354  aux_storeB<endN, counter - 1, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
355  }
356 
357  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
358  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB(
359  Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
360  EIGEN_UNUSED_VARIABLE(B_arr);
361  EIGEN_UNUSED_VARIABLE(LDB);
362  EIGEN_UNUSED_VARIABLE(ymm);
363  EIGEN_UNUSED_VARIABLE(rem_);
364  }
365 
372  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
373  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
374  Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
375  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
376  constexpr int64_t counterReverse = endN - counter;
377  constexpr int64_t startN = counterReverse;
378  transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
379  aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
380  }
381 
382  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
383  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
384  Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
385  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
386  EIGEN_UNUSED_VARIABLE(B_arr);
387  EIGEN_UNUSED_VARIABLE(LDB);
388  EIGEN_UNUSED_VARIABLE(B_temp);
389  EIGEN_UNUSED_VARIABLE(LDB_);
390  EIGEN_UNUSED_VARIABLE(ymm);
391  EIGEN_UNUSED_VARIABLE(remM_);
392  }
393 
400  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
401  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
402  Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
403  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
404  constexpr int64_t counterReverse = endN - counter;
405  constexpr int64_t startN = counterReverse;
406 
407  EIGEN_IF_CONSTEXPR(toTemp) {
408  transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
409  }
410  else {
411  transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
412  ymm, remM_);
413  }
414  aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
415  }
416 
417  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
418  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock(
419  Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
420  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
421  EIGEN_UNUSED_VARIABLE(B_arr);
422  EIGEN_UNUSED_VARIABLE(LDB);
423  EIGEN_UNUSED_VARIABLE(B_temp);
424  EIGEN_UNUSED_VARIABLE(LDB_);
425  EIGEN_UNUSED_VARIABLE(ymm);
426  EIGEN_UNUSED_VARIABLE(remM_);
427  }
428 
429 
433  template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
434  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
435  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
436  int64_t remM_ = 0) {
437  aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
438  }
439 
440  template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
441  static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
442  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
443  int64_t rem_ = 0) {
444  aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
445  }
446 
447  template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
448  static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
449  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
450  int64_t remM_ = 0) {
451  EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
452  else {
453  aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
454  }
455  }
456 
457  template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
458  static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
459  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
460  int64_t remM_ = 0) {
461  aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
462  }
463 
464  template <int64_t packetIndexOffset>
465  static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
466  // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
467  // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
468  PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
469  r.packet[0] = ymm.packet[packetIndexOffset + 0];
470  r.packet[1] = ymm.packet[packetIndexOffset + 1];
471  r.packet[2] = ymm.packet[packetIndexOffset + 2];
472  r.packet[3] = ymm.packet[packetIndexOffset + 3];
473  r.packet[4] = ymm.packet[packetIndexOffset + 4];
474  r.packet[5] = ymm.packet[packetIndexOffset + 5];
475  r.packet[6] = ymm.packet[packetIndexOffset + 6];
476  r.packet[7] = ymm.packet[packetIndexOffset + 7];
477  ptranspose(r);
478  ymm.packet[packetIndexOffset + 0] = r.packet[0];
479  ymm.packet[packetIndexOffset + 1] = r.packet[1];
480  ymm.packet[packetIndexOffset + 2] = r.packet[2];
481  ymm.packet[packetIndexOffset + 3] = r.packet[3];
482  ymm.packet[packetIndexOffset + 4] = r.packet[4];
483  ymm.packet[packetIndexOffset + 5] = r.packet[5];
484  ymm.packet[packetIndexOffset + 6] = r.packet[6];
485  ymm.packet[packetIndexOffset + 7] = r.packet[7];
486  }
487 
488  template <int64_t unrollN, bool toTemp, bool remM>
489  static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
490  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
491  int64_t remM_ = 0) {
492  constexpr int64_t U3 = PacketSize * 3;
493  constexpr int64_t U2 = PacketSize * 2;
494  constexpr int64_t U1 = PacketSize * 1;
502  EIGEN_IF_CONSTEXPR(unrollN == U3) {
503  // load LxU3 B col major, transpose LxU3 row major
504  constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3);
505  transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
506  transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
507  transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
508  transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
509  transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
510 
511  EIGEN_IF_CONSTEXPR(maxUBlock < U3) {
512  transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
513  ymm, remM_);
514  transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
515  transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
516  transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
517  transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
518  ymm, remM_);
519  }
520  }
521  else EIGEN_IF_CONSTEXPR(unrollN == U2) {
522  // load LxU2 B col major, transpose LxU2 row major
523  constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2);
524  transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
525  transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
526  transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
527  EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
528  transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
529 
530  EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
531  transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
532  &B_temp[maxUBlock], LDB_, ymm, remM_);
533  transB::template transposeLxL<0>(ymm);
534  transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
535  &B_temp[maxUBlock], LDB_, ymm, remM_);
536  }
537  }
538  else EIGEN_IF_CONSTEXPR(unrollN == U1) {
539  // load LxU1 B col major, transpose LxU1 row major
540  transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
541  transB::template transposeLxL<0>(ymm);
542  EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); }
543  transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
544  }
545  else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
546  // load Lx8 B col major, transpose Lx8 row major
547  transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
548  transB::template transposeLxL<0>(ymm);
549  transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
550  }
551  else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) {
552  // load Lx4 B col major, transpose Lx4 row major
553  transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
554  transB::template transposeLxL<0>(ymm);
555  transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
556  }
557  else EIGEN_IF_CONSTEXPR(unrollN == 2) {
558  // load Lx2 B col major, transpose Lx2 row major
559  transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
560  transB::template transposeLxL<0>(ymm);
561  transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
562  }
563  else EIGEN_IF_CONSTEXPR(unrollN == 1) {
564  // load Lx1 B col major, transpose Lx1 row major
565  transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
566  transB::template transposeLxL<0>(ymm);
567  transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
568  }
569  }
570 };
571 
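transB_kernel loads an L x unrollN panel of B (L = EIGEN_AVX_MAX_NUM_ROW), transposes it L x L at a time with transposeLxL, and stores the result either to a temporary buffer (toTemp) or back over B. Ignoring the vectorized packing details, the data movement for one panel amounts to a small out-of-place transpose; a hedged scalar picture follows (the layout conventions here are illustrative, not the exact packing used above):

// Illustrative scalar analogue of one transB panel: read a column-major
// L x N panel from B (leading dimension LDB) and write its transpose into
// B_temp (leading dimension LDB_).
void transB_panel_ref(const float *B, int LDB, float *B_temp, int LDB_, int L, int N) {
  for (int n = 0; n < N; ++n)
    for (int l = 0; l < L; ++l)
      B_temp[l * LDB_ + n] = B[n * LDB + l];  // element (l, n) of B lands at (n, l)
}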
584 template <typename Scalar>
585 class trsm {
586  public:
587  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
588  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
589 
590 
605  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
606  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
607  Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
608  constexpr int64_t counterReverse = endM * endK - counter;
609  constexpr int64_t startM = counterReverse / (endK);
610  constexpr int64_t startK = counterReverse % endK;
611 
612  constexpr int64_t packetIndex = startM * endK + startK;
613  constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
614  const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
615  EIGEN_IF_CONSTEXPR(krem) {
616  RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex], remMask<PacketSize>(rem));
617  }
618  else {
619  RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex]);
620  }
621  aux_loadRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
622  }
623 
624  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
625  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS(
626  Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
627  EIGEN_UNUSED_VARIABLE(B_arr);
628  EIGEN_UNUSED_VARIABLE(LDB);
629  EIGEN_UNUSED_VARIABLE(RHSInPacket);
630  EIGEN_UNUSED_VARIABLE(rem);
631  }
632 
640  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
641  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
642  Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
643  constexpr int64_t counterReverse = endM * endK - counter;
644  constexpr int64_t startM = counterReverse / (endK);
645  constexpr int64_t startK = counterReverse % endK;
646 
647  constexpr int64_t packetIndex = startM * endK + startK;
648  constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
649  const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
650  EIGEN_IF_CONSTEXPR(krem) {
651  pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask<PacketSize>(rem));
652  }
653  else {
654  pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]);
655  }
656  aux_storeRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
657  }
658 
659  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
660  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS(
661  Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
662  EIGEN_UNUSED_VARIABLE(B_arr);
663  EIGEN_UNUSED_VARIABLE(LDB);
664  EIGEN_UNUSED_VARIABLE(RHSInPacket);
665  EIGEN_UNUSED_VARIABLE(rem);
666  }
667 
676  template <int64_t currM, int64_t endK, int64_t counter>
677  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
678  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
679  constexpr int64_t counterReverse = endK - counter;
680  constexpr int64_t startK = counterReverse;
681 
682  constexpr int64_t packetIndex = currM * endK + startK;
683  RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]);
684  aux_divRHSByDiag<currM, endK, counter - 1>(RHSInPacket, AInPacket);
685  }
686 
687  template <int64_t currM, int64_t endK, int64_t counter>
688  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
689  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
690  EIGEN_UNUSED_VARIABLE(RHSInPacket);
691  EIGEN_UNUSED_VARIABLE(AInPacket);
692  }
693 
701  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
702  int64_t counter, int64_t currentM>
703  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
704  Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
705  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
706  constexpr int64_t counterReverse = (endM - initM) * endK - counter;
707  constexpr int64_t startM = initM + counterReverse / (endK);
708  constexpr int64_t startK = counterReverse % endK;
709 
710  // For each row of A, first update all corresponding RHS
711  constexpr int64_t packetIndex = startM * endK + startK;
712  EIGEN_IF_CONSTEXPR(currentM > 0) {
713  RHSInPacket.packet[packetIndex] =
714  pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
715  RHSInPacket.packet[packetIndex]);
716  }
717 
718  EIGEN_IF_CONSTEXPR(startK == endK - 1) {
719  // Once all RHS for the previous row of A are updated, we broadcast the next element in column A_{i, currentM}.
720  EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) {
721  // If the diagonal is not unit, we broadcast the reciprocal of the diagonal into AInPacket.packet[currentM].
722  // This will be used in divRHSByDiag.
723  EIGEN_IF_CONSTEXPR(isFWDSolve)
724  AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
725  else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
726  }
727  else {
728  // Broadcast the next off-diagonal element of A.
729  EIGEN_IF_CONSTEXPR(isFWDSolve)
730  AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
731  else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
732  }
733  }
734 
735  aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
736  A_arr, LDA, RHSInPacket, AInPacket);
737  }
738 
739  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
740  int64_t counter, int64_t currentM>
741  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
742  Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
743  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
744  EIGEN_UNUSED_VARIABLE(A_arr);
745  EIGEN_UNUSED_VARIABLE(LDA);
746  EIGEN_UNUSED_VARIABLE(RHSInPacket);
747  EIGEN_UNUSED_VARIABLE(AInPacket);
748  }
749 
756  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
757  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
758  Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
759  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
760  constexpr int64_t counterReverse = endM - counter;
761  constexpr int64_t startM = counterReverse;
762 
763  constexpr int64_t currentM = startM;
764  // Divide the right-hand side in row startM by the diagonal value of A,
765  // using the reciprocal broadcast to AInPacket.packet[startM-1] in the previous iteration.
766  //
767  // Without "if constexpr" the compiler instantiates the case <-1, numK>;
768  // this is handled with enable_if to prevent out-of-bounds warnings
769  // from the compiler.
770  EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0)
771  trsm::template divRHSByDiag<startM - 1, numK>(RHSInPacket, AInPacket);
772 
773  // After division, the RHS corresponding to subsequent rows of A can be partially updated.
774  // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
775  // to be used in the next iteration.
776  trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
777  AInPacket);
778 
779  // Handle division for the RHS corresponding to the final row of A.
780  EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
781  trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);
782 
783  aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
784  AInPacket);
785  }
786 
787  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
788  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
789  Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
790  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
791  EIGEN_UNUSED_VARIABLE(A_arr);
792  EIGEN_UNUSED_VARIABLE(LDA);
793  EIGEN_UNUSED_VARIABLE(RHSInPacket);
794  EIGEN_UNUSED_VARIABLE(AInPacket);
795  }
796 
797 
805  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
806  static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
807  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
808  aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
809  }
810 
815  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
816  static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
817  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
818  aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
819  }
820 
824  template <int64_t currM, int64_t endK>
825  static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
826  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
827  aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
828  }
829 
834  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
835  int64_t currentM>
836  static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
837  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
838  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
839  aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
840  A_arr, LDA, RHSInPacket, AInPacket);
841  }
842 
849  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
850  static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
851  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
852  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
853  static_assert(numK >= 1 && numK <= 3, "numK out of range");
854  aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
855  }
856 };
857 
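The trsm<Scalar> micro-kernel above keeps up to EIGEN_AVX_MAX_NUM_ROW rows of the right-hand side in registers; for each row of A it scales by the precomputed reciprocal of the diagonal (divRHSByDiag) and then subtracts that row's contribution from the remaining rows (updateRHS). A scalar forward-substitution sketch of the same recurrence, for a lower-triangular row-major A with a non-unit diagonal (names and layout here are illustrative, not the kernel's packing):

// Solve A * X = B in place, where A is M x M lower triangular (row-major,
// leading dimension LDA) and B holds nrhs right-hand-side columns (row-major,
// leading dimension LDB).
void trsm_ref(const double *A, int LDA, double *B, int LDB, int M, int nrhs) {
  for (int m = 0; m < M; ++m) {
    const double inv_diag = 1.0 / A[m * LDA + m];  // reciprocal of the diagonal, as in AInPacket
    for (int k = 0; k < nrhs; ++k) B[m * LDB + k] *= inv_diag;
    for (int i = m + 1; i < M; ++i)                // update RHS of the remaining rows
      for (int k = 0; k < nrhs; ++k)
        B[i * LDB + k] -= A[i * LDA + m] * B[m * LDB + k];
  }
}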
863 template <typename Scalar, bool isAdd>
864 class gemm {
865  public:
866  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
867  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
868 
869 
885  template <int64_t endM, int64_t endN, int64_t counter>
886  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
887  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
888  constexpr int64_t counterReverse = endM * endN - counter;
889  constexpr int64_t startM = counterReverse / (endN);
890  constexpr int64_t startN = counterReverse % endN;
891 
892  zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]);
893  aux_setzero<endM, endN, counter - 1>(zmm);
894  }
895 
896  template <int64_t endM, int64_t endN, int64_t counter>
897  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
898  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
899  EIGEN_UNUSED_VARIABLE(zmm);
900  }
901 
909  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
910  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
911  Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
912  EIGEN_UNUSED_VARIABLE(rem_);
913  constexpr int64_t counterReverse = endM * endN - counter;
914  constexpr int64_t startM = counterReverse / (endN);
915  constexpr int64_t startN = counterReverse % endN;
916 
917  EIGEN_IF_CONSTEXPR(rem)
918  zmm.packet[startN * endM + startM] =
919  padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
920  zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
921  else zmm.packet[startN * endM + startM] =
922  padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
923  aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
924  }
925 
926  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
927  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
928  Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
929  EIGEN_UNUSED_VARIABLE(C_arr);
930  EIGEN_UNUSED_VARIABLE(LDC);
931  EIGEN_UNUSED_VARIABLE(zmm);
932  EIGEN_UNUSED_VARIABLE(rem_);
933  }
934 
942  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
943  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
944  Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
945  EIGEN_UNUSED_VARIABLE(rem_);
946  constexpr int64_t counterReverse = endM * endN - counter;
947  constexpr int64_t startM = counterReverse / (endN);
948  constexpr int64_t startN = counterReverse % endN;
949 
950  EIGEN_IF_CONSTEXPR(rem)
951  pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
952  remMask<PacketSize>(rem_));
953  else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
954  aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
955  }
956 
957  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
958  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
959  Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
960  EIGEN_UNUSED_VARIABLE(C_arr);
961  EIGEN_UNUSED_VARIABLE(LDC);
962  EIGEN_UNUSED_VARIABLE(zmm);
963  EIGEN_UNUSED_VARIABLE(rem_);
964  }
965 
972  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
973  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
974  Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
975  EIGEN_UNUSED_VARIABLE(rem_);
976  constexpr int64_t counterReverse = endL - counter;
977  constexpr int64_t startL = counterReverse;
978 
979  EIGEN_IF_CONSTEXPR(rem)
980  zmm.packet[unrollM * unrollN + startL] =
981  ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
982  else zmm.packet[unrollM * unrollN + startL] =
983  ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);
984 
985  aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
986  }
987 
988  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
989  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
990  Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
991  EIGEN_UNUSED_VARIABLE(B_t);
992  EIGEN_UNUSED_VARIABLE(LDB);
993  EIGEN_UNUSED_VARIABLE(zmm);
994  EIGEN_UNUSED_VARIABLE(rem_);
995  }
996 
1003  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
1004  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
1005  Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
1006  constexpr int64_t counterReverse = endB - counter;
1007  constexpr int64_t startB = counterReverse;
1008 
1009  zmm.packet[unrollM * unrollN + numLoad + startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0, LDA)]);
1010 
1011  aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter - 1, numLoad>(A_t, LDA, zmm);
1012  }
1013 
1014  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
1015  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
1016  Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
1017  EIGEN_UNUSED_VARIABLE(A_t);
1018  EIGEN_UNUSED_VARIABLE(LDA);
1019  EIGEN_UNUSED_VARIABLE(zmm);
1020  }
1021 
1029  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
1030  int64_t numBCast, bool rem>
1031  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
1032  Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
1033  EIGEN_UNUSED_VARIABLE(rem_);
1034  if ((numLoad / endM + currK < unrollK)) {
1035  constexpr int64_t counterReverse = endM - counter;
1036  constexpr int64_t startM = counterReverse;
1037 
1038  EIGEN_IF_CONSTEXPR(rem) {
1039  zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
1040  ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask<PacketSize>(rem_));
1041  }
1042  else {
1043  zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
1044  ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]);
1045  }
1046 
1047  aux_loadB<endM, counter - 1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
1048  }
1049  }
1050 
1051  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
1052  int64_t numBCast, bool rem>
1053  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
1054  Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
1055  EIGEN_UNUSED_VARIABLE(B_t);
1056  EIGEN_UNUSED_VARIABLE(LDB);
1057  EIGEN_UNUSED_VARIABLE(zmm);
1058  EIGEN_UNUSED_VARIABLE(rem_);
1059  }
1060 
1069  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
1070  int64_t numBCast, bool rem>
1071  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
1072  Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1073  int64_t rem_ = 0) {
1074  EIGEN_UNUSED_VARIABLE(rem_);
1075  constexpr int64_t counterReverse = endM * endN * endK - counter;
1076  constexpr int startK = counterReverse / (endM * endN);
1077  constexpr int startN = (counterReverse / (endM)) % endN;
1078  constexpr int startM = counterReverse % endM;
1079 
1080  EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
1081  gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
1082  gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
1083  }
1084 
1085  {
1086  // Interleave FMA and Bcast
1087  EIGEN_IF_CONSTEXPR(isAdd) {
1088  zmm.packet[startN * endM + startM] =
1089  pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
1090  zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
1091  }
1092  else {
1093  zmm.packet[startN * endM + startM] =
1094  pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
1095  zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
1096  }
1097  // Bcast
1098  EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
1099  zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
1100  (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
1101  }
1102  }
1103 
1104  // We have updated all accumulators; time to load the next set of B's.
1105  EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) {
1106  gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
1107  }
1108  aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
1109  }
1110 
1111  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
1112  int64_t numBCast, bool rem>
1113  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
1114  Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1115  int64_t rem_ = 0) {
1116  EIGEN_UNUSED_VARIABLE(B_t);
1117  EIGEN_UNUSED_VARIABLE(A_t);
1118  EIGEN_UNUSED_VARIABLE(LDB);
1119  EIGEN_UNUSED_VARIABLE(LDA);
1120  EIGEN_UNUSED_VARIABLE(zmm);
1121  EIGEN_UNUSED_VARIABLE(rem_);
1122  }
1123 
1124 
1128  template <int64_t endM, int64_t endN>
1129  static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
1130  aux_setzero<endM, endN, endM * endN>(zmm);
1131  }
1132 
1136  template <int64_t endM, int64_t endN, bool rem = false>
1137  static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
1138  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1139  int64_t rem_ = 0) {
1140  EIGEN_UNUSED_VARIABLE(rem_);
1141  aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
1142  }
1143 
1144  template <int64_t endM, int64_t endN, bool rem = false>
1145  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
1146  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1147  int64_t rem_ = 0) {
1148  EIGEN_UNUSED_VARIABLE(rem_);
1149  aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
1150  }
1151 
1155  template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
1156  static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
1157  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1158  int64_t rem_ = 0) {
1159  EIGEN_UNUSED_VARIABLE(rem_);
1160  aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
1161  }
1162 
1166  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
1167  static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
1168  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
1169  aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
1170  }
1171 
1175  template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
1176  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
1177  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1178  int64_t rem_ = 0) {
1179  EIGEN_UNUSED_VARIABLE(rem_);
1180  aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
1181  }
1182 
1206  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
1207  bool rem = false>
1208  static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
1209  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1210  int64_t rem_ = 0) {
1211  EIGEN_UNUSED_VARIABLE(rem_);
1212  aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
1213  rem_);
1214  }
1215 };
1216 } // namespace unrolls
1217 
1218 #endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
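For reference, the gemm<Scalar, isAdd> micro-kernel above keeps an endM x endN tile of accumulators in registers (indexed as zmm.packet[startN * endM + startM]), cycles numLoad B packets and numBCast A broadcasts through the remaining registers, and finally adds the tile into C via updateC/storeC. Stripping out the register rotation, the arithmetic reduces to a register-blocked rank-K update; a hedged scalar sketch of that update (all layouts and names here are illustrative):

// Scalar picture of the accumulator tile maintained by gemm<Scalar, isAdd>:
// acc[n * endM + m] accumulates +/- A(n, k) * B(k, m) over K steps and is then
// added into C, mirroring setzero / microKernel / updateC / storeC.
template <int endM, int endN>
void gemm_tile_ref(const double *A, int LDA, const double *B, int LDB,
                   double *C, int LDC, int K, bool isAdd) {
  double acc[endM * endN] = {};                     // setzero()
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < endN; ++n)
      for (int m = 0; m < endM; ++m) {
        double prod = A[n * LDA + k] * B[k * LDB + m];
        acc[n * endM + m] += isAdd ? prod : -prod;  // pmadd vs. pnmadd
      }
  for (int n = 0; n < endN; ++n)
    for (int m = 0; m < endM; ++m)
      C[n * LDC + m] += acc[n * endM + m];          // updateC + storeC
}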