AVX/TypeCasting.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_TYPE_CASTING_AVX_H
11 #define EIGEN_TYPE_CASTING_AVX_H
12 
13 #include "../../InternalHeaderCheck.h"
14 
15 namespace Eigen {
16 
17 namespace internal {
18 
19 #ifndef EIGEN_VECTORIZE_AVX512
20 template <>
21 struct type_casting_traits<Eigen::half, float> {
22  enum {
23  VectorizedCast = 1,
24  SrcCoeffRatio = 1,
25  TgtCoeffRatio = 1
26  };
27 };
28 
29 
30 template <>
31 struct type_casting_traits<float, Eigen::half> {
32  enum {
33  VectorizedCast = 1,
34  SrcCoeffRatio = 1,
35  TgtCoeffRatio = 1
36  };
37 };
38 
39 template <>
40 struct type_casting_traits<bfloat16, float> {
41  enum {
42  VectorizedCast = 1,
43  SrcCoeffRatio = 1,
44  TgtCoeffRatio = 1
45  };
46 };
47 
48 template <>
49 struct type_casting_traits<float, bfloat16> {
50  enum {
51  VectorizedCast = 1,
52  SrcCoeffRatio = 1,
53  TgtCoeffRatio = 1
54  };
55 };
56 
57 template <>
58 struct type_casting_traits<float, bool> {
59  enum {
60  VectorizedCast = 1,
61  SrcCoeffRatio = 2,
62  TgtCoeffRatio = 1
63  };
64 };
65 #endif // EIGEN_VECTORIZE_AVX512
66 
67 template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet8f, Packet8i>(const Packet8f& a) {
68  return _mm256_cvttps_epi32(a);
69 }
70 
71 template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8i, Packet8f>(const Packet8i& a) {
72  return _mm256_cvtepi32_ps(a);
73 }
74 
75 template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet4d, Packet8f>(const Packet4d& a, const Packet4d& b) {
76  return _mm256_set_m128(_mm256_cvtpd_ps(b), _mm256_cvtpd_ps(a));
77 }
78 
79 template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet4d, Packet8i>(const Packet4d& a, const Packet4d& b) {
80  return _mm256_set_m128i(_mm256_cvttpd_epi32(b), _mm256_cvttpd_epi32(a));
81 }
82 
83 template <> EIGEN_STRONG_INLINE Packet4f pcast<Packet4d, Packet4f>(const Packet4d& a) {
84  return _mm256_cvtpd_ps(a);
85 }
86 
87 template <> EIGEN_STRONG_INLINE Packet4i pcast<Packet4d, Packet4i>(const Packet4d& a) {
88  return _mm256_cvttpd_epi32(a);
89 }
90 
91 template <>
92 EIGEN_STRONG_INLINE Packet16b pcast<Packet8f, Packet16b>(const Packet8f& a,
93  const Packet8f& b) {
94  __m256 nonzero_a = _mm256_cmp_ps(a, pzero(a), _CMP_NEQ_UQ);
95  __m256 nonzero_b = _mm256_cmp_ps(b, pzero(b), _CMP_NEQ_UQ);
96  constexpr char kFF = '\255';
97 #ifndef EIGEN_VECTORIZE_AVX2
98  __m128i shuffle_mask128_a_lo = _mm_set_epi8(kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0);
99  __m128i shuffle_mask128_a_hi = _mm_set_epi8(kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF);
100  __m128i shuffle_mask128_b_lo = _mm_set_epi8(kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF);
101  __m128i shuffle_mask128_b_hi = _mm_set_epi8(12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF);
102  __m128i a_hi = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_a), 1), shuffle_mask128_a_hi);
103  __m128i a_lo = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_a), 0), shuffle_mask128_a_lo);
104  __m128i b_hi = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_b), 1), shuffle_mask128_b_hi);
105  __m128i b_lo = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_b), 0), shuffle_mask128_b_lo);
106  __m128i merged = _mm_or_si128(_mm_or_si128(b_lo, b_hi), _mm_or_si128(a_lo, a_hi));
107  return _mm_and_si128(merged, _mm_set1_epi8(1));
108  #else
109  __m256i a_shuffle_mask = _mm256_set_epi8(kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF,
110  kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0);
111  __m256i b_shuffle_mask = _mm256_set_epi8( 12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF,
112  kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF);
113  __m256i a_shuff = _mm256_shuffle_epi8(_mm256_castps_si256(nonzero_a), a_shuffle_mask);
114  __m256i b_shuff = _mm256_shuffle_epi8(_mm256_castps_si256(nonzero_b), b_shuffle_mask);
115  __m256i a_or_b = _mm256_or_si256(a_shuff, b_shuff);
116  __m256i merged = _mm256_or_si256(a_or_b, _mm256_castsi128_si256(_mm256_extractf128_si256(a_or_b, 1)));
117  return _mm256_castsi256_si128(_mm256_and_si256(merged, _mm256_set1_epi8(1)));
118 #endif
119 }
120 
121 template<> EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i,Packet8f>(const Packet8f& a) {
122  return _mm256_castps_si256(a);
123 }
124 
125 template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f,Packet8i>(const Packet8i& a) {
126  return _mm256_castsi256_ps(a);
127 }
128 
129 template<> EIGEN_STRONG_INLINE Packet8ui preinterpret<Packet8ui, Packet8i>(const Packet8i& a) {
130  return Packet8ui(a);
131 }
132 
133 template<> EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i, Packet8ui>(const Packet8ui& a) {
134  return Packet8i(a);
135 }
136 
137 // truncation operations
138 
139 template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet8f>(const Packet8f& a) {
140  return _mm256_castps256_ps128(a);
141 }
142 
143 template<> EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet4d>(const Packet4d& a) {
144  return _mm256_castpd256_pd128(a);
145 }
146 
147 template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet8i>(const Packet8i& a) {
148  return _mm256_castsi256_si128(a);
149 }
150 
151 template<> EIGEN_STRONG_INLINE Packet4ui preinterpret<Packet4ui, Packet8ui>(const Packet8ui& a) {
152  return _mm256_castsi256_si128(a);
153 }
154 
155 
#ifdef EIGEN_VECTORIZE_AVX2
// Signed <-> unsigned 64-bit views, only compiled when AVX2 is enabled.
// Each conversion constructs the target wrapper from the source packet;
// presumably both wrap the same __m256i register type, so no bits change
// (TODO confirm against the Packet4l / Packet4ul typedefs).
template <>
EIGEN_STRONG_INLINE Packet4ul preinterpret<Packet4ul, Packet4l>(const Packet4l& a) {
  Packet4ul as_unsigned(a);
  return as_unsigned;
}

