#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H

#if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)

#include "./InternalHeaderCheck.h"
namespace Eigen {

template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
                               const Index m_size, const Index n_size, const Index k_size) {
  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;
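  // Each thread block computes a 64 x 64 tile of the output. The block is
  // organized as 8 x 8 x 8 = 512 threads (see the dim3 block_size(8, 8, 8)
  // used at the launch site below), so every thread accumulates an 8 x 8
  // sub-tile of results in registers.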
  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;

  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;

  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;

  // registers that prefetchIntoRegisters() fills and writeRegToShmem() drains
  Scalar lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3, lhs_pf4, lhs_pf5, lhs_pf6, lhs_pf7;
  Scalar rhs_pf0, rhs_pf1, rhs_pf2, rhs_pf3, rhs_pf4, rhs_pf5, rhs_pf6, rhs_pf7;
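  // Shared-memory layout note: the 64 x 64 tile of 8 x 8 scalar blocks is laid
  // out with strides 9 and 72 rather than 8 and 64. The extra column of padding
  // appears chosen to stagger accesses across shared-memory banks and avoid
  // bank conflicts; the shmem buffers are sized 72 * 64 accordingly (see the
  // __shared__ declarations in EigenContractionKernel below).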
#define prefetchIntoRegisters(base_k) \
  { \
    lhs_pf0 = conv(0); \
    lhs_pf1 = conv(0); \
    lhs_pf2 = conv(0); \
    lhs_pf3 = conv(0); \
    lhs_pf4 = conv(0); \
    lhs_pf5 = conv(0); \
    lhs_pf6 = conv(0); \
    lhs_pf7 = conv(0); \
    \
    rhs_pf0 = conv(0); \
    rhs_pf1 = conv(0); \
    rhs_pf2 = conv(0); \
    rhs_pf3 = conv(0); \
    rhs_pf4 = conv(0); \
    rhs_pf5 = conv(0); \
    rhs_pf6 = conv(0); \
    rhs_pf7 = conv(0); \
    \
    if (!needs_edge_check || lhs_vert < m_size) { \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
      \
      if (!needs_edge_check || lhs_horiz_7 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
      } else if (lhs_horiz_6 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
      } else if (lhs_horiz_5 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
      } else if (lhs_horiz_4 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
      } else if (lhs_horiz_3 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
      } else if (lhs_horiz_2 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      } else if (lhs_horiz_1 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      } else if (lhs_horiz_0 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      } \
    } \
    \
    const Index rhs_vert = base_k + load_idx_vert; \
    if (!needs_edge_check || rhs_vert < k_size) { \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
      \
      if (rhs_horiz_7 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
      } else if (rhs_horiz_6 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
      } else if (rhs_horiz_5 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
      } else if (rhs_horiz_4 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
      } else if (rhs_horiz_3 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
      } else if (rhs_horiz_2 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      } else if (rhs_horiz_1 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      } else if (rhs_horiz_0 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      } \
    } \
  }
#define writeRegToShmem() \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
  \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
  \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
  \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
  \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
  \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
  \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
  \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7;
#define res(i, j) _res_##i##j
#define initResultRow(i) \
  Scalar res(i, 0) = conv(0); \
  Scalar res(i, 1) = conv(0); \
  Scalar res(i, 2) = conv(0); \
  Scalar res(i, 3) = conv(0); \
  Scalar res(i, 4) = conv(0); \
  Scalar res(i, 5) = conv(0); \
  Scalar res(i, 6) = conv(0); \
  Scalar res(i, 7) = conv(0);
  internal::scalar_cast_op<int, Scalar> conv;
  initResultRow(0);
  initResultRow(1);
  initResultRow(2);
  initResultRow(3);
  initResultRow(4);
  initResultRow(5);
  initResultRow(6);
  initResultRow(7);
#undef initResultRow
  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    // wait for the previous iteration's compute to finish reading shmem
    __syncthreads();

    prefetchIntoRegisters(base_k);
    writeRegToShmem();

#undef prefetchIntoRegisters
#undef writeRegToShmem

    // wait for the shared-memory packing to complete before computing
    __syncthreads();
#define lcol(i) _lcol##i
    Scalar lcol(0);
    Scalar lcol(1);
    Scalar lcol(2);
    Scalar lcol(3);
    Scalar lcol(4);
    Scalar lcol(5);
    Scalar lcol(6);
    Scalar lcol(7);
#define rrow(j) _rrow##j
    Scalar rrow(0);
    Scalar rrow(1);
    Scalar rrow(2);
    Scalar rrow(3);
    Scalar rrow(4);
    Scalar rrow(5);
    Scalar rrow(6);
    Scalar rrow(7);
    const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
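    // lhs_block/rhs_block fold this thread's fixed offset into the padded
    // tile, so lhs_element(i, j) / rhs_element(i, j) index only by position
    // within the 8 x 8 sub-tile; the stride of 72 steps over one padded
    // 8-element slice per index, matching the store layout written above.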
#define loadData(i, j) \
    lcol(0) = lhs_element(0, j); \
    rrow(0) = rhs_element(i, 0); \
    lcol(1) = lhs_element(1, j); \
    rrow(1) = rhs_element(i, 1); \
    lcol(2) = lhs_element(2, j); \
    rrow(2) = rhs_element(i, 2); \
    lcol(3) = lhs_element(3, j); \
    rrow(3) = rhs_element(i, 3); \
    lcol(4) = lhs_element(4, j); \
    rrow(4) = rhs_element(i, 4); \
    lcol(5) = lhs_element(5, j); \
    rrow(5) = rhs_element(i, 5); \
    lcol(6) = lhs_element(6, j); \
    rrow(6) = rhs_element(i, 6); \
    lcol(7) = lhs_element(7, j); \
    rrow(7) = rhs_element(i, 7);
#define computeCol(j) \
    res(0, j) += lcol(0) * rrow(j); \
    res(1, j) += lcol(1) * rrow(j); \
    res(2, j) += lcol(2) * rrow(j); \
    res(3, j) += lcol(3) * rrow(j); \
    res(4, j) += lcol(4) * rrow(j); \
    res(5, j) += lcol(5) * rrow(j); \
    res(6, j) += lcol(6) * rrow(j); \
    res(7, j) += lcol(7) * rrow(j);
#define computePass(i) \
    loadData(i, i); \
    \
    computeCol(0); \
    computeCol(1); \
    computeCol(2); \
    computeCol(3); \
    computeCol(4); \
    computeCol(5); \
    computeCol(6); \
    computeCol(7);

    computePass(0);
    computePass(1);
    computePass(2);
    computePass(3);
    computePass(4);
    computePass(5);
    computePass(6);
    computePass(7);

#undef lcol
#undef rrow
#undef lhs_element
#undef rhs_element
#undef loadData
#undef computeCol
#undef computePass
  } // end of the outer loop over base_k
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
#else
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
#endif
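// __shfl_xor_sync(0xFFFFFFFF, v, mask) returns the value of v held by the
// lane whose id equals the caller's id XOR mask, so each shuffleInc adds the
// partner lane's partial sum into res(i, j). Applied with successive masks,
// this implements a butterfly reduction in which every participating lane
// ends up with the full sum, without touching shared memory. __shfl_xor is
// the pre-CUDA-9 / HIP spelling of the same primitive.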
#define reduceRow(i, mask) \
  shuffleInc(i, 0, mask); \
  shuffleInc(i, 1, mask); \
  shuffleInc(i, 2, mask); \
  shuffleInc(i, 3, mask); \
  shuffleInc(i, 4, mask); \
  shuffleInc(i, 5, mask); \
  shuffleInc(i, 6, mask); \
  shuffleInc(i, 7, mask);

#define reduceMatrix(mask) \
  reduceRow(0, mask); \
  reduceRow(1, mask); \
  reduceRow(2, mask); \
  reduceRow(3, mask); \
  reduceRow(4, mask); \
  reduceRow(5, mask); \
  reduceRow(6, mask); \
  reduceRow(7, mask);
#define writeResultShmem(i, j) \
  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j);

#define writeRow(i) \
  writeResultShmem(i, 0); \
  writeResultShmem(i, 1); \
  writeResultShmem(i, 2); \
  writeResultShmem(i, 3); \
  writeResultShmem(i, 4); \
  writeResultShmem(i, 5); \
  writeResultShmem(i, 6); \
  writeResultShmem(i, 7);
  if (threadIdx.x == 0) {
    writeRow(0);
    writeRow(1);
    writeRow(2);
    writeRow(3);
    writeRow(4);
    writeRow(5);
    writeRow(6);
    writeRow(7);
  }
#undef writeResultShmem
#undef writeRow

  __syncthreads();
  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
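  // Worked example of the edge-tile bound: with m_size - base_m == 20 and
  // threadIdx.y == 3, max_i_write = min((20 - 3 + 7) / 8, 8) = 3, so only
  // threads with threadIdx.x in {0, 1, 2} write rows base_m + 3 + 8 * {0, 1, 2},
  // all of which are < m_size; threadIdx.x == 3 would have targeted row
  // base_m + 27 and is correctly excluded.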
  if (threadIdx.x < max_i_write) {
    if (max_j_write == 8) {
      Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
      Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
      Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
      Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
      Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
      Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
      Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
      Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];

      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
    } else {
      for (int j = 0; j < max_j_write; j++) {
        Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
        output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
      }
    }
  }
}
template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(512, 1)
#else
__launch_bounds__(512)
#endif
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;
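  // Interior blocks (a full 64 x 64 tile in bounds) take the fast path with
  // needs_edge_check == false; only blocks touching the right or bottom edge
  // of the output pay for the bounds checks in the loads and stores.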
  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper,
         bool CHECK_LHS_BOUNDARY, bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                                         const OutputMapper output, float2 lhs_shmem2[][16],
                                         float2 rhs_shmem2[][8], const Index m_size,
                                         const Index n_size, const Index k_size,
                                         const Index base_m, const Index base_n) {
  float4 results[4];
  float4 lhs_pf0, rhs_pf0;

  for (int i = 0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }
#define prefetch_lhs(reg, row, col) \
  if (!CHECK_LHS_BOUNDARY) { \
    if (col < k_size) { \
      reg = lhs.template loadPacket<float4, Unaligned>(row, col); \
    } \
  } else { \
    if (col < k_size) { \
      if (row + 3 < m_size) { \
        reg = lhs.template loadPacket<float4, Unaligned>(row, col); \
      } else if (row + 2 < m_size) { \
        reg.x = lhs(row + 0, col); \
        reg.y = lhs(row + 1, col); \
        reg.z = lhs(row + 2, col); \
      } else if (row + 1 < m_size) { \
        reg.x = lhs(row + 0, col); \
        reg.y = lhs(row + 1, col); \
      } else if (row < m_size) { \
        reg.x = lhs(row + 0, col); \
      } \
    } \
  }
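// prefetch_lhs issues one vectorized float4 load when the full 4-row column
// is in bounds, and falls back to 1-3 scalar loads on the bottom edge.
// Out-of-range components keep the zero set by the pset1<float4>(0)
// initialization in the main loop below, so they contribute nothing to the
// accumulated products.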
  Index lhs_vert = base_m + threadIdx.x * 4;

  for (Index k = 0; k < k_size; k += 16) {
    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    Index lhs_horiz = threadIdx.y + k;
    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)

    Index rhs_vert = k + (threadIdx.x % 4) * 4;
    Index rhs_horiz0 = (threadIdx.x >> 2) + threadIdx.y * 4 + base_n;
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    float x1, x2;
    // exchange the middle components with the partner lane four lanes away
    if ((threadIdx.x % 8) < 4) {
      x1 = rhs_pf0.y;
      x2 = rhs_pf0.w;
    } else {
      x1 = rhs_pf0.x;
      x2 = rhs_pf0.z;
    }

#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
#else
    x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
    x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
#endif

    if ((threadIdx.x % 8) < 4) {
      rhs_pf0.y = x1;
      rhs_pf0.w = x2;
    } else {
      rhs_pf0.x = x1;
      rhs_pf0.z = x2;
    }
    rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2][threadIdx.x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2 + 32][threadIdx.x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);
    lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y + 16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);
#define add_vals(fl1, fl2, fr1, fr2) \
    results[0].x += fl1.x * fr1.x; \
    results[0].y += fl1.y * fr1.x; \
    results[0].z += fl2.x * fr1.x; \
    results[0].w += fl2.y * fr1.x; \
    \
    results[1].x += fl1.x * fr1.y; \
    results[1].y += fl1.y * fr1.y; \
    results[1].z += fl2.x * fr1.y; \
    results[1].w += fl2.y * fr1.y; \
    \
    results[2].x += fl1.x * fr2.x; \
    results[2].y += fl1.y * fr2.x; \
    results[2].z += fl2.x * fr2.x; \
    results[2].w += fl2.y * fr2.x; \
    \
    results[3].x += fl1.x * fr2.y; \
    results[3].y += fl1.y * fr2.y; \
    results[3].z += fl2.x * fr2.y; \
    results[3].w += fl2.y * fr2.y;
    __syncthreads();

    // Do the multiplies.
#pragma unroll
    for (int koff = 0; koff < 16; koff++) {
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      int start_feature = threadIdx.y * 4;
      float2 fr1 = rhs_shmem2[(start_feature >> 1) + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
      float2 fr2 = rhs_shmem2[(start_feature >> 1) + 1 + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];

      add_vals(fl1, fl2, fr1, fr2)
    }
#undef add_vals
    __syncthreads();
  } // end loop over k
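  // In the loop above, each of the 256 threads (16 x 16 block) holds a 4 x 4
  // accumulator tile in the four float4 `results` registers: fl1/fl2 supply
  // four lhs rows and fr1/fr2 four rhs columns, and add_vals performs the 16
  // multiply-accumulates of that outer product per koff step.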
  Index horiz_base = threadIdx.y * 4 + base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // check both boundaries
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper,
         bool CHECK_LHS_BOUNDARY, bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void
EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                                    const OutputMapper output, float2 lhs_shmem2[][32],
                                    float2 rhs_shmem2[][8], const Index m_size,
                                    const Index n_size, const Index k_size,
                                    const Index base_m, const Index base_n) {
  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
  float4 rhs_pf0, rhs_pf1;

  float4 results[8];
  for (int i = 0; i < 8; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }
  Index lhs_vert = base_m + threadIdx.x * 4 + (threadIdx.y % 4) * 32;
  for (Index k = 0; k < k_size; k += 32) {
    lhs_pf0 = internal::pset1<float4>(0);
    lhs_pf1 = internal::pset1<float4>(0);
    lhs_pf2 = internal::pset1<float4>(0);
    lhs_pf3 = internal::pset1<float4>(0);

    rhs_pf0 = internal::pset1<float4>(0);
    rhs_pf1 = internal::pset1<float4>(0);
    if (!CHECK_LHS_BOUNDARY) {
      if ((threadIdx.y/4+k+24) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
      } else if ((threadIdx.y/4+k+16) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
      } else if ((threadIdx.y/4+k+8) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
      } else if ((threadIdx.y/4+k) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k));
      }
    } else {
      if (lhs_vert + 3 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
          lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 2 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
          lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 1 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
        }
      } else if (lhs_vert < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
        }
      }
    }
    Index rhs_vert = k + threadIdx.x * 4;
    Index rhs_horiz0 = threadIdx.y * 2 + base_n;
    Index rhs_horiz1 = threadIdx.y * 2 + 1 + base_n;
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
        rhs_pf1 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz1);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
      }
    } else {
      if (rhs_horiz1 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
          rhs_pf1 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz1);
        } else if (rhs_vert + 2 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
          rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
        } else if (k + threadIdx.x * 4 + 1 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        } else if (k + threadIdx.x * 4 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        }
      } else if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
    rhs_shmem2[threadIdx.y + 32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
    rhs_shmem2[threadIdx.y + 64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
    rhs_shmem2[threadIdx.y + 96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
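    // rhs_pf0/rhs_pf1 hold four consecutive k values for two adjacent output
    // columns; the four stores above scatter those k components 32 rows apart
    // in rhs_shmem2 as (column0, column1) float2 pairs, which appears to be
    // the layout the br1..br4 reads in the compute loop below rely on.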
#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4) \
    results[0].x += a_feat1.x * f1.x; \
    results[1].x += a_feat1.x * f1.y; \
    results[2].x += a_feat1.x * f2.x; \
    results[3].x += a_feat1.x * f2.y; \
    results[4].x += a_feat1.x * f3.x; \
    results[5].x += a_feat1.x * f3.y; \
    results[6].x += a_feat1.x * f4.x; \
    results[7].x += a_feat1.x * f4.y; \
    \
    results[0].y += a_feat1.y * f1.x; \
    results[1].y += a_feat1.y * f1.y; \
    results[2].y += a_feat1.y * f2.x; \
    results[3].y += a_feat1.y * f2.y; \
    results[4].y += a_feat1.y * f3.x; \
    results[5].y += a_feat1.y * f3.y; \
    results[6].y += a_feat1.y * f4.x; \
    results[7].y += a_feat1.y * f4.y; \
    \
    results[0].z += a_feat2.x * f1.x; \
    results[1].z += a_feat2.x * f1.y; \
    results[2].z += a_feat2.x * f2.x; \
    results[3].z += a_feat2.x * f2.y; \
    results[4].z += a_feat2.x * f3.x; \
    results[5].z += a_feat2.x * f3.y; \
    results[6].z += a_feat2.x * f4.x; \
    results[7].z += a_feat2.x * f4.y; \
    \
    results[0].w += a_feat2.y * f1.x; \
    results[1].w += a_feat2.y * f1.y; \
    results[2].w += a_feat2.y * f2.x; \
    results[3].w += a_feat2.y * f2.y; \
    results[4].w += a_feat2.y * f3.x; \
    results[5].w += a_feat2.y * f3.y; \
    results[6].w += a_feat2.y * f4.x; \
    results[7].w += a_feat2.y * f4.y;
    lhs_shmem2[threadIdx.y/4][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y/4 + 8][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf1.x, lhs_pf1.y);
    lhs_shmem2[threadIdx.y/4 + 16][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf2.x, lhs_pf2.y);
    lhs_shmem2[threadIdx.y/4 + 24][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf3.x, lhs_pf3.y);

    lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf0.z, lhs_pf0.w);
    lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf1.z, lhs_pf1.w);
    lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf2.z, lhs_pf2.w);
    lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf3.z, lhs_pf3.w);
    __syncthreads();

    // Do the multiplies.
#pragma unroll
    for (int koff = 0; koff < 32; koff++) {
      float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
      float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];

      // first feature is at (threadIdx.y / 4) * 8
      int start_feature = (threadIdx.y / 4) * 8;

      float2 br1 = rhs_shmem2[start_feature/2 + (koff % 4) * 32][koff/4];
      float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
      float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
      float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];

      add_vals(a3, a4, br1, br2, br3, br4)
    }
#undef add_vals
    __syncthreads();
  } // end loop over k
  Index horiz_base = (threadIdx.y / 4) * 8 + base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // check both boundaries
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                            const OutputMapper output,
                            const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[64 * 32];
  __shared__ float2 rhs_shmem[128 * 8];

  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 128 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  bool check_rhs = (base_n + 63) >= n_size;
  bool check_lhs128 = (base_m + 127) >= m_size;
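  // This kernel tiles the output 128 rows x 64 columns per block. check_rhs
  // and check_lhs128 detect blocks overlapping the right/bottom edges, and
  // the four-way dispatch below instantiates the internal kernel with only
  // the boundary checks that are actually needed.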
  if (!check_rhs) {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                                 const OutputMapper output,
                                 const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;
  if (base_m + 63 < m_size) {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}
template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {

  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef std::remove_const_t<typename XprType::Scalar> Scalar;
  typedef typename XprType::Index Index;
  typedef std::conditional_t<Layout == static_cast<int>(ColMajor), LeftArgType, RightArgType> EvalLeftArgType;
  typedef std::conditional_t<Layout == static_cast<int>(ColMajor), RightArgType, LeftArgType> EvalRightArgType;

  static constexpr int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static constexpr int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static constexpr int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static constexpr int NumDims = LDims + RDims - 2 * ContractDims;

  typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
  typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;
  TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {
    EIGEN_STATIC_ASSERT((internal::is_same<OutputKernelType, const NoOpOutputKernel>::value),
                        GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS);
  }

  // We need to redefine this method to make nvcc happy
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(typename Base::EvaluatorPointerType data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar*>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }
  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        } else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        } else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        } else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        } else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
  template <typename LhsScalar, typename RhsScalar, typename Index,
            typename LhsMapper, typename RhsMapper, typename OutputMapper>
  struct LaunchKernels {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output,
                    Index m, Index n, Index k, const GpuDevice& device) {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
    }
  };
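  // Generic path: one 512-thread (8 x 8 x 8) block per 64 x 64 output tile,
  // matching the __launch_bounds__(512) annotation on EigenContractionKernel.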
  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
  struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output,
                    Index m, Index n, Index k, const GpuDevice& device) {
      if (m < 768 || n < 768) {
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      }
    }
  };
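  // The m < 768 || n < 768 cutoff is a heuristic: for small or skinny outputs,
  // more and smaller (64 x 64) tiles keep more SMs busy, while large square
  // problems presumably amortize the bigger 128 x 64 tile of
  // EigenFloatContractionKernel better.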
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    // rows in left side
    const Index m = this->m_i_size;
    // columns in right side
    const Index n = this->m_j_size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, 4,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // the float kernels access shared memory as float2, so request 8-byte banks
#if defined(EIGEN_USE_HIP)
    setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
#else
    setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
#endif

    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
  }
};

} // end namespace Eigen

#endif // EIGEN_USE_GPU && EIGEN_GPUCC
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H