#ifndef EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
#define EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H

#include <x86intrin.h>
#include <immintrin.h>
#include <type_traits>

#include "../../InternalHeaderCheck.h"

#if !defined(EIGEN_USE_AVX512_GEMM_KERNELS)
#define EIGEN_USE_AVX512_GEMM_KERNELS 1
#endif

#define SECOND_FETCH (32)
#if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS)
// Use fewer registers for the A-panel loads under strict GCC.
#define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
#endif

namespace Eigen {
namespace internal {
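// Register-blocked GEMM micro-kernel for AVX512. gemm_class carries the operand
// pointers, sizes and strides, and implements the m/n/k loop nest entirely with
// EIGEN_ALWAYS_INLINE members so the whole register tile stays in zmm registers.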
template <typename Scalar, bool is_unit_inc>
class gemm_class {
 public:
  using vec = typename packet_traits<Scalar>::type;
  using vec_ymm = typename unpacket_traits<vec>::half;
  using vec_xmm = typename unpacket_traits<vec_ymm>::half;
  using umask_t = typename unpacket_traits<vec>::mask_t;

  static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float);
  static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double);

#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
  static constexpr bool use_less_a_regs = !is_unit_inc;
#else
  static constexpr bool use_less_a_regs = true;
#endif
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
  static constexpr bool use_less_b_regs = !is_unit_inc;
#else
  static constexpr bool use_less_b_regs = true;
#endif
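  // Register assignment (zmm indices):
  //   zmm0-zmm5  hold elements of A (only zmm0-zmm2 when use_less_a_regs),
  //   zmm6-zmm7  hold broadcast elements of B (only zmm6 when use_less_b_regs),
  //   zmm8-zmm31 hold the C accumulators; c_regs[i + idx * 3] is the accumulator
  //              for row-block i of unrolled column idx.
  // alpha and the C-load scratch registers reuse the A/B registers, which are
  // free once the k loop has finished.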
  static constexpr int a_regs[] = {0, 1, 2, use_less_a_regs ? 0 : 3, use_less_a_regs ? 1 : 4, use_less_a_regs ? 2 : 5};
  static constexpr int b_regs[] = {6, use_less_b_regs ? 6 : 7};
  static constexpr int c_regs[] = {
      8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27, 12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31,
  };

  static constexpr int alpha_load_reg = 0;
  static constexpr int c_load_regs[] = {1, 2, 6};

  static constexpr int a_shift = 128;
  static constexpr int b_shift = 128;

  static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8;
  static constexpr int a_prefetch_size = nelems_in_cache_line * 2;
  static constexpr int b_prefetch_size = nelems_in_cache_line * 8;

  // Problem geometry and operand pointers (set in the constructor).
  Index m, n, k, ldc;
  Index inc;
  const Scalar *alpha;
  const Scalar *a, *b;
  Scalar *c;
  bool is_alpha1, is_beta0;

  const Index a_stride, b_stride;
  const Index a_off, b_off;

  vec zmm[32];
  umask_t mask;
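  // Software prefetch helpers: the packed A and B panels are prefetched with
  // hint T0, the next A panel (prefetch_x) with hint T2, and C lines with
  // PREFETCHW when the target supports it.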
  EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr) {
    _mm_prefetch((char *)(a_prefetch_size + a_addr - a_shift), _MM_HINT_T0);
  }

  EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr) {
    _mm_prefetch((char *)(b_prefetch_size + b_addr - b_shift), _MM_HINT_T0);
  }

  EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr) { _mm_prefetch((char *)(x_addr - a_shift), _MM_HINT_T2); }

  EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) {
#if defined(__PRFCHW__) && __PRFCHW__ == 1
    _m_prefetchw((void *)c_addr);
#else
    _mm_prefetch((char *)c_addr, _MM_HINT_T0);
#endif
  }
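  // a_load fills a zmm register with `nelems` elements of the packed A panel,
  // choosing a full 512-bit load, a 256- or 128-bit broadcast, or a single
  // element broadcast depending on nelems * sizeof(Scalar). b_load broadcasts
  // one B element to all lanes. c_store writes an accumulator back to C.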
  template <int nelems>
  EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) {
    switch (nelems * sizeof(*a_addr) * 8) {
      case 512 * 3:
        a_reg = ploadu<vec>(a_addr);
        break;
      case 512 * 2:
        a_reg = ploadu<vec>(a_addr);
        break;
      case 512 * 1:
        a_reg = ploadu<vec>(a_addr);
        break;
      case 256 * 1:
        a_reg = preinterpret<vec>(_mm512_broadcast_f64x4(ploadu<Packet4d>(reinterpret_cast<const double *>(a_addr))));
        break;
      case 128 * 1:
        a_reg = preinterpret<vec>(_mm512_broadcast_f32x4(ploadu<Packet4f>(reinterpret_cast<const float *>(a_addr))));
        break;
      case 64 * 1:
        a_reg = preinterpret<vec>(pload1<Packet8d>(reinterpret_cast<const double *>(a_addr)));
        break;
      case 32 * 1:
        a_reg = pload1<vec>(a_addr);
        break;
    }
  }

  EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr) { b_reg = pload1<vec>(b_addr); }

  template <int nelems>
  EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) {
    if (is_unit_inc) {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3:
        case 512 * 2:
        case 512 * 1:
          pstoreu(mem, src);
          break;
        case 256 * 1:
          pstoreu(mem, preinterpret<vec_ymm>(src));
          break;
        case 128 * 1:
          pstoreu(mem, preinterpret<vec_xmm>(src));
          break;
        case 64 * 1:
          pstorel(mem, preinterpret<vec_xmm>(src));
          break;
        case 32 * 1:
          pstores(mem, preinterpret<vec_xmm>(src));
          break;
      }
    } else {
      // Strided output path: pscatter-based stores (cases omitted here).
    }
  }
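  // Arithmetic helpers used by the C update:
  //   vaddm   : dst = src + C(mem)          (alpha == 1)
  //   vfmaddm : dst = scale * src + C(mem)  (general alpha)
  //   vfmadd  : dst += src1 * src2, with an empty inline asm that ties the
  //             operands to vector registers as a gcc/clang spill workaround.
  // The memory forms handle partial vectors and strided C via gathers.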
  template <int nelems>
  EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src, vec &reg) {
    if (is_unit_inc) {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3:
          dst = padd(src, ploadu<vec>(mem));
          break;
        case 512 * 2:
          dst = padd(src, ploadu<vec>(mem));
          break;
        case 512 * 1:
          dst = padd(src, ploadu<vec>(mem));
          break;
        case 256 * 1:
          dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
          break;
        case 128 * 1:
          dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
          break;
        case 64 * 1:
          dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
          break;
        case 32 * 1:
          dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
          break;
      }
    } else {
      // Strided access to C: gather the elements with stride `inc` before adding.
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = padd(src, reg);
          break;
        case 512 * 2:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = padd(src, reg);
          break;
        case 512 * 1:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = padd(src, reg);
          break;
        case 256 * 1:
          reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
          dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
          break;
        case 128 * 1:
          reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
          dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
          break;
        case 64 * 1:
          if (is_f32) {
            reg = pgather(reg, mem, inc, mask);
            dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
          } else {
            dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
          }
          break;
        case 32 * 1:
          dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
          break;
      }
    }
  }

  EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) {
    dst = pmadd(src1, src2, dst);

#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
    // Empty asm with register constraints: keeps dst/src in vector registers
    // and avoids spills under gcc/clang.
    __asm__("#" : [dst] "+v"(dst) : [src1] "%v"(src1), [src2] "v"(src2));
#endif
  }

  template <int nelems>
  EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale, vec &reg) {
    if (is_unit_inc) {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3:
          dst = pmadd(scale, src, ploadu<vec>(mem));
          break;
        case 512 * 2:
          dst = pmadd(scale, src, ploadu<vec>(mem));
          break;
        case 512 * 1:
          dst = pmadd(scale, src, ploadu<vec>(mem));
          break;
        case 256 * 1:
          dst = preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
          break;
        case 128 * 1:
          dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
          break;
        case 64 * 1:
          dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
          break;
        case 32 * 1:
          dst = preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
          break;
      }
    } else {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = pmadd(scale, src, reg);
          break;
        case 512 * 2:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = pmadd(scale, src, reg);
          break;
        case 512 * 1:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = pmadd(scale, src, reg);
          break;
        case 256 * 1:
          reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
          dst = preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
          break;
        case 128 * 1:
          reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
          dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
          break;
        case 64 * 1:
          if (is_f32) {
            reg = pgather(reg, mem, inc, mask);
            dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
          } else {
            dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
          }
          break;
        case 32 * 1:
          dst = preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
          break;
      }
    }
  }
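  // a_loads<j, endX, i, endY, nelems>: compile-time double loop that preloads
  // the A registers for the first one or two k iterations before entering the
  // k loop (see kloop below).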
  template <int j, int endX, int i, int endY, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> a_loads(const Scalar *ao) {
    EIGEN_UNUSED_VARIABLE(ao);
  }

  template <int j, int endX, int i, int endY, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> a_loads(const Scalar *ao) {
    if (j < endX) {
      if (i < endY) {
        auto &a_reg = zmm[a_regs[i + (j % 2) * 3]];
        const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift;
        a_load<nelems>(a_reg, a_addr);

        a_loads<j, endX, i + 1, endY, nelems>(ao);
      } else {
        a_loads<j + 1, endX, 0, endY, nelems>(ao);
      }
    }
  }
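  // prefetch_cs: recursively prefetch the C tile (b_unroll columns, um_vecs
  // cache lines per column) before the k loop starts.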
  template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> prefetch_cs(const Scalar *co1,
                                                                                         const Scalar *co2) {
    EIGEN_UNUSED_VARIABLE(co1);
    EIGEN_UNUSED_VARIABLE(co2);
  }

  template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> prefetch_cs(Scalar *&co1, Scalar *&co2) {
    if (un < max_b_unroll) {
      if (b_unroll >= un + 1) {
        if (un == 4 && i == 0) co2 = co1 + 4 * ldc;

        if (i < um_vecs) {
          Scalar *co = (un + 1 <= 4) ? co1 : co2;
          auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
          prefetch_c(co + co_off);

          prefetch_cs<un, max_b_unroll, i + 1, um_vecs, a_unroll, b_unroll>(co1, co2);
        } else {
          prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
        }
      } else {
        prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
      }
    }
  }
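  // C-update helpers. scale_load_c combines each accumulator with the existing
  // C values according to (is_alpha1, is_beta0); write_c stores the result and
  // clears the accumulator for the next tile.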
  template <int i, int um_vecs, int idx, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
    EIGEN_UNUSED_VARIABLE(cox);
    EIGEN_UNUSED_VARIABLE(alpha_reg);
  }

  template <int i, int um_vecs, int idx, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
    if (i < um_vecs) {
      auto &c_reg = zmm[c_regs[i + idx * 3]];
      auto &c_load_reg = zmm[c_load_regs[i % 3]];
      auto c_mem = cox;
      if (is_unit_inc)
        c_mem += i * nelems_in_cache_line;
      else
        c_mem += i * nelems_in_cache_line * inc;

      if (!is_beta0 && is_alpha1)
        vaddm<nelems>(c_reg, c_mem, c_reg, c_load_reg);
      else if (!is_beta0 && !is_alpha1)
        vfmaddm<nelems>(c_reg, c_mem, c_reg, alpha_reg, c_load_reg);
      else if (is_beta0 && !is_alpha1)
        c_reg = pmul(alpha_reg, c_reg);

      scale_load_c<i + 1, um_vecs, idx, nelems>(cox, alpha_reg);
    }
  }

  template <int i, int um_vecs, int idx, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> write_c(Scalar *cox) {
    EIGEN_UNUSED_VARIABLE(cox);
  }

  template <int i, int um_vecs, int idx, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> write_c(Scalar *cox) {
    if (i < um_vecs) {
      auto &c_reg = zmm[c_regs[i + idx * 3]];
      auto c_mem = cox;
      if (is_unit_inc)
        c_mem += i * nelems_in_cache_line;
      else
        c_mem += i * nelems_in_cache_line * inc;

      c_store<nelems>(c_mem, c_reg);
      c_reg = pzero(c_reg);

      write_c<i + 1, um_vecs, idx, nelems>(cox);
    }
  }
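  // c_update walks the b_unroll columns of the C tile in groups of 1, 2, 4 and
  // 8 (c_update_1pow / c_update_1count), loading alpha once and building a
  // store mask for partial rows when C is accessed with a stride.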
  template <int pow, int a_unroll, int idx>
  EIGEN_ALWAYS_INLINE void c_update_1count(Scalar *&cox) {
    if (pow >= 4) cox += ldc;

    const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
    auto &alpha_reg = zmm[alpha_load_reg];

    scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
    write_c<0, um_vecs, idx, a_unroll>(cox);
  }

  template <int pow, int a_unroll>
  EIGEN_ALWAYS_INLINE void c_update_1pow(Scalar *&co1, Scalar *&co2) {
    constexpr int idx = pow / 2;
    Scalar *&cox = idx == 0 ? co1 : co2;

    constexpr int max_count = (pow + 1) / 2;
    static_assert(max_count <= 4, "Unsupported max_count.");

    if (1 <= max_count) c_update_1count<pow, a_unroll, idx + 0>(cox);
    if (2 <= max_count) c_update_1count<pow, a_unroll, idx + 1>(cox);
    if (3 <= max_count) c_update_1count<pow, a_unroll, idx + 2>(cox);
    if (4 <= max_count) c_update_1count<pow, a_unroll, idx + 3>(cox);
  }

  template <int max_b_unroll, int a_unroll, int b_unroll>
  EIGEN_ALWAYS_INLINE void c_update(Scalar *&co1, Scalar *&co2) {
    auto &alpha_reg = zmm[alpha_load_reg];

    co2 = co1 + ldc;

    if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
    if (!is_unit_inc && a_unroll < nelems_in_cache_line) mask = static_cast<umask_t>((1ull << a_unroll) - 1);

    static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll");

    if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2);
    if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2);
    if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2);
    if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2);
  }
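  // compute performs one FMA per row block of the register tile
  // (C[um][idx] += A[um] * B[idx]) and interleaves software prefetches of the
  // packed A and B panels; load_a refills the A registers for the next k step.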
  template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
                                                               int &fetchB_idx, vec &b_reg) {
    EIGEN_UNUSED_VARIABLE(ao);
    EIGEN_UNUSED_VARIABLE(bo);
    EIGEN_UNUSED_VARIABLE(fetchA_idx);
    EIGEN_UNUSED_VARIABLE(fetchB_idx);
    EIGEN_UNUSED_VARIABLE(b_reg);
  }

  template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
                                                                int &fetchB_idx, vec &b_reg) {
    if (um < um_vecs) {
      auto &c_reg = zmm[c_regs[um + idx * 3]];
      auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];

      vfmadd(c_reg, a_reg, b_reg);

      if (!fetch_x && um == 0 &&
          (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) ||
           (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) {
        prefetch_a(ao + nelems_in_cache_line * fetchA_idx);
        fetchA_idx++;
      }

      if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) {
        prefetch_b(bo + nelems_in_cache_line * fetchB_idx);
        fetchB_idx++;
      }

      compute<um + 1, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
    }
  }

  template <int um, int um_vecs, int uk, int nelems, bool ktail>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> load_a(const Scalar *ao) {
    EIGEN_UNUSED_VARIABLE(ao);
  }

  template <int um, int um_vecs, int uk, int nelems, bool ktail>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> load_a(const Scalar *ao) {
    if (um < um_vecs) {
      auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
      const Scalar *a_addr = ao + nelems * (1 + !ktail * !use_less_a_regs + uk) + nelems_in_cache_line * um - a_shift;
      a_load<nelems>(a_reg, a_addr);

      load_a<um + 1, um_vecs, uk, nelems, ktail>(ao);
    }
  }
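  // innerkernel_1pow issues the FMAs for one group of unrolled B columns
  // (pow = 1, 2, 4 or 8): for each column it broadcasts the next B element,
  // runs `compute` over all row blocks, and optionally prefetches the next A
  // panel and the C tile. innerkernel_1uk chains the groups for one k step.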
  template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
                                                                                 const Scalar *const &ao,
                                                                                 const Scalar *const &bo, Scalar *&co2,
                                                                                 int &fetchA_idx, int &fetchB_idx) {
    EIGEN_UNUSED_VARIABLE(aa);
    EIGEN_UNUSED_VARIABLE(ao);
    EIGEN_UNUSED_VARIABLE(bo);
    EIGEN_UNUSED_VARIABLE(co2);
    EIGEN_UNUSED_VARIABLE(fetchA_idx);
    EIGEN_UNUSED_VARIABLE(fetchB_idx);
  }

  template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
                                                                                  const Scalar *const &ao,
                                                                                  const Scalar *const &bo, Scalar *&co2,
                                                                                  int &fetchA_idx, int &fetchB_idx) {
    const int idx = (pow / 2) + count;

    if (count < (pow + 1) / 2) {
      auto &b_reg = zmm[b_regs[idx % 2]];

      if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
      if (fetch_x && uk == 3 && idx == 4) aa += 8;

      if (b_unroll >= pow) {
        compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);

        const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) * !use_less_b_regs - b_shift;
        b_load(b_reg, b_addr);
      }

      // Go to the next column of this group.
      innerkernel_1pow<uk, pow, count + 1, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
                                                                                       fetchB_idx);
    } else {
      // Load the C tile into the cache during the last iterations of the k loop.
      if (pow == 2 && c_fetch) {
        if (uk % 3 == 0 && uk > 0) {
          co2 += ldc;
        } else {
          prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
        }
      }
    }
  }

  template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch,
            bool no_a_preload = false>
  EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo,
                                           Scalar *&co2, int &fetchA_idx, int &fetchB_idx) {
    const int um_vecs = div_up(a_unroll, nelems_in_cache_line);

    if (max_b_unroll >= 1)
      innerkernel_1pow<uk, 1, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (max_b_unroll >= 2)
      innerkernel_1pow<uk, 2, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (max_b_unroll >= 4)
      innerkernel_1pow<uk, 4, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (max_b_unroll >= 8)
      innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);

    // Load the A registers for the next k step, unless this is the last one.
    if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
  }
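  // innerkernel: one block of k_factor (1..4) k iterations for an
  // a_unroll x b_unroll register tile. fetch_x enables prefetching of the next
  // A panel, c_fetch enables prefetching of C near the end of the k loop.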
  template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch,
            bool no_a_preload = false>
  EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) {
    int fetchA_idx = 0;
    int fetchB_idx = 0;

    const bool fetch_x = k_factor == max_k_factor;
    const bool ktail = k_factor == 1;

    static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
    static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1),
                  "skipping a preload only allowed when k unroll is 1");

    if (k_factor > 0)
      innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2,
                                                                                                  fetchA_idx, fetchB_idx);
    if (k_factor > 1)
      innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2,
                                                                                                  fetchA_idx, fetchB_idx);
    if (k_factor > 2)
      innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2,
                                                                                                  fetchA_idx, fetchB_idx);
    if (k_factor > 3)
      innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2,
                                                                                                  fetchA_idx, fetchB_idx);

    // Advance to the next k block within the packed panels.
    ao += a_unroll * k_factor;
    bo += b_unroll * k_factor;
  }
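  // kloop: run the whole k dimension for one a_unroll x b_unroll tile: preload
  // the A/B registers, prefetch the C tile, execute the main k loop in blocks
  // of max_k_factor (reserving the last blocks to prefetch C), handle the k
  // remainder one iteration at a time, then update C.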
  template <int a_unroll, int b_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
    const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
    if (!use_less_a_regs && k > 1)
      a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
    else
      a_loads<0, 1, 0, um_vecs, a_unroll>(ao);

    b_load(zmm[b_regs[0]], bo - b_shift + 0);
    if (!use_less_b_regs) b_load(zmm[b_regs[1]], bo - b_shift + 1);

    // Prefetch the C tile before starting the k loop.
    prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);

    const int max_k_factor = 4;
    Index kRem = k % max_k_factor;
    Index k_ = k - kRem;
    if (k_ >= max_k_factor) {
      k_ -= max_k_factor;
      kRem += max_k_factor;
    }
    Index loop_count = k_ / max_k_factor;

    if (loop_count > 0) {
      // Keep SECOND_FETCH blocks at the end of the main loop for prefetching C.
      loop_count -= SECOND_FETCH;

      while (loop_count > 0) {
        innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
        loop_count--;
      }

      co2 = co1 + nelems_in_cache_line - 1;

      loop_count += b_unroll;
      while (loop_count > 0) {
        innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 1>(aa, ao, bo, co2);
        loop_count--;
      }

      loop_count += SECOND_FETCH - b_unroll;
      while (loop_count > 0) {
        innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
        loop_count--;
      }
    }

    // Handle the k remainder one iteration at a time; the last iteration skips
    // preloading A for the (nonexistent) next step.
    loop_count = kRem;
    while (loop_count > 1) {
      innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
      loop_count--;
    }
    if (loop_count > 0) {
      innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0, true>(aa, ao, bo, co2);
    }

    c_update<max_b_unroll, a_unroll, b_unroll>(co1, co2);
  }
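  // Outer loop drivers: nloop handles one tile of b_unroll columns, mloop
  // iterates nloop over the n dimension for a fixed row unroll (plus the n
  // remainder), and compute_kern iterates mloop over the m dimension (plus the
  // m remainder).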
  template <int a_unroll, int b_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
    // Rewind A to the start of the current row panel.
    ao = a + a_off * a_unroll;

    // Advance B to the requested k offset.
    bo += b_unroll * b_off;

    kloop<a_unroll, b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);

    // Advance B to the next column panel and C to the next group of columns.
    bo += b_unroll * (b_stride - k - b_off);
    co1 += ldc * b_unroll;
  }

  template <int a_unroll, int max_a_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
    // Pointer to the next A row panel, used for prefetching.
    const Scalar *aa = a + a_unroll * a_stride;

    // Set the C pointers for this row panel and advance c for the next one.
    co1 = c;
    if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc;
    if (is_unit_inc)
      c += a_unroll;
    else
      c += a_unroll * inc;

    // Start from the beginning of the packed B panel.
    bo = b;

    for (Index i = n / max_b_unroll; i > 0; i--) nloop<a_unroll, max_b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);

    // Handle the n remainder.
    if (n & 4 && max_b_unroll > 4) nloop<a_unroll, 4, max_b_unroll>(aa, ao, bo, co1, co2);
#if 0
    if (n & 2 && max_b_unroll > 2) nloop<a_unroll, 2, max_b_unroll>(aa, ao, bo, co1, co2);
    if (n & 1 && max_b_unroll > 1) nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
#else
    // Handle a remainder of 2 or 3 columns as single columns.
    int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0);
    while (n_rem > 0) {
      nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
      n_rem--;
    }
#endif

    // Advance A to the next row panel.
    a = ao + a_unroll * (a_stride - k - a_off);
  }

  template <int max_a_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void compute_kern() {
    a -= -a_shift;
    b -= -b_shift;

    const Scalar *ao = nullptr;
    const Scalar *bo = nullptr;
    Scalar *co1 = nullptr;
    Scalar *co2 = nullptr;

    for (; m >= max_a_unroll; m -= max_a_unroll) mloop<max_a_unroll, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);

    // Handle the m remainder.
    if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 8 && max_a_unroll > 8) mloop<8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 4 && max_a_unroll > 4) mloop<4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 2 && max_a_unroll > 2 && is_f64) mloop<2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 1 && max_a_unroll > 1 && is_f64) mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);

    // For single precision, remainders of 1 to 3 rows are handled one row at a time.
    if (is_f32) {
      int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0);
      while (m_rem > 0) {
        mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
        m_rem--;
      }
    }
  }
  gemm_class(Index m_, Index n_, Index k_, Index ldc_, Index inc_, const Scalar *alpha_, const Scalar *a_,
             const Scalar *b_, Scalar *c_, bool is_alpha1_, bool is_beta0_, Index a_stride_, Index b_stride_,
             Index a_off_, Index b_off_)
      : m(m_),
        n(n_),
        k(k_),
        ldc(ldc_),
        inc(inc_),
        alpha(alpha_),
        a(a_),
        b(b_),
        c(c_),
        is_alpha1(is_alpha1_),
        is_beta0(is_beta0_),
        a_stride(a_stride_),
        b_stride(b_stride_),
        a_off(a_off_),
        b_off(b_off_) {
    // Zero out all accumulation registers.
    zmm[8] = pzero(zmm[8]);
    zmm[9] = pzero(zmm[9]);
    zmm[10] = pzero(zmm[10]);
    zmm[11] = pzero(zmm[11]);
    zmm[12] = pzero(zmm[12]);
    zmm[13] = pzero(zmm[13]);
    zmm[14] = pzero(zmm[14]);
    zmm[15] = pzero(zmm[15]);
    zmm[16] = pzero(zmm[16]);
    zmm[17] = pzero(zmm[17]);
    zmm[18] = pzero(zmm[18]);
    zmm[19] = pzero(zmm[19]);
    zmm[20] = pzero(zmm[20]);
    zmm[21] = pzero(zmm[21]);
    zmm[22] = pzero(zmm[22]);
    zmm[23] = pzero(zmm[23]);
    zmm[24] = pzero(zmm[24]);
    zmm[25] = pzero(zmm[25]);
    zmm[26] = pzero(zmm[26]);
    zmm[27] = pzero(zmm[27]);
    zmm[28] = pzero(zmm[28]);
    zmm[29] = pzero(zmm[29]);
    zmm[30] = pzero(zmm[30]);
    zmm[31] = pzero(zmm[31]);
  }
};
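// Free-function entry point: substitute the default panel strides and run the
// kernel with the requested unroll factors.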
template <typename Scalar, int max_a_unroll, int max_b_unroll, bool is_alpha1, bool is_beta0, bool is_unit_inc>
EIGEN_DONT_INLINE void gemm_kern_avx512(Index m, Index n, Index k, Scalar *alpha, const Scalar *a, const Scalar *b,
                                        Scalar *c, Index ldc, Index inc = 1, Index a_stride = -1, Index b_stride = -1,
                                        Index a_off = 0, Index b_off = 0) {
  if (a_stride == -1) a_stride = k;
  if (b_stride == -1) b_stride = k;

  gemm_class<Scalar, is_unit_inc> g(m, n, k, ldc, inc, alpha, a, b, c, is_alpha1, is_beta0, a_stride, b_stride, a_off,
                                    b_off);
  g.template compute_kern<max_a_unroll, max_b_unroll>();
}
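// The specializations below (enabled by EIGEN_USE_AVX512_GEMM_KERNELS) raise
// the GEBP column blocking factor nr to 8 for float and double and provide the
// matching rhs packing and kernel routines.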
#if EIGEN_USE_AVX512_GEMM_KERNELS
template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
class gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
    : public gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
  using Base = gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;

 public:
  enum { nr = Base::Vectorizable ? 8 : 4 };
};

template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
class gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
    : public gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
  using Base = gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;

 public:
  enum { nr = Base::Vectorizable ? 8 : 4 };
};
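// Packing of the rhs (B) panel for nr == 8, column-major input: columns are
// gathered eight (or four) at a time, transposed in registers and stored
// contiguously so the kernel can broadcast consecutive B elements.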
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode> {
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename DataMapper::LinearMapper LinearMapper;
  enum { PacketSize = packet_traits<Scalar>::size };
  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0);
};

template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode>::operator()(
    Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset) {
  constexpr int nr = 8;
  EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
  EIGEN_UNUSED_VARIABLE(stride);
  EIGEN_UNUSED_VARIABLE(offset);
  eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
  Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
  Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
  Index count = 0;
  const Index peeled_k = (depth / PacketSize) * PacketSize;
  for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
    if (PanelMode) count += 8 * offset;
    const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
    const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
    const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
    const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
    const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
    const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
    const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
    const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);
    Index k = 0;
    if ((PacketSize % 8) == 0) {
      for (; k < peeled_k; k += PacketSize) {
        PacketBlock<Packet, (PacketSize % 8) == 0 ? 8 : PacketSize> kernel;

        kernel.packet[0] = dm0.template loadPacket<Packet>(k);
        kernel.packet[1] = dm1.template loadPacket<Packet>(k);
        kernel.packet[2] = dm2.template loadPacket<Packet>(k);
        kernel.packet[3] = dm3.template loadPacket<Packet>(k);
        kernel.packet[4] = dm4.template loadPacket<Packet>(k);
        kernel.packet[5] = dm5.template loadPacket<Packet>(k);
        kernel.packet[6] = dm6.template loadPacket<Packet>(k);
        kernel.packet[7] = dm7.template loadPacket<Packet>(k);

        ptranspose(kernel);

        pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
        pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
        pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
        pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
        pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4 % PacketSize]));
        pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5 % PacketSize]));
        pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6 % PacketSize]));
        pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7 % PacketSize]));
        count += 8 * PacketSize;
      }
    }
    for (; k < depth; k++) {
      blockB[count + 0] = cj(dm0(k));
      blockB[count + 1] = cj(dm1(k));
      blockB[count + 2] = cj(dm2(k));
      blockB[count + 3] = cj(dm3(k));
      blockB[count + 4] = cj(dm4(k));
      blockB[count + 5] = cj(dm5(k));
      blockB[count + 6] = cj(dm6(k));
      blockB[count + 7] = cj(dm7(k));
      count += 8;
    }
    if (PanelMode) count += 8 * (stride - offset - depth);
  }

  for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
    if (PanelMode) count += 4 * offset;
    const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
    const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
    const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
    const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
    Index k = 0;
    if ((PacketSize % 4) == 0) {
      for (; k < peeled_k; k += PacketSize) {
        PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel;
        kernel.packet[0] = dm0.template loadPacket<Packet>(k);
        kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
        kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
        kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);

        ptranspose(kernel);

        pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
        pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
        pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
        pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
        count += 4 * PacketSize;
      }
    }
    for (; k < depth; k++) {
      blockB[count + 0] = cj(dm0(k));
      blockB[count + 1] = cj(dm1(k));
      blockB[count + 2] = cj(dm2(k));
      blockB[count + 3] = cj(dm3(k));
      count += 4;
    }
    if (PanelMode) count += 4 * (stride - offset - depth);
  }

  // Copy the remaining columns one at a time.
  for (Index j2 = packet_cols4; j2 < cols; ++j2) {
    if (PanelMode) count += offset;
    const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
    for (Index k = 0; k < depth; k++) {
      blockB[count] = cj(dm0(k));
      count++;
    }
    if (PanelMode) count += (stride - offset - depth);
  }
}
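// Packing of the rhs (B) panel for nr == 8, row-major input: each row already
// holds consecutive columns, so it is copied with full, half or quarter packets
// when their size matches the column block, falling back to scalar copies.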
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, RowMajor, Conjugate, PanelMode> {
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename unpacket_traits<Packet>::half HalfPacket;
  typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
  typedef typename DataMapper::LinearMapper LinearMapper;
  enum {
    PacketSize = packet_traits<Scalar>::size,
    HalfPacketSize = unpacket_traits<HalfPacket>::size,
    QuarterPacketSize = unpacket_traits<QuarterPacket>::size
  };
  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) {
    constexpr int nr = 8;
    EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
    EIGEN_UNUSED_VARIABLE(stride);
    EIGEN_UNUSED_VARIABLE(offset);
    eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
    const bool HasHalf = (int)HalfPacketSize < (int)PacketSize;
    const bool HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize;
    conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
    Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
    Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
    Index count = 0;

    for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
      if (PanelMode) count += 8 * offset;
      for (Index k = 0; k < depth; k++) {
        if (PacketSize == 8) {
          Packet A = rhs.template loadPacket<Packet>(k, j2);
          pstoreu(blockB + count, cj.pconj(A));
        } else if (HasHalf && HalfPacketSize == 8) {
          HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
          pstoreu(blockB + count, cj.pconj(A));
        } else if (HasQuarter && QuarterPacketSize == 8) {
          QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
          pstoreu(blockB + count, cj.pconj(A));
        } else if (PacketSize == 4) {
          // Copy the eight columns as two packets of four.
          Packet A = rhs.template loadPacket<Packet>(k, j2);
          Packet B = rhs.template loadPacket<Packet>(k, j2 + PacketSize);
          pstoreu(blockB + count, cj.pconj(A));
          pstoreu(blockB + count + PacketSize, cj.pconj(B));
        } else {
          const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
          blockB[count + 0] = cj(dm0(0));
          blockB[count + 1] = cj(dm0(1));
          blockB[count + 2] = cj(dm0(2));
          blockB[count + 3] = cj(dm0(3));
          blockB[count + 4] = cj(dm0(4));
          blockB[count + 5] = cj(dm0(5));
          blockB[count + 6] = cj(dm0(6));
          blockB[count + 7] = cj(dm0(7));
        }
        count += 8;
      }
      if (PanelMode) count += 8 * (stride - offset - depth);
    }

    for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
      if (PanelMode) count += 4 * offset;
      for (Index k = 0; k < depth; k++) {
        if (PacketSize == 4) {
          Packet A = rhs.template loadPacket<Packet>(k, j2);
          pstoreu(blockB + count, cj.pconj(A));
          count += PacketSize;
        } else if (HasHalf && HalfPacketSize == 4) {
          HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
          pstoreu(blockB + count, cj.pconj(A));
          count += HalfPacketSize;
        } else if (HasQuarter && QuarterPacketSize == 4) {
          QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
          pstoreu(blockB + count, cj.pconj(A));
          count += QuarterPacketSize;
        } else {
          const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
          blockB[count + 0] = cj(dm0(0));
          blockB[count + 1] = cj(dm0(1));
          blockB[count + 2] = cj(dm0(2));
          blockB[count + 3] = cj(dm0(3));
          count += 4;
        }
      }
      if (PanelMode) count += 4 * (stride - offset - depth);
    }

    // Copy the remaining columns one at a time.
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      if (PanelMode) count += offset;
      for (Index k = 0; k < depth; k++) {
        blockB[count] = cj(rhs(k, j2));
        count++;
      }
      if (PanelMode) count += stride - offset - depth;
    }
  }
};
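// GEBP kernel specialization for nr == 8: dispatch to gemm_kern_avx512,
// selecting the is_alpha1 / is_unit_inc template variants from the runtime
// alpha value and output increment.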
template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs> {
  EIGEN_DONT_INLINE void operator()(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows,
                                    Index depth, Index cols, Scalar alpha, Index strideA = -1, Index strideB = -1,
                                    Index offsetA = 0, Index offsetB = 0);
};

template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE void gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, Index cols,
    Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  if (res.incr() == 1) {
    if (alpha == 1) {
      gemm_kern_avx512<Scalar, mr, 8, true, false, true>(rows, cols, depth, &alpha, blockA, blockB,
                                                         (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                         strideB, offsetA, offsetB);
    } else {
      gemm_kern_avx512<Scalar, mr, 8, false, false, true>(rows, cols, depth, &alpha, blockA, blockB,
                                                          (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                          strideB, offsetA, offsetB);
    }
  } else {
    if (alpha == 1) {
      gemm_kern_avx512<Scalar, mr, 8, true, false, false>(rows, cols, depth, &alpha, blockA, blockB,
                                                          (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                          strideB, offsetA, offsetB);
    } else {
      gemm_kern_avx512<Scalar, mr, 8, false, false, false>(rows, cols, depth, &alpha, blockA, blockB,
                                                           (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                           strideB, offsetA, offsetB);
    }
  }
}
#endif  // EIGEN_USE_AVX512_GEMM_KERNELS

#undef SECOND_FETCH

}  // namespace internal
}  // namespace Eigen

#endif  // EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H