GemmKernel.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2022 Intel Corporation
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
11 #define EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
12 
13 #if EIGEN_COMP_MSVC
14 #include <intrin.h>
15 #else
16 #include <x86intrin.h>
17 #endif
18 #include <immintrin.h>
19 #include <type_traits>
20 
21 #include "../../InternalHeaderCheck.h"
22 
23 #if !defined(EIGEN_USE_AVX512_GEMM_KERNELS)
24 #define EIGEN_USE_AVX512_GEMM_KERNELS 1
25 #endif
26 
27 #define SECOND_FETCH (32)
28 #if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS)
29 // Use fewer registers to load A elements to work around compiler register spills. Loses a
30 // bit of performance (less than ~2%).
31 #define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
32 #endif
33 
34 namespace Eigen {
35 namespace internal {
36 
37 template <typename Scalar, bool is_unit_inc>
38 class gemm_class {
39  using vec = typename packet_traits<Scalar>::type;
40  using vec_ymm = typename unpacket_traits<vec>::half;
41  using vec_xmm = typename unpacket_traits<vec_ymm>::half;
42  using umask_t = typename unpacket_traits<vec>::mask_t;
43 
44  static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float);
45  static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double);
46 
47 #ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
48  static constexpr bool use_less_a_regs = !is_unit_inc;
49 #else
50  static constexpr bool use_less_a_regs = true;
51 #endif
52 #ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
53  static constexpr bool use_less_b_regs = !is_unit_inc;
54 #else
55  static constexpr bool use_less_b_regs = true;
56 #endif
57 
58  static constexpr int a_regs[] = {0, 1, 2, use_less_a_regs ? 0 : 3, use_less_a_regs ? 1 : 4, use_less_a_regs ? 2 : 5};
59  static constexpr int b_regs[] = {6, use_less_b_regs ? 6 : 7};
60  static constexpr int c_regs[] = {
61  8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27, 12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31,
62  };
63 
64  static constexpr int alpha_load_reg = 0;
65  static constexpr int c_load_regs[] = {1, 2, 6};
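 // Reading aid (added comment): zmm0-zmm5 hold A values (folded onto zmm0-zmm2 when
 // use_less_a_regs is set), zmm6-zmm7 hold broadcast B values, and zmm8-zmm31 are the 24 C
 // accumulators addressed through c_regs; during the C update, zmm0 (alpha_load_reg) and
 // zmm1/zmm2/zmm6 (c_load_regs) are reused as scratch registers.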
66 
67  static constexpr int a_shift = 128;
68  static constexpr int b_shift = 128;
69 
70  static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8;
71  static constexpr int a_prefetch_size = nelems_in_cache_line * 2;
72  static constexpr int b_prefetch_size = nelems_in_cache_line * 8;
73 
74  vec zmm[32];
75  umask_t mask;
76 
77  // gemm arguments.
78  Index m;
79  const Index n, k, ldc;
80  const Index inc;
81  const Scalar *alpha;
82 
83  const Scalar *a, *b;
84  Scalar *c;
85 
86  const bool is_alpha1;
87  const bool is_beta0;
88 
89  const Index a_stride, b_stride;
90  const Index a_off, b_off;
91 
92  static EIGEN_ALWAYS_INLINE constexpr int div_up(int a, int b) { return a == 0 ? 0 : (a - 1) / b + 1; }
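 // Example (added comment): for float, nelems_in_cache_line == 16, so an a_unroll of 48 gives
 // div_up(48, 16) == 3 zmm vectors per column of the C micro-tile (um_vecs == 3).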
93 
94  EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr) {
95  _mm_prefetch((char *)(a_prefetch_size + a_addr - a_shift), _MM_HINT_T0);
96  }
97 
98  EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr) {
99  _mm_prefetch((char *)(b_prefetch_size + b_addr - b_shift), _MM_HINT_T0);
100  }
101 
102  EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr) { _mm_prefetch((char *)(x_addr - a_shift), _MM_HINT_T2); }
103 
104  EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) {
105 #if defined(__PRFCHW__) && __PRFCHW__ == 1
106  _m_prefetchw((void *)c_addr);
107 #else
108  _mm_prefetch((char *)c_addr, _MM_HINT_T0);
109 #endif
110  }
111 
112  template <int nelems>
113  EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) {
114  switch (nelems * sizeof(*a_addr) * 8) {
115  default:
116  case 512 * 3:
117  a_reg = ploadu<vec>(a_addr);
118  break;
119  case 512 * 2:
120  a_reg = ploadu<vec>(a_addr);
121  break;
122  case 512 * 1:
123  a_reg = ploadu<vec>(a_addr);
124  break;
125  case 256 * 1:
126  a_reg = preinterpret<vec>(_mm512_broadcast_f64x4(ploadu<Packet4d>(reinterpret_cast<const double *>(a_addr))));
127  break;
128  case 128 * 1:
129  a_reg = preinterpret<vec>(_mm512_broadcast_f32x4(ploadu<Packet4f>(reinterpret_cast<const float *>(a_addr))));
130  break;
131  case 64 * 1:
132  a_reg = preinterpret<vec>(pload1<Packet8d>(reinterpret_cast<const double *>(a_addr)));
133  break;
134  case 32 * 1:
135  a_reg = pload1<vec>(a_addr);
136  break;
137  }
138  }
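 // Added note: the case labels above are the payload size in bits (nelems * sizeof(Scalar) * 8).
 // Multiples of a full cache line use unaligned 512-bit loads; narrower tails fall back to
 // 256/128-bit or single-element broadcasts, so the upper lanes merely hold duplicated data.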
139 
140  EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr) { b_reg = pload1<vec>(b_addr); }
141 
142  template <int nelems>
143  EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) {
144  if (is_unit_inc) {
145  switch (nelems * sizeof(*mem) * 8) {
146  default:
147  case 512 * 3:
148  pstoreu(mem, src);
149  break;
150  case 512 * 2:
151  pstoreu(mem, src);
152  break;
153  case 512 * 1:
154  pstoreu(mem, src);
155  break;
156  case 256 * 1:
157  pstoreu(mem, preinterpret<vec_ymm>(src));
158  break;
159  case 128 * 1:
160  pstoreu(mem, preinterpret<vec_xmm>(src));
161  break;
162  case 64 * 1:
163  pstorel(mem, preinterpret<vec_xmm>(src));
164  break;
165  case 32 * 1:
166  pstores(mem, preinterpret<vec_xmm>(src));
167  break;
168  }
169  } else {
170  switch (nelems * sizeof(*mem) * 8) {
171  default:
172  case 512 * 3:
173  pscatter(mem, src, inc);
174  break;
175  case 512 * 2:
176  pscatter(mem, src, inc);
177  break;
178  case 512 * 1:
179  pscatter(mem, src, inc);
180  break;
181  case 256 * 1:
182  pscatter(mem, src, inc, mask);
183  break;
184  case 128 * 1:
185  pscatter(mem, src, inc, mask);
186  break;
187  case 64 * 1:
188  pscatter(mem, src, inc, mask);
189  break;
190  case 32 * 1:
191  pscatter(mem, src, inc, mask);
192  break;
193  }
194  }
195  }
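 // Added note: with a unit output increment, C is written through contiguous (possibly partial)
 // stores; otherwise the same size cases map onto pscatter, with `mask` restricting the scatter
 // to the valid tail elements.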
196 
197  template <int nelems>
198  EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src, vec &reg) {
199  if (is_unit_inc) {
200  switch (nelems * sizeof(*mem) * 8) {
201  default:
202  case 512 * 3:
203  dst = padd(src, ploadu<vec>(mem));
204  break;
205  case 512 * 2:
206  dst = padd(src, ploadu<vec>(mem));
207  break;
208  case 512 * 1:
209  dst = padd(src, ploadu<vec>(mem));
210  break;
211  case 256 * 1:
212  dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
213  break;
214  case 128 * 1:
215  dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
216  break;
217  case 64 * 1:
218  dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
219  break;
220  case 32 * 1:
221  dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
222  break;
223  }
224  } else {
225  // Zero out scratch register
226  reg = pzero(reg);
227 
228  switch (nelems * sizeof(*mem) * 8) {
229  default:
230  case 512 * 3:
231  reg = pgather<Scalar, vec>(mem, inc);
232  dst = padd(src, reg);
233  break;
234  case 512 * 2:
235  reg = pgather<Scalar, vec>(mem, inc);
236  dst = padd(src, reg);
237  break;
238  case 512 * 1:
239  reg = pgather<Scalar, vec>(mem, inc);
240  dst = padd(src, reg);
241  break;
242  case 256 * 1:
243  reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
244  dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
245  break;
246  case 128 * 1:
247  reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
248  dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
249  break;
250  case 64 * 1:
251  if (is_f32) {
252  reg = pgather(reg, mem, inc, mask);
253  dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
254  } else {
255  dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
256  }
257  break;
258  case 32 * 1:
259  dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
260  break;
261  }
262  }
263  }
264 
265  EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) {
266  dst = pmadd(src1, src2, dst);
267 
268 #if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
269  // Workaround register spills for gcc and clang
270  __asm__("#" : [dst] "+v"(dst) : [src1] "%v"(src1), [src2] "v"(src2));
271 #endif
272  }
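 // Added note: the empty inline-asm statement emits no instructions; its "+v"/"v" constraints
 // only tie the operands to vector registers so that gcc/clang keep the accumulator in a
 // register instead of spilling it.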
273 
274  template <int nelems>
275  EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale, vec &reg) {
276  if (is_unit_inc) {
277  switch (nelems * sizeof(*mem) * 8) {
278  default:
279  case 512 * 3:
280  dst = pmadd(scale, src, ploadu<vec>(mem));
281  break;
282  case 512 * 2:
283  dst = pmadd(scale, src, ploadu<vec>(mem));
284  break;
285  case 512 * 1:
286  dst = pmadd(scale, src, ploadu<vec>(mem));
287  break;
288  case 256 * 1:
289  dst =
290  preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
291  break;
292  case 128 * 1:
293  dst =
294  preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
295  break;
296  case 64 * 1:
297  dst =
298  preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
299  break;
300  case 32 * 1:
301  dst =
302  preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
303  break;
304  }
305  } else {
306  // Zero out scratch register
307  reg = pzero(reg);
308 
309  switch (nelems * sizeof(*mem) * 8) {
310  default:
311  case 512 * 3:
312  reg = pgather<Scalar, vec>(mem, inc);
313  dst = pmadd(scale, src, reg);
314  break;
315  case 512 * 2:
316  reg = pgather<Scalar, vec>(mem, inc);
317  dst = pmadd(scale, src, reg);
318  break;
319  case 512 * 1:
320  reg = pgather<Scalar, vec>(mem, inc);
321  dst = pmadd(scale, src, reg);
322  break;
323  case 256 * 1:
324  reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
325  dst = preinterpret<vec>(
326  pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
327  break;
328  case 128 * 1:
329  reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
330  dst = preinterpret<vec>(
331  pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
332  break;
333  case 64 * 1:
334  if (is_f32) {
335  reg = pgather(reg, mem, inc, mask);
336  dst = preinterpret<vec>(
337  pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
338  } else {
339  dst = preinterpret<vec>(
340  pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
341  }
342  break;
343  case 32 * 1:
344  dst =
345  preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
346  break;
347  }
348  }
349  }
350 
351  template <int j, int endX, int i, int endY, int nelems>
352  EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> a_loads(const Scalar *ao) {
353  EIGEN_UNUSED_VARIABLE(ao);
354  }
355 
356  template <int j, int endX, int i, int endY, int nelems>
357  EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> a_loads(const Scalar *ao) {
358  if (j < endX) {
359  if (i < endY) {
360  auto &a_reg = zmm[a_regs[i + (j % 2) * 3]];
361  const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift;
362  a_load<nelems>(a_reg, a_addr);
363 
364  a_loads<j, endX, i + 1, endY, nelems>(ao);
365  } else {
366  a_loads<j + 1, endX, 0, endY, nelems>(ao);
367  }
368  }
369  }
370 
371  template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
372  EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> prefetch_cs(const Scalar *co1,
373  const Scalar *co2) {
374  EIGEN_UNUSED_VARIABLE(co1);
375  EIGEN_UNUSED_VARIABLE(co2);
376  }
377 
378  /* C prefetch loop structure.
379  * for (int un = 0; un < 8; un++) {
380  * if (b_unroll >= un + 1) {
381  * if (un == 4) co2 = co1 + 4 * ldc;
382  *
383  * for (int i = 0; i < um_vecs; i++) {
384  * Scalar *co = (un + 1 <= 4) ? co1 : co2;
385  * auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
386  * prefetch_c(co + co_off);
387  * }
388  * }
389  * }
390  */
391 
392  template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
393  EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> prefetch_cs(Scalar *&co1, Scalar *&co2) {
394  if (un < max_b_unroll) {
395  if (b_unroll >= un + 1) {
396  if (un == 4 && i == 0) co2 = co1 + 4 * ldc;
397 
398  if (i < um_vecs) {
399  Scalar *co = (un + 1 <= 4) ? co1 : co2;
400  auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
401  prefetch_c(co + co_off);
402 
403  prefetch_cs<un, max_b_unroll, i + 1, um_vecs, a_unroll, b_unroll>(co1, co2);
404  } else {
405  prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
406  }
407 
408  } else {
409  prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
410  }
411  }
412  }
413 
414  // load_c
415  template <int i, int um_vecs, int idx, int nelems>
416  EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
417  EIGEN_UNUSED_VARIABLE(cox);
418  EIGEN_UNUSED_VARIABLE(alpha_reg);
419  }
420 
421  template <int i, int um_vecs, int idx, int nelems>
422  EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
423  if (i < um_vecs) {
424  auto &c_reg = zmm[c_regs[i + idx * 3]];
425  auto &c_load_reg = zmm[c_load_regs[i % 3]];
426  auto c_mem = cox;
427  if (is_unit_inc)
428  c_mem += i * nelems_in_cache_line;
429  else
430  c_mem += i * nelems_in_cache_line * inc;
431 
432  if (!is_beta0 && is_alpha1)
433  vaddm<nelems>(c_reg, c_mem, c_reg, c_load_reg);
434  else if (!is_beta0 && !is_alpha1)
435  vfmaddm<nelems>(c_reg, c_mem, c_reg, alpha_reg, c_load_reg);
436  else if (is_beta0 && !is_alpha1)
437  c_reg = pmul(alpha_reg, c_reg);
438 
439  scale_load_c<i + 1, um_vecs, idx, nelems>(cox, alpha_reg);
440  }
441  }
442 
443  // store_c
444  template <int i, int um_vecs, int idx, int nelems>
445  EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> write_c(Scalar *cox) {
446  EIGEN_UNUSED_VARIABLE(cox);
447  }
448 
449  template <int i, int um_vecs, int idx, int nelems>
450  EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> write_c(Scalar *cox) {
451  if (i < um_vecs) {
452  auto &c_reg = zmm[c_regs[i + idx * 3]];
453  auto c_mem = cox;
454  if (is_unit_inc)
455  c_mem += i * nelems_in_cache_line;
456  else
457  c_mem += i * nelems_in_cache_line * inc;
458 
459  c_store<nelems>(c_mem, c_reg);
460  c_reg = pzero(c_reg);
461 
462  write_c<i + 1, um_vecs, idx, nelems>(cox);
463  }
464  }
465 
466  /* C update loop structure.
467  * co2 = co1 + ldc;
468  *
469  * auto &alpha_reg = zmm[alpha_load_reg];
470  * if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
471  *
472  * int idx = 0;
473  * for (pow = 1; pow <= 8; pow <<= 1) {
474  *
475  * if (b_unroll >= pow) {
476  * for (count = 1; count < (pow + 1) / 2 + 1; count++) {
477  * if (pow >= 4) co2 += ldc;
478  *
479  * const Scalar *cox = (idx == 0) ? co1 : co2;
480  *
481  * const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
482  * scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
483  * write_c<0, um_vecs, idx, a_unroll>(cox);
484  *
485  * idx++;
486  * }
487  * }
488  * }
489  *
490  * if (b_unroll == 1)
491  * co1 += ldc;
492  * else
493  * co1 = co2 + ldc;
494  */
495 
496  template <int pow, int a_unroll, int idx>
497  EIGEN_ALWAYS_INLINE void c_update_1count(Scalar *&cox) {
498  if (pow >= 4) cox += ldc;
499 
500  const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
501  auto &alpha_reg = zmm[alpha_load_reg];
502 
503  scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
504  write_c<0, um_vecs, idx, a_unroll>(cox);
505  }
506 
507  template <int pow, int a_unroll>
508  EIGEN_ALWAYS_INLINE void c_update_1pow(Scalar *&co1, Scalar *&co2) {
509  constexpr int idx = pow / 2;
510  Scalar *&cox = idx == 0 ? co1 : co2;
511 
512  constexpr int max_count = (pow + 1) / 2;
513  static_assert(max_count <= 4, "Unsupported max_count.");
514 
515  if (1 <= max_count) c_update_1count<pow, a_unroll, idx + 0>(cox);
516  if (2 <= max_count) c_update_1count<pow, a_unroll, idx + 1>(cox);
517  if (3 <= max_count) c_update_1count<pow, a_unroll, idx + 2>(cox);
518  if (4 <= max_count) c_update_1count<pow, a_unroll, idx + 3>(cox);
519  }
520 
521  template <int max_b_unroll, int a_unroll, int b_unroll>
522  EIGEN_ALWAYS_INLINE void c_update(Scalar *&co1, Scalar *&co2) {
523  auto &alpha_reg = zmm[alpha_load_reg];
524 
525  co2 = co1 + ldc;
526  if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
527  if (!is_unit_inc && a_unroll < nelems_in_cache_line) mask = static_cast<umask_t>((1ull << a_unroll) - 1);
528 
529  static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll");
530 
531  if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2);
532  if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2);
533  if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2);
534  if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2);
535 
536  if (b_unroll == 1)
537  co1 += ldc;
538  else
539  co1 = co2 + ldc;
540  }
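 // Added note: b_unroll is a power of two, so the guarded calls above cover exactly the columns
 // of the current n-block: pow == 1 and pow == 2 update one column each, pow == 4 two columns,
 // and pow == 8 four columns (see the loop sketch above).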
541 
542  // compute
543  template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
544  EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
545  int &fetchB_idx, vec &b_reg) {
546  EIGEN_UNUSED_VARIABLE(ao);
547  EIGEN_UNUSED_VARIABLE(bo);
548  EIGEN_UNUSED_VARIABLE(fetchA_idx);
549  EIGEN_UNUSED_VARIABLE(fetchB_idx);
550  EIGEN_UNUSED_VARIABLE(b_reg);
551  }
552 
553  template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
554  EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
555  int &fetchB_idx, vec &b_reg) {
556  if (um < um_vecs) {
557  auto &c_reg = zmm[c_regs[um + idx * 3]];
558  auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
559 
560  vfmadd(c_reg, a_reg, b_reg);
561 
562  if (!fetch_x && um == 0 &&
563  (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) ||
564  (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) {
565  prefetch_a(ao + nelems_in_cache_line * fetchA_idx);
566  fetchA_idx++;
567  }
568 
569  if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) {
570  prefetch_b(bo + nelems_in_cache_line * fetchB_idx);
571  fetchB_idx++;
572  }
573 
574  compute<um + 1, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
575  }
576  }
577 
578  // load_a
579  template <int um, int um_vecs, int uk, int nelems, bool ktail>
580  EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> load_a(const Scalar *ao) {
581  EIGEN_UNUSED_VARIABLE(ao);
582  }
583 
584  template <int um, int um_vecs, int uk, int nelems, bool ktail>
585  EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> load_a(const Scalar *ao) {
586  if (um < um_vecs) {
587  auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
588  const Scalar *a_addr = ao + nelems * (1 + !ktail * !use_less_a_regs + uk) + nelems_in_cache_line * um - a_shift;
589  a_load<nelems>(a_reg, a_addr);
590 
591  load_a<um + 1, um_vecs, uk, nelems, ktail>(ao);
592  }
593  }
594  template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
595  EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
596  const Scalar *const &ao,
597  const Scalar *const &bo, Scalar *&co2,
598  int &fetchA_idx, int &fetchB_idx) {
599  EIGEN_UNUSED_VARIABLE(aa);
600  EIGEN_UNUSED_VARIABLE(ao);
601  EIGEN_UNUSED_VARIABLE(bo);
602  EIGEN_UNUSED_VARIABLE(co2);
603  EIGEN_UNUSED_VARIABLE(fetchA_idx);
604  EIGEN_UNUSED_VARIABLE(fetchB_idx);
605  }
606 
607  template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
608  EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
609  const Scalar *const &ao,
610  const Scalar *const &bo, Scalar *&co2,
611  int &fetchA_idx, int &fetchB_idx) {
612  const int idx = (pow / 2) + count;
613 
614  if (count < (pow + 1) / 2) {
615  auto &b_reg = zmm[b_regs[idx % 2]];
616 
617  if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
618  if (fetch_x && uk == 3 && idx == 4) aa += 8;
619 
620  if (b_unroll >= pow) {
621  compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
622 
623  const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) * !use_less_b_regs - b_shift;
624  b_load(b_reg, b_addr);
625  }
626 
627  // Go to the next count.
628  innerkernel_1pow<uk, pow, count + 1, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
629  fetchB_idx);
630 
631  } else {
632  // Maybe prefetch C data after count-loop.
633  if (pow == 2 && c_fetch) {
634  if (uk % 3 == 0 && uk > 0) {
635  co2 += ldc;
636  } else {
637  prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
638  }
639  }
640  }
641  }
642 
643  template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch, bool no_a_preload = false>
644  EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo,
645  Scalar *&co2, int &fetchA_idx, int &fetchB_idx) {
646  const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
647 
648  if (max_b_unroll >= 1)
649  innerkernel_1pow<uk, 1, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
650  if (max_b_unroll >= 2)
651  innerkernel_1pow<uk, 2, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
652  if (max_b_unroll >= 4)
653  innerkernel_1pow<uk, 4, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
654  if (max_b_unroll >= 8)
655  innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
656 
657  // Load A after the pow-loop. Skip this at the end to avoid reading past the end of the buffer.
658  if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
659  }
660 
661  /* Inner kernel loop structure.
662  * for (int uk = 0; uk < kfactor; uk++) {
663  * int idx = 0;
664  *
665  * for (pow = 1; pow < max_b_unroll << 1; pow <<= 1) {
666  * for (int count = 0; count < (pow + 1) / 2; count++) {
667  * auto &b_reg = zmm[b_regs[idx % 2]];
668  *
669  * if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
670  * if (fetch_x && uk == 3 && idx == 4) aa += 8;
671  *
672  * if (b_unroll >= pow) {
673  * compute<0, um_vecs, idx, uk, fetchx, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
674  *
675  * const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift ;
676  * b_load(b_reg, b_addr);
677  * }
678  * idx++;
679  * }
680  *
681  * Maybe prefetch C data.
682  * if (pow == 2 && c_fetch) {
683  * if (uk % 3 == 0 && uk > 0) {
684  * co2 += ldc;
685  * } else {
686  * prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
687  * }
688  * }
689  * }
690  *
691  * Load A.
692  * load_a<0, um_vecs, uk, ktail, a_unroll>(ao);
693  * }
694  *
695  * Advance A/B pointers after uk-loop.
696  * ao += a_unroll * kfactor;
697  * bo += b_unroll * kfactor;
698  */
699 
700  template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch, bool no_a_preload = false>
701  EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) {
702  int fetchA_idx = 0;
703  int fetchB_idx = 0;
704 
705  const bool fetch_x = k_factor == max_k_factor;
706  const bool ktail = k_factor == 1;
707 
708  static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
709  static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1), "skipping a preload only allowed when k unroll is 1");
710 
711  if (k_factor > 0)
712  innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
713  fetchB_idx);
714  if (k_factor > 1)
715  innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
716  fetchB_idx);
717  if (k_factor > 2)
718  innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
719  fetchB_idx);
720  if (k_factor > 3)
721  innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
722  fetchB_idx);
723 
724  // Advance A/B pointers after uk-loop.
725  ao += a_unroll * k_factor;
726  bo += b_unroll * k_factor;
727  }
728 
729  template <int a_unroll, int b_unroll, int max_b_unroll>
730  EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
731  const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
732  if (!use_less_a_regs && k > 1)
733  a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
734  else
735  a_loads<0, 1, 0, um_vecs, a_unroll>(ao);
736 
737  b_load(zmm[b_regs[0]], bo - b_shift + 0);
738  if (!use_less_b_regs) b_load(zmm[b_regs[1]], bo - b_shift + 1);
739 
740 #ifndef SECOND_FETCH
741  prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
742 #endif // SECOND_FETCH
743 
744  // Unrolling k-loop by a factor of 4.
745  const int max_k_factor = 4;
746  Index kRem = k % max_k_factor;
747  Index k_ = k - kRem;
748  if (k_ >= max_k_factor) {
749  k_ -= max_k_factor;
750  kRem += max_k_factor;
751  }
752  Index loop_count = k_ / max_k_factor;
753 
754  if (loop_count > 0) {
755 #ifdef SECOND_FETCH
756  loop_count -= SECOND_FETCH;
757 #endif
758  while (loop_count > 0) {
759  innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
760  loop_count--;
761  }
762 #ifdef SECOND_FETCH
763  co2 = co1 + nelems_in_cache_line - 1;
764 
765  loop_count += b_unroll;
766  while (loop_count > 0) {
767  innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 1>(aa, ao, bo, co2);
768  loop_count--;
769  }
770 
771  loop_count += SECOND_FETCH - b_unroll;
772  while (loop_count > 0) {
773  innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
774  loop_count--;
775  }
776 #endif
777  }
778 
779  // k-loop remainder handling.
780  loop_count = kRem;
781  while (loop_count > 1) {
782  innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
783  loop_count--;
784  }
785  if (loop_count > 0) {
786  innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0, true>(aa, ao, bo, co2);
787  }
788 
789  // Update C matrix.
790  c_update<max_b_unroll, a_unroll, b_unroll>(co1, co2);
791  }
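 // Added note: when the k-loop is long enough and SECOND_FETCH is defined (32 above), the last
 // SECOND_FETCH iterations of the unrolled k-loop are split off: the first b_unroll of them run
 // with c_fetch enabled so the C tile is prefetched ahead of the final c_update, and the
 // remaining ones run without it.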
792 
793  template <int a_unroll, int b_unroll, int max_b_unroll>
794  EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
795  // Set A matrix pointer.
796  ao = a + a_off * a_unroll;
797 
798  // Set B matrix pointer if needed.
799  bo += b_unroll * b_off;
800 
801  kloop<a_unroll, b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
802 
803  // Advance B matrix pointer if needed.
804  bo += b_unroll * (b_stride - k - b_off);
805 
806  // Advance prefetch A pointer.
807  aa += 16;
808  }
809 
810  template <int a_unroll, int max_a_unroll, int max_b_unroll>
811  EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
812  // Set prefetch A pointers.
813  const Scalar *aa = a + a_unroll * a_stride;
814 
815  // Set C matrix pointers.
816  co1 = c;
817  if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc;
818  if (is_unit_inc)
819  c += a_unroll;
820  else
821  c += a_unroll * inc;
822 
823  // Set B matrix pointer.
824  bo = b;
825 
826  // Main n-loop.
827  for (Index i = n / max_b_unroll; i > 0; i--) nloop<a_unroll, max_b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
828 
829  // n-remainders.
830  if (n & 4 && max_b_unroll > 4) nloop<a_unroll, 4, max_b_unroll>(aa, ao, bo, co1, co2);
831 #if 0
832  if (n & 2 && max_b_unroll > 2) nloop<a_unroll, 2, max_b_unroll>(aa, ao, bo, co1, co2);
833  if (n & 1 && max_b_unroll > 1) nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
834 #else
835  // Copy kernels don't support tails of n = 2 for single/double precision.
836  // Loop over ones.
837  int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0);
838  while (n_rem > 0) {
839  nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
840  n_rem--;
841  }
842 #endif
843 
844  // Advance A matrix pointer.
845  a = ao + a_unroll * (a_stride - k - a_off);
846  }
847 
848  public:
849  // Compute kernel unrolling C matrix by max_a_unroll x max_b_unroll.
850  template <int max_a_unroll, int max_b_unroll>
851  EIGEN_ALWAYS_INLINE void compute_kern() {
852  a -= -a_shift;
853  b -= -b_shift;
854 
855  const Scalar *ao = nullptr;
856  const Scalar *bo = nullptr;
857  Scalar *co1 = nullptr;
858  Scalar *co2 = nullptr;
859 
860  // Main m-loop.
861  for (; m >= max_a_unroll; m -= max_a_unroll) mloop<max_a_unroll, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
862 
863  // m-remainders.
864  if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
865  if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
866  if (m & 8 && max_a_unroll > 8) mloop<8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
867  if (m & 4 && max_a_unroll > 4) mloop<4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
868  if (m & 2 && max_a_unroll > 2 && is_f64) mloop<2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
869  if (m & 1 && max_a_unroll > 1 && is_f64) mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
870 
871  // Copy kernels don't support tails of m = 2 for single precision.
872  // Loop over ones.
873  if (is_f32) {
874  int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0);
875  while (m_rem > 0) {
876  mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
877  m_rem--;
878  }
879  }
880  }
881 
882  gemm_class(Index m_, Index n_, Index k_, Index ldc_, Index inc_, const Scalar *alpha_, const Scalar *a_,
883  const Scalar *b_, Scalar *c_, bool is_alpha1_, bool is_beta0_, Index a_stride_, Index b_stride_,
884  Index a_off_, Index b_off_)
885  : m(m_),
886  n(n_),
887  k(k_),
888  ldc(ldc_),
889  inc(inc_),
890  alpha(alpha_),
891  a(a_),
892  b(b_),
893  c(c_),
894  is_alpha1(is_alpha1_),
895  is_beta0(is_beta0_),
896  a_stride(a_stride_),
897  b_stride(b_stride_),
898  a_off(a_off_),
899  b_off(b_off_) {
900  // Zero out all accumulation registers.
901  zmm[8] = pzero(zmm[8]);
902  zmm[9] = pzero(zmm[9]);
903  zmm[10] = pzero(zmm[10]);
904  zmm[11] = pzero(zmm[11]);
905  zmm[12] = pzero(zmm[12]);
906  zmm[13] = pzero(zmm[13]);
907  zmm[14] = pzero(zmm[14]);
908  zmm[15] = pzero(zmm[15]);
909  zmm[16] = pzero(zmm[16]);
910  zmm[17] = pzero(zmm[17]);
911  zmm[18] = pzero(zmm[18]);
912  zmm[19] = pzero(zmm[19]);
913  zmm[20] = pzero(zmm[20]);
914  zmm[21] = pzero(zmm[21]);
915  zmm[22] = pzero(zmm[22]);
916  zmm[23] = pzero(zmm[23]);
917  zmm[24] = pzero(zmm[24]);
918  zmm[25] = pzero(zmm[25]);
919  zmm[26] = pzero(zmm[26]);
920  zmm[27] = pzero(zmm[27]);
921  zmm[28] = pzero(zmm[28]);
922  zmm[29] = pzero(zmm[29]);
923  zmm[30] = pzero(zmm[30]);
924  zmm[31] = pzero(zmm[31]);
925  }
926 };
927 
928 // Compute kernel with max unroll support of:
929 // Single precision:
930 // max_a_unroll: 48, 32, 16, 8, 4, 2, 1
931 // max_b_unroll: 8, 4, 2, 1
932 // Double precision:
933 // max_a_unroll: 24, 16, 8, 4, 2, 1
934 // max_b_unroll: 8, 4, 2, 1
935 template <typename Scalar, int max_a_unroll, int max_b_unroll, bool is_alpha1, bool is_beta0, bool is_unit_inc>
936 EIGEN_DONT_INLINE void gemm_kern_avx512(Index m, Index n, Index k, Scalar *alpha, const Scalar *a, const Scalar *b,
937  Scalar *c, Index ldc, Index inc = 1, Index a_stride = -1, Index b_stride = -1,
938  Index a_off = 0, Index b_off = 0) {
939  if (a_stride == -1) a_stride = k;
940  if (b_stride == -1) b_stride = k;
941 
942  gemm_class<Scalar, is_unit_inc> g(m, n, k, ldc, inc, alpha, a, b, c, is_alpha1, is_beta0, a_stride, b_stride, a_off,
943  b_off);
944  g.template compute_kern<max_a_unroll, max_b_unroll>();
945 }
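
// Illustrative use (added, not part of the original source): the kernel operates on A/B panels
// packed for GEBP (see the gemm_pack_rhs specializations below); buffer names here are
// hypothetical.
//
//   float alpha = 1.0f;
//   // blockA: packed m x k LHS panel, blockB: packed k x n RHS panel,
//   // C: result block with leading dimension ldc and unit inner stride.
//   gemm_kern_avx512<float, /*max_a_unroll*/ 48, /*max_b_unroll*/ 8,
//                    /*is_alpha1*/ true, /*is_beta0*/ false, /*is_unit_inc*/ true>(
//       m, n, k, &alpha, blockA, blockB, C, ldc);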
946 
947 // Template specializations of GEBP kernels with nr = 8.
948 #if EIGEN_USE_AVX512_GEMM_KERNELS
949 template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
950 class gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
951  : public gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
952  using Base = gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;
953 
954  public:
955  enum { nr = Base::Vectorizable ? 8 : 4 };
956 };
957 
958 template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
959 class gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
960  : public gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
961  using Base = gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;
962 
963  public:
964  enum { nr = Base::Vectorizable ? 8 : 4 };
965 };
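
// Added note: these two traits specializations only widen the GEBP register block to nr = 8
// columns (when vectorization is available) so the generic blocking/packing layers feed the
// 8-column kernels defined in this file.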
966 
967 template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
968 struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode> {
969  typedef typename packet_traits<Scalar>::type Packet;
970  typedef typename DataMapper::LinearMapper LinearMapper;
971  enum { PacketSize = packet_traits<Scalar>::size };
972  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
973  Index offset = 0);
974 };
975 
976 template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
977 EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode>::operator()(
978  Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset) {
979  constexpr int nr = 8;
980  EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
981  EIGEN_UNUSED_VARIABLE(stride);
982  EIGEN_UNUSED_VARIABLE(offset);
983  eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
984  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
985  Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
986  Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
987  Index count = 0;
988  const Index peeled_k = (depth / PacketSize) * PacketSize;
989  if (nr >= 8) {
990  for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
991  // skip what we have before
992  if (PanelMode) count += 8 * offset;
993  const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
994  const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
995  const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
996  const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
997  const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
998  const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
999  const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
1000  const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);
1001  Index k = 0;
1002  if ((PacketSize % 8) == 0) // TODO enable vectorized transposition for PacketSize==4
1003  {
1004  for (; k < peeled_k; k += PacketSize) {
1005  PacketBlock<Packet, (PacketSize % 8) == 0 ? 8 : PacketSize> kernel;
1006 
1007  kernel.packet[0] = dm0.template loadPacket<Packet>(k);
1008  kernel.packet[1] = dm1.template loadPacket<Packet>(k);
1009  kernel.packet[2] = dm2.template loadPacket<Packet>(k);
1010  kernel.packet[3] = dm3.template loadPacket<Packet>(k);
1011  kernel.packet[4] = dm4.template loadPacket<Packet>(k);
1012  kernel.packet[5] = dm5.template loadPacket<Packet>(k);
1013  kernel.packet[6] = dm6.template loadPacket<Packet>(k);
1014  kernel.packet[7] = dm7.template loadPacket<Packet>(k);
1015 
1016  ptranspose(kernel);
1017 
1018  pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
1019  pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
1020  pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
1021  pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
1022  pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4 % PacketSize]));
1023  pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5 % PacketSize]));
1024  pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6 % PacketSize]));
1025  pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7 % PacketSize]));
1026  count += 8 * PacketSize;
1027  }
1028  }
1029  for (; k < depth; k++) {
1030  blockB[count + 0] = cj(dm0(k));
1031  blockB[count + 1] = cj(dm1(k));
1032  blockB[count + 2] = cj(dm2(k));
1033  blockB[count + 3] = cj(dm3(k));
1034  blockB[count + 4] = cj(dm4(k));
1035  blockB[count + 5] = cj(dm5(k));
1036  blockB[count + 6] = cj(dm6(k));
1037  blockB[count + 7] = cj(dm7(k));
1038  count += 8;
1039  }
1040  // skip what we have after
1041  if (PanelMode) count += 8 * (stride - offset - depth);
1042  }
1043  }
1044 
1045  if (nr >= 4) {
1046  for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
1047  // skip what we have before
1048  if (PanelMode) count += 4 * offset;
1049  const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1050  const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1051  const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1052  const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1053 
1054  Index k = 0;
1055  if ((PacketSize % 4) == 0) // TODO enable vectorized transposition for PacketSize==2 ??
1056  {
1057  for (; k < peeled_k; k += PacketSize) {
1058  PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel;
1059  kernel.packet[0] = dm0.template loadPacket<Packet>(k);
1060  kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
1061  kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
1062  kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
1063  ptranspose(kernel);
1064  pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
1065  pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
1066  pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
1067  pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
1068  count += 4 * PacketSize;
1069  }
1070  }
1071  for (; k < depth; k++) {
1072  blockB[count + 0] = cj(dm0(k));
1073  blockB[count + 1] = cj(dm1(k));
1074  blockB[count + 2] = cj(dm2(k));
1075  blockB[count + 3] = cj(dm3(k));
1076  count += 4;
1077  }
1078  // skip what we have after
1079  if (PanelMode) count += 4 * (stride - offset - depth);
1080  }
1081  }
1082 
1083  // copy the remaining columns one at a time (nr==1)
1084  for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1085  if (PanelMode) count += offset;
1086  const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
1087  for (Index k = 0; k < depth; k++) {
1088  blockB[count] = cj(dm0(k));
1089  count += 1;
1090  }
1091  if (PanelMode) count += (stride - offset - depth);
1092  }
1093 }
1094 
1095 template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
1096 struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, RowMajor, Conjugate, PanelMode> {
1097  typedef typename packet_traits<Scalar>::type Packet;
1098  typedef typename unpacket_traits<Packet>::half HalfPacket;
1099  typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
1100  typedef typename DataMapper::LinearMapper LinearMapper;
1101  enum {
1102  PacketSize = packet_traits<Scalar>::size,
1103  HalfPacketSize = unpacket_traits<HalfPacket>::size,
1104  QuarterPacketSize = unpacket_traits<QuarterPacket>::size
1105  };
1106  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
1107  Index offset = 0) {
1108  constexpr int nr = 8;
1109  EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
1110  EIGEN_UNUSED_VARIABLE(stride);
1111  EIGEN_UNUSED_VARIABLE(offset);
1112  eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
1113  const bool HasHalf = (int)HalfPacketSize < (int)PacketSize;
1114  const bool HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize;
1115  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
1116  Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
1117  Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
1118  Index count = 0;
1119 
1120  if (nr >= 8) {
1121  for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
1122  // skip what we have before
1123  if (PanelMode) count += 8 * offset;
1124  for (Index k = 0; k < depth; k++) {
1125  if (PacketSize == 8) {
1126  // Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
1127  Packet A = rhs.template loadPacket<Packet>(k, j2);
1128  pstoreu(blockB + count, cj.pconj(A));
1129  } else if (HasHalf && HalfPacketSize == 8) {
1130  HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
1131  pstoreu(blockB + count, cj.pconj(A));
1132  } else if (HasQuarter && QuarterPacketSize == 8) {
1133  QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
1134  pstoreu(blockB + count, cj.pconj(A));
1135  } else if (PacketSize == 4) {
1136  // Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
1137  // Packet B = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2 + PacketSize]);
1138  Packet A = rhs.template loadPacket<Packet>(k, j2);
1139  Packet B = rhs.template loadPacket<Packet>(k, j2 + PacketSize);
1140  pstoreu(blockB + count, cj.pconj(A));
1141  pstoreu(blockB + count + PacketSize, cj.pconj(B));
1142  } else {
1143  // const Scalar* b0 = &rhs.data()[k*rhs.stride() + j2];
1144  const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
1145  blockB[count + 0] = cj(dm0(0));
1146  blockB[count + 1] = cj(dm0(1));
1147  blockB[count + 2] = cj(dm0(2));
1148  blockB[count + 3] = cj(dm0(3));
1149  blockB[count + 4] = cj(dm0(4));
1150  blockB[count + 5] = cj(dm0(5));
1151  blockB[count + 6] = cj(dm0(6));
1152  blockB[count + 7] = cj(dm0(7));
1153  }
1154  count += 8;
1155  }
1156  // skip what we have after
1157  if (PanelMode) count += 8 * (stride - offset - depth);
1158  }
1159  }
1160 
1161  if (nr >= 4) {
1162  for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
1163  // skip what we have before
1164  if (PanelMode) count += 4 * offset;
1165  for (Index k = 0; k < depth; k++) {
1166  if (PacketSize == 4) {
1167  Packet A = rhs.template loadPacket<Packet>(k, j2);
1168  pstoreu(blockB + count, cj.pconj(A));
1169  count += PacketSize;
1170  } else if (HasHalf && HalfPacketSize == 4) {
1171  HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
1172  pstoreu(blockB + count, cj.pconj(A));
1173  count += HalfPacketSize;
1174  } else if (HasQuarter && QuarterPacketSize == 4) {
1175  QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
1176  pstoreu(blockB + count, cj.pconj(A));
1177  count += QuarterPacketSize;
1178  } else {
1179  const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
1180  blockB[count + 0] = cj(dm0(0));
1181  blockB[count + 1] = cj(dm0(1));
1182  blockB[count + 2] = cj(dm0(2));
1183  blockB[count + 3] = cj(dm0(3));
1184  count += 4;
1185  }
1186  }
1187  // skip what we have after
1188  if (PanelMode) count += 4 * (stride - offset - depth);
1189  }
1190  }
1191  // copy the remaining columns one at a time (nr==1)
1192  for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1193  if (PanelMode) count += offset;
1194  for (Index k = 0; k < depth; k++) {
1195  blockB[count] = cj(rhs(k, j2));
1196  count += 1;
1197  }
1198  if (PanelMode) count += stride - offset - depth;
1199  }
1200  }
1201 };
1202 
1203 template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
1204 struct gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs> {
1205  EIGEN_ALWAYS_INLINE
1206  void operator()(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth,
1207  Index cols, Scalar alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0,
1208  Index offsetB = 0);
1209 };
1210 
1211 template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
1212 EIGEN_ALWAYS_INLINE void gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs>::operator()(
1213  const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, Index cols,
1214  Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
1215  if (res.incr() == 1) {
1216  if (alpha == 1) {
1217  gemm_kern_avx512<Scalar, mr, 8, true, false, true>(rows, cols, depth, &alpha, blockA, blockB,
1218  (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1219  strideB, offsetA, offsetB);
1220  } else {
1221  gemm_kern_avx512<Scalar, mr, 8, false, false, true>(rows, cols, depth, &alpha, blockA, blockB,
1222  (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1223  strideB, offsetA, offsetB);
1224  }
1225  } else {
1226  if (alpha == 1) {
1227  gemm_kern_avx512<Scalar, mr, 8, true, false, false>(rows, cols, depth, &alpha, blockA, blockB,
1228  (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1229  strideB, offsetA, offsetB);
1230  } else {
1231  gemm_kern_avx512<Scalar, mr, 8, false, false, false>(rows, cols, depth, &alpha, blockA, blockB,
1232  (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1233  strideB, offsetA, offsetB);
1234  }
1235  }
1236 }
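
// Added note: the four instantiations above differ only in is_alpha1 and is_unit_inc
// (alpha == 1 vs. general alpha, contiguous vs. strided output); is_beta0 is always false,
// so the kernel accumulates into the existing contents of res.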
1237 #endif // EIGEN_USE_AVX512_GEMM_KERNELS
1238 
1239 } // namespace internal
1240 } // namespace Eigen
1241 
1242 #undef SECOND_FETCH
1243 
1244 #endif // EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H