10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_REF_H
19 template <
typename Dimensions,
typename Scalar>
20 class TensorLazyBaseEvaluator {
22 TensorLazyBaseEvaluator() : m_refcount(0) { }
23 virtual ~TensorLazyBaseEvaluator() { }
31 void incrRefCount() { ++m_refcount; }
32 void decrRefCount() { --m_refcount; }
33 int refCount()
const {
return m_refcount; }
37 TensorLazyBaseEvaluator(
const TensorLazyBaseEvaluator& other);
38 TensorLazyBaseEvaluator& operator = (
const TensorLazyBaseEvaluator& other);
44 template <
typename Dimensions,
typename Expr,
typename Device>
45 class TensorLazyEvaluatorReadOnly :
public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> {
49 typedef StorageMemory<Scalar, Device> Storage;
51 typedef TensorEvaluator<Expr, Device> EvalType;
53 TensorLazyEvaluatorReadOnly(
const Expr& expr,
const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) {
54 m_dims = m_impl.dimensions();
55 m_impl.evalSubExprsIfNeeded(NULL);
57 virtual ~TensorLazyEvaluatorReadOnly() {
69 return m_impl.coeff(index);
72 eigen_assert(
false &&
"can't reference the coefficient of a rvalue");
77 TensorEvaluator<Expr, Device> m_impl;
82 template <
typename Dimensions,
typename Expr,
typename Device>
83 class TensorLazyEvaluatorWritable :
public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> {
85 typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base;
86 typedef typename Base::Scalar Scalar;
87 typedef StorageMemory<Scalar, Device> Storage;
90 TensorLazyEvaluatorWritable(
const Expr& expr,
const Device& device) : Base(expr, device) {
92 virtual ~TensorLazyEvaluatorWritable() {
96 return this->m_impl.coeffRef(index);
100 template <
typename Dimensions,
typename Expr,
typename Device>
101 class TensorLazyEvaluator :
public std::conditional_t<bool(internal::is_lvalue<Expr>::value),
102 TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
103 TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> > {
105 typedef std::conditional_t<bool(internal::is_lvalue<Expr>::value),
106 TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
107 TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> > Base;
108 typedef typename Base::Scalar Scalar;
110 TensorLazyEvaluator(
const Expr& expr,
const Device& device) : Base(expr, device) {
112 virtual ~TensorLazyEvaluator() {
130 typedef typename PlainObjectType::Base
Base;
131 typedef typename Eigen::internal::nested<Self>::type
Nested;
132 typedef typename internal::traits<PlainObjectType>::StorageKind
StorageKind;
133 typedef typename internal::traits<PlainObjectType>::Index
Index;
134 typedef typename internal::traits<PlainObjectType>::Scalar
Scalar;
143 static constexpr
int Layout = PlainObjectType::Layout;
160 template <
typename Expression>
165 template <
typename Expression>
183 if (
this != &other) {
212 const std::size_t num_indices = (
sizeof...(otherIndices) + 1);
214 return coeff(indices);
219 const std::size_t num_indices = (
sizeof...(otherIndices) + 1);
229 if (PlainObjectType::Options &
RowMajor) {
232 index = index * dims[
i] + indices[
i];
237 index = index * dims[
i] + indices[
i];
247 if (PlainObjectType::Options &
RowMajor) {
250 index = index * dims[
i] + indices[
i];
255 index = index * dims[
i] + indices[
i];
283 internal::TensorLazyBaseEvaluator<Dimensions, Scalar>*
m_evaluator;
288 template<
typename Derived,
typename Device>
291 typedef typename Derived::Index
Index;
304 PreferBlockAccess =
false,
326 return m_ref.coeff(index);
330 return m_ref.coeffRef(index);
341 template<
typename Derived,
typename Device>
344 typedef typename Derived::Index
Index;
356 PreferBlockAccess =
false,
368 return this->m_ref.coeffRef(index);
#define EIGEN_DEVICE_FUNC
A reference to a tensor expression. The expression will be evaluated lazily (as much as possible).
internal::traits< PlainObjectType >::Index Index
TensorRef(const TensorRef &other)
const Scalar coeff(const array< Index, NumIndices > &indices) const
PlainObjectType::Dimensions Dimensions
Scalar & coeffRef(Index index)
internal::TensorLazyBaseEvaluator< Dimensions, Scalar > * m_evaluator
Base::CoeffReturnType CoeffReturnType
NumTraits< Scalar >::Real RealScalar
static constexpr Index NumIndices
TensorRef(const Expression &expr)
Index dimension(Index n) const
PlainObjectType::Base Base
Scalar & coeffRef(Index firstIndex, IndexTypes... otherIndices)
PointerType PointerArgType
const Scalar * data() const
TensorRef< PlainObjectType > Self
Scalar & coeffRef(const array< Index, NumIndices > &indices)
Eigen::internal::nested< Self >::type Nested
internal::TensorBlockNotImplemented TensorBlock
const Scalar coeff(Index index) const
const Scalar operator()(Index index) const
const Dimensions & dimensions() const
internal::traits< PlainObjectType >::StorageKind StorageKind
internal::traits< PlainObjectType >::Scalar Scalar
static constexpr int Layout
const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const
TensorRef & operator=(const Expression &expr)
TensorContractionSycl.h: provides various tensor contraction kernels for the SYCL backend.
EIGEN_DEFAULT_DENSE_INDEX_TYPE DenseIndex
internal::packet_traits< Scalar >::type type
internal::TensorBlockNotImplemented TensorBlock
PacketType< CoeffReturnType, Device >::type PacketReturnType
TensorEvaluator< const TensorRef< Derived >, Device > Base
Derived::Scalar CoeffReturnType
Derived::Dimensions Dimensions
Scalar & coeffRef(Index index)
TensorEvaluator(TensorRef< Derived > &m, const Device &d)
TensorRef< Derived > m_ref
TensorEvaluator(const TensorRef< Derived > &m, const Device &)
CoeffReturnType coeff(Index index) const
const Scalar * data() const
PacketType< CoeffReturnType, Device >::type PacketReturnType
Derived::Dimensions Dimensions
internal::TensorBlockNotImplemented TensorBlock
Derived::Scalar CoeffReturnType
Scalar & coeffRef(Index index)
const Dimensions & dimensions() const
Storage::Type EvaluatorPointerType
bool evalSubExprsIfNeeded(EvaluatorPointerType)
StorageMemory< CoeffReturnType, Device > Storage
A cost model used to limit the number of threads used for evaluating tensor expressions.
static constexpr int Layout