10 #ifndef EIGEN_COMPLEX_AVX512_H
11 #define EIGEN_COMPLEX_AVX512_H
13 #include "../../InternalHeaderCheck.h"
22 EIGEN_STRONG_INLINE Packet8cf() {}
23 EIGEN_STRONG_INLINE
explicit Packet8cf(
const __m512&
a) :
v(
a) {}
27 template<>
struct packet_traits<
std::complex<float> > : default_packet_traits
29 typedef Packet8cf type;
30 typedef Packet4cf half;
50 template<>
struct unpacket_traits<Packet8cf> {
51 typedef std::complex<float> type;
52 typedef Packet4cf half;
56 alignment=unpacket_traits<Packet16f>::alignment,
58 masked_load_available=
false,
59 masked_store_available=
false
64 template<> EIGEN_STRONG_INLINE Packet8cf
padd<Packet8cf>(
const Packet8cf&
a,
const Packet8cf&
b) {
return Packet8cf(_mm512_add_ps(
a.v,
b.v)); }
65 template<> EIGEN_STRONG_INLINE Packet8cf
psub<Packet8cf>(
const Packet8cf&
a,
const Packet8cf&
b) {
return Packet8cf(_mm512_sub_ps(
a.v,
b.v)); }
66 template<> EIGEN_STRONG_INLINE Packet8cf
pnegate(
const Packet8cf&
a)
70 template<> EIGEN_STRONG_INLINE Packet8cf
pconj(
const Packet8cf&
a)
72 const __m512 mask = _mm512_castsi512_ps(_mm512_setr_epi32(
73 0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,
74 0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000));
75 return Packet8cf(
pxor(
a.v,mask));
78 template<> EIGEN_STRONG_INLINE Packet8cf
pmul<Packet8cf>(
const Packet8cf&
a,
const Packet8cf&
b)
80 __m512 tmp2 = _mm512_mul_ps(_mm512_movehdup_ps(
a.v), _mm512_permute_ps(
b.v, _MM_SHUFFLE(2,3,0,1)));
81 return Packet8cf(_mm512_fmaddsub_ps(_mm512_moveldup_ps(
a.v),
b.v, tmp2));
84 template<> EIGEN_STRONG_INLINE Packet8cf
pand <Packet8cf>(
const Packet8cf&
a,
const Packet8cf&
b) {
return Packet8cf(
pand(
a.v,
b.v)); }
85 template<> EIGEN_STRONG_INLINE Packet8cf
por <Packet8cf>(
const Packet8cf&
a,
const Packet8cf&
b) {
return Packet8cf(
por(
a.v,
b.v)); }
86 template<> EIGEN_STRONG_INLINE Packet8cf
pxor <Packet8cf>(
const Packet8cf&
a,
const Packet8cf&
b) {
return Packet8cf(
pxor(
a.v,
b.v)); }
90 EIGEN_STRONG_INLINE Packet8cf
pcmp_eq(
const Packet8cf&
a,
const Packet8cf&
b) {
91 __m512 eq = pcmp_eq<Packet16f>(
a.v,
b.v);
92 return Packet8cf(
pand(eq, _mm512_permute_ps(eq, 0xB1)));
99 template<> EIGEN_STRONG_INLINE Packet8cf
pset1<Packet8cf>(
const std::complex<float>& from)
103 return Packet8cf(_mm512_set_ps(im, re, im, re, im, re, im, re, im, re, im, re, im, re, im, re));
108 return Packet8cf( _mm512_castpd_ps(
ploaddup<Packet8d>((
const double*)(
const void*)from )) );
112 return Packet8cf( _mm512_castpd_ps(
ploadquad<Packet8d>((
const double*)(
const void*)from )) );
118 template<>
EIGEN_DEVICE_FUNC inline Packet8cf pgather<std::complex<float>, Packet8cf>(
const std::complex<float>* from,
Index stride)
123 template<>
EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet8cf>(std::complex<float>* to,
const Packet8cf& from,
Index stride)
125 pscatter((
double*)(
void*)to, _mm512_castps_pd(from.v), stride);
128 template<> EIGEN_STRONG_INLINE std::complex<float>
pfirst<Packet8cf>(
const Packet8cf&
a)
130 return pfirst(Packet2cf(_mm512_castps512_ps128(
a.v)));
133 template<> EIGEN_STRONG_INLINE Packet8cf
preverse(
const Packet8cf&
a) {
134 return Packet8cf(_mm512_castsi512_ps(
135 _mm512_permutexvar_epi64( _mm512_set_epi32(0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7),
136 _mm512_castps_si512(
a.v))));
139 template<> EIGEN_STRONG_INLINE std::complex<float>
predux<Packet8cf>(
const Packet8cf&
a)
142 Packet4cf(extract256<1>(
a.v))));
148 Packet4cf(extract256<1>(
a.v))));
153 __m256 lane0 = extract256<0>(
a.v);
154 __m256 lane1 = extract256<1>(
a.v);
155 __m256
res = _mm256_add_ps(lane0, lane1);
156 return Packet4cf(
res);
161 template<> EIGEN_STRONG_INLINE Packet8cf
pdiv<Packet8cf>(const Packet8cf&
a, const Packet8cf&
b)
168 return Packet8cf(_mm512_shuffle_ps(
x.v,
x.v, _MM_SHUFFLE(2, 3, 0 ,1)));
174 EIGEN_STRONG_INLINE Packet4cd() {}
175 EIGEN_STRONG_INLINE
explicit Packet4cd(
const __m512d&
a) :
v(
a) {}
179 template<>
struct packet_traits<
std::complex<double> > : default_packet_traits
181 typedef Packet4cd type;
182 typedef Packet2cd half;
202 template<>
struct unpacket_traits<Packet4cd> {
203 typedef std::complex<double> type;
204 typedef Packet2cd half;
208 alignment = unpacket_traits<Packet8d>::alignment,
210 masked_load_available=
false,
211 masked_store_available=
false
215 template<> EIGEN_STRONG_INLINE Packet4cd
padd<Packet4cd>(
const Packet4cd&
a,
const Packet4cd&
b) {
return Packet4cd(_mm512_add_pd(
a.v,
b.v)); }
216 template<> EIGEN_STRONG_INLINE Packet4cd
psub<Packet4cd>(
const Packet4cd&
a,
const Packet4cd&
b) {
return Packet4cd(_mm512_sub_pd(
a.v,
b.v)); }
217 template<> EIGEN_STRONG_INLINE Packet4cd
pnegate(
const Packet4cd&
a) {
return Packet4cd(
pnegate(
a.v)); }
218 template<> EIGEN_STRONG_INLINE Packet4cd
pconj(
const Packet4cd&
a)
220 const __m512d mask = _mm512_castsi512_pd(
221 _mm512_set_epi32(0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0,
222 0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0));
223 return Packet4cd(
pxor(
a.v,mask));
226 template<> EIGEN_STRONG_INLINE Packet4cd
pmul<Packet4cd>(
const Packet4cd&
a,
const Packet4cd&
b)
228 __m512d tmp1 = _mm512_shuffle_pd(
a.v,
a.v,0x0);
229 __m512d tmp2 = _mm512_shuffle_pd(
a.v,
a.v,0xFF);
230 __m512d tmp3 = _mm512_shuffle_pd(
b.v,
b.v,0x55);
231 __m512d odd = _mm512_mul_pd(tmp2, tmp3);
232 return Packet4cd(_mm512_fmaddsub_pd(tmp1,
b.v, odd));
236 template<> EIGEN_STRONG_INLINE Packet4cd
pand <Packet4cd>(
const Packet4cd&
a,
const Packet4cd&
b) {
return Packet4cd(
pand(
a.v,
b.v)); }
237 template<> EIGEN_STRONG_INLINE Packet4cd
por <Packet4cd>(
const Packet4cd&
a,
const Packet4cd&
b) {
return Packet4cd(
por(
a.v,
b.v)); }
238 template<> EIGEN_STRONG_INLINE Packet4cd
pxor <Packet4cd>(
const Packet4cd&
a,
const Packet4cd&
b) {
return Packet4cd(
pxor(
a.v,
b.v)); }
242 EIGEN_STRONG_INLINE Packet4cd
pcmp_eq(
const Packet4cd&
a,
const Packet4cd&
b) {
243 __m512d eq = pcmp_eq<Packet8d>(
a.v,
b.v);
244 return Packet4cd(
pand(eq, _mm512_permute_pd(eq, 0x55)));
252 template<> EIGEN_STRONG_INLINE Packet4cd
pset1<Packet4cd>(
const std::complex<double>& from)
254 return Packet4cd(_mm512_castps_pd(_mm512_broadcast_f32x4( _mm_castpd_ps(
pset1<Packet1cd>(from).
v))));
258 return Packet4cd(_mm512_insertf64x4(
262 template<> EIGEN_STRONG_INLINE
void pstore <std::complex<double> >(std::complex<double> * to,
const Packet4cd& from) {
EIGEN_DEBUG_ALIGNED_STORE pstore((
double*)to, from.v); }
265 template<>
EIGEN_DEVICE_FUNC inline Packet4cd pgather<std::complex<double>, Packet4cd>(
const std::complex<double>* from,
Index stride)
267 return Packet4cd(_mm512_insertf64x4(_mm512_castpd256_pd512(
272 template<>
EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet4cd>(std::complex<double>* to,
const Packet4cd& from,
Index stride)
274 __m512i fromi = _mm512_castpd_si512(from.v);
275 double* tod = (
double*)(
void*)to;
276 _mm_storeu_pd(tod+0*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,0)) );
277 _mm_storeu_pd(tod+2*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,1)) );
278 _mm_storeu_pd(tod+4*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,2)) );
279 _mm_storeu_pd(tod+6*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,3)) );
282 template<> EIGEN_STRONG_INLINE std::complex<double>
pfirst<Packet4cd>(
const Packet4cd&
a)
284 __m128d low = extract128<0>(
a.v);
286 _mm_store_pd(
res, low);
287 return std::complex<double>(
res[0],
res[1]);
290 template<> EIGEN_STRONG_INLINE Packet4cd
preverse(
const Packet4cd&
a) {
291 return Packet4cd(_mm512_shuffle_f64x2(
a.v,
a.v, (shuffle_mask<3,2,1,0>::mask)));
294 template<> EIGEN_STRONG_INLINE std::complex<double>
predux<Packet4cd>(
const Packet4cd&
a)
296 return predux(
padd(Packet2cd(_mm512_extractf64x4_pd(
a.v,0)),
297 Packet2cd(_mm512_extractf64x4_pd(
a.v,1))));
303 Packet2cd(_mm512_extractf64x4_pd(
a.v,1))));
308 template<> EIGEN_STRONG_INLINE Packet4cd
pdiv<Packet4cd>(const Packet4cd&
a, const Packet4cd&
b)
315 return Packet4cd(_mm512_permute_pd(
x.v,0x55));
320 PacketBlock<Packet8d,4> pb;
322 pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v);
323 pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v);
324 pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v);
325 pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v);
327 kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]);
328 kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]);
329 kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]);
330 kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]);
335 PacketBlock<Packet8d,8> pb;
337 pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v);
338 pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v);
339 pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v);
340 pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v);
341 pb.packet[4] = _mm512_castps_pd(kernel.packet[4].v);
342 pb.packet[5] = _mm512_castps_pd(kernel.packet[5].v);
343 pb.packet[6] = _mm512_castps_pd(kernel.packet[6].v);
344 pb.packet[7] = _mm512_castps_pd(kernel.packet[7].v);
346 kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]);
347 kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]);
348 kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]);
349 kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]);
350 kernel.packet[4].v = _mm512_castpd_ps(pb.packet[4]);
351 kernel.packet[5].v = _mm512_castpd_ps(pb.packet[5]);
352 kernel.packet[6].v = _mm512_castpd_ps(pb.packet[6]);
353 kernel.packet[7].v = _mm512_castpd_ps(pb.packet[7]);
358 __m512d T0 = _mm512_shuffle_f64x2(kernel.packet[0].v, kernel.packet[1].v, (shuffle_mask<0,1,0,1>::mask));
359 __m512d T1 = _mm512_shuffle_f64x2(kernel.packet[0].v, kernel.packet[1].v, (shuffle_mask<2,3,2,3>::mask));
360 __m512d T2 = _mm512_shuffle_f64x2(kernel.packet[2].v, kernel.packet[3].v, (shuffle_mask<0,1,0,1>::mask));
361 __m512d T3 = _mm512_shuffle_f64x2(kernel.packet[2].v, kernel.packet[3].v, (shuffle_mask<2,3,2,3>::mask));
363 kernel.packet[3] = Packet4cd(_mm512_shuffle_f64x2(T1, T3, (shuffle_mask<1,3,1,3>::mask)));
364 kernel.packet[2] = Packet4cd(_mm512_shuffle_f64x2(T1, T3, (shuffle_mask<0,2,0,2>::mask)));
365 kernel.packet[1] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<1,3,1,3>::mask)));
366 kernel.packet[0] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<0,2,0,2>::mask)));
370 return psqrt_complex<Packet4cd>(
a);
374 return psqrt_complex<Packet8cf>(
a);
Array< int, Dynamic, 1 > v
const ImagReturnType imag() const
RealReturnType real() const
#define EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(PACKET_CPLX, PACKET_REAL)
#define EIGEN_DEBUG_ALIGNED_STORE
#define EIGEN_DEBUG_ALIGNED_LOAD
#define EIGEN_DEBUG_UNALIGNED_STORE
#define EIGEN_DEBUG_UNALIGNED_LOAD
#define EIGEN_DEVICE_FUNC
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
Packet8cf pcplxflip< Packet8cf >(const Packet8cf &x)
Packet padd(const Packet &a, const Packet &b)
void pstore(Scalar *to, const Packet &from)
Packet4cd pmul< Packet4cd >(const Packet4cd &a, const Packet4cd &b)
Packet4cd por< Packet4cd >(const Packet4cd &a, const Packet4cd &b)
Packet8cf ploadquad< Packet8cf >(const std::complex< float > *from)
unpacket_traits< Packet >::type predux(const Packet &a)
Packet8h ptrue(const Packet8h &a)
Packet8d ploadquad< Packet8d >(const double *from)
Packet8d ploadu< Packet8d >(const double *from)
Packet8cf ptrue< Packet8cf >(const Packet8cf &a)
Packet4cd ploaddup< Packet4cd >(const std::complex< double > *from)
Packet8h pandnot(const Packet8h &a, const Packet8h &b)
Packet8cf por< Packet8cf >(const Packet8cf &a, const Packet8cf &b)
Packet2cf pnegate(const Packet2cf &a)
Packet1cd ploadu< Packet1cd >(const std::complex< double > *from)
std::complex< float > predux< Packet8cf >(const Packet8cf &a)
Packet2cd ploaddup< Packet2cd >(const std::complex< double > *from)
void pstoreu(Scalar *to, const Packet &from)
Packet8cf pmul< Packet8cf >(const Packet8cf &a, const Packet8cf &b)
Packet2cf pcmp_eq(const Packet2cf &a, const Packet2cf &b)
bfloat16 pfirst(const Packet8bf &a)
std::complex< double > predux< Packet4cd >(const Packet4cd &a)
Packet pmul(const Packet &a, const Packet &b)
void pscatter(Scalar *to, const Packet &from, Index stride, typename unpacket_traits< Packet >::mask_t umask)
void ptranspose(PacketBlock< Packet2cf, 2 > &kernel)
Packet4cd pandnot< Packet4cd >(const Packet4cd &a, const Packet4cd &b)
Packet16f ploadu< Packet16f >(const float *from)
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pdiv_complex(const Packet &x, const Packet &y)
Packet4cd pand< Packet4cd >(const Packet4cd &a, const Packet4cd &b)
Packet4cf predux_half_dowto4< Packet8cf >(const Packet8cf &a)
Packet4cd padd< Packet4cd >(const Packet4cd &a, const Packet4cd &b)
Packet4cd psub< Packet4cd >(const Packet4cd &a, const Packet4cd &b)
Packet8cf pandnot< Packet8cf >(const Packet8cf &a, const Packet8cf &b)
unpacket_traits< Packet >::type predux_mul(const Packet &a)
Packet8d pload< Packet8d >(const double *from)
Packet8h pand(const Packet8h &a, const Packet8h &b)
Packet8cf pxor< Packet8cf >(const Packet8cf &a, const Packet8cf &b)
std::complex< double > pfirst< Packet4cd >(const Packet4cd &a)
Packet4cd pcplxflip< Packet4cd >(const Packet4cd &x)
Packet4cd pload< Packet4cd >(const std::complex< double > *from)
std::complex< float > predux_mul< Packet8cf >(const Packet8cf &a)
Packet8h pxor(const Packet8h &a, const Packet8h &b)
Packet8cf padd< Packet8cf >(const Packet8cf &a, const Packet8cf &b)
Packet pdiv(const Packet &a, const Packet &b)
Packet8d ploaddup< Packet8d >(const double *from)
Packet1cd pset1< Packet1cd >(const std::complex< double > &from)
Packet2cf pconj(const Packet2cf &a)
Packet8cf pset1< Packet8cf >(const std::complex< float > &from)
Packet8cf pload< Packet8cf >(const std::complex< float > *from)
std::complex< float > pfirst< Packet8cf >(const Packet8cf &a)
Packet4cd ploadu< Packet4cd >(const std::complex< double > *from)
Packet2cf preverse(const Packet2cf &a)
Packet4cd pxor< Packet4cd >(const Packet4cd &a, const Packet4cd &b)
Packet4cd ptrue< Packet4cd >(const Packet4cd &a)
Packet8d pgather< double, Packet8d >(const Packet8d &src, const double *from, Index stride, uint8_t umask)
Packet8h por(const Packet8h &a, const Packet8h &b)
Packet16f pload< Packet16f >(const float *from)
Packet8cf ploadu< Packet8cf >(const std::complex< float > *from)
Packet4cd pset1< Packet4cd >(const std::complex< double > &from)
Packet8cf psqrt< Packet8cf >(const Packet8cf &a)
Packet8cf pand< Packet8cf >(const Packet8cf &a, const Packet8cf &b)
std::complex< double > predux_mul< Packet4cd >(const Packet4cd &a)
Packet8cf psub< Packet8cf >(const Packet8cf &a, const Packet8cf &b)
Packet8cf ploaddup< Packet8cf >(const std::complex< float > *from)
Packet4cd psqrt< Packet4cd >(const Packet4cd &a)
internal::add_const_on_value_type_t< EIGEN_MATHFUNC_RETVAL(real_ref, Scalar) > real_ref(const Scalar &x)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.