10 #ifndef EIGEN_TYPE_CASTING_AVX512_H
11 #define EIGEN_TYPE_CASTING_AVX512_H
13 #include "../../InternalHeaderCheck.h"
20 struct type_casting_traits<float,
bool> {
29 struct type_casting_traits<
bool, float> {
38 __mmask16 mask = _mm512_cmpneq_ps_mask(
a,
pzero(
a));
39 return _mm512_maskz_cvtepi32_epi8(mask, _mm512_set1_epi32(1));
43 return _mm512_cvtepi32_ps(_mm512_and_si512(_mm512_cvtepi8_epi32(
a), _mm512_set1_epi32(1)));
47 return _mm512_cvttps_epi32(
a);
51 return _mm512_cvtepi32_ps(
a);
55 return cat256(_mm512_cvtpd_ps(
a), _mm512_cvtpd_ps(
b));
59 return cat256i(_mm512_cvttpd_epi32(
a), _mm512_cvttpd_epi32(
b));
63 return _mm512_cvtpd_epi32(
a);
66 return _mm512_cvtpd_ps(
a);
70 return _mm512_castps_si512(
a);
74 return _mm512_castsi512_ps(
a);
78 return _mm512_castps_pd(
a);
82 return _mm512_castpd_ps(
a);
86 return _mm512_castps512_ps256(
a);
90 return _mm512_castps512_ps128(
a);
94 return _mm512_castpd512_pd256(
a);
98 return _mm512_castpd512_pd128(
a);
102 return _mm512_castps256_ps512(
a);
106 return _mm512_castps128_ps512(
a);
110 return _mm512_castpd256_pd512(
a);
114 return _mm512_castpd128_pd512(
a);
118 return _mm512_castsi512_si256(
a);
121 return _mm512_castsi512_si128(
a);
125 return _mm256_castsi256_si128(
a);
129 return _mm256_castsi256_si128(
a);
132 #ifndef EIGEN_VECTORIZE_AVX512FP16
135 struct type_casting_traits<
half, float> {
148 struct type_casting_traits<float,
half> {
163 struct type_casting_traits<
bfloat16, float> {
176 struct type_casting_traits<float,
bfloat16> {
188 #ifdef EIGEN_VECTORIZE_AVX512FP16
191 struct type_casting_traits<
half, float> {
200 struct type_casting_traits<float, half> {
208 template<> EIGEN_STRONG_INLINE
Packet16h preinterpret<Packet16h, Packet32h>(
const Packet32h&
a) {
209 return _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(
a), 0));
211 template<> EIGEN_STRONG_INLINE
Packet8h preinterpret<Packet8h, Packet32h>(
const Packet32h&
a) {
212 return _mm256_castsi256_si128(preinterpret<Packet16h>(
a));
218 Packet16h low = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(
a), 0));
219 return _mm512_cvtxph_ps(_mm256_castsi256_ph(low));
225 __m512d result = _mm512_undefined_pd();
226 result = _mm512_insertf64x4(result, _mm256_castsi256_pd(_mm512_cvtps_ph(
a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC)), 0);
227 result = _mm512_insertf64x4(result, _mm256_castsi256_pd(_mm512_cvtps_ph(
b, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC)), 1);
228 return _mm512_castpd_ph(result);
234 Packet8h low = _mm_castps_si128(_mm256_extractf32x4_ps(_mm256_castsi256_ps(
a), 0));
235 return _mm256_cvtxph_ps(_mm_castsi128_ph(low));
241 __m256d result = _mm256_undefined_pd();
242 result = _mm256_insertf64x2(result, _mm_castsi128_pd(_mm256_cvtps_ph(
a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC)), 0);
243 result = _mm256_insertf64x2(result, _mm_castsi128_pd(_mm256_cvtps_ph(
b, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC)), 1);
244 return _mm256_castpd_si256(result);
249 Packet8f full = _mm256_cvtxph_ps(_mm_castsi128_ph(
a));
251 return _mm256_extractf32x4_ps(full, 0);
257 __m256 result = _mm256_undefined_ps();
258 result = _mm256_insertf128_ps(result,
a, 0);
259 result = _mm256_insertf128_ps(result,
b, 1);
260 return _mm256_cvtps_ph(result, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC);
Packet8i preinterpret< Packet8i, Packet16i >(const Packet16i &a)
Packet8f pzero(const Packet8f &)
Packet4d preinterpret< Packet4d, Packet8d >(const Packet8d &a)
Packet8f preinterpret< Packet8f, Packet16f >(const Packet16f &a)
Packet16f pcast< Packet16bf, Packet16f >(const Packet16bf &a)
Packet8bf F32ToBf16(Packet4f p4f)
Packet16f pcast< Packet16i, Packet16f >(const Packet16i &a)
Packet8h preinterpret< Packet8h, Packet16h >(const Packet16h &a)
Packet16bf pcast< Packet16f, Packet16bf >(const Packet16f &a)
Packet8f Bf16ToF32(const Packet8bf &a)
eigen_packet_wrapper< __m128i, 1 > Packet16b
Packet4f preinterpret< Packet4f, Packet16f >(const Packet16f &a)
Packet4i preinterpret< Packet4i, Packet16i >(const Packet16i &a)
Packet16f pcast< Packet16h, Packet16f >(const Packet16h &a)
eigen_packet_wrapper< __m256i, 2 > Packet16bf
Packet8bf preinterpret< Packet8bf, Packet16bf >(const Packet16bf &a)
Packet16f preinterpret< Packet16f, Packet8f >(const Packet8f &a)
Packet8f pcast< Packet8d, Packet8f >(const Packet8d &a)
Packet16f preinterpret< Packet16f, Packet4f >(const Packet4f &a)
eigen_packet_wrapper< __vector unsigned short int, 0 > Packet8bf
Packet8h float2half(const Packet8f &a)
Packet8i pcast< Packet8d, Packet8i >(const Packet8d &a)
Packet16f pcast< Packet8d, Packet16f >(const Packet8d &a, const Packet8d &b)
Packet8f half2float(const Packet8h &a)
Packet16f pcast< Packet16b, Packet16f >(const Packet16b &a)
Packet8d preinterpret< Packet8d, Packet4d >(const Packet4d &a)
eigen_packet_wrapper< __m256i, 0 > Packet8i
Packet16f preinterpret< Packet16f, Packet8d >(const Packet8d &a)
Packet16b pcast< Packet16f, Packet16b >(const Packet16f &a)
Packet16i cat256i(Packet8i a, Packet8i b)
Packet16f cat256(Packet8f a, Packet8f b)
Packet16i pcast< Packet16f, Packet16i >(const Packet16f &a)
Packet8d preinterpret< Packet8d, Packet16f >(const Packet16f &a)
Packet8d preinterpret< Packet8d, Packet2d >(const Packet2d &a)
Packet16f preinterpret< Packet16f, Packet16i >(const Packet16i &a)
Packet16i preinterpret< Packet16i, Packet16f >(const Packet16f &a)
Packet2d preinterpret< Packet2d, Packet8d >(const Packet8d &a)
Packet16h pcast< Packet16f, Packet16h >(const Packet16f &a)
eigen_packet_wrapper< __m256i, 1 > Packet16h
Packet16i pcast< Packet8d, Packet16i >(const Packet8d &a, const Packet8d &b)
eigen_packet_wrapper< __m128i, 2 > Packet8h