#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
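
// Computes the flat index of A(i, j) given leading dimension LDA, for row-major or
// column-major storage.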
template <bool isARowMajor = true>
EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
  EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
  else return i + j * LDA;
}
// Builds a bit mask with the low m bits set; used for masked (remainder) loads and stores.
template <int64_t N>
EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
  EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
  else EIGEN_IF_CONSTEXPR(N == 8) { return 0xFF >> (8 - m); }
  else EIGEN_IF_CONSTEXPR(N == 4) { return 0x0F >> (4 - m); }
  return 0;
}
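
// For example, remMask<8>(3) == 0x07, so a masked ploadu/pstoreu touches only the first
// three lanes of an 8-lane (vecHalf) packet.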
template <typename Packet>
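// The body below transposes two 8x8 float blocks at once, one held in the lower and one in
// the upper 256-bit half of the eight __m512 registers kernel.packet[0..7]: 32-bit unpacks,
// then 64-bit unpacks, then 128-bit swaps and blends.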
  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
  __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);

  kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
  kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));

  T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
  T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
  T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
  T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
  T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
  T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
  T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
  T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
  T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
  T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
  T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
  T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
  T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
  T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
  T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
  T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);

  kernel.packet[0] = T0;
  kernel.packet[1] = T1;
  kernel.packet[2] = T2;
  kernel.packet[3] = T3;
  kernel.packet[4] = T4;
  kernel.packet[5] = T5;
  kernel.packet[6] = T6;
  kernel.packet[7] = T7;
}
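
// Helpers, parameterized on Scalar, for storing C from a transposed register layout: storeC
// accumulates the zmm packets into C_arr (column stride LDC) with optional row-remainder
// masking, and transpose runs an in-register 8x8 transpose over packets that are spaced
// unrollN/PacketSize apart.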
template <typename Scalar>
template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
    Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
  constexpr int64_t counterReverse = endN - counter;
  constexpr int64_t startN = counterReverse;

  EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) {
    EIGEN_IF_CONSTEXPR(remM) {
      pstoreu<Scalar>(
          C_arr + LDC * startN,
          padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
               preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]),
               remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
          remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
    }
    else {
      pstoreu<Scalar>(C_arr + LDC * startN,
                      padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
                           preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
    }
  }
  else {  // The data to store now sits in the upper half of the zmm register (float case).
    vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
        zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]);
    // Swap the lower and upper 256-bit halves so the store below can use the lower half.
    zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] =
        preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));
    EIGEN_IF_CONSTEXPR(remM) {
      pstoreu<Scalar>(
          C_arr + LDC * startN,
          padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
               preinterpret<vecHalf>(
                   zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
          remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
    }
    else {
      pstoreu<Scalar>(
          C_arr + LDC * startN,
          padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
               preinterpret<vecHalf>(
                   zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
    }
  }
  aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
}

// Base case of the compile-time recursion: nothing left to do.
template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
    Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {}

// Adds the accumulators in zmm into C, unrolled over endN columns.
template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                       PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                       int64_t remM_ = 0) {
  aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
}
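
// transpose: gathers eight zmm packets spaced unrollN/PacketSize apart (one unrolled column
// tile), transposes them as an 8x8 block in registers, and writes them back in place.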
template <int64_t unrollN, int64_t packetIndexOffset>
static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
  constexpr int64_t zmmStride = unrollN / PacketSize;
  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> r;
  r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0];
  r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1];
  r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2];
  r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3];
  r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4];
  r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5];
  r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6];
  r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7];
  // r is transposed in place as an 8x8 block before being written back.
  zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0];
  zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1];
  zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2];
  zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3];
  zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4];
  zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5];
  zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6];
  zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7];
}
template <typename Scalar>
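// Helpers for transposing a block of B, referenced below as transB. loadB/storeB move up to
// eight rows through ymm (vecHalf) registers with optional remainder masks, loadBBlock and
// storeBBlock handle EIGEN_AVX_MAX_NUM_ROW rows at a time (optionally through a temporary
// buffer B_temp), and transposeLxL performs the in-register 8x8 transpose.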
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
    Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
    int64_t remM_ = 0) {
  constexpr int64_t counterReverse = endN - counter;
  constexpr int64_t startN = counterReverse;

  EIGEN_IF_CONSTEXPR(remM) {
    // Row remainder: load with a mask derived from the runtime count remM_.
    ymm.packet[packetIndexOffset + startN] =
        ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
  }
  else {
    EIGEN_IF_CONSTEXPR(remN_ == 0) ymm.packet[packetIndexOffset + startN] =
        ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
    else ymm.packet[packetIndexOffset + startN] =
        ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
  }
  aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
}

// Base case of the compile-time recursion: nothing left to do.
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_loadB(
    Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
    int64_t remM_ = 0) {}
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
    Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
    int64_t rem_ = 0) {
  constexpr int64_t counterReverse = endN - counter;
  constexpr int64_t startN = counterReverse;

  EIGEN_IF_CONSTEXPR(remK || remM) {
    pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
                    remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
  }
  else {
    pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]);
  }
  aux_storeB<endN, counter - 1, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
}

// Base case of the compile-time recursion: nothing left to do.
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_storeB(
    Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
    int64_t rem_ = 0) {}
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
    Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
    PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
  constexpr int64_t counterReverse = endN - counter;
  constexpr int64_t startN = counterReverse;
  transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
  aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}

// Base case of the compile-time recursion: nothing left to do.
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_loadBBlock(
    Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
    PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {}
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
    Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
    PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
  constexpr int64_t counterReverse = endN - counter;
  constexpr int64_t startN = counterReverse;

  EIGEN_IF_CONSTEXPR(toTemp) {
    transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
  }
  else {
    transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
                                                                                        ymm, remM_);
  }
  aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}

// Base case of the compile-time recursion: nothing left to do.
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_storeBBlock(
    Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
    PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {}
template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
                                      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                      int64_t remM_ = 0) {
  aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
}

template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
                                       PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                       int64_t rem_ = 0) {
  aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
}

// Loads a block of B into ymm, either straight from B_arr (toTemp) or from the temporary B_temp.
template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                           PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                           int64_t remM_ = 0) {
  EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
  else {
    aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }
}

// Stores the block held in ymm back out, either into B_temp (toTemp) or into B_arr.
template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                            PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                            int64_t remM_ = 0) {
  aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
template <int64_t packetIndexOffset>
static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
  PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
  r.packet[0] = ymm.packet[packetIndexOffset + 0];
  r.packet[1] = ymm.packet[packetIndexOffset + 1];
  r.packet[2] = ymm.packet[packetIndexOffset + 2];
  r.packet[3] = ymm.packet[packetIndexOffset + 3];
  r.packet[4] = ymm.packet[packetIndexOffset + 4];
  r.packet[5] = ymm.packet[packetIndexOffset + 5];
  r.packet[6] = ymm.packet[packetIndexOffset + 6];
  r.packet[7] = ymm.packet[packetIndexOffset + 7];
  // r is transposed in place as an 8x8 block before being copied back.
  ymm.packet[packetIndexOffset + 0] = r.packet[0];
  ymm.packet[packetIndexOffset + 1] = r.packet[1];
  ymm.packet[packetIndexOffset + 2] = r.packet[2];
  ymm.packet[packetIndexOffset + 3] = r.packet[3];
  ymm.packet[packetIndexOffset + 4] = r.packet[4];
  ymm.packet[packetIndexOffset + 5] = r.packet[5];
  ymm.packet[packetIndexOffset + 6] = r.packet[6];
  ymm.packet[packetIndexOffset + 7] = r.packet[7];
}
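
// Transposes an unrollN-column slice of B: each compile-time case below loads a block, runs
// one or more 8x8 transposes, and stores the result, optionally routing through B_temp.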
template <int64_t unrollN, bool toTemp, bool remM>
static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                              PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                              int64_t remM_ = 0) {
  constexpr int64_t U3 = PacketSize * 3;
  constexpr int64_t U2 = PacketSize * 2;
  constexpr int64_t U1 = PacketSize * 1;
  // maxUBlock (a constexpr defined in the full source) caps how many columns are handled per pass.
  EIGEN_IF_CONSTEXPR(unrollN == U3) {
    transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
    transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
    transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
    transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

    transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                         ymm, remM_);
    transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
    transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
    transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
    transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                             ymm, remM_);
  }
  else EIGEN_IF_CONSTEXPR(unrollN == U2) {
    transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
    transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
    EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
    transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

    EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
      transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
                                                                       &B_temp[maxUBlock], LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
                                                                           &B_temp[maxUBlock], LDB_, ymm, remM_);
    }
  }
  else EIGEN_IF_CONSTEXPR(unrollN == U1) {
    transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    transB::template transposeLxL<0>(ymm);
    transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }
  else EIGEN_IF_CONSTEXPR(unrollN == 8) {
    transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    transB::template transposeLxL<0>(ymm);
    transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }
  else EIGEN_IF_CONSTEXPR(unrollN == 4) {
    transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    transB::template transposeLxL<0>(ymm);
    transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }
  else EIGEN_IF_CONSTEXPR(unrollN == 2) {
    transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    transB::template transposeLxL<0>(ymm);
    transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }
  else EIGEN_IF_CONSTEXPR(unrollN == 1) {
    transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    transB::template transposeLxL<0>(ymm);
    transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }
}
template <typename Scalar>
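// Triangular-solve helpers, referenced below as trsm: the right-hand-side panel is kept in
// RHSInPacket (up to EIGEN_AVX_MAX_NUM_ACC packets, endK packets per row), while AInPacket
// caches broadcast values of A (reciprocal diagonals and column entries) used to update it.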
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
    Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
  constexpr int64_t counterReverse = endM * endK - counter;
  constexpr int64_t startM = counterReverse / (endK);
  constexpr int64_t startK = counterReverse % endK;

  constexpr int64_t packetIndex = startM * endK + startK;
  constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
  const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
  EIGEN_IF_CONSTEXPR(krem) {
    RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex], remMask<PacketSize>(rem));
  }
  else {
    RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex]);
  }
  aux_loadRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
}

// Base case of the compile-time recursion: nothing left to do.
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_loadRHS(
    Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {}
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
    Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
  constexpr int64_t counterReverse = endM * endK - counter;
  constexpr int64_t startM = counterReverse / (endK);
  constexpr int64_t startK = counterReverse % endK;

  constexpr int64_t packetIndex = startM * endK + startK;
  constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
  const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
  EIGEN_IF_CONSTEXPR(krem) {
    pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask<PacketSize>(rem));
  }
  else {
    pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]);
  }
  aux_storeRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
}

// Base case of the compile-time recursion: nothing left to do.
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_storeRHS(
    Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {}
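
// divRHSByDiag multiplies every RHS packet of row currM by AInPacket.packet[currM], which
// holds the broadcast reciprocal of the corresponding diagonal entry of A.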
template <int64_t currM, int64_t endK, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
  constexpr int64_t counterReverse = endK - counter;
  constexpr int64_t startK = counterReverse;

  constexpr int64_t packetIndex = currM * endK + startK;
  RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]);
  aux_divRHSByDiag<currM, endK, counter - 1>(RHSInPacket, AInPacket);
}

// Base case of the compile-time recursion (also reached when currM < 0): nothing to do.
template <int64_t currM, int64_t endK, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {}
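
// updateRHS eliminates the contribution of the just-solved row (currentM - 1) from rows
// startM..endM-1 with pnmadd, and pre-broadcasts the A values needed next: the reciprocal
// diagonal for row currentM and the column entries A(startM, currentM) (indices are negated
// for the backward solve).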
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
          int64_t counter, int64_t currentM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
    Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
  constexpr int64_t counterReverse = (endM - initM) * endK - counter;
  constexpr int64_t startM = initM + counterReverse / (endK);
  constexpr int64_t startK = counterReverse % endK;

  constexpr int64_t packetIndex = startM * endK + startK;
  // Subtract the contribution of the previously solved row from this packet.
  RHSInPacket.packet[packetIndex] =
      pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
             RHSInPacket.packet[packetIndex]);

  // Broadcast the reciprocal diagonal for the row solved next (guarded on !isUnitDiag in the
  // full kernel); forward and backward solves index A with opposite signs.
  EIGEN_IF_CONSTEXPR(isFWDSolve)
  AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
  else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);

  // Broadcast the A entries of column currentM used to update row startM.
  EIGEN_IF_CONSTEXPR(isFWDSolve)
  AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
  else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);

  aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
      A_arr, LDA, RHSInPacket, AInPacket);
}
// Base case of the compile-time recursion: nothing left to do.
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
          int64_t counter, int64_t currentM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_updateRHS(
    Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {}
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
    Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
  constexpr int64_t counterReverse = endM - counter;
  constexpr int64_t startM = counterReverse;

  constexpr int64_t currentM = startM;
  // Scale the row solved in the previous iteration by its reciprocal diagonal (a no-op when
  // startM == 0; skipped entirely for a unit diagonal in the full kernel).
  trsm::template divRHSByDiag<startM - 1, numK>(RHSInPacket, AInPacket);
  // Eliminate that row's contribution from the remaining rows and broadcast the next A values.
  trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
                                                                                              AInPacket);
  // The last row has no following iteration to scale it, so scale it here.
  EIGEN_IF_CONSTEXPR(startM == endM - 1)
  trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);
  aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
                                                                                        AInPacket);
}
// Base case of the compile-time recursion: nothing left to do.
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_triSolveMicroKernel(
    Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {}

// Loads the endM x (endK*PacketSize) right-hand-side panel of B into RHSInPacket.
template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
                                        PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
  aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
}

// Writes the solved panel in RHSInPacket back to B.
template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
                                         PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
  aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
}

template <int64_t currM, int64_t endK>
static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                             PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
  aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
}

template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
          int64_t currentM>
static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
                                          PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                          PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
  aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
      A_arr, LDA, RHSInPacket, AInPacket);
}
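
// Solves an endM x endM triangular system against numK*PacketSize right-hand-side columns kept
// in RHSInPacket: each row is scaled by its reciprocal diagonal (unless the diagonal is unit)
// and then eliminated from the rows still to be solved. numK is bounded by the static_assert.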
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
                                                    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                                    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
  static_assert(numK >= 1 && numK <= 3, "numK out of range");
  aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
}
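
// A minimal usage sketch (editor's illustration, not part of the original file): solving one
// forward, non-unit-diagonal 8-row panel of B with three packets per row. B_panel/A_panel are
// placeholder pointers, and the unroll counts and surrounding loop structure are assumptions.
//
//   PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
//   PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> AInPacket;
//   trsm<Scalar>::template loadRHS<true, 8, 3>(B_panel, LDB, RHSInPacket);
//   trsm<Scalar>::template triSolveMicroKernel<isARowMajor, true, false, 8, 3>(A_panel, LDA, RHSInPacket, AInPacket);
//   trsm<Scalar>::template storeRHS<true, 8, 3>(B_panel, LDB, RHSInPacket);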
template <typename Scalar, bool isAdd>
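// Register-blocked GEMM micro-kernel helpers, referenced below as gemm. The zmm block is
// partitioned as: packets [0, endM*endN) hold the C accumulators, the next numLoad packets hold
// rotating loads of B, and the next numBCast packets hold rotating broadcasts of A. isAdd
// selects C += A*B versus C -= A*B.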
template <int64_t endM, int64_t endN, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
  constexpr int64_t counterReverse = endM * endN - counter;
  constexpr int64_t startM = counterReverse / (endN);
  constexpr int64_t startN = counterReverse % endN;

  zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]);
  aux_setzero<endM, endN, counter - 1>(zmm);
}

// Base case of the compile-time recursion: nothing left to do.
template <int64_t endM, int64_t endN, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_setzero(
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {}
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
    Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
  constexpr int64_t counterReverse = endM * endN - counter;
  constexpr int64_t startM = counterReverse / (endN);
  constexpr int64_t startN = counterReverse % endN;

  EIGEN_IF_CONSTEXPR(rem)
  zmm.packet[startN * endM + startM] =
      padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
           zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
  else zmm.packet[startN * endM + startM] =
      padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
  aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
}

// Base case of the compile-time recursion: nothing left to do.
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_updateC(
    Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {}

template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
    Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
  constexpr int64_t counterReverse = endM * endN - counter;
  constexpr int64_t startM = counterReverse / (endN);
  constexpr int64_t startN = counterReverse % endN;

  EIGEN_IF_CONSTEXPR(rem)
  pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
                  remMask<PacketSize>(rem_));
  else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
  aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
}

// Base case of the compile-time recursion: nothing left to do.
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_storeC(
    Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {}
template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
    Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
  constexpr int64_t counterReverse = endL - counter;
  constexpr int64_t startL = counterReverse;

  EIGEN_IF_CONSTEXPR(rem)
  zmm.packet[unrollM * unrollN + startL] =
      ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
  else zmm.packet[unrollM * unrollN + startL] =
      ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);

  aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
}

// Base case of the compile-time recursion: nothing left to do.
template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_startLoadB(
    Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {}

template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
    Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
  constexpr int64_t counterReverse = endB - counter;
  constexpr int64_t startB = counterReverse;

  zmm.packet[unrollM * unrollN + numLoad + startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0, LDA)]);

  aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter - 1, numLoad>(A_t, LDA, zmm);
}

// Base case of the compile-time recursion: nothing left to do.
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_startBCastA(
    Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {}
template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
          int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
    Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
  if ((numLoad / endM + currK < unrollK)) {
    constexpr int64_t counterReverse = endM - counter;
    constexpr int64_t startM = counterReverse;

    EIGEN_IF_CONSTEXPR(rem) {
      zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
          ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask<PacketSize>(rem_));
    }
    else {
      zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
          ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]);
    }
  }
  aux_loadB<endM, counter - 1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
}

// Base case of the compile-time recursion: nothing left to do.
template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
          int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_loadB(
    Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {}
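
// aux_microKernel is the FMA core: counter walks the endM x endN x endK unrolled iteration
// space; each step multiplies one broadcast A value with one loaded B packet into one C
// accumulator, refreshing the rotating A-broadcast and B-load registers as it goes.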
template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
          int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
    Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
    int64_t rem_ = 0) {
  constexpr int64_t counterReverse = endM * endN * endK - counter;
  constexpr int startK = counterReverse / (endM * endN);
  constexpr int startN = (counterReverse / (endM)) % endN;
  constexpr int startM = counterReverse % endM;

  // On the very first step, prime the rotating B-load and A-broadcast registers.
  EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
    gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
    gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
  }

  // isAdd selects between accumulating (C += A*B) and subtracting (C -= A*B).
  EIGEN_IF_CONSTEXPR(isAdd) {
    zmm.packet[startN * endM + startM] =
        pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
              zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
  }
  else {
    zmm.packet[startN * endM + startM] =
        pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
               zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
  }

  // Once a column of the tile is finished, broadcast the next A value...
  EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
    zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
        (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
  }
  // ...and refill the rotating B registers for a later k-step (guarded further in the full kernel).
  gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);

  aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
}

// Base case of the compile-time recursion: nothing left to do.
template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
          int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0)> aux_microKernel(
    Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
    int64_t rem_ = 0) {}
// Zeroes the endM x endN block of C accumulators.
template <int64_t endM, int64_t endN>
static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
  aux_setzero<endM, endN, endM * endN>(zmm);
}

// Adds the current contents of C into the accumulators.
template <int64_t endM, int64_t endN, bool rem = false>
static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
                                        PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                        int64_t rem_ = 0) {
  aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
}

// Writes the accumulators back to C.
template <int64_t endM, int64_t endN, bool rem = false>
static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                       PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                       int64_t rem_ = 0) {
  aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
}

// Initial fill of the rotating B-load registers.
template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
                                           PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                           int64_t rem_ = 0) {
  aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
}

// Initial fill of the rotating A-broadcast registers.
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
                                            PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
  aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
}

// Loads the B packets needed for k-step currK into the rotating load registers.
template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
                                      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                      int64_t rem_ = 0) {
  aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
}
// Runs the full unrolled micro-kernel: endM x endN accumulators over endK broadcast steps.
template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
                                            PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                            int64_t rem_ = 0) {
  aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
                                                                                             rem_);
}