template <>
EIGEN_STRONG_INLINE Packet4l preinterpret<Packet4l, Packet4ul>(const Packet4ul& a) {
  Packet4l as_signed(a);
  return as_signed;
}

#endif
166 
167 template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
168  return half2float(a);
169 }
170 
171 template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
172  return Bf16ToF32(a);
173 }
174 
175 template<> EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
176  return float2half(a);
177 }
178 
179 template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8f, Packet8bf>(const Packet8f& a) {
180  return F32ToBf16(a);
181 }
182 
183 } // end namespace internal
184 
185 } // end namespace Eigen
186 
187 #endif // EIGEN_TYPE_CASTING_AVX_H
Array< int, 3, 1 > b
Packet16b pcast< Packet8f, Packet16b >(const Packet8f &a, const Packet8f &b)
Packet8bf pcast< Packet8f, Packet8bf >(const Packet8f &a)
Packet8f pzero(const Packet8f &)
Packet4f preinterpret< Packet4f, Packet8f >(const Packet8f &a)
__vector int Packet4i
Packet8f pcast< Packet4d, Packet8f >(const Packet4d &a, const Packet4d &b)
Packet8bf F32ToBf16(Packet4f p4f)
Packet8i preinterpret< Packet8i, Packet8f >(const Packet8f &a)
Packet8f pcast< Packet8h, Packet8f >(const Packet8h &a)
Packet8f pcast< Packet8i, Packet8f >(const Packet8i &a)
Packet8f Bf16ToF32(const Packet8bf &a)
Packet4ui preinterpret< Packet4ui, Packet8ui >(const Packet8ui &a)
Packet8f preinterpret< Packet8f, Packet8i >(const Packet8i &a)
eigen_packet_wrapper< __m128i, 1 > Packet16b
Packet8i pcast< Packet4d, Packet8i >(const Packet4d &a, const Packet4d &b)
Packet8ui preinterpret< Packet8ui, Packet8i >(const Packet8i &a)
__vector unsigned int Packet4ui
Packet8h pcast< Packet8f, Packet8h >(const Packet8f &a)
eigen_packet_wrapper< __vector unsigned short int, 0 > Packet8bf
Packet8h float2half(const Packet8f &a)
Packet8f pcast< Packet8bf, Packet8f >(const Packet8bf &a)
Packet8f half2float(const Packet8h &a)
eigen_packet_wrapper< __m256i, 0 > Packet8i
Packet8i preinterpret< Packet8i, Packet8ui >(const Packet8ui &a)
eigen_packet_wrapper< __m256i, 4 > Packet8ui
Packet8i pcast< Packet8f, Packet8i >(const Packet8f &a)
Packet4i preinterpret< Packet4i, Packet8i >(const Packet8i &a)
__vector float Packet4f
Packet4f pcast< Packet4d, Packet4f >(const Packet4d &a)
Packet2d preinterpret< Packet2d, Packet4d >(const Packet4d &a)
Packet4i pcast< Packet4d, Packet4i >(const Packet4d &a)
eigen_packet_wrapper< __m128i, 2 > Packet8h
: InteropHeaders
Definition: Core:139