TrsmKernel.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2022 Intel Corporation
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
11 #define EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
12 
13 #include "../../InternalHeaderCheck.h"
14 
15 #if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
16 #define EIGEN_USE_AVX512_TRSM_KERNELS 1
17 #endif
18 
19 // TRSM kernels currently unconditionally rely on malloc with AVX512.
20 // Disable them if malloc is explicitly disabled at compile-time.
21 #ifdef EIGEN_NO_MALLOC
22 #undef EIGEN_USE_AVX512_TRSM_KERNELS
23 #define EIGEN_USE_AVX512_TRSM_KERNELS 0
24 #endif
25 
26 #if EIGEN_USE_AVX512_TRSM_KERNELS
27 #if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
28 #define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
29 #endif
30 #if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
31 #define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
32 #endif
33 #else // EIGEN_USE_AVX512_TRSM_KERNELS == 0
34 #define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
35 #define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
36 #endif
37 
38 // Need this for some std::min calls.
39 #ifdef min
40 #undef min
41 #endif
42 
43 namespace Eigen {
44 namespace internal {
45 
46 #define EIGEN_AVX_MAX_NUM_ACC (int64_t(24))
47 #define EIGEN_AVX_MAX_NUM_ROW (int64_t(8)) // Denoted L in code.
48 #define EIGEN_AVX_MAX_K_UNROL (int64_t(4))
49 #define EIGEN_AVX_B_LOAD_SETS (int64_t(2))
50 #define EIGEN_AVX_MAX_A_BCAST (int64_t(2))
51 typedef Packet16f vecFullFloat;
52 typedef Packet8d vecFullDouble;
53 typedef Packet8f vecHalfFloat;
54 typedef Packet4d vecHalfDouble;
55 
56 // Compile-time unrolls are implemented here.
57 // Note: this depends on macros and typedefs above.
58 #include "TrsmUnrolls.inc"
59 
60 #if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
77 #if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
78 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
79 #endif
80 
81 #if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
82 
83 #if EIGEN_USE_AVX512_TRSM_R_KERNELS
84 #if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
85 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
86 #endif // !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
87 #endif
88 
89 #if EIGEN_USE_AVX512_TRSM_L_KERNELS
90 #if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
91 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
92 #endif
93 #endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
94 
95 #else // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS == 0
96 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
97 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
98 #endif // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
99 
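// The cutoff computed below can be read as a rough L2-occupancy bound: with
// L = EIGEN_AVX_MAX_NUM_ROW and Nb = min(5*U3, N), it picks (up to rounding) the
// largest M satisfying
//   M*Nb + L*Nb + M*L <= L2Cap * L2Size / sizeof(Scalar)
// i.e. M <= (L2Cap*L2Size/sizeof(Scalar) - L*Nb) / (L + Nb),
// then rounds M down to a multiple of L.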
100 template <typename Scalar>
101 int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap) {
102  const int64_t U3 = 3 * packet_traits<Scalar>::size;
103  const int64_t MaxNb = 5 * U3;
104  int64_t Nb = std::min(MaxNb, N);
105  double cutoff_d =
106  (((L2Size * L2Cap) / (sizeof(Scalar))) - (EIGEN_AVX_MAX_NUM_ROW)*Nb) / ((EIGEN_AVX_MAX_NUM_ROW) + Nb);
107  int64_t cutoff_l = static_cast<int64_t>(cutoff_d);
108  return (cutoff_l / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
109 }
110 #else // !(EIGEN_USE_AVX512_TRSM_KERNELS) || !(EIGEN_COMP_CLANG != 0)
111 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 0
112 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
113 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
114 #endif
115 
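// transStoreC transposes the accumulator block zmm in registers and writes it to a
// column-major C: the first U1 columns are stored at C_arr, the next U1 at
// C_arr + U1*LDC, and so on, with remM_/remN_ selecting the partial-block paths.
// The static_asserts below restrict unrollN to {U1, U2, U3} and unrollM to
// EIGEN_AVX_MAX_NUM_ROW.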
119 template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
120 EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
121  Scalar *C_arr, int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
122  EIGEN_UNUSED_VARIABLE(remN_);
123  EIGEN_UNUSED_VARIABLE(remM_);
124  using urolls = unrolls::trans<Scalar>;
125 
126  constexpr int64_t U3 = urolls::PacketSize * 3;
127  constexpr int64_t U2 = urolls::PacketSize * 2;
128  constexpr int64_t U1 = urolls::PacketSize * 1;
129 
130  static_assert(unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize");
131  static_assert(unrollM == EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
132 
133  urolls::template transpose<unrollN, 0>(zmm);
134  EIGEN_IF_CONSTEXPR(unrollN > U2) urolls::template transpose<unrollN, 2>(zmm);
135  EIGEN_IF_CONSTEXPR(unrollN > U1) urolls::template transpose<unrollN, 1>(zmm);
136 
137  static_assert((remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1");
138  EIGEN_IF_CONSTEXPR(!remN) {
139  urolls::template storeC<std::min(unrollN, U1), unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
140  EIGEN_IF_CONSTEXPR(unrollN > U1) {
141  constexpr int64_t unrollN_ = std::min(unrollN - U1, U1);
142  urolls::template storeC<unrollN_, unrollN, 1, remM>(C_arr + U1 * LDC, LDC, zmm, remM_);
143  }
144  EIGEN_IF_CONSTEXPR(unrollN > U2) {
145  constexpr int64_t unrollN_ = std::min(unrollN - U2, U1);
146  urolls::template storeC<unrollN_, unrollN, 2, remM>(C_arr + U2 * LDC, LDC, zmm, remM_);
147  }
148  }
149  else {
150  EIGEN_IF_CONSTEXPR((std::is_same<Scalar, float>::value)) {
151  // Note: without "if constexpr" this section of code will also be
152  // parsed by the compiler so each of the storeC will still be instantiated.
153  // We use enable_if in aux_storeC to set it to an empty function for
154  // these cases.
155  if (remN_ == 15)
156  urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
157  else if (remN_ == 14)
158  urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
159  else if (remN_ == 13)
160  urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
161  else if (remN_ == 12)
162  urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
163  else if (remN_ == 11)
164  urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
165  else if (remN_ == 10)
166  urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
167  else if (remN_ == 9)
168  urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
169  else if (remN_ == 8)
170  urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
171  else if (remN_ == 7)
172  urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
173  else if (remN_ == 6)
174  urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
175  else if (remN_ == 5)
176  urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
177  else if (remN_ == 4)
178  urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
179  else if (remN_ == 3)
180  urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
181  else if (remN_ == 2)
182  urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
183  else if (remN_ == 1)
184  urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
185  }
186  else {
187  if (remN_ == 7)
188  urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
189  else if (remN_ == 6)
190  urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
191  else if (remN_ == 5)
192  urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
193  else if (remN_ == 4)
194  urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
195  else if (remN_ == 3)
196  urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
197  else if (remN_ == 2)
198  urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
199  else if (remN_ == 1)
200  urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
201  }
202  }
203 }
204 
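// gemmKernel is the GEMM-like update used for the trsm panel updates. Its loop structure,
// as written below: N is blocked by U3, then U2, then U1, then a remainder path of N - j
// columns; inside each N block, M is handled in strips of 8 (EIGEN_AVX_MAX_NUM_ROW), 4, 2
// and 1 rows; K is unrolled by EIGEN_AVX_MAX_K_UNROL with a scalar remainder loop when
// handleKRem is set. Row-major C is updated/stored directly, column-major C goes through
// transStoreC. As a scalar reference it amounts to (the sign convention for isAdd is an
// assumption here, not taken from the original comments):
//   for all (i, j, k): C(i,j) = C(i,j) + (isAdd ? 1 : -1) * A(i,k) * B(k,j)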
219 template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
220 void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB,
221  int64_t LDC) {
222  using urolls = unrolls::gemm<Scalar, isAdd>;
223  constexpr int64_t U3 = urolls::PacketSize * 3;
224  constexpr int64_t U2 = urolls::PacketSize * 2;
225  constexpr int64_t U1 = urolls::PacketSize * 1;
226  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
227  int64_t N_ = (N / U3) * U3;
228  int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
229  int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL;
230  int64_t j = 0;
231  for (; j < N_; j += U3) {
232  constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 3;
233  int64_t i = 0;
234  for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
235  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
236  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
237  urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
238  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
239  urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
240  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
241  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
242  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
243  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
244  }
245  EIGEN_IF_CONSTEXPR(handleKRem) {
246  for (int64_t k = K_; k < K; k++) {
247  urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 3,
248  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
249  B_t += LDB;
250  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
251  else A_t += LDA;
252  }
253  }
254  EIGEN_IF_CONSTEXPR(isCRowMajor) {
255  urolls::template updateC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
256  urolls::template storeC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
257  }
258  else {
259  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, false, false>(zmm, &C_arr[i + j * LDC], LDC);
260  }
261  }
262  if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
263  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
264  Scalar *B_t = &B_arr[0 * LDB + j];
265  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
266  urolls::template setzero<3, 4>(zmm);
267  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
268  urolls::template microKernel<isARowMajor, 3, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
269  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
270  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
271  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
272  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
273  }
274  EIGEN_IF_CONSTEXPR(handleKRem) {
275  for (int64_t k = K_; k < K; k++) {
276  urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
277  B_t, A_t, LDB, LDA, zmm);
278  B_t += LDB;
279  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
280  else A_t += LDA;
281  }
282  }
283  EIGEN_IF_CONSTEXPR(isCRowMajor) {
284  urolls::template updateC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
285  urolls::template storeC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
286  }
287  else {
288  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
289  }
290  i += 4;
291  }
292  if (M - i >= 2) {
293  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
294  Scalar *B_t = &B_arr[0 * LDB + j];
295  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
296  urolls::template setzero<3, 2>(zmm);
297  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
298  urolls::template microKernel<isARowMajor, 3, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
299  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
300  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
301  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
302  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
303  }
304  EIGEN_IF_CONSTEXPR(handleKRem) {
305  for (int64_t k = K_; k < K; k++) {
306  urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
307  B_t, A_t, LDB, LDA, zmm);
308  B_t += LDB;
309  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
310  else A_t += LDA;
311  }
312  }
313  EIGEN_IF_CONSTEXPR(isCRowMajor) {
314  urolls::template updateC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
315  urolls::template storeC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
316  }
317  else {
318  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
319  }
320  i += 2;
321  }
322  if (M - i > 0) {
323  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
324  Scalar *B_t = &B_arr[0 * LDB + j];
325  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
326  urolls::template setzero<3, 1>(zmm);
327  {
328  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
329  urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
330  B_t, A_t, LDB, LDA, zmm);
331  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
332  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
333  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
334  }
335  EIGEN_IF_CONSTEXPR(handleKRem) {
336  for (int64_t k = K_; k < K; k++) {
337  urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);
338  B_t += LDB;
339  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
340  else A_t += LDA;
341  }
342  }
343  EIGEN_IF_CONSTEXPR(isCRowMajor) {
344  urolls::template updateC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
345  urolls::template storeC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
346  }
347  else {
348  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
349  }
350  }
351  }
352  }
353  if (N - j >= U2) {
354  constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 2;
355  int64_t i = 0;
356  for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
357  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
358  EIGEN_IF_CONSTEXPR(isCRowMajor) B_t = &B_arr[0 * LDB + j];
359  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
360  urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
361  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
362  urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
363  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
364  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
365  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
366  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
367  }
368  EIGEN_IF_CONSTEXPR(handleKRem) {
369  for (int64_t k = K_; k < K; k++) {
370  urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
371  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
372  B_t += LDB;
373  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
374  else A_t += LDA;
375  }
376  }
377  EIGEN_IF_CONSTEXPR(isCRowMajor) {
378  urolls::template updateC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
379  urolls::template storeC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
380  }
381  else {
382  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, false, false>(zmm, &C_arr[i + j * LDC], LDC);
383  }
384  }
385  if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
386  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
387  Scalar *B_t = &B_arr[0 * LDB + j];
388  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
389  urolls::template setzero<2, 4>(zmm);
390  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
391  urolls::template microKernel<isARowMajor, 2, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
392  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
393  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
394  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
395  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
396  }
397  EIGEN_IF_CONSTEXPR(handleKRem) {
398  for (int64_t k = K_; k < K; k++) {
399  urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
400  LDA, zmm);
401  B_t += LDB;
402  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
403  else A_t += LDA;
404  }
405  }
406  EIGEN_IF_CONSTEXPR(isCRowMajor) {
407  urolls::template updateC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
408  urolls::template storeC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
409  }
410  else {
411  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
412  }
413  i += 4;
414  }
415  if (M - i >= 2) {
416  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
417  Scalar *B_t = &B_arr[0 * LDB + j];
418  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
419  urolls::template setzero<2, 2>(zmm);
420  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
421  urolls::template microKernel<isARowMajor, 2, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
422  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
423  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
424  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
425  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
426  }
427  EIGEN_IF_CONSTEXPR(handleKRem) {
428  for (int64_t k = K_; k < K; k++) {
429  urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
430  LDA, zmm);
431  B_t += LDB;
432  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
433  else A_t += LDA;
434  }
435  }
436  EIGEN_IF_CONSTEXPR(isCRowMajor) {
437  urolls::template updateC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
438  urolls::template storeC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
439  }
440  else {
441  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
442  }
443  i += 2;
444  }
445  if (M - i > 0) {
446  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
447  Scalar *B_t = &B_arr[0 * LDB + j];
448  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
449  urolls::template setzero<2, 1>(zmm);
450  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
451  urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
452  LDA, zmm);
453  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
454  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
455  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
456  }
457  EIGEN_IF_CONSTEXPR(handleKRem) {
458  for (int64_t k = K_; k < K; k++) {
459  urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);
460  B_t += LDB;
461  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
462  else A_t += LDA;
463  }
464  }
465  EIGEN_IF_CONSTEXPR(isCRowMajor) {
466  urolls::template updateC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
467  urolls::template storeC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
468  }
469  else {
470  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
471  }
472  }
473  j += U2;
474  }
475  if (N - j >= U1) {
476  constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
477  int64_t i = 0;
478  for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
479  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
480  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
481  urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
482  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
483  urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
484  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
485  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
486  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
487  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
488  }
489  EIGEN_IF_CONSTEXPR(handleKRem) {
490  for (int64_t k = K_; k < K; k++) {
491  urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 1,
492  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
493  B_t += LDB;
494  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
495  else A_t += LDA;
496  }
497  }
498  EIGEN_IF_CONSTEXPR(isCRowMajor) {
499  urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
500  urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
501  }
502  else {
503  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, false>(zmm, &C_arr[i + j * LDC], LDC);
504  }
505  }
506  if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
507  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
508  Scalar *B_t = &B_arr[0 * LDB + j];
509  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
510  urolls::template setzero<1, 4>(zmm);
511  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
512  urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
513  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
514  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
515  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
516  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
517  }
518  EIGEN_IF_CONSTEXPR(handleKRem) {
519  for (int64_t k = K_; k < K; k++) {
520  urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
521  LDA, zmm);
522  B_t += LDB;
523  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
524  else A_t += LDA;
525  }
526  }
527  EIGEN_IF_CONSTEXPR(isCRowMajor) {
528  urolls::template updateC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
529  urolls::template storeC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
530  }
531  else {
532  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
533  }
534  i += 4;
535  }
536  if (M - i >= 2) {
537  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
538  Scalar *B_t = &B_arr[0 * LDB + j];
539  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
540  urolls::template setzero<1, 2>(zmm);
541  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
542  urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
543  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
544  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
545  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
546  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
547  }
548  EIGEN_IF_CONSTEXPR(handleKRem) {
549  for (int64_t k = K_; k < K; k++) {
550  urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
551  LDA, zmm);
552  B_t += LDB;
553  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
554  else A_t += LDA;
555  }
556  }
557  EIGEN_IF_CONSTEXPR(isCRowMajor) {
558  urolls::template updateC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
559  urolls::template storeC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
560  }
561  else {
562  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
563  }
564  i += 2;
565  }
566  if (M - i > 0) {
567  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
568  Scalar *B_t = &B_arr[0 * LDB + j];
569  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
570  urolls::template setzero<1, 1>(zmm);
571  {
572  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
573  urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
574  LDA, zmm);
575  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
576  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
577  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
578  }
579  EIGEN_IF_CONSTEXPR(handleKRem) {
580  for (int64_t k = K_; k < K; k++) {
581  urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);
582  B_t += LDB;
583  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
584  else A_t += LDA;
585  }
586  }
587  EIGEN_IF_CONSTEXPR(isCRowMajor) {
588  urolls::template updateC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
589  urolls::template storeC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
590  }
591  else {
592  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
593  }
594  }
595  }
596  j += U1;
597  }
598  if (N - j > 0) {
599  constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
600  int64_t i = 0;
601  for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
602  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
603  Scalar *B_t = &B_arr[0 * LDB + j];
604  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
605  urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
606  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
607  urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
608  EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
609  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
610  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
611  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
612  }
613  EIGEN_IF_CONSTEXPR(handleKRem) {
614  for (int64_t k = K_; k < K; k++) {
615  urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
616  EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
617  B_t += LDB;
618  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
619  else A_t += LDA;
620  }
621  }
622  EIGEN_IF_CONSTEXPR(isCRowMajor) {
623  urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
624  urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
625  }
626  else {
627  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, true>(zmm, &C_arr[i + j * LDC], LDC, 0, N - j);
628  }
629  }
630  if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
631  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
632  Scalar *B_t = &B_arr[0 * LDB + j];
633  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
634  urolls::template setzero<1, 4>(zmm);
635  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
636  urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
637  EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
638  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
639  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
640  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
641  }
642  EIGEN_IF_CONSTEXPR(handleKRem) {
643  for (int64_t k = K_; k < K; k++) {
644  urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
645  B_t, A_t, LDB, LDA, zmm, N - j);
646  B_t += LDB;
647  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
648  else A_t += LDA;
649  }
650  }
651  EIGEN_IF_CONSTEXPR(isCRowMajor) {
652  urolls::template updateC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
653  urolls::template storeC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
654  }
655  else {
656  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 4, N - j);
657  }
658  i += 4;
659  }
660  if (M - i >= 2) {
661  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
662  Scalar *B_t = &B_arr[0 * LDB + j];
663  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
664  urolls::template setzero<1, 2>(zmm);
665  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
666  urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
667  EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
668  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
669  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
670  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
671  }
672  EIGEN_IF_CONSTEXPR(handleKRem) {
673  for (int64_t k = K_; k < K; k++) {
674  urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
675  B_t, A_t, LDB, LDA, zmm, N - j);
676  B_t += LDB;
677  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
678  else A_t += LDA;
679  }
680  }
681  EIGEN_IF_CONSTEXPR(isCRowMajor) {
682  urolls::template updateC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
683  urolls::template storeC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
684  }
685  else {
686  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 2, N - j);
687  }
688  i += 2;
689  }
690  if (M - i > 0) {
691  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
692  Scalar *B_t = &B_arr[0 * LDB + j];
693  PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
694  urolls::template setzero<1, 1>(zmm);
695  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
696  urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
697  B_t, A_t, LDB, LDA, zmm, N - j);
698  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
699  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
700  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
701  }
702  EIGEN_IF_CONSTEXPR(handleKRem) {
703  for (int64_t k = K_; k < K; k++) {
704  urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
705  N - j);
706  B_t += LDB;
707  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
708  else A_t += LDA;
709  }
710  }
711  EIGEN_IF_CONSTEXPR(isCRowMajor) {
712  urolls::template updateC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
713  urolls::template storeC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
714  }
715  else {
716  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 1, N - j);
717  }
718  }
719  }
720 }
721 
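// triSolveKernel solves one unrollM x unrollM triangular system against K right-hand
// sides stored in B: the RHS panel is consumed in chunks of 3, 2 and 1 packets (U3, U2,
// U1 columns) plus a remainder path for the last K - k right-hand sides, each chunk doing
// loadRHS -> triSolveMicroKernel -> storeRHS.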
730 template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
731 EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
732  static_assert(unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be at most EIGEN_AVX_MAX_NUM_ROW");
733  using urolls = unrolls::trsm<Scalar>;
734  constexpr int64_t U3 = urolls::PacketSize * 3;
735  constexpr int64_t U2 = urolls::PacketSize * 2;
736  constexpr int64_t U1 = urolls::PacketSize * 1;
737 
738  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
739  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> AInPacket;
740 
741  int64_t k = 0;
742  while (K - k >= U3) {
743  urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
744  urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(A_arr, LDA, RHSInPacket,
745  AInPacket);
746  urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
747  k += U3;
748  }
749  if (K - k >= U2) {
750  urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
751  urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(A_arr, LDA, RHSInPacket,
752  AInPacket);
753  urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
754  k += U2;
755  }
756  if (K - k >= U1) {
757  urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
758  urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
759  AInPacket);
760  urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
761  k += U1;
762  }
763  if (K - k > 0) {
764  // Handle remaining number of RHS
765  urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
766  urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
767  AInPacket);
768  urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
769  }
770 }
771 
780 template <typename Scalar, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
781 void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) {
782  // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
783  // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
784  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
785  if (M == 8)
786  triSolveKernel<Scalar, vec, 8, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
787  else if (M == 7)
788  triSolveKernel<Scalar, vec, 7, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
789  else if (M == 6)
790  triSolveKernel<Scalar, vec, 6, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
791  else if (M == 5)
792  triSolveKernel<Scalar, vec, 5, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
793  else if (M == 4)
794  triSolveKernel<Scalar, vec, 4, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
795  else if (M == 3)
796  triSolveKernel<Scalar, vec, 3, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
797  else if (M == 2)
798  triSolveKernel<Scalar, vec, 2, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
799  else if (M == 1)
800  triSolveKernel<Scalar, vec, 1, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
801  return;
802 }
803 
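// copyBToRowMajor moves a panel between a column-major B and the row-major scratch
// buffer B_temp (B -> B_temp when toTemp is true, B_temp -> B otherwise), K columns at a
// time in chunks of U3, U2, U1, 8, 4, 2 and 1, so the solve and gemm kernels can operate
// on contiguous row-major RHS data. remM/remM_ handle a partial strip of rows.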
811 template <typename Scalar, bool toTemp = true, bool remM = false>
812 EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
813  int64_t remM_ = 0) {
814  EIGEN_UNUSED_VARIABLE(remM_);
815  using urolls = unrolls::transB<Scalar>;
816  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
817  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> ymm;
818  constexpr int64_t U3 = urolls::PacketSize * 3;
819  constexpr int64_t U2 = urolls::PacketSize * 2;
820  constexpr int64_t U1 = urolls::PacketSize * 1;
821  int64_t K_ = K / U3 * U3;
822  int64_t k = 0;
823 
824  for (; k < K_; k += U3) {
825  urolls::template transB_kernel<U3, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
826  B_temp += U3;
827  }
828  if (K - k >= U2) {
829  urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
830  B_temp += U2;
831  k += U2;
832  }
833  if (K - k >= U1) {
834  urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
835  B_temp += U1;
836  k += U1;
837  }
838  EIGEN_IF_CONSTEXPR(U1 > 8) {
839  // Note: without "if constexpr" this section of code will also be
840  // parsed by the compiler so there is an additional check in {load/store}BBlock
841  // to make sure the counter is non-negative.
842  if (K - k >= 8) {
843  urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
844  B_temp += 8;
845  k += 8;
846  }
847  }
848  EIGEN_IF_CONSTEXPR(U1 > 4) {
849  // Note: without "if constexpr" this section of code will also be
850  // parsed by the compiler so there is an additional check in {load/store}BBlock
851  // to make sure the counter is non-negative.
852  if (K - k >= 4) {
853  urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
854  B_temp += 4;
855  k += 4;
856  }
857  }
858  if (K - k >= 2) {
859  urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
860  B_temp += 2;
861  k += 2;
862  }
863  if (K - k >= 1) {
864  urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
865  B_temp += 1;
866  k += 1;
867  }
868 }
869 
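// triSolve is the driver: it walks the right-hand sides in column blocks of kB and, inside
// each block, alternates a small triangular solve on an EIGEN_AVX_MAX_NUM_ROW strip
// (triSolveKernelLxK) with a gemmKernel update of the rows below (or above, for backward
// solves). For a column-major B it stages row-major copies of the active strip in the
// heap-allocated buffer B_temp. A minimal sketch of a direct call, solving L*X = B in
// place for a column-major lower-triangular L (the parameter mapping mirrors the
// trsmKernelL specializations further below; illustrative only):
//   // A: M x M triangular, leading dimension LDA; B: M x numRHS, leading dimension LDB
//   triSolve<double, /*isARowMajor=*/false, /*isBRowMajor=*/false,
//            /*isFWDSolve=*/true, /*isUnitDiag=*/false>(A, B, M, numRHS, LDA, LDB);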
897 template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
898  bool isUnitDiag = false>
899 void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
900  constexpr int64_t psize = packet_traits<Scalar>::size;
914  constexpr int64_t kB = (3 * psize) * 5; // 5*U3
915  constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
916 
917  int64_t sizeBTemp = 0;
918  Scalar *B_temp = NULL;
919  EIGEN_IF_CONSTEXPR(!isBRowMajor) {
925  sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
926  }
927 
928  EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64);
929 
930  for (int64_t k = 0; k < numRHS; k += kB) {
931  int64_t bK = numRHS - k > kB ? kB : numRHS - k;
932  int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0;
933 
934  // bK rounded up to next multiple of L=EIGEN_AVX_MAX_NUM_ROW. When B_temp is used, we solve for bkL RHS
935  // instead of bK RHS in triSolveKernelLxK.
936  int64_t bkL = ((bK + (EIGEN_AVX_MAX_NUM_ROW - 1)) / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
937  const int64_t numScalarPerCache = 64 / sizeof(Scalar);
938  // Leading dimension of B_temp, will be a multiple of the cache line size.
939  int64_t LDT = ((bkL + (numScalarPerCache - 1)) / numScalarPerCache) * numScalarPerCache;
940  int64_t offsetBTemp = 0;
941  for (int64_t i = 0; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
942  EIGEN_IF_CONSTEXPR(!isBRowMajor) {
943  int64_t indA_i = isFWDSolve ? i : M - 1 - i;
944  int64_t indB_i = isFWDSolve ? i : M - (i + EIGEN_AVX_MAX_NUM_ROW);
945  int64_t offB_1 = isFWDSolve ? offsetBTemp : sizeBTemp - EIGEN_AVX_MAX_NUM_ROW * LDT - offsetBTemp;
946  int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp;
947  // Copy values from B to B_temp.
948  copyBToRowMajor<Scalar, true, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
949  // Triangular solve with a small block of A and long horizontal blocks of B (or B_temp if B col-major)
950  triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
951  &A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT);
952  // Copy values from B_temp back to B. B_temp will be reused in gemm call below.
953  copyBToRowMajor<Scalar, false, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
954 
955  offsetBTemp += EIGEN_AVX_MAX_NUM_ROW * LDT;
956  }
957  else {
958  int64_t ind = isFWDSolve ? i : M - 1 - i;
959  triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
960  &A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind * LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB);
961  }
962  if (i + EIGEN_AVX_MAX_NUM_ROW < M_) {
975  EIGEN_IF_CONSTEXPR(isBRowMajor) {
976  int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
977  int64_t indA_j = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
978  int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
979  int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
980  gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
981  &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
982  EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, LDA, LDB, LDB);
983  }
984  else {
985  if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) {
994  int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
995  int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
996  int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
997  int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
998  gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
999  &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
1000  M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
1001  offsetBTemp = 0;
1002  gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
1003  } else {
1007  int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
1008  int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
1009  int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
1010  int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
1011  gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
1012  &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
1013  EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
1014  }
1015  }
1016  }
1017  }
1018  // Handle M remainder.
1019  int64_t bM = M - M_;
1020  if (bM > 0) {
1021  if (M_ > 0) {
1022  EIGEN_IF_CONSTEXPR(isBRowMajor) {
1023  int64_t indA_i = isFWDSolve ? M_ : 0;
1024  int64_t indA_j = isFWDSolve ? 0 : bM;
1025  int64_t indB_i = isFWDSolve ? 0 : bM;
1026  int64_t indB_i2 = isFWDSolve ? M_ : 0;
1027  gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
1028  &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM,
1029  bK, M_, LDA, LDB, LDB);
1030  }
1031  else {
1032  int64_t indA_i = isFWDSolve ? M_ : 0;
1033  int64_t indA_j = isFWDSolve ? gemmOff : bM;
1034  int64_t indB_i = isFWDSolve ? M_ : 0;
1035  int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
1036  gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)],
1037  B_temp + offB_1, B_arr + indB_i + (k)*LDB, bM, bK,
1038  M_ - gemmOff, LDA, LDT, LDB);
1039  }
1040  }
1041  EIGEN_IF_CONSTEXPR(!isBRowMajor) {
1042  int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_;
1043  int64_t indB_i = isFWDSolve ? M_ : 0;
1044  int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
1045  copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
1046  triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)],
1047  B_temp + offB_1, bM, bkL, LDA, bkL);
1048  copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
1049  }
1050  else {
1051  int64_t ind = isFWDSolve ? M_ : M - 1 - M_;
1052  triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(ind, ind, LDA)],
1053  B_arr + k + ind * LDB, bM, bK, LDA, LDB);
1054  }
1055  }
1056  }
1057 
1058  EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
1059 }
1060 
1061 // Template specializations of trsmKernelL/R for float/double and inner strides of 1.
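// When EIGEN_NO_RUNTIME_MALLOC is defined, each specialization below first checks
// is_malloc_allowed() and, if runtime allocation is currently forbidden, falls back to the
// generic (Specialized = false) implementation instead of the AVX512 path.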
1062 #if (EIGEN_USE_AVX512_TRSM_KERNELS)
1063 #if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
1064 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride, bool Specialized>
1065 struct trsmKernelR;
1066 
1067 template <typename Index, int Mode, int TriStorageOrder>
1068 struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true> {
1069  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
1070  Index otherStride);
1071 };
1072 
1073 template <typename Index, int Mode, int TriStorageOrder>
1074 struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true> {
1075  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
1076  Index otherStride);
1077 };
1078 
1079 template <typename Index, int Mode, int TriStorageOrder>
1080 EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
1081  Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
1082  Index otherStride) {
1083  EIGEN_UNUSED_VARIABLE(otherIncr);
1084 #ifdef EIGEN_NO_RUNTIME_MALLOC
1085  if (!is_malloc_allowed()) {
1086  trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
1087  size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
1088  return;
1089  }
1090 #endif
1091  triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
1092  const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
1093 }
1094 
1095 template <typename Index, int Mode, int TriStorageOrder>
1096 EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
1097  Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
1098  Index otherStride) {
1099  EIGEN_UNUSED_VARIABLE(otherIncr);
1100 #ifdef EIGEN_NO_RUNTIME_MALLOC
1101  if (!is_malloc_allowed()) {
1102  trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
1103  size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
1104  return;
1105  }
1106 #endif
1107  triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
1108  const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
1109 }
1110 #endif // (EIGEN_USE_AVX512_TRSM_R_KERNELS)
1111 
1112 // These trsm kernels require temporary memory allocation
1113 #if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
1114 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride, bool Specialized = true>
1115 struct trsmKernelL;
1116 
1117 template <typename Index, int Mode, int TriStorageOrder>
1118 struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true> {
1119  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
1120  Index otherStride);
1121 };
1122 
1123 template <typename Index, int Mode, int TriStorageOrder>
1124 struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true> {
1125  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
1126  Index otherStride);
1127 };
1128 
1129 template <typename Index, int Mode, int TriStorageOrder>
1130 EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
1131  Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
1132  Index otherStride) {
1133  EIGEN_UNUSED_VARIABLE(otherIncr);
1134 #ifdef EIGEN_NO_RUNTIME_MALLOC
1135  if (!is_malloc_allowed()) {
1136  trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
1137  size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
1138  return;
1139  }
1140 #endif
1141  triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
1142  const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
1143 }
1144 
1145 template <typename Index, int Mode, int TriStorageOrder>
1146 EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
1147  Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
1148  Index otherStride) {
1149  EIGEN_UNUSED_VARIABLE(otherIncr);
1150 #ifdef EIGEN_NO_RUNTIME_MALLOC
1151  if (!is_malloc_allowed()) {
1152  trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
1153  size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
1154  return;
1155  }
1156 #endif
1157  triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
1158  const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
1159 }
1160 #endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
1161 #endif // EIGEN_USE_AVX512_TRSM_KERNELS
1162 } // namespace internal
1163 } // namespace Eigen
1164 #endif // EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H