21 #ifndef EIGEN_PACKET_MATH_SYCL_H
22 #define EIGEN_PACKET_MATH_SYCL_H
23 #include <type_traits>
25 #include "../../InternalHeaderCheck.h"
30 #ifdef SYCL_DEVICE_ONLY
// SYCL_PLOAD(packet_type, AlignedType): emits a definition of
// `pload##AlignedType<packet_type>` that takes a pointer to the packet's
// scalar type (`unpacket_traits<packet_type>::type`). The visible part of the
// body casts `from` to an undecorated pointer in the generic address space
// with cl::sycl::address_space_cast.
// NOTE(review): the statements that consume `ptr` and produce the returned
// packet are not visible in this excerpt -- confirm against the full header.
// The macro is instantiated with the suffix `u` and with an empty suffix, for
// both cl_float4 and cl_double2.
31 #define SYCL_PLOAD(packet_type, AlignedType) \
33 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type \
34 pload##AlignedType<packet_type>( \
35 const typename unpacket_traits<packet_type>::type* from) { \
36 auto ptr = cl::sycl::address_space_cast<cl::sycl::access::address_space::generic_space, cl::sycl::access::decorated::no>(from);\
42 SYCL_PLOAD(cl::sycl::cl_float4, u)
43 SYCL_PLOAD(cl::sycl::cl_float4, )
44 SYCL_PLOAD(cl::sycl::cl_double2, u)
45 SYCL_PLOAD(cl::sycl::cl_double2, )
// SYCL_PSTORE(scalar, packet_type, alignment): emits a definition of
// `pstore##alignment(scalar* to, const packet_type& from)` returning void.
// The visible part of the body casts the destination pointer `to` to an
// undecorated pointer in the generic address space with
// cl::sycl::address_space_cast.
// NOTE(review): the statement that writes `from` through `ptr` is not visible
// in this excerpt -- confirm against the full header.
// The macro is instantiated with an empty suffix and with the suffix `u`, for
// (float, cl_float4) and (double, cl_double2).
49 #define SYCL_PSTORE(scalar, packet_type, alignment) \
51 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstore##alignment( \
52 scalar* to, const packet_type& from) { \
53 auto ptr = cl::sycl::address_space_cast<cl::sycl::access::address_space::generic_space, cl::sycl::access::decorated::no>(to);\
57 SYCL_PSTORE(
float, cl::sycl::cl_float4, )
58 SYCL_PSTORE(
float, cl::sycl::cl_float4, u)
59 SYCL_PSTORE(
double, cl::sycl::cl_double2, )
60 SYCL_PSTORE(
double, cl::sycl::cl_double2, u)
// SYCL_PSET1(packet_type): emits a definition of `pset1<packet_type>` that
// takes one scalar of the packet's element type by const reference and returns
// `packet_type(from)`, i.e. the packet constructed from that single scalar.
// NOTE(review): for SYCL vec types the single-scalar constructor is expected
// to replicate the value into every component -- confirm against the SYCL
// vec specification.
// Instantiated for cl_float4 and cl_double2.
64 #define SYCL_PSET1(packet_type) \
66 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pset1<packet_type>( \
67 const typename unpacket_traits<packet_type>::type& from) { \
68 return packet_type(from); \
72 SYCL_PSET1(cl::sycl::cl_float4)
73 SYCL_PSET1(cl::sycl::cl_double2)
// Primary get_base_packet template, parameterised on the packet type.
// Both members visible here (get_ploaddup and get_pgather) take their
// arguments unnamed and have empty bodies; the cl_float4 and cl_double2
// specializations that follow carry the actual implementations.
77 template <
typename packet_type>
78 struct get_base_packet {
// Placeholder for the duplicating load: the pointer argument is ignored.
79 template <
typename sycl_multi_po
inter>
81 get_ploaddup(sycl_multi_pointer) {}
// Placeholder for the strided gather: pointer and stride are both ignored.
83 template <
typename sycl_multi_po
inter>
85 get_pgather(sycl_multi_pointer,
Index) {}
// get_base_packet specialization for the 4-component float packet
// (cl::sycl::cl_float4). The member-name lines are not visible in this
// excerpt; the comments below describe only what each visible body does.
89 struct get_base_packet<cl::sycl::cl_float4> {
// Duplicating load: reads from[0] and from[1] and returns
// {from[0], from[0], from[1], from[1]}.
90 template <
typename sycl_multi_po
inter>
92 sycl_multi_pointer from) {
93 return cl::sycl::cl_float4(from[0], from[0], from[1], from[1]);
// Strided gather: component k of the result is from[k * stride], k = 0..3.
95 template <
typename sycl_multi_po
inter>
97 sycl_multi_pointer from,
Index stride) {
98 return cl::sycl::cl_float4(from[0 * stride], from[1 * stride],
99 from[2 * stride], from[3 * stride]);
// Strided scatter of `from` into `to`. In the visible statements the running
// offset `tmp` is advanced by `stride` immediately before the z store and
// again before the w store.
// NOTE(review): the initialisation of `tmp` and the x / y stores are not
// visible in this excerpt.
102 template <
typename sycl_multi_po
inter>
104 sycl_multi_pointer to,
const cl::sycl::cl_float4& from,
Index stride) {
108 to[tmp += stride] = from.z();
109 to[tmp += stride] = from.w();
// Linear sequence: returns {a, a + 1, a + 2, a + 3}, each cast to float.
113 return cl::sycl::cl_float4(
static_cast<float>(
a),
static_cast<float>(
a + 1),
114 static_cast<float>(
a + 2),
115 static_cast<float>(
a + 3));
// get_base_packet specialization for the 2-component double packet
// (cl::sycl::cl_double2).
120 struct get_base_packet<cl::sycl::cl_double2> {
// Duplicating load: returns {from[0], from[0]}.
121 template <
typename sycl_multi_po
inter>
123 get_ploaddup(
const sycl_multi_pointer from) {
124 return cl::sycl::cl_double2(from[0], from[0]);
// Strided gather: returns {from[0 * stride], from[1 * stride]}.
127 template <
typename sycl_multi_po
inter,
typename Index>
129 const sycl_multi_pointer from,
Index stride) {
130 return cl::sycl::cl_double2(from[0 * stride], from[1 * stride]);
// Strided scatter: the visible statement stores the y component at
// to[stride].
// NOTE(review): the store of the x component is not visible in this excerpt.
133 template <
typename sycl_multi_po
inter>
135 sycl_multi_pointer to,
const cl::sycl::cl_double2& from,
Index stride) {
137 to[stride] = from.y();
// Linear sequence: returns {a, a + 1}, each cast to double.
142 return cl::sycl::cl_double2(
static_cast<double>(
a),
143 static_cast<double>(
a + 1));
// SYCL_PLOAD_DUP_SPECILIZE(packet_type): emits a definition of
// `ploaddup<packet_type>` that takes a pointer to the packet's scalar type and
// forwards it unchanged to get_base_packet<packet_type>::get_ploaddup.
// Instantiated for cl_float4 and cl_double2, then #undef'd so the helper
// macro does not remain defined past this point.
147 #define SYCL_PLOAD_DUP_SPECILIZE(packet_type) \
149 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ploaddup<packet_type>( \
150 const typename unpacket_traits<packet_type>::type* from) { \
151 return get_base_packet<packet_type>::get_ploaddup(from); \
154 SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_float4)
155 SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_double2)
157 #undef SYCL_PLOAD_DUP_SPECILIZE
// SYCL_PLSET(packet_type): emits a definition of `plset<packet_type>` that
// takes a scalar `a` of the packet's element type by const reference and
// returns get_base_packet<packet_type>::set_plset(a).
// Instantiated for cl_float4 and cl_double2.
159 #define SYCL_PLSET(packet_type) \
161 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type plset<packet_type>( \
162 const typename unpacket_traits<packet_type>::type& a) { \
163 return get_base_packet<packet_type>::set_plset(a); \
165 SYCL_PLSET(cl::sycl::cl_float4)
166 SYCL_PLSET(cl::sycl::cl_double2)
// SYCL_PGATHER_SPECILIZE(scalar, packet_type): emits a definition of
// `pgather<scalar, packet_type>` that takes a source pointer and an Index
// stride and forwards both to get_base_packet<packet_type>::get_pgather.
// Instantiated for (float, cl_float4) and (double, cl_double2), then #undef'd.
170 #define SYCL_PGATHER_SPECILIZE(scalar, packet_type) \
172 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type \
173 pgather<scalar, packet_type>( \
174 const typename unpacket_traits<packet_type>::type* from, Index stride) { \
175 return get_base_packet<packet_type>::get_pgather(from, stride); \
178 SYCL_PGATHER_SPECILIZE(
float, cl::sycl::cl_float4)
179 SYCL_PGATHER_SPECILIZE(
double, cl::sycl::cl_double2)
181 #undef SYCL_PGATHER_SPECILIZE
// SYCL_PSCATTER_SPECILIZE(scalar, packet_type): emits a definition of
// `pscatter<scalar, packet_type>` returning void. It takes a destination
// pointer `to`, the packet `from` and an Index stride, and forwards all three
// to get_base_packet<packet_type>::set_pscatter.
// Instantiated for (float, cl_float4) and (double, cl_double2), then #undef'd.
183 #define SYCL_PSCATTER_SPECILIZE(scalar, packet_type) \
185 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<scalar, packet_type>( \
186 typename unpacket_traits<packet_type>::type * to, \
187 const packet_type& from, Index stride) { \
188 get_base_packet<packet_type>::set_pscatter(to, from, stride); \
191 SYCL_PSCATTER_SPECILIZE(
float, cl::sycl::cl_float4)
192 SYCL_PSCATTER_SPECILIZE(
double, cl::sycl::cl_double2)
194 #undef SYCL_PSCATTER_SPECILIZE
// SYCL_PMAD(packet_type): emits a definition of
// `pmadd(const packet_type& a, const packet_type& b, const packet_type& c)`
// that returns cl::sycl::mad(a, b, c).
// NOTE(review): cl::sycl::mad is expected to compute a * b + c per component,
// with implementation-defined rounding of the intermediate product -- confirm
// against the SYCL math built-ins.
// Instantiated for cl_float4 and cl_double2.
196 #define SYCL_PMAD(packet_type) \
198 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pmadd( \
199 const packet_type& a, const packet_type& b, const packet_type& c) { \
200 return cl::sycl::mad(a, b, c); \
203 SYCL_PMAD(cl::sycl::cl_float4)
204 SYCL_PMAD(cl::sycl::cl_double2)
209 const cl::sycl::cl_float4&
a) {
214 const cl::sycl::cl_double2&
a) {
220 const cl::sycl::cl_float4&
a) {
// Horizontal sum: adds the four components x, y, z and w of `a`,
// accumulating left to right.
221 return a.x() +
a.y() +
a.z() +
a.w();
226 const cl::sycl::cl_double2&
a) {
// Horizontal sum: adds the two components x and y of `a`.
227 return a.x() +
a.y();
232 const cl::sycl::cl_float4&
a) {
238 const cl::sycl::cl_double2&
a) {
244 const cl::sycl::cl_float4&
a) {
250 const cl::sycl::cl_double2&
a) {
256 const cl::sycl::cl_float4&
a) {
// Horizontal product: multiplies the four components x, y, z and w of `a`,
// accumulating left to right.
257 return a.x() *
a.y() *
a.z() *
a.w();
261 const cl::sycl::cl_double2&
a) {
// Horizontal product: multiplies the two components x and y of `a`.
262 return a.x() *
a.y();
267 pabs<cl::sycl::cl_float4>(
const cl::sycl::cl_float4&
a) {
// Applies cl::sycl::fabs to each of the four components separately and
// reassembles the results in x, y, z, w order.
268 return cl::sycl::cl_float4(cl::sycl::fabs(
a.x()), cl::sycl::fabs(
a.y()),
269 cl::sycl::fabs(
a.z()), cl::sycl::fabs(
a.w()));
273 pabs<cl::sycl::cl_double2>(
const cl::sycl::cl_double2&
a) {
// Applies cl::sycl::fabs to the x and y components separately and
// reassembles the results in that order.
274 return cl::sycl::cl_double2(cl::sycl::fabs(
a.x()), cl::sycl::fabs(
a.y()));
// Three comparison helpers templated on the packet type. Each applies one
// comparison operator to `a` and `b` and passes the result through
// `.template as<Packet>()` so that the return value has the packet's own type.
// NOTE(review): the function-name lines are not visible in this excerpt;
// presumably these are the sycl_pcmp_le / sycl_pcmp_lt / sycl_pcmp_eq helpers
// that the SYCL_PCMP macro calls -- confirm against the full header.

// a <= b
277 template <
typename Packet>
280 return (
a <=
b).template as<Packet>();
// a < b
283 template <
typename Packet>
286 return (
a <
b).template as<Packet>();
// a == b
289 template <
typename Packet>
292 return (
a ==
b).template as<Packet>();
// SYCL_PCMP(OP, TYPE): emits a definition of `pcmp_##OP<TYPE>` that returns
// `sycl_pcmp_##OP<TYPE>(a, b)`, i.e. it forwards to the comparison helper with
// the same OP suffix.
// Instantiated for OP in {le, lt, eq} with TYPE in {cl_float4, cl_double2}.
295 #define SYCL_PCMP(OP, TYPE) \
297 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE TYPE pcmp_##OP<TYPE>(const TYPE &a, \
299 return sycl_pcmp_##OP<TYPE>(a, b); \
302 SYCL_PCMP(le, cl::sycl::cl_float4)
303 SYCL_PCMP(lt, cl::sycl::cl_float4)
304 SYCL_PCMP(eq, cl::sycl::cl_float4)
305 SYCL_PCMP(le, cl::sycl::cl_double2)
306 SYCL_PCMP(lt, cl::sycl::cl_double2)
307 SYCL_PCMP(eq, cl::sycl::cl_double2)
311 PacketBlock<cl::sycl::cl_float4, 4>& kernel) {
// In-place transpose of the 4x4 block, treating kernel.packet[i] as row i and
// its x, y, z, w components as columns 0..3. Each of the six pairs above the
// diagonal is swapped with its mirror below the diagonal through `tmp`.
// (0,1) <-> (1,0)
312 float tmp = kernel.packet[0].y();
313 kernel.packet[0].y() = kernel.packet[1].x();
314 kernel.packet[1].x() = tmp;
// (0,2) <-> (2,0)
316 tmp = kernel.packet[0].z();
317 kernel.packet[0].z() = kernel.packet[2].x();
318 kernel.packet[2].x() = tmp;
// (0,3) <-> (3,0)
320 tmp = kernel.packet[0].w();
321 kernel.packet[0].w() = kernel.packet[3].x();
322 kernel.packet[3].x() = tmp;
// (1,2) <-> (2,1)
324 tmp = kernel.packet[1].z();
325 kernel.packet[1].z() = kernel.packet[2].y();
326 kernel.packet[2].y() = tmp;
// (1,3) <-> (3,1)
328 tmp = kernel.packet[1].w();
329 kernel.packet[1].w() = kernel.packet[3].y();
330 kernel.packet[3].y() = tmp;
// (2,3) <-> (3,2)
332 tmp = kernel.packet[2].w();
333 kernel.packet[2].w() = kernel.packet[3].z();
334 kernel.packet[3].z() = tmp;
338 PacketBlock<cl::sycl::cl_double2, 2>& kernel) {
// In-place transpose of the 2x2 block: only the off-diagonal pair needs
// swapping, packet[0].y with packet[1].x, through `tmp`.
339 double tmp = kernel.packet[0].y();
340 kernel.packet[0].y() = kernel.packet[1].x();
341 kernel.packet[1].x() = tmp;
347 const cl::sycl::cl_float4& thenPacket,
348 const cl::sycl::cl_float4& elsePacket) {
// Build a per-component integer condition from the selector: 0 where
// ifPacket.select[i] is non-zero, -1 (all bits set) where it is zero.
349 cl::sycl::cl_int4 condition(
350 ifPacket.select[0] ? 0 : -1, ifPacket.select[1] ? 0 : -1,
351 ifPacket.select[2] ? 0 : -1, ifPacket.select[3] ? 0 : -1);
// NOTE(review): for vector arguments, select(a, b, c) is expected to return
// b[i] where the most significant bit of c[i] is set and a[i] otherwise, so
// -1 picks elsePacket and 0 picks thenPacket -- confirm against the SYCL
// relational built-ins.
352 return cl::sycl::select(thenPacket, elsePacket, condition);
// Blend for 2-component double packets, driven by `ifPacket.select`.
356 inline cl::sycl::cl_double2
pblend(
358 const cl::sycl::cl_double2& thenPacket,
359 const cl::sycl::cl_double2& elsePacket) {
// Per-component 64-bit condition: 0 where ifPacket.select[i] is non-zero,
// -1 (all bits set) where it is zero.
360 cl::sycl::cl_long2 condition(ifPacket.select[0] ? 0 : -1,
361 ifPacket.select[1] ? 0 : -1);
// NOTE(review): select(a, b, c) is expected to return b[i] where the most
// significant bit of c[i] is set and a[i] otherwise, so -1 picks elsePacket
// and 0 picks thenPacket -- confirm against the SYCL relational built-ins.
362 return cl::sycl::select(thenPacket, elsePacket, condition);
#define EIGEN_ALWAYS_INLINE
#define EIGEN_DEVICE_FUNC
bfloat16 fmax(const bfloat16 &a, const bfloat16 &b)
bfloat16 fmin(const bfloat16 &a, const bfloat16 &b)
void ptranspose(PacketBlock< Packet2cf, 2 > &kernel)
Packet4i pblend(const Selector< 4 > &ifPacket, const Packet4i &thenPacket, const Packet4i &elsePacket)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.