10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
// Fragment: template-parameter head of the CoeffLoader primary template.
// CoeffLoader abstracts coefficient access for the contraction mappers:
// HasRawAccess selects (via partial specialization, see below) between going
// through the tensor evaluator's coeff()/packet() interface and reading from a
// raw data pointer directly. MakePointer_ defaults to MakePointer.
// NOTE(review): the declaration body is missing from this extraction — only the
// parameter list is visible here.
29 template <
typename Tensor,
bool HasRawAccess,
template <
class>
class MakePointer_ =
MakePointer>
// Forward declaration of BaseTensorContractionMapper, the mapper that presents
// a tensor expression as a logical 2-D (LHS or RHS) matrix to the contraction
// kernels. `side` selects Lhs/Rhs; nocontract_t/contract_t carry the stride
// arrays for the non-contracting and contracting dimensions respectively.
// NOTE(review): one template-parameter line (the MakePointer_ template-template
// parameter, original line 35) is missing from this extraction.
32 template <
typename Scalar,
typename Index,
int side,
typename Tensor,
33 typename nocontract_t,
typename contract_t,
int packet_size,
34 bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment,
36 class BaseTensorContractionMapper;
// Fragment of the CoeffLoader primary template (no raw access): coefficients
// are fetched through the stored tensor evaluator itself.
// NOTE(review): the struct header, coeff() accessor and surrounding braces are
// missing from this extraction; only packet() and the data member survive.
38 template <
typename Tensor,
bool HasRawAccess,
template <
class>
class MakePointer_>
// Loads a full packet at `index` by delegating to the evaluator's own
// packet<LoadMode>() implementation (LoadMode is a template parameter declared
// on a line missing from this view).
59 typename Tensor::PacketReturnType packet(
typename Tensor::Index index)
const
61 return m_tensor.template packet<LoadMode>(index);
// The evaluator is stored by value; all reads go through it.
65 const Tensor m_tensor;
// Specialization of CoeffLoader for tensors with raw data access
// (HasRawAccess == true): reads go straight through a raw pointer, which also
// enables the DirectOffsets / offsetBuffer optimization used by the sub-mapper.
68 template <
typename Tensor,
template <
class>
class MakePointer_>
69 struct CoeffLoader<Tensor, true, MakePointer_> {
// Packet load straight from memory at m_data + index. ploadt_ro is the
// read-only ("load through texture/ldg where available") variant of ploadt.
88 typename Tensor::PacketReturnType packet(
typename Tensor::Index index)
const
90 return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
// Raw pointer to the tensor's coefficient storage, wrapped by MakePointer_
// (allows device-specific pointer types).
96 typename MakePointer_<const Scalar>::Type m_data;
// SimpleTensorContractionMapper: maps 2-D (row, col) coordinates of the
// logical contraction matrix back to linear indices into the underlying
// tensor, using precomputed stride arrays. Scalar loads only; packet loads are
// layered on top by BaseTensorContractionMapper.
// NOTE(review): this extraction is missing several original lines (the Tensor
// template parameter, access specifiers, enum/brace lines, and the
// computeIndex signature itself); code below is kept byte-identical.
99 template<
typename Scalar,
typename Index,
int side,
101 typename nocontract_t,
typename contract_t,
102 int packet_size,
bool inner_dim_contiguous,
int Alignment,
template <
class>
class MakePointer_ = MakePointer>
103 class SimpleTensorContractionMapper {
// Stores the tensor (via CoeffLoader, see m_tensor below) and the four
// stride arrays:
//   m_nocontract_strides / m_ij_strides : strides of the non-contracting
//     dims in the tensor resp. in the logical matrix,
//   m_contract_strides / m_k_strides    : same for the contracting dims.
106 SimpleTensorContractionMapper(
const Tensor& tensor,
107 const nocontract_t& nocontract_strides,
108 const nocontract_t& ij_strides,
109 const contract_t& contract_strides,
110 const contract_t& k_strides) :
112 m_nocontract_strides(nocontract_strides),
113 m_ij_strides(ij_strides),
114 m_contract_strides(contract_strides),
115 m_k_strides(k_strides) { }
// DirectOffsets is inherited from the CoeffLoader: true only when the
// loader holds a raw pointer that can simply be advanced by offsetBuffer().
118 DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets
122 m_tensor.offsetBuffer(offset);
// operator() fragment: scalar access = linearize (row, col) then coeff().
136 return m_tensor.coeff(computeIndex(row, col));
// --- computeIndex fragment ------------------------------------------------
// Decomposes the matrix coordinate into per-dimension indices by repeated
// divide/subtract against the logical strides, accumulating the tensor-side
// linear index with the tensor strides.
141 const bool left = (side ==
Lhs);
// Peel off the non-contracting dimensions from outermost to dimension 1.
146 for (
int i =
static_cast<int>(array_size<nocontract_t>::value) - 1;
i > 0;
i--) {
147 const Index idx = nocontract_val / m_ij_strides[
i];
148 linidx += idx * m_nocontract_strides[
i];
149 nocontract_val -= idx * m_ij_strides[
i];
// Innermost non-contracting dimension: when it is contiguous on the Lhs the
// stride is 1 and the residual value can be added directly.
151 if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
152 if (side ==
Lhs && inner_dim_contiguous) {
154 linidx += nocontract_val;
156 linidx += nocontract_val * m_nocontract_strides[0];
// Same decomposition for the contracting dimensions (guarded: there may be
// none, e.g. for an outer product).
161 if(array_size<contract_t>::value > 0) {
163 for (
int i =
static_cast<int>(array_size<contract_t>::value) - 1;
i > 0;
i--) {
164 const Index idx = contract_val / m_k_strides[
i];
165 linidx += idx * m_contract_strides[
i];
166 contract_val -= idx * m_k_strides[
i];
// Innermost contracting dimension; contiguous case is the Rhs mirror of the
// non-contracting handling above.
169 if (side ==
Rhs && inner_dim_contiguous) {
171 linidx += contract_val;
173 linidx += contract_val * m_contract_strides[0];
// computeIndexPair: like computeIndex, but linearizes two matrix coordinates
// at once — (row, col) and (row + distance, col) — sharing the divide/subtract
// work. Used by the packet loaders to get the first and last element index of
// a prospective packet in one pass.
// NOTE(review): the lines initializing nocontract_val[2] / contract_val[2]
// (originals 183-184, 203-206) are missing from this extraction; presumably
// they seed {row, row + distance} resp. {col, col + distance} — confirm
// against the upstream file.
181 EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(
Index row,
Index col,
const Index distance)
const {
182 const bool left = (side ==
Lhs);
185 Index linidx[2] = {0, 0};
// Non-contracting dimensions, outermost first (same scheme as computeIndex,
// applied to both coordinates in lockstep).
186 if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
188 for (
int i =
static_cast<int>(array_size<nocontract_t>::value) - 1;
i > 0;
i--) {
189 const Index idx0 = nocontract_val[0] / m_ij_strides[
i];
190 const Index idx1 = nocontract_val[1] / m_ij_strides[
i];
191 linidx[0] += idx0 * m_nocontract_strides[
i];
192 linidx[1] += idx1 * m_nocontract_strides[
i];
193 nocontract_val[0] -= idx0 * m_ij_strides[
i];
194 nocontract_val[1] -= idx1 * m_ij_strides[
i];
// Innermost non-contracting dimension (stride 1 when Lhs + contiguous).
196 if (side ==
Lhs && inner_dim_contiguous) {
198 linidx[0] += nocontract_val[0];
199 linidx[1] += nocontract_val[1];
201 linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
202 linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
// Contracting dimensions (skipped entirely when there are none).
207 if (array_size<contract_t>::value> 0) {
209 for (
int i =
static_cast<int>(array_size<contract_t>::value) - 1;
i > 0;
i--) {
210 const Index idx0 = contract_val[0] / m_k_strides[
i];
211 const Index idx1 = contract_val[1] / m_k_strides[
i];
212 linidx[0] += idx0 * m_contract_strides[
i];
213 linidx[1] += idx1 * m_contract_strides[
i];
214 contract_val[0] -= idx0 * m_k_strides[
i];
215 contract_val[1] -= idx1 * m_k_strides[
i];
// Innermost contracting dimension (stride 1 when Rhs + contiguous).
218 if (side ==
Rhs && inner_dim_contiguous) {
220 linidx[0] += contract_val[0];
221 linidx[1] += contract_val[1];
223 linidx[0] += contract_val[0] * m_contract_strides[0];
224 linidx[1] += contract_val[1] * m_contract_strides[0];
// Return (index of first element, index of last element).
227 return IndexPair<Index>(linidx[0], linidx[1]);
// firstAligned fragment: when loads are aligned and the Lhs inner dimension is
// contiguous, element 0 is already aligned (return 0); otherwise report `size`
// (meaning: no aligned element available).
234 return (Alignment ==
Aligned) && (side ==
Lhs) && inner_dim_contiguous ? 0 : size;
// stride fragment: distance between consecutive columns of the logical
// matrix — the first contracting stride in the contiguous Lhs case, else 1.
237 return ((side ==
Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
// Read-only accessors for the wrapped tensor loader and the stride arrays.
240 const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor()
const {
244 const nocontract_t& nocontract_strides()
const {
245 return m_nocontract_strides;
247 const nocontract_t& ij_strides()
const {
return m_ij_strides; }
248 const contract_t& contract_strides()
const {
return m_contract_strides; }
249 const contract_t& k_strides()
const {
return m_k_strides; }
// Data members: coefficient loader plus the four stride arrays captured at
// construction (all immutable after construction).
252 CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor;
253 const nocontract_t m_nocontract_strides;
254 const nocontract_t m_ij_strides;
255 const contract_t m_contract_strides;
256 const contract_t m_k_strides;
// BaseTensorContractionMapper: extends SimpleTensorContractionMapper with
// vectorized (packet) loads. Two `load` overloads are selected by enable_if on
// whether the requested packet size matches this mapper's packet_size; a
// loadPacket forwarder dispatches to them.
// NOTE(review): several original lines are missing from this extraction
// (Tensor template parameter, the load() signatures and their EIGEN_ALIGN_MAX
// scratch-buffer declarations, closing braces); code is kept byte-identical.
259 template<
typename Scalar,
typename Index,
int side,
261 typename nocontract_t,
typename contract_t,
262 int packet_size,
bool inner_dim_contiguous,
263 bool inner_dim_reordered,
int Alignment,
template <
class>
class MakePointer_>
264 class BaseTensorContractionMapper :
public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment, MakePointer_>
267 typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment, MakePointer_> ParentMapper;
// Constructor: forwards everything to the parent mapper.
270 BaseTensorContractionMapper(
const Tensor& tensor,
271 const nocontract_t& nocontract_strides,
272 const nocontract_t& ij_strides,
273 const contract_t& contract_strides,
274 const contract_t& k_strides) :
275 ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
// load(): full-packet overload (unpacket_traits<PacketT>::size == packet_size).
277 template <
typename PacketT,
int AlignmentType>
279 std::enable_if_t<internal::unpacket_traits<PacketT>::size==packet_size,PacketT>
// Fast path 1: contiguous, non-reordered inner dim — indices are provably
// consecutive (asserted), so a single evaluator packet load suffices.
288 if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
289 const Index index = this->computeIndex(i, j);
290 eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
291 return this->m_tensor.template packet<AlignmentType>(index);
// Otherwise compute first/last element indices of the would-be packet.
294 const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
296 const Index lastIdx = indexPair.second;
// Fast path 2: first and last happen to be exactly packet_size-1 apart (and
// reordering cannot break the elements in between), so a packet load is
// still valid.
302 if (Tensor::PacketAccess &&
303 (side ==
Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
304 (lastIdx -
first) == (packet_size - 1)) {
306 return this->m_tensor.template packet<AlignmentType>(
first);
// Slow path: gather the coefficients one by one into a scratch buffer
// (declared on a missing line), two per iteration via computeIndexPair,
// then assemble the packet with pload.
313 for (
Index k = 1; k < packet_size - 1; k += 2) {
314 const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
315 data[k] = this->m_tensor.coeff(internal_pair.first);
316 data[k + 1] = this->m_tensor.coeff(internal_pair.second);
318 data[packet_size - 1] = this->m_tensor.coeff(lastIdx);
320 return pload<PacketT>(data);
// load(): partial-packet overload (requested packet smaller/larger than
// packet_size) — always takes the element-wise gather path.
323 template <
typename PacketT,
int AlignmentType>
325 std::enable_if_t<internal::unpacket_traits<PacketT>::size!=packet_size,PacketT>
328 const Index requested_packet_size = internal::unpacket_traits<PacketT>::size;
331 const IndexPair<Index> indexPair = this->computeIndexPair(i, j, requested_packet_size - 1);
333 const Index lastIdx = indexPair.second;
336 for (
Index k = 1; k < requested_packet_size - 1; k += 2) {
337 const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
338 data[k] = this->m_tensor.coeff(internal_pair.first);
339 data[k + 1] = this->m_tensor.coeff(internal_pair.second);
341 data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx);
343 return pload<PacketT>(data);
// loadPacket: thin forwarder to whichever load() overload enable_if selects.
346 template <
typename PacketT,
int AlignmentType>
348 EIGEN_STRONG_INLINE PacketT loadPacket(
Index i,
Index j)
const {
349 return this->load<PacketT,AlignmentType>(i,j);
// Partial specialization for packet_size == 1: packets degenerate to a single
// scalar, so loadPacket/load just fetch one coefficient into a scratch buffer
// and broadcast-load it with pload.
// NOTE(review): the Tensor template parameter line, method signatures/template
// headers and the scratch-buffer declarations are missing from this
// extraction; code is kept byte-identical.
354 template<
typename Scalar,
typename Index,
int side,
356 typename nocontract_t,
typename contract_t,
357 bool inner_dim_contiguous,
358 bool inner_dim_reordered,
int Alignment,
template <
class>
class MakePointer_>
359 class BaseTensorContractionMapper<Scalar,
Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
360 :
public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment, MakePointer_>
363 typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment, MakePointer_> ParentMapper;
// Constructor: forwards everything to the parent mapper.
366 BaseTensorContractionMapper(
const Tensor& tensor,
367 const nocontract_t& nocontract_strides,
368 const nocontract_t& ij_strides,
369 const contract_t& contract_strides,
370 const contract_t& k_strides) :
371 ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
// Single-element "packet" load: one coeff() into data[0], then pload.
374 EIGEN_STRONG_INLINE PacketT loadPacket(
Index i,
Index j)
const {
376 data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
377 return pload<PacketT>(data);
// load: identical behavior to loadPacket in this specialization.
380 EIGEN_STRONG_INLINE PacketT load(
Index i,
Index j)
const {
382 data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
383 return pload<PacketT>(data);
// TensorContractionSubMapper: a view of a ParentMapper shifted by a fixed
// (vert_offset, horiz_offset). When UseDirectOffsets holds (raw-access tensor,
// contiguous Lhs inner dim, at least one contracting dim) the offset is folded
// into the base mapper's raw pointer at construction time via offsetBuffer(),
// so every access can use (i, 0)/(i, j) directly; otherwise each access adds
// the offsets explicitly.
// NOTE(review): many signature lines, `enum` / access-specifier lines and
// closing braces are missing from this extraction; code is kept byte-identical.
388 template<
typename Scalar,
typename Index,
int side,
390 typename nocontract_t,
typename contract_t,
392 bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment,
template <
class>
class MakePointer_=MakePointer>
393 class TensorContractionSubMapper {
396 typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> ParentMapper;
397 typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Self;
// Sub-views of a sub-mapper are again Self.
398 typedef Self LinearMapper;
399 typedef Self SubMapper;
// Compile-time switch for the baked-in-offset fast path (see class comment).
404 UseDirectOffsets = ParentMapper::DirectOffsets && (side ==
Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
// Constructor: record base mapper and offsets; in the direct-offsets case,
// immediately advance the underlying raw buffer by the linearized offset.
408 : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
411 if (UseDirectOffsets) {
412 Index stride = m_base_mapper.stride();
413 m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
// operator() fragment, 1-arg form: column fixed at the horizontal offset.
418 if (UseDirectOffsets) {
419 return m_base_mapper(i, 0);
421 return m_base_mapper(i + m_vert_offset, m_horiz_offset);
// operator() fragment, 2-arg form.
424 if (UseDirectOffsets) {
425 return m_base_mapper(i, j);
427 return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
// loadPacket fragment, 1-arg form (class-level Alignment).
430 template <
typename PacketT>
432 if (UseDirectOffsets) {
433 return m_base_mapper.template loadPacket<PacketT,Alignment>(i, 0);
435 return m_base_mapper.template loadPacket<PacketT,Alignment>(i + m_vert_offset, m_horiz_offset);
// loadPacket fragment, 2-arg form.
438 template <
typename PacketT>
440 if (UseDirectOffsets) {
441 return m_base_mapper.template loadPacket<PacketT,Alignment>(i, j);
443 return m_base_mapper.template loadPacket<PacketT,Alignment>(i + m_vert_offset, j + m_horiz_offset);
// load fragment with caller-chosen AlignmentType.
// NOTE(review): the two branches call load() vs loadPacket() on the base
// mapper — behaviorally equivalent in BaseTensorContractionMapper (loadPacket
// forwards to load), but the asymmetry looks accidental; confirm upstream.
446 template <
typename PacketT,
int AlignmentType>
448 if (UseDirectOffsets) {
449 return m_base_mapper.template load<PacketT,AlignmentType>(i, j);
451 return m_base_mapper.template loadPacket<PacketT,AlignmentType>(i + m_vert_offset, j + m_horiz_offset);
// storePacket fragment: write-side mirror of the 1-arg loadPacket.
454 template <
typename PacketT>
456 if (UseDirectOffsets) {
457 m_base_mapper.storePacket(i, 0, p);
459 m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
// getLinearMapper fragment: new sub-view at (i, j) relative to this view.
463 if (UseDirectOffsets) {
464 return LinearMapper(m_base_mapper, i, j);
466 return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
// getSubMapper fragment: same, under the SubMapper alias.
470 if (UseDirectOffsets) {
471 return SubMapper(m_base_mapper, i, j);
473 return SubMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
// loadPacketStandard-style fragment: static-asserts PacketT validity, then
// loads with ActualAlignment (computed on a line missing from this view).
476 template <
typename PacketT,
int AlignmentType>
478 EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
480 if (UseDirectOffsets) {
481 return m_base_mapper.template loadPacket<PacketT,ActualAlignment>(i, 0);
483 return m_base_mapper.template loadPacket<PacketT,ActualAlignment>(i + m_vert_offset, m_horiz_offset);
// aligned() fragment — body missing from this extraction.
486 template <
typename PacketT>
// Read-only accessors for base mapper and offsets.
491 const ParentMapper& base_mapper()
const {
return m_base_mapper; }
492 Index vert_offset()
const {
return m_vert_offset; }
493 Index horiz_offset()
const {
return m_horiz_offset; }
// Members: base mapper by value (so offsetBuffer can shift a private copy),
// plus the immutable view offsets.
496 ParentMapper m_base_mapper;
497 const Index m_vert_offset;
498 const Index m_horiz_offset;
// TensorContractionInputMapper: the public entry point handed to the
// contraction kernels. Inherits all load machinery from
// BaseTensorContractionMapper and adds factories for offset sub-views
// (SubMapper / LinearMapper / VectorMapper — all the same type).
// NOTE(review): the Tensor/packet_size template parameters, the constructor's
// first line, and the getLinearMapper/getVectorMapper/get_tensor signatures
// are missing from this extraction; code is kept byte-identical.
502 template<
typename Scalar_,
typename Index,
int side,
504 typename nocontract_t,
typename contract_t,
506 bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment,
template <
class>
class MakePointer_=MakePointer>
507 class TensorContractionInputMapper
508 :
public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> {
511 typedef Scalar_ Scalar;
512 typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Base;
513 typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> SubMapper;
// Vector and linear views are just sub-mappers.
514 typedef SubMapper VectorMapper;
515 typedef SubMapper LinearMapper;
// Constructor tail: forwards tensor + stride arrays to the base chain.
518 const nocontract_t& nocontract_strides,
519 const nocontract_t& ij_strides,
520 const contract_t& contract_strides,
521 const contract_t& k_strides)
522 : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
// Sub-view factories: each anchors a view of *this at offset (i, j).
525 EIGEN_STRONG_INLINE SubMapper getSubMapper(
Index i,
Index j)
const {
526 return SubMapper(*
this, i, j);
530 return LinearMapper(*
this, i, j);
534 return VectorMapper(*
this, i, j);
// Exposes the underlying coefficient loader (Base::m_tensor).
538 return Base::m_tensor;
// TensorContractionInputMapperTrait: compile-time introspection of a
// TensorContractionInputMapper instantiation — re-exports its tensor
// expression type and layout flags (used e.g. by backend-specific kernels).
// NOTE(review): the Tensor_/packet_size_ template-parameter lines of the
// specialization head are missing from this extraction.
543 template <
typename T>
struct TensorContractionInputMapperTrait;
545 template<
typename Scalar_,
typename Index_,
int side_,
547 typename nocontract_t_,
typename contract_t_,
549 bool inner_dim_contiguous_,
bool inner_dim_reordered_,
int Alignment_,
template <
class>
class MakePointer_>
550 struct TensorContractionInputMapperTrait<TensorContractionInputMapper<Scalar_, Index_, side_, Tensor_,
551 nocontract_t_, contract_t_, packet_size_, inner_dim_contiguous_,
552 inner_dim_reordered_, Alignment_, MakePointer_> > {
// The mapped tensor expression and its inner-dimension layout properties.
554 typedef Tensor_ XprType;
555 static const bool inner_dim_contiguous = inner_dim_contiguous_;
556 static const bool inner_dim_reordered = inner_dim_reordered_;
RowXpr row(Index i) const
ColXpr col(Index i) const
IndexedView_or_VectorBlock operator()(const Indices &indices)
#define EIGEN_ALWAYS_INLINE
#define EIGEN_UNROLL_LOOP
#define EIGEN_UNUSED_VARIABLE(var)
#define EIGEN_DEVICE_FUNC
#define EIGEN_STATIC_ASSERT(X, MSG)
internal::traits< Self >::Index Index
EIGEN_ALWAYS_INLINE T loadConstant(const T *address)
void prefetch(const Scalar *addr)
EIGEN_CONSTEXPR Index first(const T &x) EIGEN_NOEXCEPT
: TensorContractionSycl.h, provides various tensor contraction kernel for SYCL backend
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index