SVE/PacketMath.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2020, Arm Limited and Contributors
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_PACKET_MATH_SVE_H
11 #define EIGEN_PACKET_MATH_SVE_H
12 
13 #include "../../InternalHeaderCheck.h"
14 
15 namespace Eigen
16 {
17 namespace internal
18 {
19 #ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
20 #define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
21 #endif
22 
23 #ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
24 #define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
25 #endif
26 
27 #define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
28 
29 template <typename Scalar, int SVEVectorLength>
30 struct sve_packet_size_selector {
31  enum { size = SVEVectorLength / (sizeof(Scalar) * CHAR_BIT) };
32 };
33 
34 
35 typedef svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));
36 
37 template <>
38 struct packet_traits<numext::int32_t> : default_packet_traits {
39  typedef PacketXi type;
40  typedef PacketXi half; // Half not implemented yet
41  enum {
42  Vectorizable = 1,
43  AlignedOnScalar = 1,
45 
46  HasAdd = 1,
47  HasSub = 1,
48  HasShift = 1,
49  HasMul = 1,
50  HasNegate = 1,
51  HasAbs = 1,
52  HasArg = 0,
53  HasAbs2 = 1,
54  HasMin = 1,
55  HasMax = 1,
56  HasConj = 1,
57  HasSetLinear = 0,
58  HasBlend = 0,
59  HasReduxp = 0 // Not implemented in SVE
60  };
61 };
62 
63 template <>
64 struct unpacket_traits<PacketXi> {
65  typedef numext::int32_t type;
66  typedef PacketXi half; // Half not yet implemented
67  enum {
69  alignment = Aligned64,
70  vectorizable = true,
71  masked_load_available = false,
72  masked_store_available = false
73  };
74 };
75 
76 template <>
77 EIGEN_STRONG_INLINE void prefetch<numext::int32_t>(const numext::int32_t* addr)
78 {
79  svprfw(svptrue_b32(), addr, SV_PLDL1KEEP);
80 }
81 
82 template <>
83 EIGEN_STRONG_INLINE PacketXi pset1<PacketXi>(const numext::int32_t& from)
84 {
85  return svdup_n_s32(from);
86 }
87 
88 template <>
89 EIGEN_STRONG_INLINE PacketXi plset<PacketXi>(const numext::int32_t& a)
90 {
92  for (int i = 0; i < packet_traits<numext::int32_t>::size; i++) c[i] = i;
93  return svadd_s32_z(svptrue_b32(), pset1<PacketXi>(a), svld1_s32(svptrue_b32(), c));
94 }
95 
96 template <>
97 EIGEN_STRONG_INLINE PacketXi padd<PacketXi>(const PacketXi& a, const PacketXi& b)
98 {
99  return svadd_s32_z(svptrue_b32(), a, b);
100 }
101 
102 template <>
103 EIGEN_STRONG_INLINE PacketXi psub<PacketXi>(const PacketXi& a, const PacketXi& b)
104 {
105  return svsub_s32_z(svptrue_b32(), a, b);
106 }
107 
108 template <>
109 EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a)
110 {
111  return svneg_s32_z(svptrue_b32(), a);
112 }
113 
114 template <>
115 EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a)
116 {
117  return a;
118 }
119 
120 template <>
121 EIGEN_STRONG_INLINE PacketXi pmul<PacketXi>(const PacketXi& a, const PacketXi& b)
122 {
123  return svmul_s32_z(svptrue_b32(), a, b);
124 }
125 
126 template <>
127 EIGEN_STRONG_INLINE PacketXi pdiv<PacketXi>(const PacketXi& a, const PacketXi& b)
128 {
129  return svdiv_s32_z(svptrue_b32(), a, b);
130 }
131 
132 template <>
133 EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c)
134 {
135  return svmla_s32_z(svptrue_b32(), c, a, b);
136 }
137 
138 template <>
139 EIGEN_STRONG_INLINE PacketXi pmin<PacketXi>(const PacketXi& a, const PacketXi& b)
140 {
141  return svmin_s32_z(svptrue_b32(), a, b);
142 }
143 
144 template <>
145 EIGEN_STRONG_INLINE PacketXi pmax<PacketXi>(const PacketXi& a, const PacketXi& b)
146 {
147  return svmax_s32_z(svptrue_b32(), a, b);
148 }
149 
150 template <>
151 EIGEN_STRONG_INLINE PacketXi pcmp_le<PacketXi>(const PacketXi& a, const PacketXi& b)
152 {
153  return svdup_n_s32_z(svcmple_s32(svptrue_b32(), a, b), 0xffffffffu);
154 }
155 
156 template <>
157 EIGEN_STRONG_INLINE PacketXi pcmp_lt<PacketXi>(const PacketXi& a, const PacketXi& b)
158 {
159  return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu);
160 }
161 
162 template <>
163 EIGEN_STRONG_INLINE PacketXi pcmp_eq<PacketXi>(const PacketXi& a, const PacketXi& b)
164 {
165  return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu);
166 }
167 
168 template <>
169 EIGEN_STRONG_INLINE PacketXi ptrue<PacketXi>(const PacketXi& /*a*/)
170 {
171  return svdup_n_s32_z(svptrue_b32(), 0xffffffffu);
172 }
173 
174 template <>
175 EIGEN_STRONG_INLINE PacketXi pzero<PacketXi>(const PacketXi& /*a*/)
176 {
177  return svdup_n_s32_z(svptrue_b32(), 0);
178 }
179 
180 template <>
181 EIGEN_STRONG_INLINE PacketXi pand<PacketXi>(const PacketXi& a, const PacketXi& b)
182 {
183  return svand_s32_z(svptrue_b32(), a, b);
184 }
185 
186 template <>
187 EIGEN_STRONG_INLINE PacketXi por<PacketXi>(const PacketXi& a, const PacketXi& b)
188 {
189  return svorr_s32_z(svptrue_b32(), a, b);
190 }
191 
192 template <>
193 EIGEN_STRONG_INLINE PacketXi pxor<PacketXi>(const PacketXi& a, const PacketXi& b)
194 {
195  return sveor_s32_z(svptrue_b32(), a, b);
196 }
197 
198 template <>
199 EIGEN_STRONG_INLINE PacketXi pandnot<PacketXi>(const PacketXi& a, const PacketXi& b)
200 {
201  return svbic_s32_z(svptrue_b32(), a, b);
202 }
203 
204 template <int N>
205 EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a)
206 {
207  return svasrd_n_s32_z(svptrue_b32(), a, N);
208 }
209 
210 template <int N>
211 EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a)
212 {
213  return svreinterpret_s32_u32(svlsr_n_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), N));
214 }
215 
216 template <int N>
217 EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a)
218 {
219  return svlsl_n_s32_z(svptrue_b32(), a, N);
220 }
221 
222 template <>
223 EIGEN_STRONG_INLINE PacketXi pload<PacketXi>(const numext::int32_t* from)
224 {
225  EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
226 }
227 
228 template <>
229 EIGEN_STRONG_INLINE PacketXi ploadu<PacketXi>(const numext::int32_t* from)
230 {
231  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
232 }
233 
234 template <>
235 EIGEN_STRONG_INLINE PacketXi ploaddup<PacketXi>(const numext::int32_t* from)
236 {
237  svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
238  indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
239  return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
240 }
241 
242 template <>
243 EIGEN_STRONG_INLINE PacketXi ploadquad<PacketXi>(const numext::int32_t* from)
244 {
245  svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
246  indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
247  indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
248  return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
249 }
250 
251 template <>
252 EIGEN_STRONG_INLINE void pstore<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
253 {
254  EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
255 }
256 
257 template <>
258 EIGEN_STRONG_INLINE void pstoreu<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
259 {
260  EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
261 }
262 
263 template <>
264 EIGEN_DEVICE_FUNC inline PacketXi pgather<numext::int32_t, PacketXi>(const numext::int32_t* from, Index stride)
265 {
266  // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
267  svint32_t indices = svindex_s32(0, stride);
268  return svld1_gather_s32index_s32(svptrue_b32(), from, indices);
269 }
270 
271 template <>
272 EIGEN_DEVICE_FUNC inline void pscatter<numext::int32_t, PacketXi>(numext::int32_t* to, const PacketXi& from, Index stride)
273 {
274  // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
275  svint32_t indices = svindex_s32(0, stride);
276  svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from);
277 }
278 
279 template <>
280 EIGEN_STRONG_INLINE numext::int32_t pfirst<PacketXi>(const PacketXi& a)
281 {
282  // svlasta returns the first element if all predicate bits are 0
283  return svlasta_s32(svpfalse_b(), a);
284 }
285 
286 template <>
287 EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a)
288 {
289  return svrev_s32(a);
290 }
291 
292 template <>
293 EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a)
294 {
295  return svabs_s32_z(svptrue_b32(), a);
296 }
297 
298 template <>
299 EIGEN_STRONG_INLINE numext::int32_t predux<PacketXi>(const PacketXi& a)
300 {
301  return static_cast<numext::int32_t>(svaddv_s32(svptrue_b32(), a));
302 }
303 
304 template <>
305 EIGEN_STRONG_INLINE numext::int32_t predux_mul<PacketXi>(const PacketXi& a)
306 {
307  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
308  EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
309 
310  // Multiply the vector by its reverse
311  svint32_t prod = svmul_s32_z(svptrue_b32(), a, svrev_s32(a));
312  svint32_t half_prod;
313 
314  // Extract the high half of the vector. Depending on the VL more reductions need to be done
315  if (EIGEN_ARM64_SVE_VL >= 2048) {
316  half_prod = svtbl_s32(prod, svindex_u32(32, 1));
317  prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
318  }
319  if (EIGEN_ARM64_SVE_VL >= 1024) {
320  half_prod = svtbl_s32(prod, svindex_u32(16, 1));
321  prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
322  }
323  if (EIGEN_ARM64_SVE_VL >= 512) {
324  half_prod = svtbl_s32(prod, svindex_u32(8, 1));
325  prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
326  }
327  if (EIGEN_ARM64_SVE_VL >= 256) {
328  half_prod = svtbl_s32(prod, svindex_u32(4, 1));
329  prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
330  }
331  // Last reduction
332  half_prod = svtbl_s32(prod, svindex_u32(2, 1));
333  prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
334 
335  // The reduction is done to the first element.
336  return pfirst<PacketXi>(prod);
337 }
338 
339 template <>
340 EIGEN_STRONG_INLINE numext::int32_t predux_min<PacketXi>(const PacketXi& a)
341 {
342  return svminv_s32(svptrue_b32(), a);
343 }
344 
345 template <>
346 EIGEN_STRONG_INLINE numext::int32_t predux_max<PacketXi>(const PacketXi& a)
347 {
348  return svmaxv_s32(svptrue_b32(), a);
349 }
350 
351 template <int N>
352 EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXi, N>& kernel) {
353  int buffer[packet_traits<numext::int32_t>::size * N] = {0};
354  int i = 0;
355 
356  PacketXi stride_index = svindex_s32(0, N);
357 
358  for (i = 0; i < N; i++) {
359  svst1_scatter_s32index_s32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
360  }
361  for (i = 0; i < N; i++) {
362  kernel.packet[i] = svld1_s32(svptrue_b32(), buffer + i * packet_traits<numext::int32_t>::size);
363  }
364 }
365 
366 
368 typedef svfloat32_t PacketXf __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));
369 
370 template <>
371 struct packet_traits<float> : default_packet_traits {
372  typedef PacketXf type;
373  typedef PacketXf half;
374 
375  enum {
376  Vectorizable = 1,
377  AlignedOnScalar = 1,
379 
380  HasAdd = 1,
381  HasSub = 1,
382  HasShift = 1,
383  HasMul = 1,
384  HasNegate = 1,
385  HasAbs = 1,
386  HasArg = 0,
387  HasAbs2 = 1,
388  HasMin = 1,
389  HasMax = 1,
390  HasConj = 1,
391  HasSetLinear = 0,
392  HasBlend = 0,
393  HasReduxp = 0, // Not implemented in SVE
394 
395  HasDiv = 1,
396  HasFloor = 1,
397 
398  HasSin = EIGEN_FAST_MATH,
399  HasCos = EIGEN_FAST_MATH,
400  HasLog = 1,
401  HasExp = 1,
402  HasSqrt = 0,
403  HasTanh = EIGEN_FAST_MATH,
404  HasErf = EIGEN_FAST_MATH
405  };
406 };
407 
408 template <>
409 struct unpacket_traits<PacketXf> {
410  typedef float type;
411  typedef PacketXf half; // Half not yet implemented
412  typedef PacketXi integer_packet;
413 
414  enum {
416  alignment = Aligned64,
417  vectorizable = true,
418  masked_load_available = false,
419  masked_store_available = false
420  };
421 };
422 
423 template <>
424 EIGEN_STRONG_INLINE PacketXf pset1<PacketXf>(const float& from)
425 {
426  return svdup_n_f32(from);
427 }
428 
429 template <>
430 EIGEN_STRONG_INLINE PacketXf pset1frombits<PacketXf>(numext::uint32_t from)
431 {
432  return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), from));
433 }
434 
435 template <>
436 EIGEN_STRONG_INLINE PacketXf plset<PacketXf>(const float& a)
437 {
439  for (int i = 0; i < packet_traits<float>::size; i++) c[i] = i;
440  return svadd_f32_z(svptrue_b32(), pset1<PacketXf>(a), svld1_f32(svptrue_b32(), c));
441 }
442 
443 template <>
444 EIGEN_STRONG_INLINE PacketXf padd<PacketXf>(const PacketXf& a, const PacketXf& b)
445 {
446  return svadd_f32_z(svptrue_b32(), a, b);
447 }
448 
449 template <>
450 EIGEN_STRONG_INLINE PacketXf psub<PacketXf>(const PacketXf& a, const PacketXf& b)
451 {
452  return svsub_f32_z(svptrue_b32(), a, b);
453 }
454 
455 template <>
456 EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a)
457 {
458  return svneg_f32_z(svptrue_b32(), a);
459 }
460 
461 template <>
462 EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a)
463 {
464  return a;
465 }
466 
467 template <>
468 EIGEN_STRONG_INLINE PacketXf pmul<PacketXf>(const PacketXf& a, const PacketXf& b)
469 {
470  return svmul_f32_z(svptrue_b32(), a, b);
471 }
472 
473 template <>
474 EIGEN_STRONG_INLINE PacketXf pdiv<PacketXf>(const PacketXf& a, const PacketXf& b)
475 {
476  return svdiv_f32_z(svptrue_b32(), a, b);
477 }
478 
479 template <>
480 EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c)
481 {
482  return svmla_f32_z(svptrue_b32(), c, a, b);
483 }
484 
485 template <>
486 EIGEN_STRONG_INLINE PacketXf pmin<PacketXf>(const PacketXf& a, const PacketXf& b)
487 {
488  return svmin_f32_z(svptrue_b32(), a, b);
489 }
490 
491 template <>
492 EIGEN_STRONG_INLINE PacketXf pmin<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
493 {
494  return pmin<PacketXf>(a, b);
495 }
496 
497 template <>
498 EIGEN_STRONG_INLINE PacketXf pmin<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
499 {
500  return svminnm_f32_z(svptrue_b32(), a, b);
501 }
502 
503 template <>
504 EIGEN_STRONG_INLINE PacketXf pmax<PacketXf>(const PacketXf& a, const PacketXf& b)
505 {
506  return svmax_f32_z(svptrue_b32(), a, b);
507 }
508 
509 template <>
510 EIGEN_STRONG_INLINE PacketXf pmax<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
511 {
512  return pmax<PacketXf>(a, b);
513 }
514 
515 template <>
516 EIGEN_STRONG_INLINE PacketXf pmax<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
517 {
518  return svmaxnm_f32_z(svptrue_b32(), a, b);
519 }
520 
521 // Float comparisons in SVE return svbool (predicate). Use svdup to set active
522 // lanes to 1 (0xffffffffu) and inactive lanes to 0.
523 template <>
524 EIGEN_STRONG_INLINE PacketXf pcmp_le<PacketXf>(const PacketXf& a, const PacketXf& b)
525 {
526  return svreinterpret_f32_u32(svdup_n_u32_z(svcmple_f32(svptrue_b32(), a, b), 0xffffffffu));
527 }
528 
529 template <>
530 EIGEN_STRONG_INLINE PacketXf pcmp_lt<PacketXf>(const PacketXf& a, const PacketXf& b)
531 {
532  return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu));
533 }
534 
535 template <>
536 EIGEN_STRONG_INLINE PacketXf pcmp_eq<PacketXf>(const PacketXf& a, const PacketXf& b)
537 {
538  return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu));
539 }
540 
541 // Do a predicate inverse (svnot_b_z) on the predicate resulted from the
542 // greater/equal comparison (svcmpge_f32). Then fill a float vector with the
543 // active elements.
544 template <>
545 EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan<PacketXf>(const PacketXf& a, const PacketXf& b)
546 {
547  return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu));
548 }
549 
550 template <>
551 EIGEN_STRONG_INLINE PacketXf pfloor<PacketXf>(const PacketXf& a)
552 {
553  return svrintm_f32_z(svptrue_b32(), a);
554 }
555 
556 template <>
557 EIGEN_STRONG_INLINE PacketXf ptrue<PacketXf>(const PacketXf& /*a*/)
558 {
559  return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), 0xffffffffu));
560 }
561 
562 // Logical Operations are not supported for float, so reinterpret casts
563 template <>
564 EIGEN_STRONG_INLINE PacketXf pand<PacketXf>(const PacketXf& a, const PacketXf& b)
565 {
566  return svreinterpret_f32_u32(svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
567 }
568 
569 template <>
570 EIGEN_STRONG_INLINE PacketXf por<PacketXf>(const PacketXf& a, const PacketXf& b)
571 {
572  return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
573 }
574 
575 template <>
576 EIGEN_STRONG_INLINE PacketXf pxor<PacketXf>(const PacketXf& a, const PacketXf& b)
577 {
578  return svreinterpret_f32_u32(sveor_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
579 }
580 
581 template <>
582 EIGEN_STRONG_INLINE PacketXf pandnot<PacketXf>(const PacketXf& a, const PacketXf& b)
583 {
584  return svreinterpret_f32_u32(svbic_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
585 }
586 
587 template <>
588 EIGEN_STRONG_INLINE PacketXf pload<PacketXf>(const float* from)
589 {
590  EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
591 }
592 
593 template <>
594 EIGEN_STRONG_INLINE PacketXf ploadu<PacketXf>(const float* from)
595 {
596  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
597 }
598 
599 template <>
600 EIGEN_STRONG_INLINE PacketXf ploaddup<PacketXf>(const float* from)
601 {
602  svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
603  indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
604  return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
605 }
606 
607 template <>
608 EIGEN_STRONG_INLINE PacketXf ploadquad<PacketXf>(const float* from)
609 {
610  svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
611  indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
612  indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
613  return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
614 }
615 
616 template <>
617 EIGEN_STRONG_INLINE void pstore<float>(float* to, const PacketXf& from)
618 {
619  EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
620 }
621 
622 template <>
623 EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const PacketXf& from)
624 {
625  EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
626 }
627 
628 template <>
629 EIGEN_DEVICE_FUNC inline PacketXf pgather<float, PacketXf>(const float* from, Index stride)
630 {
631  // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
632  svint32_t indices = svindex_s32(0, stride);
633  return svld1_gather_s32index_f32(svptrue_b32(), from, indices);
634 }
635 
636 template <>
637 EIGEN_DEVICE_FUNC inline void pscatter<float, PacketXf>(float* to, const PacketXf& from, Index stride)
638 {
639  // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
640  svint32_t indices = svindex_s32(0, stride);
641  svst1_scatter_s32index_f32(svptrue_b32(), to, indices, from);
642 }
643 
644 template <>
645 EIGEN_STRONG_INLINE float pfirst<PacketXf>(const PacketXf& a)
646 {
647  // svlasta returns the first element if all predicate bits are 0
648  return svlasta_f32(svpfalse_b(), a);
649 }
650 
651 template <>
652 EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a)
653 {
654  return svrev_f32(a);
655 }
656 
657 template <>
658 EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a)
659 {
660  return svabs_f32_z(svptrue_b32(), a);
661 }
662 
663 // TODO(tellenbach): Should this go into MathFunctions.h? If so, change for
664 // all vector extensions and the generic version.
665 template <>
666 EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent)
667 {
668  return pfrexp_generic(a, exponent);
669 }
670 
671 template <>
672 EIGEN_STRONG_INLINE float predux<PacketXf>(const PacketXf& a)
673 {
674  return svaddv_f32(svptrue_b32(), a);
675 }
676 
677 // Other reduction functions:
678 // mul
679 // Only works for SVE Vls multiple of 128
680 template <>
681 EIGEN_STRONG_INLINE float predux_mul<PacketXf>(const PacketXf& a)
682 {
683  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
684  EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
685  // Multiply the vector by its reverse
686  svfloat32_t prod = svmul_f32_z(svptrue_b32(), a, svrev_f32(a));
687  svfloat32_t half_prod;
688 
689  // Extract the high half of the vector. Depending on the VL more reductions need to be done
690  if (EIGEN_ARM64_SVE_VL >= 2048) {
691  half_prod = svtbl_f32(prod, svindex_u32(32, 1));
692  prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
693  }
694  if (EIGEN_ARM64_SVE_VL >= 1024) {
695  half_prod = svtbl_f32(prod, svindex_u32(16, 1));
696  prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
697  }
698  if (EIGEN_ARM64_SVE_VL >= 512) {
699  half_prod = svtbl_f32(prod, svindex_u32(8, 1));
700  prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
701  }
702  if (EIGEN_ARM64_SVE_VL >= 256) {
703  half_prod = svtbl_f32(prod, svindex_u32(4, 1));
704  prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
705  }
706  // Last reduction
707  half_prod = svtbl_f32(prod, svindex_u32(2, 1));
708  prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
709 
710  // The reduction is done to the first element.
711  return pfirst<PacketXf>(prod);
712 }
713 
714 template <>
715 EIGEN_STRONG_INLINE float predux_min<PacketXf>(const PacketXf& a)
716 {
717  return svminv_f32(svptrue_b32(), a);
718 }
719 
720 template <>
721 EIGEN_STRONG_INLINE float predux_max<PacketXf>(const PacketXf& a)
722 {
723  return svmaxv_f32(svptrue_b32(), a);
724 }
725 
726 template<int N>
727 EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXf, N>& kernel)
728 {
729  float buffer[packet_traits<float>::size * N] = {0};
730  int i = 0;
731 
732  PacketXi stride_index = svindex_s32(0, N);
733 
734  for (i = 0; i < N; i++) {
735  svst1_scatter_s32index_f32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
736  }
737 
738  for (i = 0; i < N; i++) {
739  kernel.packet[i] = svld1_f32(svptrue_b32(), buffer + i * packet_traits<float>::size);
740  }
741 }
742 
743 template<>
744 EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent)
745 {
746  return pldexp_generic(a, exponent);
747 }
748 
749 } // namespace internal
750 } // namespace Eigen
751 
752 #endif // EIGEN_PACKET_MATH_SVE_H
Array< int, 3, 1 > b
Array33i c
#define EIGEN_DEBUG_ALIGNED_STORE
#define EIGEN_DEBUG_ALIGNED_LOAD
#define EIGEN_DEBUG_UNALIGNED_STORE
#define EIGEN_DEBUG_UNALIGNED_LOAD
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:883
#define EIGEN_FAST_MATH
Definition: Macros.h:50
#define EIGEN_STATIC_ASSERT(X, MSG)
Definition: StaticAssert.h:26
@ Aligned64
Definition: Constants.h:239
PacketXf pmax< PropagateNumbers, PacketXf >(const PacketXf &a, const PacketXf &b)
PacketXf pgather< float, PacketXf >(const float *from, Index stride)
void pstore< float >(float *to, const Packet4f &from)
PacketXi pand< PacketXi >(const PacketXi &a, const PacketXi &b)
PacketXf pmin< PacketXf >(const PacketXf &a, const PacketXf &b)
PacketXi pzero< PacketXi >(const PacketXi &)
PacketXi psub< PacketXi >(const PacketXi &a, const PacketXi &b)
PacketXf psub< PacketXf >(const PacketXf &a, const PacketXf &b)
numext::int32_t predux_mul< PacketXi >(const PacketXi &a)
PacketXi pmax< PacketXi >(const PacketXi &a, const PacketXi &b)
PacketXf pmul< PacketXf >(const PacketXf &a, const PacketXf &b)
PacketXi plset< PacketXi >(const numext::int32_t &a)
PacketXf pmax< PropagateNaN, PacketXf >(const PacketXf &a, const PacketXf &b)
PacketXf pmin< PropagateNaN, PacketXf >(const PacketXf &a, const PacketXf &b)
Packet4f pabs(const Packet4f &a)
PacketXf pmax< PacketXf >(const PacketXf &a, const PacketXf &b)
float predux_min< PacketXf >(const PacketXf &a)
Packet2cf pnegate(const Packet2cf &a)
PacketXf pset1< PacketXf >(const float &from)
PacketXf pcmp_lt_or_nan< PacketXf >(const PacketXf &a, const PacketXf &b)
PacketXi por< PacketXi >(const PacketXi &a, const PacketXi &b)
PacketXf pset1frombits< PacketXf >(numext::uint32_t from)
Packet4i plogical_shift_right(const Packet4i &a)
Packet4f pmadd(const Packet4f &a, const Packet4f &b, const Packet4f &c)
PacketXf ploadquad< PacketXf >(const float *from)
float predux_max< PacketXf >(const PacketXf &a)
PacketXi pcmp_lt< PacketXi >(const PacketXi &a, const PacketXi &b)
PacketXf pldexp< PacketXf >(const PacketXf &a, const PacketXf &exponent)
PacketXi pcmp_eq< PacketXi >(const PacketXi &a, const PacketXi &b)
void ptranspose(PacketBlock< Packet2cf, 2 > &kernel)
PacketXf pfrexp< PacketXf >(const PacketXf &a, PacketXf &exponent)
float predux_mul< PacketXf >(const PacketXf &a)
Packet pfrexp_generic(const Packet &a, Packet &exponent)
Packet pldexp_generic(const Packet &a, const Packet &exponent)
PacketXf pcmp_eq< PacketXf >(const PacketXf &a, const PacketXf &b)
PacketXi ploadquad< PacketXi >(const numext::int32_t *from)
PacketXi pset1< PacketXi >(const numext::int32_t &from)
PacketXi pcmp_le< PacketXi >(const PacketXi &a, const PacketXi &b)
PacketXf pcmp_le< PacketXf >(const PacketXf &a, const PacketXf &b)
PacketXi pandnot< PacketXi >(const PacketXi &a, const PacketXi &b)
PacketXi pdiv< PacketXi >(const PacketXi &a, const PacketXi &b)
PacketXf pmin< PropagateNumbers, PacketXf >(const PacketXf &a, const PacketXf &b)
void pstoreu< float >(float *to, const Packet4f &from)
PacketXi ptrue< PacketXi >(const PacketXi &)
PacketXf pxor< PacketXf >(const PacketXf &a, const PacketXf &b)
float pfirst< PacketXf >(const PacketXf &a)
PacketXf ploaddup< PacketXf >(const float *from)
PacketXf padd< PacketXf >(const PacketXf &a, const PacketXf &b)
numext::int32_t pfirst< PacketXi >(const PacketXi &a)
Packet2cf pconj(const Packet2cf &a)
PacketXf plset< PacketXf >(const float &a)
PacketXf pfloor< PacketXf >(const PacketXf &a)
PacketXf ptrue< PacketXf >(const PacketXf &)
PacketXi padd< PacketXi >(const PacketXi &a, const PacketXi &b)
Packet4i plogical_shift_left(const Packet4i &a)
Packet2cf preverse(const Packet2cf &a)
PacketXi pmul< PacketXi >(const PacketXi &a, const PacketXi &b)
PacketXi ploaddup< PacketXi >(const numext::int32_t *from)
Packet4i parithmetic_shift_right(const Packet4i &a)
PacketXi pload< PacketXi >(const numext::int32_t *from)
svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)))
numext::int32_t predux< PacketXi >(const PacketXi &a)
void pscatter< float, PacketXf >(float *to, const PacketXf &from, Index stride)
PacketXf pload< PacketXf >(const float *from)
PacketXi pmin< PacketXi >(const PacketXi &a, const PacketXi &b)
PacketXf pand< PacketXf >(const PacketXf &a, const PacketXf &b)
float predux< PacketXf >(const PacketXf &a)
PacketXf pandnot< PacketXf >(const PacketXf &a, const PacketXf &b)
numext::int32_t predux_max< PacketXi >(const PacketXi &a)
PacketXf por< PacketXf >(const PacketXf &a, const PacketXf &b)
PacketXf pdiv< PacketXf >(const PacketXf &a, const PacketXf &b)
PacketXi pxor< PacketXi >(const PacketXi &a, const PacketXi &b)
PacketXf pcmp_lt< PacketXf >(const PacketXf &a, const PacketXf &b)
PacketXi ploadu< PacketXi >(const numext::int32_t *from)
numext::int32_t predux_min< PacketXi >(const PacketXi &a)
PacketXf ploadu< PacketXf >(const float *from)
std::int32_t int32_t
Definition: Meta.h:40
std::uint32_t uint32_t
Definition: Meta.h:39
: InteropHeaders
Definition: Core:139
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82