#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSOR_SYCL_SYCL_HPP
#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSOR_SYCL_SYCL_HPP

namespace Eigen {
namespace TensorSycl {
namespace internal {

#ifndef EIGEN_SYCL_MAX_GLOBAL_RANGE
#define EIGEN_SYCL_MAX_GLOBAL_RANGE (EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1 * 4)
#endif
template <typename index_t>
struct ScanParameters {
  // ScanPerThread is referenced by every kernel below; it must be a power of two.
  static EIGEN_CONSTEXPR index_t ScanPerThread = 8;
  const index_t total_size;
  const index_t non_scan_size;
  const index_t scan_size;
  const index_t non_scan_stride;
  const index_t scan_stride;
  const index_t panel_threads;
  const index_t group_threads;
  const index_t block_threads;
  const index_t elements_per_group;
  const index_t elements_per_block;
  const index_t loop_range;

  ScanParameters(index_t total_size_, index_t non_scan_size_, index_t scan_size_, index_t non_scan_stride_,
                 index_t scan_stride_, index_t panel_threads_, index_t group_threads_, index_t block_threads_,
                 index_t elements_per_group_, index_t elements_per_block_, index_t loop_range_)
      : total_size(total_size_),
        non_scan_size(non_scan_size_),
        scan_size(scan_size_),
        non_scan_stride(non_scan_stride_),
        scan_stride(scan_stride_),
        panel_threads(panel_threads_),
        group_threads(group_threads_),
        block_threads(block_threads_),
        elements_per_group(elements_per_group_),
        elements_per_block(elements_per_block_),
        loop_range(loop_range_) {}
};
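
// Illustrative example (hand-worked, not part of the original header): a cumulative
// sum over dimension 1 of a ColMajor 4x100x3 tensor is launched with
//   scan_size = 100, non_scan_size = 4, panel_size = 3, total_size = 1200,
//   scan_stride = 4, non_scan_stride = 1.
// ScanInfo (further below) derives the thread/block decomposition from these values.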

enum class scan_step { first, second };

template <typename Evaluator, typename CoeffReturnType, typename OutAccessor, typename Op, typename Index,
          scan_step stp>
struct ScanKernelFunctor {
  typedef cl::sycl::accessor<CoeffReturnType, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
      LocalAccessor;
  static EIGEN_CONSTEXPR int PacketSize = ScanParameters<Index>::ScanPerThread / 2;
  LocalAccessor scratch;
  Evaluator dev_eval;
  OutAccessor out_ptr;
  OutAccessor tmp_ptr;
  const ScanParameters<Index> scanParameters;
  Op accumulator;
  const bool inclusive;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScanKernelFunctor(LocalAccessor scratch_, const Evaluator dev_eval_,
                                                          OutAccessor out_accessor_, OutAccessor temp_accessor_,
                                                          const ScanParameters<Index> scanParameters_, Op accumulator_,
                                                          const bool inclusive_)
      : scratch(scratch_),
        dev_eval(dev_eval_),
        out_ptr(out_accessor_),
        tmp_ptr(temp_accessor_),
        scanParameters(scanParameters_),
        accumulator(accumulator_),
        inclusive(inclusive_) {}
  // On the first scan step the input is a tensor evaluator; read through coeff().
  template <scan_step sst = stp, typename Input>
  std::enable_if_t<sst == scan_step::first, CoeffReturnType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  read(const Input &inpt, Index global_id) const {
    return inpt.coeff(global_id);
  }

  // On the second (recursive) step the input is a raw device pointer; read by index.
  template <scan_step sst = stp, typename Input>
  std::enable_if_t<sst != scan_step::first, CoeffReturnType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  read(const Input &inpt, Index global_id) const {
    return inpt[global_id];
  }

  // The inclusive-scan fix-up only runs on the first step; on the second step it is a no-op.
  template <scan_step sst = stp, typename InclusiveOp>
  std::enable_if_t<sst == scan_step::first> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  first_step_inclusive_Operation(InclusiveOp inclusive_op) const {
    inclusive_op();
  }

  template <scan_step sst = stp, typename InclusiveOp>
  std::enable_if_t<sst != scan_step::first> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  first_step_inclusive_Operation(InclusiveOp) const {}
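
  // Overview of the kernel below: each thread (1) loads up to ScanPerThread elements,
  // (2) scans them in registers one PacketSize-wide chunk at a time with an
  // up-sweep/down-sweep (Blelloch) pass, (3) the per-chunk totals are scanned
  // cooperatively in local memory, (4) the local-memory results are folded back into
  // the private results, and (5) the results are written out, shifted by one element
  // for inclusive scans.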
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) const {
    for (Index loop_offset = 0; loop_offset < scanParameters.loop_range; loop_offset++) {
      Index data_offset = (itemID.get_global_id(0) + (itemID.get_global_range(0) * loop_offset));
      Index tmp = data_offset % scanParameters.panel_threads;
      const Index panel_id = data_offset / scanParameters.panel_threads;
      const Index group_id = tmp / scanParameters.group_threads;
      tmp = tmp % scanParameters.group_threads;
      const Index block_id = tmp / scanParameters.block_threads;
      const Index local_id = tmp % scanParameters.block_threads;
      // One element per packet goes into local (scratch) memory.
      const Index scratch_stride = scanParameters.elements_per_block / PacketSize;
      const Index scratch_offset = (itemID.get_local_id(0) / scanParameters.block_threads) * scratch_stride;
      CoeffReturnType private_scan[ScanParameters<Index>::ScanPerThread];
      CoeffReturnType inclusive_scan;
      const Index panel_offset = panel_id * scanParameters.scan_size * scanParameters.non_scan_size;
      const Index group_offset = group_id * scanParameters.non_scan_stride;
      const Index block_offset = block_id * scanParameters.elements_per_block * scanParameters.scan_stride;
      const Index thread_offset = (ScanParameters<Index>::ScanPerThread * local_id * scanParameters.scan_stride);
      const Index global_offset = panel_offset + group_offset + block_offset + thread_offset;
      Index next_elements = 0;
      for (int i = 0; i < ScanParameters<Index>::ScanPerThread; i++) {
        Index global_id = global_offset + next_elements;
        private_scan[i] = ((((block_id * scanParameters.elements_per_block) +
                             (ScanParameters<Index>::ScanPerThread * local_id) + i) < scanParameters.scan_size) &&
                           (global_id < scanParameters.total_size))
                              ? read(dev_eval, global_id)
                              : accumulator.initialize();
        next_elements += scanParameters.scan_stride;
      }
      // For an inclusive scan on the first step, remember the last element loaded by this thread.
      first_step_inclusive_Operation([&]() EIGEN_DEVICE_FUNC {
        if (inclusive) {
          inclusive_scan = private_scan[ScanParameters<Index>::ScanPerThread - 1];
        }
      });
      // Per-packet Blelloch scan held entirely in registers.
      for (int packetIndex = 0; packetIndex < ScanParameters<Index>::ScanPerThread; packetIndex += PacketSize) {
        Index private_offset = 1;
        // Up-sweep: build the sums in place up the tree.
        for (Index d = PacketSize >> 1; d > 0; d >>= 1) {
          for (Index l = 0; l < d; l++) {
            Index ai = private_offset * (2 * l + 1) - 1 + packetIndex;
            Index bi = private_offset * (2 * l + 2) - 1 + packetIndex;
            CoeffReturnType accum = accumulator.initialize();
            accumulator.reduce(private_scan[ai], &accum);
            accumulator.reduce(private_scan[bi], &accum);
            private_scan[bi] = accumulator.finalize(accum);
          }
          private_offset *= 2;
        }
        scratch[2 * local_id + (packetIndex / PacketSize) + scratch_offset] =
            private_scan[PacketSize - 1 + packetIndex];
        private_scan[PacketSize - 1 + packetIndex] = accumulator.initialize();
        // Down-sweep: traverse back down the tree to build the scan.
        for (Index d = 1; d < PacketSize; d *= 2) {
          private_offset >>= 1;
          for (Index l = 0; l < d; l++) {
            Index ai = private_offset * (2 * l + 1) - 1 + packetIndex;
            Index bi = private_offset * (2 * l + 2) - 1 + packetIndex;
            CoeffReturnType accum = accumulator.initialize();
            accumulator.reduce(private_scan[ai], &accum);
            accumulator.reduce(private_scan[bi], &accum);
            private_scan[ai] = private_scan[bi];
            private_scan[bi] = accumulator.finalize(accum);
          }
        }
      }
      Index offset = 1;
      // Up-sweep over the per-packet totals held in local (scratch) memory.
      for (Index d = scratch_stride >> 1; d > 0; d >>= 1) {
        // Synchronise
        itemID.barrier(cl::sycl::access::fence_space::local_space);
        if (local_id < d) {
          Index ai = offset * (2 * local_id + 1) - 1 + scratch_offset;
          Index bi = offset * (2 * local_id + 2) - 1 + scratch_offset;
          CoeffReturnType accum = accumulator.initialize();
          accumulator.reduce(scratch[ai], &accum);
          accumulator.reduce(scratch[bi], &accum);
          scratch[bi] = accumulator.finalize(accum);
        }
        offset *= 2;
      }
      // Synchronise
      itemID.barrier(cl::sycl::access::fence_space::local_space);
      if (local_id == 0) {
        if (((scanParameters.elements_per_group / scanParameters.elements_per_block) > 1)) {
          const Index temp_id = panel_id * (scanParameters.elements_per_group / scanParameters.elements_per_block) *
                                    scanParameters.non_scan_size +
                                group_id * (scanParameters.elements_per_group / scanParameters.elements_per_block) +
                                block_id;
          tmp_ptr[temp_id] = scratch[scratch_stride - 1 + scratch_offset];
        }
        // Clear the last element before the down-sweep.
        scratch[scratch_stride - 1 + scratch_offset] = accumulator.initialize();
      }
      // Down-sweep over local memory.
      for (Index d = 1; d < scratch_stride; d *= 2) {
        offset >>= 1;
        // Synchronise
        itemID.barrier(cl::sycl::access::fence_space::local_space);
        if (local_id < d) {
          Index ai = offset * (2 * local_id + 1) - 1 + scratch_offset;
          Index bi = offset * (2 * local_id + 2) - 1 + scratch_offset;
          CoeffReturnType accum = accumulator.initialize();
          accumulator.reduce(scratch[ai], &accum);
          accumulator.reduce(scratch[bi], &accum);
          scratch[ai] = scratch[bi];
          scratch[bi] = accumulator.finalize(accum);
        }
      }
      // Synchronise
      itemID.barrier(cl::sycl::access::fence_space::local_space);
      // Fold the block-level offsets held in scratch back into the private results.
      for (int packetIndex = 0; packetIndex < ScanParameters<Index>::ScanPerThread; packetIndex += PacketSize) {
        for (Index i = 0; i < PacketSize; i++) {
          CoeffReturnType accum = private_scan[packetIndex + i];
          accumulator.reduce(scratch[2 * local_id + (packetIndex / PacketSize) + scratch_offset], &accum);
          private_scan[packetIndex + i] = accumulator.finalize(accum);
        }
      }
      first_step_inclusive_Operation([&]() EIGEN_DEVICE_FUNC {
        if (inclusive) {
          accumulator.reduce(private_scan[ScanParameters<Index>::ScanPerThread - 1], &inclusive_scan);
          private_scan[0] = accumulator.finalize(inclusive_scan);
        }
      });
      next_elements = 0;
      for (Index i = 0; i < ScanParameters<Index>::ScanPerThread; i++) {
        Index global_id = global_offset + next_elements;
        if ((((block_id * scanParameters.elements_per_block) + (ScanParameters<Index>::ScanPerThread * local_id) + i) <
             scanParameters.scan_size) &&
            (global_id < scanParameters.total_size)) {
          // An inclusive scan writes the element shifted by one within this thread's chunk.
          Index private_id = (i * !inclusive) + (((i + 1) % ScanParameters<Index>::ScanPerThread) * (inclusive));
          out_ptr[global_id] = private_scan[private_id];
        }
        next_elements += scanParameters.scan_stride;
      }
    }
  }
};
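
// ScanAdjustmentKernelFunctor runs after the per-block totals have themselves been
// scanned: it adds the scanned block offset back onto every element of the
// corresponding block in the output.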
template <typename CoeffReturnType, typename InAccessor, typename OutAccessor, typename Op, typename Index>
struct ScanAdjustmentKernelFunctor {
  typedef cl::sycl::accessor<CoeffReturnType, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
      LocalAccessor;
  static EIGEN_CONSTEXPR int PacketSize = ScanParameters<Index>::ScanPerThread / 2;
  InAccessor in_ptr;
  OutAccessor out_ptr;
  const ScanParameters<Index> scanParameters;
  Op accumulator;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScanAdjustmentKernelFunctor(LocalAccessor, InAccessor in_accessor_,
                                                                    OutAccessor out_accessor_,
                                                                    const ScanParameters<Index> scanParameters_,
                                                                    Op accumulator_)
      : in_ptr(in_accessor_),
        out_ptr(out_accessor_),
        scanParameters(scanParameters_),
        accumulator(accumulator_) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) const {
    for (Index loop_offset = 0; loop_offset < scanParameters.loop_range; loop_offset++) {
      Index data_offset = (itemID.get_global_id(0) + (itemID.get_global_range(0) * loop_offset));
      Index tmp = data_offset % scanParameters.panel_threads;
      const Index panel_id = data_offset / scanParameters.panel_threads;
      const Index group_id = tmp / scanParameters.group_threads;
      tmp = tmp % scanParameters.group_threads;
      const Index block_id = tmp / scanParameters.block_threads;
      const Index local_id = tmp % scanParameters.block_threads;
      const Index panel_offset = panel_id * scanParameters.scan_size * scanParameters.non_scan_size;
      const Index group_offset = group_id * scanParameters.non_scan_stride;
      const Index block_offset = block_id * scanParameters.elements_per_block * scanParameters.scan_stride;
      const Index thread_offset = ScanParameters<Index>::ScanPerThread * local_id * scanParameters.scan_stride;
      const Index global_offset = panel_offset + group_offset + block_offset + thread_offset;
      const Index block_size = scanParameters.elements_per_group / scanParameters.elements_per_block;
      const Index in_id = (panel_id * block_size * scanParameters.non_scan_size) + (group_id * block_size) + block_id;
      CoeffReturnType adjust_val = in_ptr[in_id];
      Index next_elements = 0;
      for (Index i = 0; i < ScanParameters<Index>::ScanPerThread; i++) {
        Index global_id = global_offset + next_elements;
        if ((((block_id * scanParameters.elements_per_block) + (ScanParameters<Index>::ScanPerThread * local_id) + i) <
             scanParameters.scan_size) &&
            (global_id < scanParameters.total_size)) {
          CoeffReturnType accum = adjust_val;
          accumulator.reduce(out_ptr[global_id], &accum);
          out_ptr[global_id] = accumulator.finalize(accum);
        }
        next_elements += scanParameters.scan_stride;
      }
    }
  }
};

template <typename Index>
struct ScanInfo {
  const Index &total_size;
  const Index &scan_size;
  const Index &panel_size;
  const Index &non_scan_size;
  const Index &scan_stride;
  const Index &non_scan_stride;

  Index max_elements_per_block;
  Index block_size;
  Index panel_threads;
  Index group_threads;
  Index block_threads;
  Index elements_per_group;
  Index elements_per_block;
  Index loop_range;
  Index global_range;
  Index local_range;
  const Eigen::SyclDevice &dev;
  EIGEN_STRONG_INLINE ScanInfo(const Index &total_size_, const Index &scan_size_, const Index &panel_size_,
                               const Index &non_scan_size_, const Index &scan_stride_, const Index &non_scan_stride_,
                               const Eigen::SyclDevice &dev_)
      : total_size(total_size_),
        scan_size(scan_size_),
        panel_size(panel_size_),
        non_scan_size(non_scan_size_),
        scan_stride(scan_stride_),
        non_scan_stride(non_scan_stride_),
        dev(dev_) {
    // The work-group size must be a power of two.
    local_range = std::min(Index(dev.getNearestPowerOfTwoWorkGroupSize()),
                           Index(EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1));

    max_elements_per_block = local_range * ScanParameters<Index>::ScanPerThread;

    elements_per_group =
        dev.getPowerOfTwo(Index(roundUp(Index(scan_size), ScanParameters<Index>::ScanPerThread)), true);
    const Index elements_per_panel = elements_per_group * non_scan_size;
    elements_per_block = std::min(Index(elements_per_group), Index(max_elements_per_block));
    panel_threads = elements_per_panel / ScanParameters<Index>::ScanPerThread;
    group_threads = elements_per_group / ScanParameters<Index>::ScanPerThread;
    block_threads = elements_per_block / ScanParameters<Index>::ScanPerThread;
    block_size = elements_per_group / elements_per_block;
#ifdef EIGEN_SYCL_MAX_GLOBAL_RANGE
    const Index max_threads = std::min(Index(panel_threads * panel_size), Index(EIGEN_SYCL_MAX_GLOBAL_RANGE));
#else
    const Index max_threads = panel_threads * panel_size;
#endif
    global_range = roundUp(max_threads, local_range);
    loop_range = Index(
        std::ceil(double(elements_per_panel * panel_size) / (global_range * ScanParameters<Index>::ScanPerThread)));
  }
  inline ScanParameters<Index> get_scan_parameter() {
    return ScanParameters<Index>(total_size, non_scan_size, scan_size, non_scan_stride, scan_stride, panel_threads,
                                 group_threads, block_threads, elements_per_group, elements_per_block, loop_range);
  }
  inline cl::sycl::nd_range<1> get_thread_range() {
    return cl::sycl::nd_range<1>(cl::sycl::range<1>(global_range), cl::sycl::range<1>(local_range));
  }
};
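
// Continuing the illustrative example from above (scan_size = 100, non_scan_size = 4,
// panel_size = 3, ScanPerThread = 8): elements_per_group rounds 100 up to the next
// power of two, 128, so group_threads = 16 and panel_threads = 64; with a work-group
// of at least 16 threads elements_per_block = 128 as well, giving block_size = 1 and
// therefore no recursive second scan step. (Hand-derived numbers, for orientation only.)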

template <typename EvaluatorPointerType, typename CoeffReturnType, typename Reducer, typename Index>
struct SYCLAdjustBlockOffset {
  EIGEN_STRONG_INLINE static void adjust_scan_block_offset(EvaluatorPointerType in_ptr, EvaluatorPointerType out_ptr,
                                                           Reducer &accumulator, const Index total_size,
                                                           const Index scan_size, const Index panel_size,
                                                           const Index non_scan_size, const Index scan_stride,
                                                           const Index non_scan_stride, const Eigen::SyclDevice &dev) {
    auto scan_info =
        ScanInfo<Index>(total_size, scan_size, panel_size, non_scan_size, scan_stride, non_scan_stride, dev);

    typedef ScanAdjustmentKernelFunctor<CoeffReturnType, EvaluatorPointerType, EvaluatorPointerType, Reducer, Index>
        AdjustFuctor;
    dev.template unary_kernel_launcher<CoeffReturnType, AdjustFuctor>(in_ptr, out_ptr, scan_info.get_thread_range(),
                                                                      scan_info.max_elements_per_block,
                                                                      scan_info.get_scan_parameter(), accumulator)
        .wait();
  }
};

template <typename CoeffReturnType, scan_step stp>
struct ScanLauncher_impl {
  template <typename Input, typename EvaluatorPointerType, typename Reducer, typename Index>
  EIGEN_STRONG_INLINE static void scan_block(Input in_ptr, EvaluatorPointerType out_ptr, Reducer &accumulator,
                                             const Index total_size, const Index scan_size, const Index panel_size,
                                             const Index non_scan_size, const Index scan_stride,
                                             const Index non_scan_stride, const bool inclusive,
                                             const Eigen::SyclDevice &dev) {
    auto scan_info =
        ScanInfo<Index>(total_size, scan_size, panel_size, non_scan_size, scan_stride, non_scan_stride, dev);
    const Index temp_pointer_size = scan_info.block_size * non_scan_size * panel_size;
    const Index scratch_size = scan_info.max_elements_per_block / (ScanParameters<Index>::ScanPerThread / 2);
    CoeffReturnType *temp_pointer =
        static_cast<CoeffReturnType *>(dev.allocate_temp(temp_pointer_size * sizeof(CoeffReturnType)));
    EvaluatorPointerType tmp_global_accessor = dev.get(temp_pointer);

    typedef ScanKernelFunctor<Input, CoeffReturnType, EvaluatorPointerType, Reducer, Index, stp> ScanFunctor;
    dev.template binary_kernel_launcher<CoeffReturnType, ScanFunctor>(
           in_ptr, out_ptr, tmp_global_accessor, scan_info.get_thread_range(), scratch_size,
           scan_info.get_scan_parameter(), accumulator, inclusive)
        .wait();
    if (scan_info.block_size > 1) {
      // The scan did not fit in a single block: scan the per-block totals recursively,
      // then add them back onto each block's results.
      ScanLauncher_impl<CoeffReturnType, scan_step::second>::scan_block(
          tmp_global_accessor, tmp_global_accessor, accumulator, temp_pointer_size, scan_info.block_size, panel_size,
          non_scan_size, Index(1), scan_info.block_size, false, dev);

      SYCLAdjustBlockOffset<EvaluatorPointerType, CoeffReturnType, Reducer, Index>::adjust_scan_block_offset(
          tmp_global_accessor, out_ptr, accumulator, total_size, scan_size, panel_size, non_scan_size, scan_stride,
          non_scan_stride, dev);
    }
    dev.deallocate_temp(temp_pointer);
  }
};

}  // namespace internal
}  // namespace TensorSycl
namespace internal {
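
// The partial specialization below plugs the SYCL scan into Eigen's generic
// ScanLauncher. A sketch of how it is reached from user code (the device pointers
// gpu_in_ptr/gpu_out_ptr and the device selection are assumed, not taken from this header):
//
//   // `device` is any supported cl::sycl::device, e.g. from Eigen::get_sycl_supported_devices().
//   Eigen::QueueInterface queue_interface(device);
//   Eigen::SyclDevice sycl_device(&queue_interface);
//   Eigen::TensorMap<Eigen::Tensor<float, 3>> gpu_in(gpu_in_ptr, 4, 100, 3);
//   Eigen::TensorMap<Eigen::Tensor<float, 3>> gpu_out(gpu_out_ptr, 4, 100, 3);
//   gpu_out.device(sycl_device) = gpu_in.cumsum(1);  // dispatches into ScanLauncher below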
template <typename Self, typename Reducer, bool vectorize>
struct ScanLauncher<Self, Reducer, Eigen::SyclDevice, vectorize> {
  typedef typename Self::Index Index;
  typedef typename Self::CoeffReturnType CoeffReturnType;
  typedef typename Self::Storage Storage;
  typedef typename Self::EvaluatorPointerType EvaluatorPointerType;
  void operator()(Self &self, EvaluatorPointerType data) const {
    const Index total_size = internal::array_prod(self.dimensions());
    const Index scan_size = self.size();
    const Index scan_stride = self.stride();
    // The scan accumulator (e.g. a sum or product reducer).
    auto accumulator = self.accumulator();
    auto inclusive = !self.exclusive();
    auto consume_dim = self.consume_dim();
    auto dev = self.device();

    auto dims = self.inner().dimensions();

    Index non_scan_size = 1;
    Index panel_size = 1;
    if (static_cast<int>(Self::Layout) == static_cast<int>(ColMajor)) {
      for (int i = 0; i < consume_dim; i++) {
        non_scan_size *= dims[i];
      }
      for (int i = consume_dim + 1; i < Self::NumDims; i++) {
        panel_size *= dims[i];
      }
    } else {
      for (int i = Self::NumDims - 1; i > consume_dim; i--) {
        non_scan_size *= dims[i];
      }
      for (int i = consume_dim - 1; i >= 0; i--) {
        panel_size *= dims[i];
      }
    }
    const Index non_scan_stride = (scan_stride > 1) ? 1 : scan_size;
    auto eval_impl = self.inner();
    TensorSycl::internal::ScanLauncher_impl<CoeffReturnType, TensorSycl::internal::scan_step::first>::scan_block(
        eval_impl, data, accumulator, total_size, scan_size, panel_size, non_scan_size, scan_stride, non_scan_stride,
        inclusive, dev);
  }
};
}  // namespace internal
}  // namespace Eigen

#endif  // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSOR_SYCL_SYCL_HPP