PacketMathFP16.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 //
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_PACKET_MATH_FP16_AVX512_H
11 #define EIGEN_PACKET_MATH_FP16_AVX512_H
12 
13 #include "../../InternalHeaderCheck.h"
14 
15 namespace Eigen {
16 
17 namespace internal {
18 
19 typedef __m512h Packet32h;
20 typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
21 typedef eigen_packet_wrapper<__m128i, 2> Packet8h;
22 
23 template <>
24 struct is_arithmetic<Packet8h> {
25  enum { value = true };
26 };
27 
28 template <>
29 struct packet_traits<half> : default_packet_traits {
30  typedef Packet32h type;
31  typedef Packet16h half;
32  enum {
33  Vectorizable = 1,
34  AlignedOnScalar = 1,
35  size = 32,
36 
37  HasCmp = 1,
38  HasAdd = 1,
39  HasSub = 1,
40  HasMul = 1,
41  HasDiv = 1,
42  HasNegate = 1,
43  HasAbs = 1,
44  HasAbs2 = 0,
45  HasMin = 1,
46  HasMax = 1,
47  HasConj = 1,
48  HasSetLinear = 0,
49  HasLog = 1,
50  HasLog1p = 1,
51  HasExp = 1,
52  HasExpm1 = 1,
53  HasSqrt = 1,
54  HasRsqrt = 1,
55  // These ones should be implemented in future
56  HasBessel = 0,
57  HasNdtri = 0,
58  HasSin = EIGEN_FAST_MATH,
59  HasCos = EIGEN_FAST_MATH,
60  HasTanh = EIGEN_FAST_MATH,
61  HasErf = 0, // EIGEN_FAST_MATH,
62  HasBlend = 0,
63  HasRound = 1,
64  HasFloor = 1,
65  HasCeil = 1,
66  HasRint = 1
67  };
68 };
69 
70 template <>
71 struct unpacket_traits<Packet32h> {
72  typedef Eigen::half type;
73  typedef Packet16h half;
74  enum {
75  size = 32,
76  alignment = Aligned64,
77  vectorizable = true,
78  masked_load_available = false,
79  masked_store_available = false
80  };
81 };
82 
83 template <>
84 struct unpacket_traits<Packet16h> {
85  typedef Eigen::half type;
86  typedef Packet8h half;
87  enum {
88  size = 16,
89  alignment = Aligned32,
90  vectorizable = true,
91  masked_load_available = false,
92  masked_store_available = false
93  };
94 };
95 
96 template <>
97 struct unpacket_traits<Packet8h> {
98  typedef Eigen::half type;
99  typedef Packet8h half;
100  enum {
101  size = 8,
102  alignment = Aligned16,
103  vectorizable = true,
104  masked_load_available = false,
105  masked_store_available = false
106  };
107 };
108 
109 // Memory functions
110 
111 // pset1
112 
113 template <>
114 EIGEN_STRONG_INLINE Packet32h pset1<Packet32h>(const Eigen::half& from) {
115  return _mm512_set1_ph(static_cast<_Float16>(from));
116 }
117 
118 // pset1frombits
119 template <>
120 EIGEN_STRONG_INLINE Packet32h pset1frombits<Packet32h>(unsigned short from) {
121  return _mm512_castsi512_ph(_mm512_set1_epi16(from));
122 }
123 
124 // pfirst
125 
126 template <>
127 EIGEN_STRONG_INLINE Eigen::half pfirst<Packet32h>(const Packet32h& from) {
128 #ifdef EIGEN_VECTORIZE_AVX512DQ
130  static_cast<unsigned short>(_mm256_extract_epi16(_mm512_extracti32x8_epi32(_mm512_castph_si512(from), 0), 0)));
131 #else
132  Eigen::half dest[32];
133  _mm512_storeu_ph(dest, from);
134  return dest[0];
135 #endif
136 }
137 
138 // pload
139 
140 template <>
141 EIGEN_STRONG_INLINE Packet32h pload<Packet32h>(const Eigen::half* from) {
142  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ph(from);
143 }
144 
145 // ploadu
146 
147 template <>
148 EIGEN_STRONG_INLINE Packet32h ploadu<Packet32h>(const Eigen::half* from) {
149  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_ph(from);
150 }
151 
152 // pstore
153 
154 template <>
155 EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet32h& from) {
156  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ph(to, from);
157 }
158 
159 // pstoreu
160 
161 template <>
162 EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet32h& from) {
163  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_ph(to, from);
164 }
165 
166 // ploaddup
167 template <>
168 EIGEN_STRONG_INLINE Packet32h ploaddup<Packet32h>(const Eigen::half* from) {
169  __m512h a = _mm512_castph256_ph512(_mm256_loadu_ph(from));
170  return _mm512_permutexvar_ph(_mm512_set_epi16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6,
171  5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0),
172  a);
173 }
174 
175 // ploadquad
176 template <>
177 EIGEN_STRONG_INLINE Packet32h ploadquad<Packet32h>(const Eigen::half* from) {
178  __m512h a = _mm512_castph128_ph512(_mm_loadu_ph(from));
179  return _mm512_permutexvar_ph(
180  _mm512_set_epi16(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0),
181  a);
182 }
183 
184 // pabs
185 
186 template <>
187 EIGEN_STRONG_INLINE Packet32h pabs<Packet32h>(const Packet32h& a) {
188  return _mm512_abs_ph(a);
189 }
190 
191 // psignbit
192 
193 template <>
194 EIGEN_STRONG_INLINE Packet32h psignbit<Packet32h>(const Packet32h& a) {
195  return _mm512_castsi512_ph(_mm512_srai_epi16(_mm512_castph_si512(a), 15));
196 }
197 
198 // pmin
199 
200 template <>
201 EIGEN_STRONG_INLINE Packet32h pmin<Packet32h>(const Packet32h& a, const Packet32h& b) {
202  return _mm512_min_ph(a, b);
203 }
204 
205 // pmax
206 
207 template <>
208 EIGEN_STRONG_INLINE Packet32h pmax<Packet32h>(const Packet32h& a, const Packet32h& b) {
209  return _mm512_max_ph(a, b);
210 }
211 
212 // plset
213 template <>
214 EIGEN_STRONG_INLINE Packet32h plset<Packet32h>(const half& a) {
215  return _mm512_add_ph(_mm512_set1_ph(a),
216  _mm512_set_ph(31.0f, 30.0f, 29.0f, 28.0f, 27.0f, 26.0f, 25.0f, 24.0f, 23.0f, 22.0f, 21.0f, 20.0f,
217  19.0f, 18.0f, 17.0f, 16.0f, 15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f, 8.0f,
218  7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f));
219 }
220 
221 // por
222 
223 template <>
224 EIGEN_STRONG_INLINE Packet32h por(const Packet32h& a, const Packet32h& b) {
225  return _mm512_castsi512_ph(_mm512_or_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
226 }
227 
228 // pxor
229 
230 template <>
231 EIGEN_STRONG_INLINE Packet32h pxor(const Packet32h& a, const Packet32h& b) {
232  return _mm512_castsi512_ph(_mm512_xor_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
233 }
234 
235 // pand
236 
237 template <>
238 EIGEN_STRONG_INLINE Packet32h pand(const Packet32h& a, const Packet32h& b) {
239  return _mm512_castsi512_ph(_mm512_and_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
240 }
241 
242 // pandnot
243 
244 template <>
245 EIGEN_STRONG_INLINE Packet32h pandnot(const Packet32h& a, const Packet32h& b) {
246  return _mm512_castsi512_ph(_mm512_andnot_si512(_mm512_castph_si512(b), _mm512_castph_si512(a)));
247 }
248 
249 // pselect
250 
251 template <>
252 EIGEN_DEVICE_FUNC inline Packet32h pselect(const Packet32h& mask, const Packet32h& a, const Packet32h& b) {
253  __mmask32 mask32 = _mm512_cmp_epi16_mask(_mm512_castph_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
254  return _mm512_mask_blend_ph(mask32, a, b);
255 }
256 
257 // pcmp_eq
258 
259 template <>
260 EIGEN_STRONG_INLINE Packet32h pcmp_eq(const Packet32h& a, const Packet32h& b) {
261  __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_EQ_OQ);
262  return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, 0xffffu));
263 }
264 
265 // pcmp_le
266 
267 template <>
268 EIGEN_STRONG_INLINE Packet32h pcmp_le(const Packet32h& a, const Packet32h& b) {
269  __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_LE_OQ);
270  return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, 0xffffu));
271 }
272 
273 // pcmp_lt
274 
275 template <>
276 EIGEN_STRONG_INLINE Packet32h pcmp_lt(const Packet32h& a, const Packet32h& b) {
277  __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_LT_OQ);
278  return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, 0xffffu));
279 }
280 
281 // pcmp_lt_or_nan
282 
283 template <>
284 EIGEN_STRONG_INLINE Packet32h pcmp_lt_or_nan(const Packet32h& a, const Packet32h& b) {
285  __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_NGE_UQ);
286  return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi16(0), mask, 0xffffu));
287 }
288 
289 // padd
290 
291 template <>
292 EIGEN_STRONG_INLINE Packet32h padd<Packet32h>(const Packet32h& a, const Packet32h& b) {
293  return _mm512_add_ph(a, b);
294 }
295 
296 template <>
297 EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
298  return _mm256_castph_si256(_mm256_add_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
299 }
300 
301 template <>
302 EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {
303  return _mm_castph_si128(_mm_add_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
304 }
305 
306 // psub
307 
308 template <>
309 EIGEN_STRONG_INLINE Packet32h psub<Packet32h>(const Packet32h& a, const Packet32h& b) {
310  return _mm512_sub_ph(a, b);
311 }
312 
313 template <>
314 EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) {
315  return _mm256_castph_si256(_mm256_sub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
316 }
317 
318 template <>
319 EIGEN_STRONG_INLINE Packet8h psub<Packet8h>(const Packet8h& a, const Packet8h& b) {
320  return _mm_castph_si128(_mm_sub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
321 }
322 
323 // pmul
324 
325 template <>
326 EIGEN_STRONG_INLINE Packet32h pmul<Packet32h>(const Packet32h& a, const Packet32h& b) {
327  return _mm512_mul_ph(a, b);
328 }
329 
330 template <>
331 EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
332  return _mm256_castph_si256(_mm256_mul_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
333 }
334 
335 template <>
336 EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) {
337  return _mm_castph_si128(_mm_mul_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
338 }
339 
340 // pdiv
341 
342 template <>
343 EIGEN_STRONG_INLINE Packet32h pdiv<Packet32h>(const Packet32h& a, const Packet32h& b) {
344  return _mm512_div_ph(a, b);
345 }
346 
347 template <>
348 EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
349  return _mm256_castph_si256(_mm256_div_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
350 }
351 
352 template <>
353 EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) {
354  return _mm_castph_si128(_mm_div_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
355 }
356 
357 // pround
358 
359 template <>
360 EIGEN_STRONG_INLINE Packet32h pround<Packet32h>(const Packet32h& a) {
361  // Work-around for default std::round rounding mode.
362 
363  // Mask for the sign bit
364  const Packet32h signMask = pset1frombits<Packet32h>(static_cast<numext::uint16_t>(0x8000u));
365  // The largest half-preicision float less than 0.5
366  const Packet32h prev0dot5 = pset1frombits<Packet32h>(static_cast<numext::uint16_t>(0x37FFu));
367 
368  return _mm512_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
369 }
370 
371 // print
372 
373 template <>
374 EIGEN_STRONG_INLINE Packet32h print<Packet32h>(const Packet32h& a) {
375  return _mm512_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
376 }
377 
378 // pceil
379 
380 template <>
381 EIGEN_STRONG_INLINE Packet32h pceil<Packet32h>(const Packet32h& a) {
382  return _mm512_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
383 }
384 
385 // pfloor
386 
387 template <>
388 EIGEN_STRONG_INLINE Packet32h pfloor<Packet32h>(const Packet32h& a) {
389  return _mm512_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
390 }
391 
392 // predux
393 template <>
394 EIGEN_STRONG_INLINE half predux<Packet32h>(const Packet32h& a) {
395  return (half)_mm512_reduce_add_ph(a);
396 }
397 
398 template <>
399 EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& a) {
400  return (half)_mm256_reduce_add_ph(_mm256_castsi256_ph(a));
401 }
402 
403 template <>
404 EIGEN_STRONG_INLINE half predux<Packet8h>(const Packet8h& a) {
405  return (half)_mm_reduce_add_ph(_mm_castsi128_ph(a));
406 }
407 
408 // predux_half_dowto4
409 template <>
410 EIGEN_STRONG_INLINE Packet16h predux_half_dowto4<Packet32h>(const Packet32h& a) {
411 #ifdef EIGEN_VECTORIZE_AVX512DQ
412  __m256i lowHalf = _mm256_castps_si256(_mm512_extractf32x8_ps(_mm512_castph_ps(a), 0));
413  __m256i highHalf = _mm256_castps_si256(_mm512_extractf32x8_ps(_mm512_castph_ps(a), 1));
414 
415  return Packet16h(padd<Packet16h>(lowHalf, highHalf));
416 #else
417  Eigen::half data[32];
418  _mm512_storeu_ph(data, a);
419 
420  __m256i lowHalf = _mm256_castph_si256(_mm256_loadu_ph(data));
421  __m256i highHalf = _mm256_castph_si256(_mm256_loadu_ph(data + 16));
422 
423  return Packet16h(padd<Packet16h>(lowHalf, highHalf));
424 #endif
425 }
426 
427 // predux_max
428 
429 // predux_min
430 
431 // predux_mul
432 
433 #ifdef EIGEN_VECTORIZE_FMA
434 
435 // pmadd
436 
437 template <>
438 EIGEN_STRONG_INLINE Packet32h pmadd(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
439  return _mm512_fmadd_ph(a, b, c);
440 }
441 
442 template <>
443 EIGEN_STRONG_INLINE Packet16h pmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
444  return _mm256_castph_si256(_mm256_fmadd_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
445 }
446 
447 template <>
448 EIGEN_STRONG_INLINE Packet8h pmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
449  return _mm_castph_si128(_mm_fmadd_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
450 }
451 
452 // pmsub
453 
454 template <>
455 EIGEN_STRONG_INLINE Packet32h pmsub(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
456  return _mm512_fmsub_ph(a, b, c);
457 }
458 
459 template <>
460 EIGEN_STRONG_INLINE Packet16h pmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
461  return _mm256_castph_si256(_mm256_fmsub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
462 }
463 
464 template <>
465 EIGEN_STRONG_INLINE Packet8h pmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
466  return _mm_castph_si128(_mm_fmsub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
467 }
468 
469 // pnmadd
470 
471 template <>
472 EIGEN_STRONG_INLINE Packet32h pnmadd(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
473  return _mm512_fnmadd_ph(a, b, c);
474 }
475 
476 template <>
477 EIGEN_STRONG_INLINE Packet16h pnmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
478  return _mm256_castph_si256(_mm256_fnmadd_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
479 }
480 
481 template <>
482 EIGEN_STRONG_INLINE Packet8h pnmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
483  return _mm_castph_si128(_mm_fnmadd_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
484 }
485 
486 // pnmsub
487 
488 template <>
489 EIGEN_STRONG_INLINE Packet32h pnmsub(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
490  return _mm512_fnmsub_ph(a, b, c);
491 }
492 
493 template <>
494 EIGEN_STRONG_INLINE Packet16h pnmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
495  return _mm256_castph_si256(_mm256_fnmsub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
496 }
497 
498 template <>
499 EIGEN_STRONG_INLINE Packet8h pnmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
500  return _mm_castph_si128(_mm_fnmsub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
501 }
502 
503 #endif
504 
505 // pnegate
506 
507 template <>
508 EIGEN_STRONG_INLINE Packet32h pnegate<Packet32h>(const Packet32h& a) {
509  return _mm512_sub_ph(_mm512_set1_ph(0.0), a);
510 }
511 
512 // pconj
513 
514 template <>
515 EIGEN_STRONG_INLINE Packet32h pconj<Packet32h>(const Packet32h& a) {
516  return a;
517 }
518 
519 // psqrt
520 
521 template <>
522 EIGEN_STRONG_INLINE Packet32h psqrt<Packet32h>(const Packet32h& a) {
523  return _mm512_sqrt_ph(a);
524 }
525 
526 // prsqrt
527 
528 template <>
529 EIGEN_STRONG_INLINE Packet32h prsqrt<Packet32h>(const Packet32h& a) {
530  return _mm512_rsqrt_ph(a);
531 }
532 
533 // preciprocal
534 
535 template <>
536 EIGEN_STRONG_INLINE Packet32h preciprocal<Packet32h>(const Packet32h& a) {
537  return _mm512_rcp_ph(a);
538 }
539 
540 // ptranspose
541 
542 EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet32h, 32>& a) {
543  __m512i t[32];
544 
546  for (int i = 0; i < 16; i++) {
547  t[2 * i] = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[2 * i]), _mm512_castph_si512(a.packet[2 * i + 1]));
548  t[2 * i + 1] =
549  _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[2 * i]), _mm512_castph_si512(a.packet[2 * i + 1]));
550  }
551 
552  __m512i p[32];
553 
555  for (int i = 0; i < 8; i++) {
556  p[4 * i] = _mm512_unpacklo_epi32(t[4 * i], t[4 * i + 2]);
557  p[4 * i + 1] = _mm512_unpackhi_epi32(t[4 * i], t[4 * i + 2]);
558  p[4 * i + 2] = _mm512_unpacklo_epi32(t[4 * i + 1], t[4 * i + 3]);
559  p[4 * i + 3] = _mm512_unpackhi_epi32(t[4 * i + 1], t[4 * i + 3]);
560  }
561 
562  __m512i q[32];
563 
565  for (int i = 0; i < 4; i++) {
566  q[8 * i] = _mm512_unpacklo_epi64(p[8 * i], p[8 * i + 4]);
567  q[8 * i + 1] = _mm512_unpackhi_epi64(p[8 * i], p[8 * i + 4]);
568  q[8 * i + 2] = _mm512_unpacklo_epi64(p[8 * i + 1], p[8 * i + 5]);
569  q[8 * i + 3] = _mm512_unpackhi_epi64(p[8 * i + 1], p[8 * i + 5]);
570  q[8 * i + 4] = _mm512_unpacklo_epi64(p[8 * i + 2], p[8 * i + 6]);
571  q[8 * i + 5] = _mm512_unpackhi_epi64(p[8 * i + 2], p[8 * i + 6]);
572  q[8 * i + 6] = _mm512_unpacklo_epi64(p[8 * i + 3], p[8 * i + 7]);
573  q[8 * i + 7] = _mm512_unpackhi_epi64(p[8 * i + 3], p[8 * i + 7]);
574  }
575 
576  __m512i f[32];
577 
578 #define PACKET32H_TRANSPOSE_HELPER(X, Y) \
579  do { \
580  f[Y * 8] = _mm512_inserti32x4(f[Y * 8], _mm512_extracti32x4_epi32(q[X * 8], Y), X); \
581  f[Y * 8 + 1] = _mm512_inserti32x4(f[Y * 8 + 1], _mm512_extracti32x4_epi32(q[X * 8 + 1], Y), X); \
582  f[Y * 8 + 2] = _mm512_inserti32x4(f[Y * 8 + 2], _mm512_extracti32x4_epi32(q[X * 8 + 2], Y), X); \
583  f[Y * 8 + 3] = _mm512_inserti32x4(f[Y * 8 + 3], _mm512_extracti32x4_epi32(q[X * 8 + 3], Y), X); \
584  f[Y * 8 + 4] = _mm512_inserti32x4(f[Y * 8 + 4], _mm512_extracti32x4_epi32(q[X * 8 + 4], Y), X); \
585  f[Y * 8 + 5] = _mm512_inserti32x4(f[Y * 8 + 5], _mm512_extracti32x4_epi32(q[X * 8 + 5], Y), X); \
586  f[Y * 8 + 6] = _mm512_inserti32x4(f[Y * 8 + 6], _mm512_extracti32x4_epi32(q[X * 8 + 6], Y), X); \
587  f[Y * 8 + 7] = _mm512_inserti32x4(f[Y * 8 + 7], _mm512_extracti32x4_epi32(q[X * 8 + 7], Y), X); \
588  } while (false);
589 
594 
601 
608 
609 #undef PACKET32H_TRANSPOSE_HELPER
610 
612  for (int i = 0; i < 32; i++) {
613  a.packet[i] = _mm512_castsi512_ph(f[i]);
614  }
615 }
616 
617 EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet32h, 4>& a) {
618  __m512i p0, p1, p2, p3, t0, t1, t2, t3, a0, a1, a2, a3;
619  t0 = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[0]), _mm512_castph_si512(a.packet[1]));
620  t1 = _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[0]), _mm512_castph_si512(a.packet[1]));
621  t2 = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[2]), _mm512_castph_si512(a.packet[3]));
622  t3 = _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[2]), _mm512_castph_si512(a.packet[3]));
623 
624  p0 = _mm512_unpacklo_epi32(t0, t2);
625  p1 = _mm512_unpackhi_epi32(t0, t2);
626  p2 = _mm512_unpacklo_epi32(t1, t3);
627  p3 = _mm512_unpackhi_epi32(t1, t3);
628 
629  a0 = p0;
630  a1 = p1;
631  a2 = p2;
632  a3 = p3;
633 
634  a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p1, 0), 1);
635  a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p0, 1), 0);
636 
637  a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p2, 0), 2);
638  a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p0, 2), 0);
639 
640  a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p3, 0), 3);
641  a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p0, 3), 0);
642 
643  a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p2, 1), 2);
644  a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p1, 2), 1);
645 
646  a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p3, 2), 3);
647  a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p2, 3), 2);
648 
649  a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p3, 1), 3);
650  a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p1, 3), 1);
651 
652  a.packet[0] = _mm512_castsi512_ph(a0);
653  a.packet[1] = _mm512_castsi512_ph(a1);
654  a.packet[2] = _mm512_castsi512_ph(a2);
655  a.packet[3] = _mm512_castsi512_ph(a3);
656 }
657 
658 // preverse
659 
660 template <>
661 EIGEN_STRONG_INLINE Packet32h preverse(const Packet32h& a) {
662  return _mm512_permutexvar_ph(_mm512_set_epi16(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
663  20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31),
664  a);
665 }
666 
667 // pscatter
668 
669 template <>
670 EIGEN_STRONG_INLINE void pscatter<half, Packet32h>(half* to, const Packet32h& from, Index stride) {
671  EIGEN_ALIGN64 half aux[32];
672  pstore(aux, from);
673 
675  for (int i = 0; i < 32; i++) {
676  to[stride * i] = aux[i];
677  }
678 }
679 
680 // pgather
681 
682 template <>
683 EIGEN_STRONG_INLINE Packet32h pgather<Eigen::half, Packet32h>(const Eigen::half* from, Index stride) {
684  return _mm512_castsi512_ph(_mm512_set_epi16(
685  from[31 * stride].x, from[30 * stride].x, from[29 * stride].x, from[28 * stride].x, from[27 * stride].x,
686  from[26 * stride].x, from[25 * stride].x, from[24 * stride].x, from[23 * stride].x, from[22 * stride].x,
687  from[21 * stride].x, from[20 * stride].x, from[19 * stride].x, from[18 * stride].x, from[17 * stride].x,
688  from[16 * stride].x, from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
689  from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x, from[7 * stride].x,
690  from[6 * stride].x, from[5 * stride].x, from[4 * stride].x, from[3 * stride].x, from[2 * stride].x,
691  from[1 * stride].x, from[0 * stride].x));
692 }
693 
694 template <>
695 EIGEN_STRONG_INLINE Packet16h pcos<Packet16h>(const Packet16h&);
696 template <>
697 EIGEN_STRONG_INLINE Packet16h psin<Packet16h>(const Packet16h&);
698 template <>
699 EIGEN_STRONG_INLINE Packet16h plog<Packet16h>(const Packet16h&);
700 template <>
701 EIGEN_STRONG_INLINE Packet16h plog2<Packet16h>(const Packet16h&);
702 template <>
703 EIGEN_STRONG_INLINE Packet16h plog1p<Packet16h>(const Packet16h&);
704 template <>
705 EIGEN_STRONG_INLINE Packet16h pexp<Packet16h>(const Packet16h&);
706 template <>
707 EIGEN_STRONG_INLINE Packet16h pexpm1<Packet16h>(const Packet16h&);
708 template <>
709 EIGEN_STRONG_INLINE Packet16h ptanh<Packet16h>(const Packet16h&);
710 template <>
711 EIGEN_STRONG_INLINE Packet16h pfrexp<Packet16h>(const Packet16h&, Packet16h&);
712 template <>
713 EIGEN_STRONG_INLINE Packet16h pldexp<Packet16h>(const Packet16h&, const Packet16h&);
714 
715 EIGEN_STRONG_INLINE Packet32h combine2Packet16h(const Packet16h& a, const Packet16h& b) {
716  __m512d result = _mm512_undefined_pd();
717  result = _mm512_insertf64x4(result, _mm256_castsi256_pd(a), 0);
718  result = _mm512_insertf64x4(result, _mm256_castsi256_pd(b), 1);
719  return _mm512_castpd_ph(result);
720 }
721 
722 EIGEN_STRONG_INLINE void extract2Packet16h(const Packet32h& x, Packet16h& a, Packet16h& b) {
723  a = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(x), 0));
724  b = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(x), 1));
725 }
726 
727 // psin
728 template <>
729 EIGEN_STRONG_INLINE Packet32h psin<Packet32h>(const Packet32h& a) {
730  Packet16h low;
731  Packet16h high;
732  extract2Packet16h(a, low, high);
733 
734  Packet16h lowOut = psin(low);
735  Packet16h highOut = psin(high);
736 
737  return combine2Packet16h(lowOut, highOut);
738 }
739 
740 // pcos
741 template <>
742 EIGEN_STRONG_INLINE Packet32h pcos<Packet32h>(const Packet32h& a) {
743  Packet16h low;
744  Packet16h high;
745  extract2Packet16h(a, low, high);
746 
747  Packet16h lowOut = pcos(low);
748  Packet16h highOut = pcos(high);
749 
750  return combine2Packet16h(lowOut, highOut);
751 }
752 
753 // plog
754 template <>
755 EIGEN_STRONG_INLINE Packet32h plog<Packet32h>(const Packet32h& a) {
756  Packet16h low;
757  Packet16h high;
758  extract2Packet16h(a, low, high);
759 
760  Packet16h lowOut = plog(low);
761  Packet16h highOut = plog(high);
762 
763  return combine2Packet16h(lowOut, highOut);
764 }
765 
766 // plog2
767 template <>
768 EIGEN_STRONG_INLINE Packet32h plog2<Packet32h>(const Packet32h& a) {
769  Packet16h low;
770  Packet16h high;
771  extract2Packet16h(a, low, high);
772 
773  Packet16h lowOut = plog2(low);
774  Packet16h highOut = plog2(high);
775 
776  return combine2Packet16h(lowOut, highOut);
777 }
778 
779 // plog1p
780 template <>
781 EIGEN_STRONG_INLINE Packet32h plog1p<Packet32h>(const Packet32h& a) {
782  Packet16h low;
783  Packet16h high;
784  extract2Packet16h(a, low, high);
785 
786  Packet16h lowOut = plog1p(low);
787  Packet16h highOut = plog1p(high);
788 
789  return combine2Packet16h(lowOut, highOut);
790 }
791 
792 // pexp
793 template <>
794 EIGEN_STRONG_INLINE Packet32h pexp<Packet32h>(const Packet32h& a) {
795  Packet16h low;
796  Packet16h high;
797  extract2Packet16h(a, low, high);
798 
799  Packet16h lowOut = pexp(low);
800  Packet16h highOut = pexp(high);
801 
802  return combine2Packet16h(lowOut, highOut);
803 }
804 
805 // pexpm1
806 template <>
807 EIGEN_STRONG_INLINE Packet32h pexpm1<Packet32h>(const Packet32h& a) {
808  Packet16h low;
809  Packet16h high;
810  extract2Packet16h(a, low, high);
811 
812  Packet16h lowOut = pexpm1(low);
813  Packet16h highOut = pexpm1(high);
814 
815  return combine2Packet16h(lowOut, highOut);
816 }
817 
818 // ptanh
819 template <>
820 EIGEN_STRONG_INLINE Packet32h ptanh<Packet32h>(const Packet32h& a) {
821  Packet16h low;
822  Packet16h high;
823  extract2Packet16h(a, low, high);
824 
825  Packet16h lowOut = ptanh(low);
826  Packet16h highOut = ptanh(high);
827 
828  return combine2Packet16h(lowOut, highOut);
829 }
830 
831 // pfrexp
832 template <>
833 EIGEN_STRONG_INLINE Packet32h pfrexp<Packet32h>(const Packet32h& a, Packet32h& exponent) {
834  Packet16h low;
835  Packet16h high;
836  extract2Packet16h(a, low, high);
837 
838  Packet16h exp1 = _mm256_undefined_si256();
839  Packet16h exp2 = _mm256_undefined_si256();
840 
841  Packet16h lowOut = pfrexp(low, exp1);
842  Packet16h highOut = pfrexp(high, exp2);
843 
844  exponent = combine2Packet16h(exp1, exp2);
845 
846  return combine2Packet16h(lowOut, highOut);
847 }
848 
849 // pldexp
850 template <>
851 EIGEN_STRONG_INLINE Packet32h pldexp<Packet32h>(const Packet32h& a, const Packet32h& exponent) {
852  Packet16h low;
853  Packet16h high;
854  extract2Packet16h(a, low, high);
855 
856  Packet16h exp1;
857  Packet16h exp2;
858  extract2Packet16h(exponent, exp1, exp2);
859 
860  Packet16h lowOut = pldexp(low, exp1);
861  Packet16h highOut = pldexp(high, exp2);
862 
863  return combine2Packet16h(lowOut, highOut);
864 }
865 
866 } // end namespace internal
867 } // end namespace Eigen
868 
869 #endif // EIGEN_PACKET_MATH_FP16_AVX512_H
Array< int, 3, 1 > b
#define EIGEN_ALIGN64
Array33i c
#define EIGEN_DEBUG_ALIGNED_STORE
#define EIGEN_DEBUG_ALIGNED_LOAD
#define EIGEN_DEBUG_UNALIGNED_STORE
#define EIGEN_DEBUG_UNALIGNED_LOAD
#define EIGEN_UNROLL_LOOP
Definition: Macros.h:1290
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:883
#define EIGEN_FAST_MATH
Definition: Macros.h:50
int data[]
Vector3f p0
Vector3f p1
#define PACKET32H_TRANSPOSE_HELPER(X, Y)
float * p
@ Aligned64
Definition: Constants.h:239
@ Aligned32
Definition: Constants.h:238
@ Aligned16
Definition: Constants.h:237
EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x)
Definition: Half.h:551
void pscatter< half, Packet32h >(half *to, const Packet32h &from, Index stride)
Packet16h pfrexp< Packet16h >(const Packet16h &, Packet16h &)
Packet16h pexpm1< Packet16h >(const Packet16h &)
Packet pnmsub(const Packet &a, const Packet &b, const Packet &c)
Packet32h ploadu< Packet32h >(const Eigen::half *from)
Packet32h plog1p< Packet32h >(const Packet32h &a)
EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexpm1(const Packet &a)
Packet padd(const Packet &a, const Packet &b)
Packet8h pdiv< Packet8h >(const Packet8h &a, const Packet8h &b)
void pstore(Scalar *to, const Packet &from)
Packet32h pceil< Packet32h >(const Packet32h &a)
Packet32h pround< Packet32h >(const Packet32h &a)
Packet32h plset< Packet32h >(const half &a)
void pstoreu< half >(Eigen::half *to, const Packet16h &from)
Packet32h pmul< Packet32h >(const Packet32h &a, const Packet32h &b)
Packet4f pcmp_lt_or_nan(const Packet4f &a, const Packet4f &b)
Packet16h pldexp< Packet16h >(const Packet16h &, const Packet16h &)
half predux< Packet32h >(const Packet32h &a)
Packet32h prsqrt< Packet32h >(const Packet32h &a)
Packet8h padd< Packet8h >(const Packet8h &a, const Packet8h &b)
Packet16h plog1p< Packet16h >(const Packet16h &)
Packet32h pexpm1< Packet32h >(const Packet32h &a)
Packet32h print< Packet32h >(const Packet32h &a)
EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2(const Packet &a)
EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog(const Packet &a)
Packet8h pandnot(const Packet8h &a, const Packet8h &b)
Packet32h ptanh< Packet32h >(const Packet32h &a)
Packet16h plog2< Packet16h >(const Packet16h &)
EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp(const Packet &a)
void extract2Packet16h(const Packet32h &x, Packet16h &a, Packet16h &b)
EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos(const Packet &a)
Packet4f pmadd(const Packet4f &a, const Packet4f &b, const Packet4f &c)
EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin(const Packet &a)
Packet2cf pcmp_eq(const Packet2cf &a, const Packet2cf &b)
Packet32h pset1< Packet32h >(const Eigen::half &from)
Packet16h plog< Packet16h >(const Packet16h &)
Packet4f pselect(const Packet4f &mask, const Packet4f &a, const Packet4f &b)
void ptranspose(PacketBlock< Packet2cf, 2 > &kernel)
Packet32h pdiv< Packet32h >(const Packet32h &a, const Packet32h &b)
Packet32h pload< Packet32h >(const Eigen::half *from)
Packet pmsub(const Packet &a, const Packet &b, const Packet &c)
Packet16h ptanh< Packet16h >(const Packet16h &)
Packet32h pnegate< Packet32h >(const Packet32h &a)
Packet32h ploaddup< Packet32h >(const Eigen::half *from)
Packet8h pfrexp(const Packet8h &a, Packet8h &exponent)
Eigen::half predux< Packet8h >(const Packet8h &a)
Packet8h psub< Packet8h >(const Packet8h &a, const Packet8h &b)
Packet32h psin< Packet32h >(const Packet32h &a)
EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptanh(const Packet &a)
Packet32h pexp< Packet32h >(const Packet32h &a)
Packet pnmadd(const Packet &a, const Packet &b, const Packet &c)
Packet32h pmin< Packet32h >(const Packet32h &a, const Packet32h &b)
void pstore< half >(Eigen::half *to, const Packet16h &from)
Packet32h plog2< Packet32h >(const Packet32h &a)
Packet16h padd< Packet16h >(const Packet16h &a, const Packet16h &b)
Packet8h pand(const Packet8h &a, const Packet8h &b)
Packet8h pldexp(const Packet8h &a, const Packet8h &exponent)
EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog1p(const Packet &a)
Packet16h pmul< Packet16h >(const Packet16h &a, const Packet16h &b)
Packet16h predux_half_dowto4< Packet32h >(const Packet32h &a)
half predux< Packet16h >(const Packet16h &from)
Packet16h pcos< Packet16h >(const Packet16h &)
Packet8h pxor(const Packet8h &a, const Packet8h &b)
Packet32h psqrt< Packet32h >(const Packet32h &a)
Packet32h pmax< Packet32h >(const Packet32h &a, const Packet32h &b)
Packet16h psub< Packet16h >(const Packet16h &a, const Packet16h &b)
Packet16h psin< Packet16h >(const Packet16h &)
Packet32h padd< Packet32h >(const Packet32h &a, const Packet32h &b)
Packet32h pldexp< Packet32h >(const Packet32h &a, const Packet32h &exponent)
Packet2cf preverse(const Packet2cf &a)
Packet8h pmul< Packet8h >(const Packet8h &a, const Packet8h &b)
Packet8h por(const Packet8h &a, const Packet8h &b)
Packet32h plog< Packet32h >(const Packet32h &a)
Packet4i pcmp_lt(const Packet4i &a, const Packet4i &b)
Packet32h pcos< Packet32h >(const Packet32h &a)
Packet32h pconj< Packet32h >(const Packet32h &a)
Packet32h psignbit< Packet32h >(const Packet32h &a)
Eigen::half pfirst< Packet32h >(const Packet32h &from)
Packet32h pset1frombits< Packet32h >(unsigned short from)
Packet16h pexp< Packet16h >(const Packet16h &)
eigen_packet_wrapper< __m256i, 1 > Packet16h
Packet32h preciprocal< Packet32h >(const Packet32h &a)
Packet32h pfloor< Packet32h >(const Packet32h &a)
Packet32h psub< Packet32h >(const Packet32h &a, const Packet32h &b)
Packet16h pdiv< Packet16h >(const Packet16h &a, const Packet16h &b)
Packet32h ploadquad< Packet32h >(const Eigen::half *from)
eigen_packet_wrapper< __m128i, 2 > Packet8h
Packet32h combine2Packet16h(const Packet16h &a, const Packet16h &b)
Packet4f pcmp_le(const Packet4f &a, const Packet4f &b)
Packet32h pabs< Packet32h >(const Packet32h &a)
Packet32h pfrexp< Packet32h >(const Packet32h &a, Packet32h &exponent)
std::uint16_t uint16_t
Definition: Meta.h:37
: InteropHeaders
Definition: Core:139
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82