#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
#define EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H

#include "../../InternalHeaderCheck.h"
#if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
#define EIGEN_USE_AVX512_TRSM_KERNELS 1
#endif

// The kernels allocate temporary workspace, so disable them when malloc is not allowed.
#ifdef EIGEN_NO_MALLOC
#undef EIGEN_USE_AVX512_TRSM_KERNELS
#define EIGEN_USE_AVX512_TRSM_KERNELS 0
#endif
#if EIGEN_USE_AVX512_TRSM_KERNELS
#if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
#endif
#if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
#endif
#else
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
#endif
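
// Register-blocking parameters for the AVX512 micro-kernels below: number of accumulator
// registers, rows per micro-tile, k-loop unroll factor, and the grouping of B loads / A broadcasts.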
#define EIGEN_AVX_MAX_NUM_ACC (int64_t(24))
#define EIGEN_AVX_MAX_NUM_ROW (int64_t(8))
#define EIGEN_AVX_MAX_K_UNROL (int64_t(4))
#define EIGEN_AVX_B_LOAD_SETS (int64_t(2))
#define EIGEN_AVX_MAX_A_BCAST (int64_t(2))
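
// Cutoff heuristics for the "no-copy" TRSM path, which calls the AVX512 kernels directly on the
// unpacked triangular matrix; they are currently only enabled for clang builds (see the condition
// below).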
#if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
#endif

#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
#if EIGEN_USE_AVX512_TRSM_R_KERNELS
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
#endif
#endif
#if EIGEN_USE_AVX512_TRSM_L_KERNELS
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
#endif
#endif
#else
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
#endif
template <typename Scalar>
#else  // no-copy cutoffs disabled
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
#endif
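
// transStoreC: transpose a tile of accumulators held in zmm and store it to C, with remM/remN
// handling partially filled tiles in the row and column directions.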
template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, Scalar *C_arr,
                                     int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
  using urolls = unrolls::trans<Scalar>;
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;

  static_assert(unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize");

  urolls::template transpose<unrollN, 0>(zmm);

  static_assert((remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1");

  urolls::template storeC<std::min(unrollN, U1), unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  urolls::template storeC<unrollN_, unrollN, 1, remM>(C_arr + U1 * LDC, LDC, zmm, remM_);
  urolls::template storeC<unrollN_, unrollN, 2, remM>(C_arr + U2 * LDC, LDC, zmm, remM_);
  // remN_ < PacketSize; for float (PacketSize == 16):
  if (remN_ == 15)
    urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 14)
    urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 13)
    urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 12)
    urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 11)
    urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 10)
    urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 9)
    urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 8)
    urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 7)
    urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 6)
    urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 5)
    urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 4)
    urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 3)
    urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 2)
    urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 1)
    urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);

  // For double (PacketSize == 8) only remainders 1..7 occur:
  if (remN_ == 7)
    urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 6)
    urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 5)
    urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 4)
    urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 3)
    urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 2)
    urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
  else if (remN_ == 1)
    urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
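
// gemmKernel: register-blocked GEMM update of C used by the TRSM driver below. N is processed in
// panels of three/two/one packets (U3/U2/U1) and M in tiles of 8/4/2/1 rows; when handleKRem is
// set, a scalar loop handles the K remainder, and isCRowMajor selects between a direct store of C
// and a transposed store via transStoreC.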
template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB,
                int64_t LDC) {
  using urolls = unrolls::gemm<Scalar, isAdd>;
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;
  for (; j < N_; j += U3) {
    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
    for (int64_t k = K_; k < K; k++) {

    urolls::template updateC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
    urolls::template storeC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, false, false>(zmm, &C_arr[i + j * LDC], LDC);

    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<3, 4>(zmm);
    for (int64_t k = K_; k < K; k++) {
      urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
          B_t, A_t, LDB, LDA, zmm);

    urolls::template updateC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
    urolls::template storeC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);

    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<3, 2>(zmm);
    for (int64_t k = K_; k < K; k++) {
      urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
          B_t, A_t, LDB, LDA, zmm);

    urolls::template updateC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
    urolls::template storeC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);

    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<3, 1>(zmm);
    urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
        B_t, A_t, LDB, LDA, zmm);
    for (int64_t k = K_; k < K; k++) {
      urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);

    urolls::template updateC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
    urolls::template storeC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
    for (int64_t k = K_; k < K; k++) {

    urolls::template updateC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
    urolls::template storeC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, false, false>(zmm, &C_arr[i + j * LDC], LDC);

    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<2, 4>(zmm);
    for (int64_t k = K_; k < K; k++) {
      urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(
          B_t, A_t, LDB, LDA, zmm);

    urolls::template updateC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
    urolls::template storeC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);

    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<2, 2>(zmm);
    for (int64_t k = K_; k < K; k++) {
      urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(
          B_t, A_t, LDB, LDA, zmm);

    urolls::template updateC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
    urolls::template storeC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);

    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<2, 1>(zmm);
    urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(
        B_t, A_t, LDB, LDA, zmm);
    for (int64_t k = K_; k < K; k++) {
      urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);

    urolls::template updateC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
    urolls::template storeC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
    for (int64_t k = K_; k < K; k++) {

    urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
    urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, false>(zmm, &C_arr[i + j * LDC], LDC);

    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<1, 4>(zmm);
    for (int64_t k = K_; k < K; k++) {
      urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(
          B_t, A_t, LDB, LDA, zmm);

    urolls::template updateC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
    urolls::template storeC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);

    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<1, 2>(zmm);
    for (int64_t k = K_; k < K; k++) {
      urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(
          B_t, A_t, LDB, LDA, zmm);

    urolls::template updateC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
    urolls::template storeC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);

    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<1, 1>(zmm);
    urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(
        B_t, A_t, LDB, LDA, zmm);
    for (int64_t k = K_; k < K; k++) {
      urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);

    urolls::template updateC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
    urolls::template storeC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
    for (int64_t k = K_; k < K; k++) {

    urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
    urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, true>(zmm, &C_arr[i + j * LDC], LDC, 0, N - j);

    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<1, 4>(zmm);
    for (int64_t k = K_; k < K; k++) {
      urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
          B_t, A_t, LDB, LDA, zmm, N - j);

    urolls::template updateC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
    urolls::template storeC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 4, N - j);

    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<1, 2>(zmm);
    for (int64_t k = K_; k < K; k++) {
      urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
          B_t, A_t, LDB, LDA, zmm, N - j);

    urolls::template updateC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
    urolls::template storeC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 2, N - j);

    Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
    Scalar *B_t = &B_arr[0 * LDB + j];
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
    urolls::template setzero<1, 1>(zmm);
    urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
        B_t, A_t, LDB, LDA, zmm, N - j);
    for (int64_t k = K_; k < K; k++) {
      urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
                                                                                        N - j);

    urolls::template updateC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
    urolls::template storeC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
    transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 1, N - j);
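
// triSolveKernel: solve a small triangular system for unrollM rows of A against K right-hand
// sides held in B, processing the RHS in chunks of three/two/one packets followed by a masked
// remainder chunk.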
template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
  using urolls = unrolls::trsm<Scalar>;
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;

  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> AInPacket;
  while (K - k >= U3) {
    urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(
        A_arr, LDA, RHSInPacket, AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);

  urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
  urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(
      A_arr, LDA, RHSInPacket, AInPacket);
  urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);

  urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
  urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(
      A_arr, LDA, RHSInPacket, AInPacket);
  urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);

  urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
  urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(
      A_arr, LDA, RHSInPacket, AInPacket);
  urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
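
// triSolveKernelLxK: runtime dispatch on the number of rows M (1..EIGEN_AVX_MAX_NUM_ROW) to the
// statically unrolled triSolveKernel instantiations above.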
template <typename Scalar, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) {
  if (M == 8)
    triSolveKernel<Scalar, vec, 8, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 7)
    triSolveKernel<Scalar, vec, 7, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 6)
    triSolveKernel<Scalar, vec, 6, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 5)
    triSolveKernel<Scalar, vec, 5, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 4)
    triSolveKernel<Scalar, vec, 4, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 3)
    triSolveKernel<Scalar, vec, 3, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 2)
    triSolveKernel<Scalar, vec, 2, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 1)
    triSolveKernel<Scalar, vec, 1, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
}
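
// copyBToRowMajor: copy a panel of B into a row-major temporary (toTemp == true) or back out of it
// (toTemp == false) using the transB unrolls; remM handles a partial row block.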
template <typename Scalar, bool toTemp = true, bool remM = false>
EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
                                         int64_t remM_ = 0) {
  using urolls = unrolls::transB<Scalar>;
  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> ymm;
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;

  for (; k < K_; k += U3) {
    urolls::template transB_kernel<U3, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);

  urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
  urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);

  urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
  urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
  urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
  urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
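
// triSolve: top-level blocked triangular solve. Right-hand sides are processed in panels of kB
// columns; each diagonal block of A is solved with triSolveKernelLxK and the remaining rows are
// updated with gemmKernel, with B optionally staged in the row-major temporary B_temp for the
// solve step.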
template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
          bool isUnitDiag = false>
void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
  constexpr int64_t kB = (3 * psize) * 5;

  Scalar *B_temp = NULL;
  sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
  for (int64_t k = 0; k < numRHS; k += kB) {
    int64_t bK = numRHS - k > kB ? kB : numRHS - k;

    const int64_t numScalarPerCache = 64 / sizeof(Scalar);
    int64_t LDT = ((bkL + (numScalarPerCache - 1)) / numScalarPerCache) * numScalarPerCache;

    int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp;
    copyBToRowMajor<Scalar, true, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
    triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
        &A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT);
    copyBToRowMajor<Scalar, false, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);

    triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(

    gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
        &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,

    int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
    gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
        &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + k * LDB,

    int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
    gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
        &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + k * LDB,

    int64_t indA_i = isFWDSolve ? M_ : 0;
    int64_t indA_j = isFWDSolve ? 0 : bM;
    int64_t indB_i = isFWDSolve ? 0 : bM;
    int64_t indB_i2 = isFWDSolve ? M_ : 0;
    gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
        &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM,
        bK, M_, LDA, LDB, LDB);

    int64_t indA_i = isFWDSolve ? M_ : 0;
    int64_t indA_j = isFWDSolve ? gemmOff : bM;
    int64_t indB_i = isFWDSolve ? M_ : 0;
    int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
    gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)],
                                                               B_temp + offB_1, B_arr + indB_i + k * LDB, bM, bK,
                                                               M_ - gemmOff, LDA, LDT, LDB);

    int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_;
    int64_t indB_i = isFWDSolve ? M_ : 0;
    int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
    copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
    triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)],
                                                                   B_temp + offB_1, bM, bkL, LDA, bkL);
    copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);

    triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(ind, ind, LDA)],
                                                                   B_arr + k + ind * LDB, bM, bK, LDA, LDB);
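
// A minimal calling sketch (an assumption for illustration, not code taken from this header):
// solving L * X = B in place for a column-major lower-triangular L and a column-major B with M
// rows and numRHS columns would look roughly like
//   triSolve<float, /*isARowMajor=*/false, /*isBRowMajor=*/false, /*isFWDSolve=*/true,
//            /*isUnitDiag=*/false>(L_data, B_data, M, numRHS, /*LDA=*/M, /*LDB=*/M);
// where L_data and B_data are hypothetical raw pointers to the matrix storage.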
#if (EIGEN_USE_AVX512_TRSM_KERNELS)
#if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
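// Right-side solves (X A = B): the specializations below pass TriStorageOrder != RowMajor as
// isARowMajor and (Mode & Lower) != Lower as isFWDSolve, which is consistent with handing triSolve
// the transposed problem so the same left-solve kernels can be reused.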
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
          bool Specialized>
struct trsmKernelR;
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
                     Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other,
                     Index otherIncr, Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
    Index otherStride) {
#ifdef EIGEN_NO_RUNTIME_MALLOC
  if (!is_malloc_allowed()) {
    // Fall back to the generic (non-specialized) kernel when runtime allocation is not permitted.
    trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, false>::kernel(size, otherSize, _tri, triStride, _other,
                                                                              otherIncr, otherStride);
    return;
  }
#endif
  triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
      const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
}
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
    Index otherStride) {
#ifdef EIGEN_NO_RUNTIME_MALLOC
  if (!is_malloc_allowed()) {
    trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, false>::kernel(size, otherSize, _tri, triStride,
                                                                               _other, otherIncr, otherStride);
    return;
  }
#endif
  triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
      const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
#endif  // (EIGEN_USE_AVX512_TRSM_R_KERNELS)
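
// Left-side solves (A X = B): here the storage-order and triangular-part flags are forwarded
// directly (TriStorageOrder == RowMajor, (Mode & Lower) == Lower), so Lower-triangular systems
// use forward substitution.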
#if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
          bool Specialized = true>
struct trsmKernelL;
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
                     Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other,
                     Index otherIncr, Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
    Index otherStride) {
#ifdef EIGEN_NO_RUNTIME_MALLOC
  if (!is_malloc_allowed()) {
    trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, false>::kernel(size, otherSize, _tri, triStride, _other,
                                                                              otherIncr, otherStride);
    return;
  }
#endif
  triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
      const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
}
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
    Index otherStride) {
#ifdef EIGEN_NO_RUNTIME_MALLOC
  if (!is_malloc_allowed()) {
    trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, false>::kernel(size, otherSize, _tri, triStride,
                                                                               _other, otherIncr, otherStride);
    return;
  }
#endif
  triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
      const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
#endif  // (EIGEN_USE_AVX512_TRSM_L_KERNELS)
#endif  // (EIGEN_USE_AVX512_TRSM_KERNELS)

#endif  // EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H