10 #ifndef EIGEN_TYPE_CASTING_AVX_H
11 #define EIGEN_TYPE_CASTING_AVX_H
13 #include "../../InternalHeaderCheck.h"
19 #ifndef EIGEN_VECTORIZE_AVX512
21 struct type_casting_traits<
Eigen::half, float> {
31 struct type_casting_traits<float,
Eigen::half> {
40 struct type_casting_traits<bfloat16, float> {
49 struct type_casting_traits<float, bfloat16> {
58 struct type_casting_traits<float,
bool> {
68 return _mm256_cvttps_epi32(
a);
72 return _mm256_cvtepi32_ps(
a);
76 return _mm256_set_m128(_mm256_cvtpd_ps(
b), _mm256_cvtpd_ps(
a));
80 return _mm256_set_m128i(_mm256_cvttpd_epi32(
b), _mm256_cvttpd_epi32(
a));
84 return _mm256_cvtpd_ps(
a);
88 return _mm256_cvttpd_epi32(
a);
94 __m256 nonzero_a = _mm256_cmp_ps(
a,
pzero(
a), _CMP_NEQ_UQ);
95 __m256 nonzero_b = _mm256_cmp_ps(
b,
pzero(
b), _CMP_NEQ_UQ);
96 constexpr
char kFF =
'\255';
97 #ifndef EIGEN_VECTORIZE_AVX2
98 __m128i shuffle_mask128_a_lo = _mm_set_epi8(kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0);
99 __m128i shuffle_mask128_a_hi = _mm_set_epi8(kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF);
100 __m128i shuffle_mask128_b_lo = _mm_set_epi8(kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF);
101 __m128i shuffle_mask128_b_hi = _mm_set_epi8(12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF);
102 __m128i a_hi = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_a), 1), shuffle_mask128_a_hi);
103 __m128i a_lo = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_a), 0), shuffle_mask128_a_lo);
104 __m128i b_hi = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_b), 1), shuffle_mask128_b_hi);
105 __m128i b_lo = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_b), 0), shuffle_mask128_b_lo);
106 __m128i merged = _mm_or_si128(_mm_or_si128(b_lo, b_hi), _mm_or_si128(a_lo, a_hi));
107 return _mm_and_si128(merged, _mm_set1_epi8(1));
109 __m256i a_shuffle_mask = _mm256_set_epi8(kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF,
110 kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0);
111 __m256i b_shuffle_mask = _mm256_set_epi8( 12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF,
112 kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF);
113 __m256i a_shuff = _mm256_shuffle_epi8(_mm256_castps_si256(nonzero_a), a_shuffle_mask);
114 __m256i b_shuff = _mm256_shuffle_epi8(_mm256_castps_si256(nonzero_b), b_shuffle_mask);
115 __m256i a_or_b = _mm256_or_si256(a_shuff, b_shuff);
116 __m256i merged = _mm256_or_si256(a_or_b, _mm256_castsi128_si256(_mm256_extractf128_si256(a_or_b, 1)));
117 return _mm256_castsi256_si128(_mm256_and_si256(merged, _mm256_set1_epi8(1)));
122 return _mm256_castps_si256(
a);
126 return _mm256_castsi256_ps(
a);
140 return _mm256_castps256_ps128(
a);
144 return _mm256_castpd256_pd128(
a);
148 return _mm256_castsi256_si128(
a);
152 return _mm256_castsi256_si128(
a);
156 #ifdef EIGEN_VECTORIZE_AVX2
157 template<> EIGEN_STRONG_INLINE Packet4ul preinterpret<Packet4ul, Packet4l>(
const Packet4l&
a) {
161 template<> EIGEN_STRONG_INLINE Packet4l preinterpret<Packet4l, Packet4ul>(
const Packet4ul&
a) {
Packet16b pcast< Packet8f, Packet16b >(const Packet8f &a, const Packet8f &b)
Packet8bf pcast< Packet8f, Packet8bf >(const Packet8f &a)
Packet8f pzero(const Packet8f &)
Packet4f preinterpret< Packet4f, Packet8f >(const Packet8f &a)
Packet8f pcast< Packet4d, Packet8f >(const Packet4d &a, const Packet4d &b)
Packet8bf F32ToBf16(Packet4f p4f)
Packet8i preinterpret< Packet8i, Packet8f >(const Packet8f &a)
Packet8f pcast< Packet8h, Packet8f >(const Packet8h &a)
Packet8f pcast< Packet8i, Packet8f >(const Packet8i &a)
Packet8f Bf16ToF32(const Packet8bf &a)
Packet4ui preinterpret< Packet4ui, Packet8ui >(const Packet8ui &a)
Packet8f preinterpret< Packet8f, Packet8i >(const Packet8i &a)
eigen_packet_wrapper< __m128i, 1 > Packet16b
Packet8i pcast< Packet4d, Packet8i >(const Packet4d &a, const Packet4d &b)
Packet8ui preinterpret< Packet8ui, Packet8i >(const Packet8i &a)
__vector unsigned int Packet4ui
Packet8h pcast< Packet8f, Packet8h >(const Packet8f &a)
eigen_packet_wrapper< __vector unsigned short int, 0 > Packet8bf
Packet8h float2half(const Packet8f &a)
Packet8f pcast< Packet8bf, Packet8f >(const Packet8bf &a)
Packet8f half2float(const Packet8h &a)
eigen_packet_wrapper< __m256i, 0 > Packet8i
Packet8i preinterpret< Packet8i, Packet8ui >(const Packet8ui &a)
eigen_packet_wrapper< __m256i, 4 > Packet8ui
Packet8i pcast< Packet8f, Packet8i >(const Packet8f &a)
Packet4i preinterpret< Packet4i, Packet8i >(const Packet8i &a)
Packet4f pcast< Packet4d, Packet4f >(const Packet4d &a)
Packet2d preinterpret< Packet2d, Packet4d >(const Packet4d &a)
Packet4i pcast< Packet4d, Packet4i >(const Packet4d &a)
eigen_packet_wrapper< __m128i, 2 > Packet8h